Image gen improved
This commit is contained in:
parent
74dea019fa
commit
38b5185cfd
180
frontend/src/NewApp/Components/Pulse.tsx
Normal file
180
frontend/src/NewApp/Components/Pulse.tsx
Normal file
@ -0,0 +1,180 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import { SxProps } from '@mui/material';
|
||||
import Box from '@mui/material/Box';
|
||||
|
||||
interface PulseProps {
|
||||
timestamp: number | string;
|
||||
sx?: SxProps;
|
||||
}
|
||||
|
||||
const Pulse: React.FC<PulseProps> = ({ timestamp, sx }) => {
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const [animationKey, setAnimationKey] = useState(0);
|
||||
const previousTimestamp = useRef<number | string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (timestamp && timestamp !== previousTimestamp.current) {
|
||||
previousTimestamp.current = timestamp;
|
||||
setAnimationKey(prev => prev + 1);
|
||||
setIsAnimating(true);
|
||||
|
||||
// Reset animation state after animation completes
|
||||
const timer = setTimeout(() => {
|
||||
setIsAnimating(false);
|
||||
}, 1000);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [timestamp]);
|
||||
|
||||
const containerStyle: React.CSSProperties = {
|
||||
position: 'relative',
|
||||
width: 80,
|
||||
height: 80,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
};
|
||||
|
||||
const baseCoreStyle: React.CSSProperties = {
|
||||
width: 24,
|
||||
height: 24,
|
||||
borderRadius: '50%',
|
||||
backgroundColor: '#2196f3',
|
||||
position: 'relative',
|
||||
zIndex: 3,
|
||||
};
|
||||
|
||||
const coreStyle: React.CSSProperties = {
|
||||
...baseCoreStyle,
|
||||
animation: isAnimating ? 'pulse-glow 1s ease-out' : 'none',
|
||||
};
|
||||
|
||||
const pulseRing1Style: React.CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: 24,
|
||||
height: 24,
|
||||
borderRadius: '50%',
|
||||
backgroundColor: '#2196f3',
|
||||
zIndex: 2,
|
||||
animation: 'pulse-expand 1s ease-out forwards',
|
||||
};
|
||||
|
||||
const pulseRing2Style: React.CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: 24,
|
||||
height: 24,
|
||||
borderRadius: '50%',
|
||||
backgroundColor: '#64b5f6',
|
||||
zIndex: 1,
|
||||
animation: 'pulse-expand 1s ease-out 0.2s forwards',
|
||||
};
|
||||
|
||||
const rippleStyle: React.CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: 32,
|
||||
height: 32,
|
||||
borderRadius: '50%',
|
||||
border: '2px solid #2196f3',
|
||||
backgroundColor: 'transparent',
|
||||
zIndex: 0,
|
||||
animation: 'ripple-expand 1s ease-out forwards',
|
||||
};
|
||||
|
||||
const outerRippleStyle: React.CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: 40,
|
||||
height: 40,
|
||||
borderRadius: '50%',
|
||||
border: '1px solid #90caf9',
|
||||
backgroundColor: 'transparent',
|
||||
zIndex: 0,
|
||||
animation: 'ripple-expand 1s ease-out 0.3s forwards',
|
||||
};
|
||||
|
||||
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>
|
||||
{`
|
||||
@keyframes pulse-expand {
|
||||
0% {
|
||||
transform: scale(1);
|
||||
opacity: 1;
|
||||
}
|
||||
50% {
|
||||
transform: scale(1.3);
|
||||
opacity: 0.7;
|
||||
}
|
||||
100% {
|
||||
transform: scale(1.6);
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes ripple-expand {
|
||||
0% {
|
||||
transform: scale(0.8);
|
||||
opacity: 0.8;
|
||||
}
|
||||
100% {
|
||||
transform: scale(2);
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes pulse-glow {
|
||||
0% {
|
||||
box-shadow: 0 0 5px #2196f3, 0 0 10px #2196f3, 0 0 15px #2196f3;
|
||||
}
|
||||
50% {
|
||||
box-shadow: 0 0 10px #2196f3, 0 0 20px #2196f3, 0 0 30px #2196f3;
|
||||
}
|
||||
100% {
|
||||
box-shadow: 0 0 5px #2196f3, 0 0 10px #2196f3, 0 0 15px #2196f3;
|
||||
}
|
||||
}
|
||||
`}
|
||||
</style>
|
||||
|
||||
<Box sx={{...containerStyle, ...sx}}>
|
||||
{/* Base circle */}
|
||||
<div style={coreStyle} />
|
||||
|
||||
{/* Pulse rings */}
|
||||
{isAnimating && (
|
||||
<>
|
||||
{/* Primary pulse ring */}
|
||||
<div
|
||||
key={`pulse-1-${animationKey}`}
|
||||
style={pulseRing1Style}
|
||||
/>
|
||||
|
||||
{/* Secondary pulse ring with delay */}
|
||||
<div
|
||||
key={`pulse-2-${animationKey}`}
|
||||
style={pulseRing2Style}
|
||||
/>
|
||||
|
||||
{/* Ripple effect */}
|
||||
<div
|
||||
key={`ripple-${animationKey}`}
|
||||
style={rippleStyle}
|
||||
/>
|
||||
|
||||
{/* Outer ripple */}
|
||||
<div
|
||||
key={`ripple-outer-${animationKey}`}
|
||||
style={outerRippleStyle}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export { Pulse } ;
|
330
frontend/src/NewApp/Pages/GenerateCandidate.tsx
Normal file
330
frontend/src/NewApp/Pages/GenerateCandidate.tsx
Normal file
@ -0,0 +1,330 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import Avatar from '@mui/material/Avatar';
|
||||
import Box from '@mui/material/Box';
|
||||
import Tooltip from '@mui/material/Tooltip';
|
||||
import Button from '@mui/material/Button';
|
||||
import Paper from '@mui/material/Paper';
|
||||
import IconButton from '@mui/material/IconButton';
|
||||
import CancelIcon from '@mui/icons-material/Cancel';
|
||||
import SendIcon from '@mui/icons-material/Send';
|
||||
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||
import { CandidateInfo } from '../Components/CandidateInfo';
|
||||
import { Query } from '../../Components/ChatQuery'
|
||||
import { streamQueryResponse, StreamQueryController } from '../Components/streamQueryResponse';
|
||||
import { connectionBase } from 'Global';
|
||||
import { UserInfo } from '../Components/UserContext';
|
||||
import { BackstoryElementProps } from 'Components/BackstoryTab';
|
||||
import { BackstoryTextField, BackstoryTextFieldRef } from 'Components/BackstoryTextField';
|
||||
import { jsonrepair } from 'jsonrepair';
|
||||
import { StyledMarkdown } from 'Components/StyledMarkdown';
|
||||
import { Scrollable } from 'Components/Scrollable';
|
||||
import { useForkRef } from '@mui/material';
|
||||
import { BackstoryMessage } from 'Components/Message';
|
||||
import { Pulse } from 'NewApp/Components/Pulse';
|
||||
|
||||
const emptyUser : UserInfo = {
|
||||
type: 'candidate',
|
||||
description: "[blank]",
|
||||
rag_content_size: 0,
|
||||
username: "[blank]",
|
||||
first_name: "[blank]",
|
||||
last_name: "[blank]",
|
||||
full_name: "[blank] [blank]",
|
||||
contact_info: {},
|
||||
questions: [],
|
||||
isAuthenticated: false,
|
||||
has_profile: false
|
||||
};
|
||||
|
||||
const GenerateCandidate = (props: BackstoryElementProps) => {
|
||||
const {sessionId, setSnack, submitQuery} = props;
|
||||
const [streaming, setStreaming] = useState<string>('');
|
||||
const [processing, setProcessing] = useState<boolean>(false);
|
||||
const [user, setUser] = useState<UserInfo>(emptyUser);
|
||||
const controllerRef = useRef<StreamQueryController>(null);
|
||||
const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
|
||||
const promptRef = useRef<string>(null);
|
||||
const stateRef = useRef<number>(0); /* Generating persona */
|
||||
const userRef = useRef<UserInfo>(user);
|
||||
const [prompt, setPrompt] = useState<string>('');
|
||||
const [resume, setResume] = useState<string>('');
|
||||
const [canGenImage, setCanGenImage] = useState<boolean>(false);
|
||||
const [hasProfile, setHasProfile] = useState<boolean>(false);
|
||||
const [status, setStatus] = useState<string>('');
|
||||
const [timestamp, setTimestamp] = useState<number>(0);
|
||||
|
||||
const generateProfile = () => {
|
||||
if (controllerRef.current) {
|
||||
return;
|
||||
}
|
||||
setProcessing(true);
|
||||
setCanGenImage(false);
|
||||
stateRef.current = 3;
|
||||
const start = Date.now();
|
||||
|
||||
controllerRef.current = streamQueryResponse({
|
||||
query: {
|
||||
prompt: `A photorealistic profile picture of a ${user.age} year old ${user.gender} ${user.ethnicity} person. ${prompt}`,
|
||||
agent_options: {
|
||||
username: userRef.current.username,
|
||||
filename: "profile.png"
|
||||
}
|
||||
},
|
||||
type: "image",
|
||||
sessionId,
|
||||
connectionBase,
|
||||
onComplete: (msg) => {
|
||||
console.log({ msg, state: stateRef.current, prompt: promptRef.current || '' });
|
||||
switch (msg.status) {
|
||||
case "partial":
|
||||
case "done":
|
||||
if (msg.status === "done") {
|
||||
setProcessing(false);
|
||||
controllerRef.current = null;
|
||||
stateRef.current = 0;
|
||||
setCanGenImage(true);
|
||||
setHasProfile(true);
|
||||
}
|
||||
break;
|
||||
case "error":
|
||||
console.log(`Error generating persona: ${msg.response} after ${Date.now() - start}`);
|
||||
setSnack(msg.response || "", "error");
|
||||
setProcessing(false);
|
||||
setUser({...userRef.current});
|
||||
controllerRef.current = null;
|
||||
stateRef.current = 0;
|
||||
setCanGenImage(true);
|
||||
setHasProfile(true); /* Hack for now */
|
||||
break;
|
||||
default:
|
||||
const data = JSON.parse(msg.response || '');
|
||||
if (data.timestamp) {
|
||||
setTimestamp(data.timestamp);
|
||||
} else {
|
||||
setTimestamp(Date.now())
|
||||
}
|
||||
if (data.message) {
|
||||
setStatus(data.message);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
};
|
||||
|
||||
const generatePersona = (query: Query) => {
|
||||
if (controllerRef.current) {
|
||||
return;
|
||||
}
|
||||
setPrompt(query.prompt);
|
||||
promptRef.current = query.prompt;
|
||||
stateRef.current = 0;
|
||||
setUser(emptyUser);
|
||||
setStreaming('');
|
||||
setResume('');
|
||||
setProcessing(true);
|
||||
setCanGenImage(false);
|
||||
|
||||
controllerRef.current = streamQueryResponse({
|
||||
query,
|
||||
type: "persona",
|
||||
sessionId,
|
||||
connectionBase,
|
||||
onComplete: (msg) => {
|
||||
switch (msg.status) {
|
||||
case "partial":
|
||||
case "done":
|
||||
switch (stateRef.current) {
|
||||
case 0: /* Generating persona */
|
||||
let partialUser = JSON.parse(jsonrepair((msg.response || '').trim()));
|
||||
if (!partialUser.full_name) {
|
||||
partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`;
|
||||
}
|
||||
console.log(partialUser);
|
||||
setUser({...partialUser});
|
||||
stateRef.current++; /* Generating resume */
|
||||
break;
|
||||
case 1: /* Generating resume */
|
||||
stateRef.current++; /* RAG generation */
|
||||
setResume(msg.response || '');
|
||||
break;
|
||||
case 2: /* RAG generation */
|
||||
stateRef.current++; /* Image generation */
|
||||
break;
|
||||
}
|
||||
if (msg.status === "done") {
|
||||
setProcessing(false);
|
||||
setCanGenImage(true);
|
||||
controllerRef.current = null;
|
||||
stateRef.current = 0;
|
||||
setTimeout(() => {
|
||||
generateProfile();
|
||||
}, 0);
|
||||
}
|
||||
break;
|
||||
case "thinking":
|
||||
setStatus(msg.response || '');
|
||||
break;
|
||||
|
||||
case "error":
|
||||
console.log(`Error generating persona: ${msg.response}`);
|
||||
setSnack(msg.response || "", "error");
|
||||
setProcessing(false);
|
||||
setUser({...userRef.current});
|
||||
controllerRef.current = null;
|
||||
stateRef.current = 0;
|
||||
break;
|
||||
}
|
||||
},
|
||||
onStreaming: (chunk) => {
|
||||
setStreaming(chunk);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const cancelQuery = () => {
|
||||
if (controllerRef.current) {
|
||||
controllerRef.current.abort();
|
||||
controllerRef.current = null;
|
||||
stateRef.current = 0;
|
||||
setProcessing(false);
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
promptRef.current = prompt;
|
||||
}, [prompt]);
|
||||
|
||||
useEffect(() => {
|
||||
userRef.current = user;
|
||||
}, [user]);
|
||||
|
||||
useEffect(() => {
|
||||
if (streaming.trim().length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
switch (stateRef.current) {
|
||||
case 0: /* Generating persona */
|
||||
const partialUser = {...emptyUser, ...JSON.parse(jsonrepair(`${streaming.trim()}...`))};
|
||||
if (!partialUser.full_name) {
|
||||
partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`;
|
||||
}
|
||||
setUser(partialUser);
|
||||
break;
|
||||
case 1: /* Generating resume */
|
||||
setResume(streaming);
|
||||
break;
|
||||
case 3: /* RAG streaming */
|
||||
break;
|
||||
case 4: /* Image streaming */
|
||||
break;
|
||||
}
|
||||
} catch {
|
||||
}
|
||||
}, [streaming]);
|
||||
|
||||
if (!sessionId) {
|
||||
return <></>;
|
||||
}
|
||||
|
||||
const onEnter = (value: string) => {
|
||||
if (processing) {
|
||||
return;
|
||||
}
|
||||
const query: Query = {
|
||||
prompt: value,
|
||||
}
|
||||
generatePersona(query);
|
||||
};
|
||||
|
||||
return (
|
||||
<Box className="GenerateCandidate" sx={{
|
||||
display: "flex", flexDirection: "column", flexGrow: 1, gap: 1, width: { xs: '100%', md: '700px', lg: '1024px' }
|
||||
}}>
|
||||
{ user && <CandidateInfo sessionId={sessionId} user={user}/> }
|
||||
{processing &&
|
||||
<Box sx={{
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
m: 2,
|
||||
}}>
|
||||
<Box sx={{flexDirection: "row", fontWeight: "bold"}}>{status}</Box>
|
||||
<PropagateLoader
|
||||
size="10px"
|
||||
loading={processing}
|
||||
aria-label="Loading Spinner"
|
||||
data-testid="loader"
|
||||
/>
|
||||
</Box>
|
||||
}
|
||||
<Box sx={{display: "flex", flexDirection: "column"}}>
|
||||
<Box sx={{ display: "flex", flexDirection: "row"}}>
|
||||
<Avatar
|
||||
src={hasProfile ? `/api/u/${user.username}/profile/${sessionId}` : ''}
|
||||
alt={`${user.full_name}'s profile`}
|
||||
sx={{
|
||||
width: 80,
|
||||
height: 80,
|
||||
border: '2px solid #e0e0e0',
|
||||
}}
|
||||
/>
|
||||
{ processing && <Pulse sx={{position: "absolute", left: "-80px" }} timestamp={timestamp}/> }
|
||||
<Tooltip title={"Generate Profile Picture"}>
|
||||
<span style={{ display: "flex", flexGrow: 1 }}>
|
||||
<Button
|
||||
sx={{ m: 1, gap: 1, flexGrow: 1 }}
|
||||
variant="contained"
|
||||
disabled={sessionId === undefined || processing || !canGenImage}
|
||||
onClick={() => { generateProfile(); }}>
|
||||
Generate Profile Picture<SendIcon />
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
</Box>
|
||||
{ resume !== '' && <Paper sx={{pt: 1, pb: 1, pl: 2, pr: 2}}><Scrollable sx={{flexGrow: 1}}><StyledMarkdown {...{content: resume, setSnack, sessionId, submitQuery}}/></Scrollable></Paper> }
|
||||
<BackstoryTextField
|
||||
style={{ flexGrow: 0, flexShrink: 1 }}
|
||||
ref={backstoryTextRef}
|
||||
disabled={processing}
|
||||
onEnter={onEnter}
|
||||
placeholder='Specify any characteristics you would like the persona to have. For example, "This person likes yo-yos."'
|
||||
/>
|
||||
<Box sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
|
||||
<Tooltip title={"Send"}>
|
||||
<span style={{ display: "flex", flexGrow: 1 }}>
|
||||
<Button
|
||||
sx={{ m: 1, gap: 1, flexGrow: 1 }}
|
||||
variant="contained"
|
||||
disabled={sessionId === undefined || processing}
|
||||
onClick={() => { generatePersona({ prompt: (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || "" }); }}>
|
||||
Send<SendIcon />
|
||||
</Button>
|
||||
</span>
|
||||
</Tooltip>
|
||||
<Tooltip title="Cancel">
|
||||
<span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
|
||||
<IconButton
|
||||
aria-label="cancel"
|
||||
onClick={() => { cancelQuery(); }}
|
||||
sx={{ display: "flex", margin: 'auto 0px' }}
|
||||
size="large"
|
||||
edge="start"
|
||||
disabled={controllerRef.current === null || !sessionId || processing === false}
|
||||
>
|
||||
<CancelIcon />
|
||||
</IconButton>
|
||||
</span>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
</Box>);
|
||||
};
|
||||
|
||||
export {
|
||||
GenerateCandidate
|
||||
};
|
218
src/utils/agents/image_generator.py
Normal file
218
src/utils/agents/image_generator.py
Normal file
@ -0,0 +1,218 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import inspect
|
||||
import random
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
import asyncio
|
||||
import time
|
||||
import os
|
||||
|
||||
from . base import Agent, agent_registry, LLMMessage
|
||||
from .. message import Message
|
||||
from .. rag import ChromaDBGetResponse
|
||||
from .. setup_logging import setup_logging
|
||||
from .. profile_image import generate_image, ImageRequest
|
||||
from .. import defines
|
||||
from .. user import User
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
seed = int(time.time())
|
||||
random.seed(seed)
|
||||
|
||||
class ImageGenerator(Agent):
|
||||
agent_type: Literal["image"] = "image" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
agent_persist: bool = False
|
||||
|
||||
system_prompt: str = "" # No system prompt is used
|
||||
username: str
|
||||
filename: str
|
||||
|
||||
llm: Any = Field(default=None, exclude=True)
|
||||
model: str = Field(default=None, exclude=True)
|
||||
user: Dict[str, Any] = Field(default={}, exclude=True)
|
||||
|
||||
async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
|
||||
async def process_message(
|
||||
self, llm: Any, model: str, message: Message
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
|
||||
self.llm = llm
|
||||
self.model = model
|
||||
|
||||
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||
tick: int = 0
|
||||
while self.context.processing:
|
||||
logger.info(
|
||||
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||
)
|
||||
message.status = "waiting"
|
||||
message.response = (
|
||||
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||
)
|
||||
tick = (tick + 1) % len(spinner)
|
||||
yield message
|
||||
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||
|
||||
self.context.processing = True
|
||||
|
||||
try:
|
||||
#
|
||||
# Generate the profile picture
|
||||
#
|
||||
prompt = message.prompt
|
||||
message.status = "thinking"
|
||||
message.response = f"Generating: {prompt}"
|
||||
yield message
|
||||
|
||||
user_info = os.path.join(defines.user_dir, self.username, defines.user_info_file)
|
||||
with open(user_info, "r") as f:
|
||||
self.user = json.loads(f.read())
|
||||
|
||||
logger.info("Beginning image generation...", self.user)
|
||||
logger.info("TODO: Add safety checks for filename... actually figure out an entirely different way to figure out where to store them.")
|
||||
self.filename = "profile.png"
|
||||
request = ImageRequest(filepath=os.path.join(defines.user_dir, self.user["username"], self.filename), prompt=prompt)
|
||||
async for message in generate_image(
|
||||
message=message,
|
||||
request=request
|
||||
):
|
||||
if message.status != "done":
|
||||
yield message
|
||||
logger.info("Image generation done...")
|
||||
images = self.user.get("images", [])
|
||||
if self.filename not in images:
|
||||
images.append(self.filename)
|
||||
if self.filename == "profile.png":
|
||||
self.user["has_profile"] = True
|
||||
|
||||
#
|
||||
# Write out the completed user information
|
||||
#
|
||||
with open(user_info, "w") as f:
|
||||
f.write(json.dumps(self.user))
|
||||
|
||||
# Image generated
|
||||
message.status = "done"
|
||||
message.response = json.dumps(self.user)
|
||||
|
||||
except Exception as e:
|
||||
message.status = "error"
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(message.response)
|
||||
message.response = f"Error in image generation: {str(e)}"
|
||||
logger.error(message.response)
|
||||
yield message
|
||||
return
|
||||
|
||||
# Done processing, add message to conversation
|
||||
self.context.processing = False
|
||||
# Return the final message
|
||||
yield message
|
||||
return
|
||||
|
||||
async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=prompt),
|
||||
]
|
||||
message.metadata.options = {
|
||||
"seed": 8911,
|
||||
"num_ctx": self.context_size,
|
||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||
}
|
||||
|
||||
message.status = "streaming"
|
||||
yield message
|
||||
|
||||
last_chunk_time = 0
|
||||
message.chunk = ""
|
||||
message.response = ""
|
||||
for response in self.llm.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
options={
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
message.status = "error"
|
||||
message.response = "No response from LLM."
|
||||
yield message
|
||||
return
|
||||
|
||||
message.status = "streaming"
|
||||
message.chunk += response.message.content
|
||||
message.response += response.message.content
|
||||
|
||||
if not response.done:
|
||||
now = time.perf_counter()
|
||||
if now - last_chunk_time > 0.25:
|
||||
yield message
|
||||
last_chunk_time = now
|
||||
message.chunk = ""
|
||||
|
||||
if response.done:
|
||||
self.collect_metrics(response)
|
||||
message.metadata.eval_count += response.eval_count
|
||||
message.metadata.eval_duration += response.eval_duration
|
||||
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||
message.chunk = ""
|
||||
message.status = "done"
|
||||
yield message
|
||||
|
||||
def extract_json_from_text(self, text: str) -> str:
|
||||
"""Extract JSON string from text that may contain other content."""
|
||||
json_pattern = r"```json\s*([\s\S]*?)\s*```"
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# Try to find JSON without the markdown code block
|
||||
json_pattern = r"({[\s\S]*})"
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
raise ValueError("No JSON found in the response")
|
||||
|
||||
def extract_markdown_from_text(self, text: str) -> str:
|
||||
"""Extract Markdown string from text that may contain other content."""
|
||||
markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
|
||||
match = re.search(markdown_pattern, text)
|
||||
if match:
|
||||
return match.group(2).strip()
|
||||
|
||||
raise ValueError("No Markdown found in the response")
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)
|
422
src/utils/agents/persona_generator.py
Normal file
422
src/utils/agents/persona_generator.py
Normal file
@ -0,0 +1,422 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import inspect
|
||||
import random
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
import asyncio
|
||||
import time
|
||||
import os
|
||||
|
||||
from .base import Agent, agent_registry, LLMMessage
|
||||
from ..message import Message
|
||||
from ..rag import ChromaDBGetResponse
|
||||
from ..setup_logging import setup_logging
|
||||
from .. import defines
|
||||
from ..user import User
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
seed = int(time.time())
|
||||
random.seed(seed)
|
||||
|
||||
emptyUser = {
|
||||
"profile_url": "",
|
||||
"description": "",
|
||||
"rag_content_size": 0,
|
||||
"username": "",
|
||||
"first_name": "",
|
||||
"last_name": "",
|
||||
"full_name": "",
|
||||
"contact_info": {},
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
generate_persona_system_prompt = """\
|
||||
You are a casing director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
|
||||
|
||||
All response field MUST BE IN ENGLISH, regardless of ethnicity.
|
||||
|
||||
You will be provided with defaults to use if not specified by the user:
|
||||
|
||||
```json
|
||||
{
|
||||
"age": number,
|
||||
"gender": "male" | "female",
|
||||
"ethnicity": string,
|
||||
}
|
||||
```
|
||||
|
||||
Additional information provided in the user message can override those defaults.
|
||||
|
||||
You need to randomly assign an English username (can include numbers), a first name, last name, and a two English sentence description of that individual's work given the demographics provided.
|
||||
|
||||
Your response must be in JSON.
|
||||
Provide only the JSON response, and match the field names EXACTLY.
|
||||
Provide all information in English ONLY, with no other commentary:
|
||||
|
||||
```json
|
||||
{
|
||||
"username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters)
|
||||
"first_name": string,
|
||||
"last_name": string,
|
||||
"description": string, # One to two sentence description of their job
|
||||
"location": string, # In the location, provide ALL of: City, State/Region, and Country
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to provide a username and that the field name for the job description is "description".
|
||||
"""
|
||||
|
||||
generate_resume_system_prompt = """
|
||||
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
|
||||
of that is to create an in-depth resume for the person. You will be provided with the following information:
|
||||
|
||||
```json
|
||||
"full_name": string, # Person full name
|
||||
"location": string, # Location of residence
|
||||
"age": number, # Age of candidate
|
||||
"description": string # A brief description of the person
|
||||
```
|
||||
|
||||
Use that information to invent a full career resume. Include sections such as:
|
||||
|
||||
* Contact information
|
||||
* Job goal
|
||||
* Top skills
|
||||
* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent
|
||||
* In the work history, provide company names, years of employment, and their role
|
||||
* Education
|
||||
|
||||
Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume.
|
||||
"""
|
||||
|
||||
class PersonaGenerator(Agent):
|
||||
agent_type: Literal["persona"] = "persona" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
agent_persist: bool = False
|
||||
|
||||
system_prompt: str = generate_persona_system_prompt
|
||||
age: int = Field(default_factory=lambda: random.randint(22, 67))
|
||||
gender: str = Field(default_factory=lambda: random.choice(["male", "female"]))
|
||||
ethnicity: Literal[
|
||||
"Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"
|
||||
] = Field(
|
||||
default_factory=lambda: random.choices(
|
||||
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
|
||||
k=1
|
||||
)[0]
|
||||
)
|
||||
username: str = ""
|
||||
|
||||
llm: Any = Field(default=None, exclude=True)
|
||||
model: str = Field(default=None, exclude=True)
|
||||
|
||||
def randomize(self):
|
||||
self.age = random.randint(22, 67)
|
||||
self.gender = random.choice(["male", "female"])
|
||||
# Use random.choices with explicit type casting to satisfy Literal type
|
||||
self.ethnicity = cast(
|
||||
Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||
random.choices(
|
||||
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
|
||||
k=1
|
||||
)[0]
|
||||
)
|
||||
|
||||
async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
|
||||
message.tunables.enable_tools = False
|
||||
message.tunables.enable_rag = False
|
||||
message.tunables.enable_context = False
|
||||
|
||||
message.prompt = f"""\
|
||||
```json
|
||||
{json.dumps({
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity
|
||||
})}
|
||||
```
|
||||
{message.prompt}
|
||||
"""
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
|
||||
async def process_message(
|
||||
self, llm: Any, model: str, message: Message
|
||||
) -> AsyncGenerator[Message, None]:
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
if not self.context:
|
||||
raise ValueError("Context is not set for this agent.")
|
||||
|
||||
self.llm = llm
|
||||
self.model = model
|
||||
original_prompt = message.prompt
|
||||
|
||||
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||
tick: int = 0
|
||||
while self.context.processing:
|
||||
logger.info(
|
||||
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||
)
|
||||
message.status = "waiting"
|
||||
message.response = (
|
||||
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||
)
|
||||
tick = (tick + 1) % len(spinner)
|
||||
yield message
|
||||
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||
|
||||
self.context.processing = True
|
||||
|
||||
try:
|
||||
|
||||
#
|
||||
# Generate the persona
|
||||
#
|
||||
async for message in self.call_llm(
|
||||
message=message, system_prompt=self.system_prompt, prompt=original_prompt
|
||||
):
|
||||
if message.status != "done":
|
||||
yield message
|
||||
if message.status == "error":
|
||||
raise Exception(message.response)
|
||||
|
||||
json_str = self.extract_json_from_text(message.response)
|
||||
try:
|
||||
persona = json.loads(json_str) | {
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity
|
||||
}
|
||||
if not persona.get("full_name", None):
|
||||
persona["full_name"] = f"{persona['first_name']} {persona['last_name']}"
|
||||
self.username = persona.get("username", None)
|
||||
if not self.username:
|
||||
raise ValueError("LLM did not generate a username")
|
||||
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||
while os.path.exists(user_dir):
|
||||
match = re.match(r"^(.*?)(\d*)$", persona["username"])
|
||||
if match:
|
||||
base = match.group(1)
|
||||
num = match.group(2)
|
||||
iteration = int(num) + 1 if num else 1
|
||||
persona["username"] = f"{base}{iteration}"
|
||||
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||
|
||||
for key in persona:
|
||||
if isinstance(persona[key], str):
|
||||
persona[key] = persona[key].strip()
|
||||
# Mark this persona as AI generated
|
||||
persona["is_ai"] = True
|
||||
except Exception as e:
|
||||
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
|
||||
message.status = "error"
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(message.response)
|
||||
yield message
|
||||
return
|
||||
|
||||
# Persona generated
|
||||
message.response = json.dumps(persona)
|
||||
message.status = "partial"
|
||||
yield message
|
||||
|
||||
#
|
||||
# Generate the resume
|
||||
#
|
||||
message.status = "thinking"
|
||||
message.response = f"Generating resume for {persona['full_name']}..."
|
||||
yield message
|
||||
|
||||
prompt = f"""
|
||||
```json
|
||||
{{
|
||||
"full_name": "{persona["full_name"]}",
|
||||
"location": "{persona["location"]}",
|
||||
"age": {persona["age"]},
|
||||
"description": {persona["description"]}
|
||||
}}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
async for message in self.call_llm(
|
||||
message=message, system_prompt=generate_resume_system_prompt, prompt=prompt
|
||||
):
|
||||
if message.status != "done":
|
||||
yield message
|
||||
if message.status == "error":
|
||||
raise Exception(message.response)
|
||||
|
||||
except Exception as e:
|
||||
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
|
||||
message.status = "error"
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(message.response)
|
||||
yield message
|
||||
return
|
||||
|
||||
resume = self.extract_markdown_from_text(message.response)
|
||||
if resume:
|
||||
user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir)
|
||||
os.makedirs(user_resume_dir, exist_ok=True)
|
||||
user_resume_file = os.path.join(user_resume_dir, defines.resume_doc)
|
||||
with open(user_resume_file, "w") as f:
|
||||
f.write(resume)
|
||||
|
||||
# Resume generated
|
||||
message.response = resume
|
||||
message.status = "partial"
|
||||
yield message
|
||||
|
||||
#
|
||||
# Generate RAG database
|
||||
#
|
||||
message.status = "thinking"
|
||||
message.response = f"Generating RAG content from resume..."
|
||||
yield message
|
||||
|
||||
# Prior to instancing a new User, the json data has to be created
|
||||
# so the system can process it
|
||||
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||
os.makedirs(user_dir, exist_ok=True)
|
||||
user_info = os.path.join(user_dir, "info.json")
|
||||
with open(user_info, "w") as f:
|
||||
f.write(json.dumps(persona, indent=2))
|
||||
|
||||
user = User(llm=self.llm, username=self.username)
|
||||
await user.initialize()
|
||||
await user.file_watcher.initialize_collection()
|
||||
# RAG content generated
|
||||
message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store."
|
||||
|
||||
#
|
||||
# Write out the completed user information
|
||||
#
|
||||
with open(user_info, "w") as f:
|
||||
f.write(json.dumps(persona, indent=2))
|
||||
|
||||
# Image generated
|
||||
message.status = "done"
|
||||
message.response = json.dumps(persona)
|
||||
|
||||
except Exception as e:
|
||||
message.status = "error"
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(message.response)
|
||||
message.response = f"Error in persona generation: {str(e)}"
|
||||
logger.error(message.response)
|
||||
self.randomize() # Randomize for next generation
|
||||
yield message
|
||||
return
|
||||
|
||||
# Done processing, add message to conversation
|
||||
self.context.processing = False
|
||||
self.randomize() # Randomize for next generation
|
||||
# Return the final message
|
||||
yield message
|
||||
return
|
||||
|
||||
async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
|
||||
messages: List[LLMMessage] = [
|
||||
LLMMessage(role="system", content=system_prompt),
|
||||
LLMMessage(role="user", content=prompt),
|
||||
]
|
||||
message.metadata.options = {
|
||||
"seed": 8911,
|
||||
"num_ctx": self.context_size,
|
||||
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||
}
|
||||
|
||||
message.status = "streaming"
|
||||
yield message
|
||||
|
||||
last_chunk_time = 0
|
||||
message.chunk = ""
|
||||
message.response = ""
|
||||
for response in self.llm.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
options={
|
||||
**message.metadata.options,
|
||||
},
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
message.status = "error"
|
||||
message.response = "No response from LLM."
|
||||
yield message
|
||||
return
|
||||
|
||||
message.status = "streaming"
|
||||
message.chunk += response.message.content
|
||||
message.response += response.message.content
|
||||
|
||||
if not response.done:
|
||||
now = time.perf_counter()
|
||||
if now - last_chunk_time > 0.25:
|
||||
yield message
|
||||
last_chunk_time = now
|
||||
message.chunk = ""
|
||||
|
||||
if response.done:
|
||||
self.collect_metrics(response)
|
||||
message.metadata.eval_count += response.eval_count
|
||||
message.metadata.eval_duration += response.eval_duration
|
||||
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||
message.chunk = ""
|
||||
message.status = "done"
|
||||
yield message
|
||||
|
||||
def extract_json_from_text(self, text: str) -> str:
|
||||
"""Extract JSON string from text that may contain other content."""
|
||||
json_pattern = r"```json\s*([\s\S]*?)\s*```"
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
# Try to find JSON without the markdown code block
|
||||
json_pattern = r"({[\s\S]*})"
|
||||
match = re.search(json_pattern, text)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
raise ValueError("No JSON found in the response")
|
||||
|
||||
def extract_markdown_from_text(self, text: str) -> str:
|
||||
"""Extract Markdown string from text that may contain other content."""
|
||||
markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
|
||||
match = re.search(markdown_pattern, text)
|
||||
if match:
|
||||
return match.group(2).strip()
|
||||
|
||||
raise ValueError("No Markdown found in the response")
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)
|
Loading…
x
Reference in New Issue
Block a user