Image gen improved
This commit is contained in:
parent
74dea019fa
commit
38b5185cfd
180
frontend/src/NewApp/Components/Pulse.tsx
Normal file
180
frontend/src/NewApp/Components/Pulse.tsx
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import React, { useEffect, useState, useRef } from 'react';
|
||||||
|
import { SxProps } from '@mui/material';
|
||||||
|
import Box from '@mui/material/Box';
|
||||||
|
|
||||||
|
interface PulseProps {
|
||||||
|
timestamp: number | string;
|
||||||
|
sx?: SxProps;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Pulse: React.FC<PulseProps> = ({ timestamp, sx }) => {
|
||||||
|
const [isAnimating, setIsAnimating] = useState(false);
|
||||||
|
const [animationKey, setAnimationKey] = useState(0);
|
||||||
|
const previousTimestamp = useRef<number | string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (timestamp && timestamp !== previousTimestamp.current) {
|
||||||
|
previousTimestamp.current = timestamp;
|
||||||
|
setAnimationKey(prev => prev + 1);
|
||||||
|
setIsAnimating(true);
|
||||||
|
|
||||||
|
// Reset animation state after animation completes
|
||||||
|
const timer = setTimeout(() => {
|
||||||
|
setIsAnimating(false);
|
||||||
|
}, 1000);
|
||||||
|
|
||||||
|
return () => clearTimeout(timer);
|
||||||
|
}
|
||||||
|
}, [timestamp]);
|
||||||
|
|
||||||
|
const containerStyle: React.CSSProperties = {
|
||||||
|
position: 'relative',
|
||||||
|
width: 80,
|
||||||
|
height: 80,
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
};
|
||||||
|
|
||||||
|
const baseCoreStyle: React.CSSProperties = {
|
||||||
|
width: 24,
|
||||||
|
height: 24,
|
||||||
|
borderRadius: '50%',
|
||||||
|
backgroundColor: '#2196f3',
|
||||||
|
position: 'relative',
|
||||||
|
zIndex: 3,
|
||||||
|
};
|
||||||
|
|
||||||
|
const coreStyle: React.CSSProperties = {
|
||||||
|
...baseCoreStyle,
|
||||||
|
animation: isAnimating ? 'pulse-glow 1s ease-out' : 'none',
|
||||||
|
};
|
||||||
|
|
||||||
|
const pulseRing1Style: React.CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: 24,
|
||||||
|
height: 24,
|
||||||
|
borderRadius: '50%',
|
||||||
|
backgroundColor: '#2196f3',
|
||||||
|
zIndex: 2,
|
||||||
|
animation: 'pulse-expand 1s ease-out forwards',
|
||||||
|
};
|
||||||
|
|
||||||
|
const pulseRing2Style: React.CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: 24,
|
||||||
|
height: 24,
|
||||||
|
borderRadius: '50%',
|
||||||
|
backgroundColor: '#64b5f6',
|
||||||
|
zIndex: 1,
|
||||||
|
animation: 'pulse-expand 1s ease-out 0.2s forwards',
|
||||||
|
};
|
||||||
|
|
||||||
|
const rippleStyle: React.CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: 32,
|
||||||
|
height: 32,
|
||||||
|
borderRadius: '50%',
|
||||||
|
border: '2px solid #2196f3',
|
||||||
|
backgroundColor: 'transparent',
|
||||||
|
zIndex: 0,
|
||||||
|
animation: 'ripple-expand 1s ease-out forwards',
|
||||||
|
};
|
||||||
|
|
||||||
|
const outerRippleStyle: React.CSSProperties = {
|
||||||
|
position: 'absolute',
|
||||||
|
width: 40,
|
||||||
|
height: 40,
|
||||||
|
borderRadius: '50%',
|
||||||
|
border: '1px solid #90caf9',
|
||||||
|
backgroundColor: 'transparent',
|
||||||
|
zIndex: 0,
|
||||||
|
animation: 'ripple-expand 1s ease-out 0.3s forwards',
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<style>
|
||||||
|
{`
|
||||||
|
@keyframes pulse-expand {
|
||||||
|
0% {
|
||||||
|
transform: scale(1);
|
||||||
|
opacity: 1;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
transform: scale(1.3);
|
||||||
|
opacity: 0.7;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: scale(1.6);
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes ripple-expand {
|
||||||
|
0% {
|
||||||
|
transform: scale(0.8);
|
||||||
|
opacity: 0.8;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
transform: scale(2);
|
||||||
|
opacity: 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse-glow {
|
||||||
|
0% {
|
||||||
|
box-shadow: 0 0 5px #2196f3, 0 0 10px #2196f3, 0 0 15px #2196f3;
|
||||||
|
}
|
||||||
|
50% {
|
||||||
|
box-shadow: 0 0 10px #2196f3, 0 0 20px #2196f3, 0 0 30px #2196f3;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
box-shadow: 0 0 5px #2196f3, 0 0 10px #2196f3, 0 0 15px #2196f3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`}
|
||||||
|
</style>
|
||||||
|
|
||||||
|
<Box sx={{...containerStyle, ...sx}}>
|
||||||
|
{/* Base circle */}
|
||||||
|
<div style={coreStyle} />
|
||||||
|
|
||||||
|
{/* Pulse rings */}
|
||||||
|
{isAnimating && (
|
||||||
|
<>
|
||||||
|
{/* Primary pulse ring */}
|
||||||
|
<div
|
||||||
|
key={`pulse-1-${animationKey}`}
|
||||||
|
style={pulseRing1Style}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Secondary pulse ring with delay */}
|
||||||
|
<div
|
||||||
|
key={`pulse-2-${animationKey}`}
|
||||||
|
style={pulseRing2Style}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Ripple effect */}
|
||||||
|
<div
|
||||||
|
key={`ripple-${animationKey}`}
|
||||||
|
style={rippleStyle}
|
||||||
|
/>
|
||||||
|
|
||||||
|
{/* Outer ripple */}
|
||||||
|
<div
|
||||||
|
key={`ripple-outer-${animationKey}`}
|
||||||
|
style={outerRippleStyle}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
|
||||||
|
</Box>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export { Pulse } ;
|
330
frontend/src/NewApp/Pages/GenerateCandidate.tsx
Normal file
330
frontend/src/NewApp/Pages/GenerateCandidate.tsx
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
import React, { useEffect, useState, useRef } from 'react';
|
||||||
|
import Avatar from '@mui/material/Avatar';
|
||||||
|
import Box from '@mui/material/Box';
|
||||||
|
import Tooltip from '@mui/material/Tooltip';
|
||||||
|
import Button from '@mui/material/Button';
|
||||||
|
import Paper from '@mui/material/Paper';
|
||||||
|
import IconButton from '@mui/material/IconButton';
|
||||||
|
import CancelIcon from '@mui/icons-material/Cancel';
|
||||||
|
import SendIcon from '@mui/icons-material/Send';
|
||||||
|
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||||
|
import { CandidateInfo } from '../Components/CandidateInfo';
|
||||||
|
import { Query } from '../../Components/ChatQuery'
|
||||||
|
import { streamQueryResponse, StreamQueryController } from '../Components/streamQueryResponse';
|
||||||
|
import { connectionBase } from 'Global';
|
||||||
|
import { UserInfo } from '../Components/UserContext';
|
||||||
|
import { BackstoryElementProps } from 'Components/BackstoryTab';
|
||||||
|
import { BackstoryTextField, BackstoryTextFieldRef } from 'Components/BackstoryTextField';
|
||||||
|
import { jsonrepair } from 'jsonrepair';
|
||||||
|
import { StyledMarkdown } from 'Components/StyledMarkdown';
|
||||||
|
import { Scrollable } from 'Components/Scrollable';
|
||||||
|
import { useForkRef } from '@mui/material';
|
||||||
|
import { BackstoryMessage } from 'Components/Message';
|
||||||
|
import { Pulse } from 'NewApp/Components/Pulse';
|
||||||
|
|
||||||
|
const emptyUser : UserInfo = {
|
||||||
|
type: 'candidate',
|
||||||
|
description: "[blank]",
|
||||||
|
rag_content_size: 0,
|
||||||
|
username: "[blank]",
|
||||||
|
first_name: "[blank]",
|
||||||
|
last_name: "[blank]",
|
||||||
|
full_name: "[blank] [blank]",
|
||||||
|
contact_info: {},
|
||||||
|
questions: [],
|
||||||
|
isAuthenticated: false,
|
||||||
|
has_profile: false
|
||||||
|
};
|
||||||
|
|
||||||
|
const GenerateCandidate = (props: BackstoryElementProps) => {
|
||||||
|
const {sessionId, setSnack, submitQuery} = props;
|
||||||
|
const [streaming, setStreaming] = useState<string>('');
|
||||||
|
const [processing, setProcessing] = useState<boolean>(false);
|
||||||
|
const [user, setUser] = useState<UserInfo>(emptyUser);
|
||||||
|
const controllerRef = useRef<StreamQueryController>(null);
|
||||||
|
const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
|
||||||
|
const promptRef = useRef<string>(null);
|
||||||
|
const stateRef = useRef<number>(0); /* Generating persona */
|
||||||
|
const userRef = useRef<UserInfo>(user);
|
||||||
|
const [prompt, setPrompt] = useState<string>('');
|
||||||
|
const [resume, setResume] = useState<string>('');
|
||||||
|
const [canGenImage, setCanGenImage] = useState<boolean>(false);
|
||||||
|
const [hasProfile, setHasProfile] = useState<boolean>(false);
|
||||||
|
const [status, setStatus] = useState<string>('');
|
||||||
|
const [timestamp, setTimestamp] = useState<number>(0);
|
||||||
|
|
||||||
|
const generateProfile = () => {
|
||||||
|
if (controllerRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setProcessing(true);
|
||||||
|
setCanGenImage(false);
|
||||||
|
stateRef.current = 3;
|
||||||
|
const start = Date.now();
|
||||||
|
|
||||||
|
controllerRef.current = streamQueryResponse({
|
||||||
|
query: {
|
||||||
|
prompt: `A photorealistic profile picture of a ${user.age} year old ${user.gender} ${user.ethnicity} person. ${prompt}`,
|
||||||
|
agent_options: {
|
||||||
|
username: userRef.current.username,
|
||||||
|
filename: "profile.png"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
type: "image",
|
||||||
|
sessionId,
|
||||||
|
connectionBase,
|
||||||
|
onComplete: (msg) => {
|
||||||
|
console.log({ msg, state: stateRef.current, prompt: promptRef.current || '' });
|
||||||
|
switch (msg.status) {
|
||||||
|
case "partial":
|
||||||
|
case "done":
|
||||||
|
if (msg.status === "done") {
|
||||||
|
setProcessing(false);
|
||||||
|
controllerRef.current = null;
|
||||||
|
stateRef.current = 0;
|
||||||
|
setCanGenImage(true);
|
||||||
|
setHasProfile(true);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "error":
|
||||||
|
console.log(`Error generating persona: ${msg.response} after ${Date.now() - start}`);
|
||||||
|
setSnack(msg.response || "", "error");
|
||||||
|
setProcessing(false);
|
||||||
|
setUser({...userRef.current});
|
||||||
|
controllerRef.current = null;
|
||||||
|
stateRef.current = 0;
|
||||||
|
setCanGenImage(true);
|
||||||
|
setHasProfile(true); /* Hack for now */
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
const data = JSON.parse(msg.response || '');
|
||||||
|
if (data.timestamp) {
|
||||||
|
setTimestamp(data.timestamp);
|
||||||
|
} else {
|
||||||
|
setTimestamp(Date.now())
|
||||||
|
}
|
||||||
|
if (data.message) {
|
||||||
|
setStatus(data.message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
const generatePersona = (query: Query) => {
|
||||||
|
if (controllerRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setPrompt(query.prompt);
|
||||||
|
promptRef.current = query.prompt;
|
||||||
|
stateRef.current = 0;
|
||||||
|
setUser(emptyUser);
|
||||||
|
setStreaming('');
|
||||||
|
setResume('');
|
||||||
|
setProcessing(true);
|
||||||
|
setCanGenImage(false);
|
||||||
|
|
||||||
|
controllerRef.current = streamQueryResponse({
|
||||||
|
query,
|
||||||
|
type: "persona",
|
||||||
|
sessionId,
|
||||||
|
connectionBase,
|
||||||
|
onComplete: (msg) => {
|
||||||
|
switch (msg.status) {
|
||||||
|
case "partial":
|
||||||
|
case "done":
|
||||||
|
switch (stateRef.current) {
|
||||||
|
case 0: /* Generating persona */
|
||||||
|
let partialUser = JSON.parse(jsonrepair((msg.response || '').trim()));
|
||||||
|
if (!partialUser.full_name) {
|
||||||
|
partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`;
|
||||||
|
}
|
||||||
|
console.log(partialUser);
|
||||||
|
setUser({...partialUser});
|
||||||
|
stateRef.current++; /* Generating resume */
|
||||||
|
break;
|
||||||
|
case 1: /* Generating resume */
|
||||||
|
stateRef.current++; /* RAG generation */
|
||||||
|
setResume(msg.response || '');
|
||||||
|
break;
|
||||||
|
case 2: /* RAG generation */
|
||||||
|
stateRef.current++; /* Image generation */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (msg.status === "done") {
|
||||||
|
setProcessing(false);
|
||||||
|
setCanGenImage(true);
|
||||||
|
controllerRef.current = null;
|
||||||
|
stateRef.current = 0;
|
||||||
|
setTimeout(() => {
|
||||||
|
generateProfile();
|
||||||
|
}, 0);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "thinking":
|
||||||
|
setStatus(msg.response || '');
|
||||||
|
break;
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
console.log(`Error generating persona: ${msg.response}`);
|
||||||
|
setSnack(msg.response || "", "error");
|
||||||
|
setProcessing(false);
|
||||||
|
setUser({...userRef.current});
|
||||||
|
controllerRef.current = null;
|
||||||
|
stateRef.current = 0;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onStreaming: (chunk) => {
|
||||||
|
setStreaming(chunk);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const cancelQuery = () => {
|
||||||
|
if (controllerRef.current) {
|
||||||
|
controllerRef.current.abort();
|
||||||
|
controllerRef.current = null;
|
||||||
|
stateRef.current = 0;
|
||||||
|
setProcessing(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
promptRef.current = prompt;
|
||||||
|
}, [prompt]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
userRef.current = user;
|
||||||
|
}, [user]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (streaming.trim().length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
switch (stateRef.current) {
|
||||||
|
case 0: /* Generating persona */
|
||||||
|
const partialUser = {...emptyUser, ...JSON.parse(jsonrepair(`${streaming.trim()}...`))};
|
||||||
|
if (!partialUser.full_name) {
|
||||||
|
partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`;
|
||||||
|
}
|
||||||
|
setUser(partialUser);
|
||||||
|
break;
|
||||||
|
case 1: /* Generating resume */
|
||||||
|
setResume(streaming);
|
||||||
|
break;
|
||||||
|
case 3: /* RAG streaming */
|
||||||
|
break;
|
||||||
|
case 4: /* Image streaming */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
}
|
||||||
|
}, [streaming]);
|
||||||
|
|
||||||
|
if (!sessionId) {
|
||||||
|
return <></>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const onEnter = (value: string) => {
|
||||||
|
if (processing) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const query: Query = {
|
||||||
|
prompt: value,
|
||||||
|
}
|
||||||
|
generatePersona(query);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box className="GenerateCandidate" sx={{
|
||||||
|
display: "flex", flexDirection: "column", flexGrow: 1, gap: 1, width: { xs: '100%', md: '700px', lg: '1024px' }
|
||||||
|
}}>
|
||||||
|
{ user && <CandidateInfo sessionId={sessionId} user={user}/> }
|
||||||
|
{processing &&
|
||||||
|
<Box sx={{
|
||||||
|
display: "flex",
|
||||||
|
flexDirection: "column",
|
||||||
|
alignItems: "center",
|
||||||
|
justifyContent: "center",
|
||||||
|
m: 2,
|
||||||
|
}}>
|
||||||
|
<Box sx={{flexDirection: "row", fontWeight: "bold"}}>{status}</Box>
|
||||||
|
<PropagateLoader
|
||||||
|
size="10px"
|
||||||
|
loading={processing}
|
||||||
|
aria-label="Loading Spinner"
|
||||||
|
data-testid="loader"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
}
|
||||||
|
<Box sx={{display: "flex", flexDirection: "column"}}>
|
||||||
|
<Box sx={{ display: "flex", flexDirection: "row"}}>
|
||||||
|
<Avatar
|
||||||
|
src={hasProfile ? `/api/u/${user.username}/profile/${sessionId}` : ''}
|
||||||
|
alt={`${user.full_name}'s profile`}
|
||||||
|
sx={{
|
||||||
|
width: 80,
|
||||||
|
height: 80,
|
||||||
|
border: '2px solid #e0e0e0',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{ processing && <Pulse sx={{position: "absolute", left: "-80px" }} timestamp={timestamp}/> }
|
||||||
|
<Tooltip title={"Generate Profile Picture"}>
|
||||||
|
<span style={{ display: "flex", flexGrow: 1 }}>
|
||||||
|
<Button
|
||||||
|
sx={{ m: 1, gap: 1, flexGrow: 1 }}
|
||||||
|
variant="contained"
|
||||||
|
disabled={sessionId === undefined || processing || !canGenImage}
|
||||||
|
onClick={() => { generateProfile(); }}>
|
||||||
|
Generate Profile Picture<SendIcon />
|
||||||
|
</Button>
|
||||||
|
</span>
|
||||||
|
</Tooltip>
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
{ resume !== '' && <Paper sx={{pt: 1, pb: 1, pl: 2, pr: 2}}><Scrollable sx={{flexGrow: 1}}><StyledMarkdown {...{content: resume, setSnack, sessionId, submitQuery}}/></Scrollable></Paper> }
|
||||||
|
<BackstoryTextField
|
||||||
|
style={{ flexGrow: 0, flexShrink: 1 }}
|
||||||
|
ref={backstoryTextRef}
|
||||||
|
disabled={processing}
|
||||||
|
onEnter={onEnter}
|
||||||
|
placeholder='Specify any characteristics you would like the persona to have. For example, "This person likes yo-yos."'
|
||||||
|
/>
|
||||||
|
<Box sx={{ display: "flex", justifyContent: "center", flexDirection: "row" }}>
|
||||||
|
<Tooltip title={"Send"}>
|
||||||
|
<span style={{ display: "flex", flexGrow: 1 }}>
|
||||||
|
<Button
|
||||||
|
sx={{ m: 1, gap: 1, flexGrow: 1 }}
|
||||||
|
variant="contained"
|
||||||
|
disabled={sessionId === undefined || processing}
|
||||||
|
onClick={() => { generatePersona({ prompt: (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || "" }); }}>
|
||||||
|
Send<SendIcon />
|
||||||
|
</Button>
|
||||||
|
</span>
|
||||||
|
</Tooltip>
|
||||||
|
<Tooltip title="Cancel">
|
||||||
|
<span style={{ display: "flex" }}> { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */}
|
||||||
|
<IconButton
|
||||||
|
aria-label="cancel"
|
||||||
|
onClick={() => { cancelQuery(); }}
|
||||||
|
sx={{ display: "flex", margin: 'auto 0px' }}
|
||||||
|
size="large"
|
||||||
|
edge="start"
|
||||||
|
disabled={controllerRef.current === null || !sessionId || processing === false}
|
||||||
|
>
|
||||||
|
<CancelIcon />
|
||||||
|
</IconButton>
|
||||||
|
</span>
|
||||||
|
</Tooltip>
|
||||||
|
</Box>
|
||||||
|
</Box>);
|
||||||
|
};
|
||||||
|
|
||||||
|
export {
|
||||||
|
GenerateCandidate
|
||||||
|
};
|
218
src/utils/agents/image_generator.py
Normal file
218
src/utils/agents/image_generator.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Literal,
|
||||||
|
ClassVar,
|
||||||
|
cast,
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
List,
|
||||||
|
Optional
|
||||||
|
# override
|
||||||
|
) # NOTE: You must import Optional for late binding to work
|
||||||
|
import inspect
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
from . base import Agent, agent_registry, LLMMessage
|
||||||
|
from .. message import Message
|
||||||
|
from .. rag import ChromaDBGetResponse
|
||||||
|
from .. setup_logging import setup_logging
|
||||||
|
from .. profile_image import generate_image, ImageRequest
|
||||||
|
from .. import defines
|
||||||
|
from .. user import User
|
||||||
|
|
||||||
|
logger = setup_logging()
|
||||||
|
|
||||||
|
seed = int(time.time())
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
class ImageGenerator(Agent):
|
||||||
|
agent_type: Literal["image"] = "image" # type: ignore
|
||||||
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
system_prompt: str = "" # No system prompt is used
|
||||||
|
username: str
|
||||||
|
filename: str
|
||||||
|
|
||||||
|
llm: Any = Field(default=None, exclude=True)
|
||||||
|
model: str = Field(default=None, exclude=True)
|
||||||
|
user: Dict[str, Any] = Field(default={}, exclude=True)
|
||||||
|
|
||||||
|
async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
message.status = "done"
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
async def process_message(
|
||||||
|
self, llm: Any, model: str, message: Message
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
if not self.context:
|
||||||
|
raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
self.llm = llm
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||||
|
tick: int = 0
|
||||||
|
while self.context.processing:
|
||||||
|
logger.info(
|
||||||
|
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||||
|
)
|
||||||
|
message.status = "waiting"
|
||||||
|
message.response = (
|
||||||
|
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||||
|
)
|
||||||
|
tick = (tick + 1) % len(spinner)
|
||||||
|
yield message
|
||||||
|
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||||
|
|
||||||
|
self.context.processing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
#
|
||||||
|
# Generate the profile picture
|
||||||
|
#
|
||||||
|
prompt = message.prompt
|
||||||
|
message.status = "thinking"
|
||||||
|
message.response = f"Generating: {prompt}"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
user_info = os.path.join(defines.user_dir, self.username, defines.user_info_file)
|
||||||
|
with open(user_info, "r") as f:
|
||||||
|
self.user = json.loads(f.read())
|
||||||
|
|
||||||
|
logger.info("Beginning image generation...", self.user)
|
||||||
|
logger.info("TODO: Add safety checks for filename... actually figure out an entirely different way to figure out where to store them.")
|
||||||
|
self.filename = "profile.png"
|
||||||
|
request = ImageRequest(filepath=os.path.join(defines.user_dir, self.user["username"], self.filename), prompt=prompt)
|
||||||
|
async for message in generate_image(
|
||||||
|
message=message,
|
||||||
|
request=request
|
||||||
|
):
|
||||||
|
if message.status != "done":
|
||||||
|
yield message
|
||||||
|
logger.info("Image generation done...")
|
||||||
|
images = self.user.get("images", [])
|
||||||
|
if self.filename not in images:
|
||||||
|
images.append(self.filename)
|
||||||
|
if self.filename == "profile.png":
|
||||||
|
self.user["has_profile"] = True
|
||||||
|
|
||||||
|
#
|
||||||
|
# Write out the completed user information
|
||||||
|
#
|
||||||
|
with open(user_info, "w") as f:
|
||||||
|
f.write(json.dumps(self.user))
|
||||||
|
|
||||||
|
# Image generated
|
||||||
|
message.status = "done"
|
||||||
|
message.response = json.dumps(self.user)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
message.status = "error"
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(message.response)
|
||||||
|
message.response = f"Error in image generation: {str(e)}"
|
||||||
|
logger.error(message.response)
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
# Done processing, add message to conversation
|
||||||
|
self.context.processing = False
|
||||||
|
# Return the final message
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
messages: List[LLMMessage] = [
|
||||||
|
LLMMessage(role="system", content=system_prompt),
|
||||||
|
LLMMessage(role="user", content=prompt),
|
||||||
|
]
|
||||||
|
message.metadata.options = {
|
||||||
|
"seed": 8911,
|
||||||
|
"num_ctx": self.context_size,
|
||||||
|
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||||
|
}
|
||||||
|
|
||||||
|
message.status = "streaming"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
last_chunk_time = 0
|
||||||
|
message.chunk = ""
|
||||||
|
message.response = ""
|
||||||
|
for response in self.llm.chat(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
options={
|
||||||
|
**message.metadata.options,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
):
|
||||||
|
if not response:
|
||||||
|
message.status = "error"
|
||||||
|
message.response = "No response from LLM."
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
message.status = "streaming"
|
||||||
|
message.chunk += response.message.content
|
||||||
|
message.response += response.message.content
|
||||||
|
|
||||||
|
if not response.done:
|
||||||
|
now = time.perf_counter()
|
||||||
|
if now - last_chunk_time > 0.25:
|
||||||
|
yield message
|
||||||
|
last_chunk_time = now
|
||||||
|
message.chunk = ""
|
||||||
|
|
||||||
|
if response.done:
|
||||||
|
self.collect_metrics(response)
|
||||||
|
message.metadata.eval_count += response.eval_count
|
||||||
|
message.metadata.eval_duration += response.eval_duration
|
||||||
|
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||||
|
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||||
|
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||||
|
message.chunk = ""
|
||||||
|
message.status = "done"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
def extract_json_from_text(self, text: str) -> str:
|
||||||
|
"""Extract JSON string from text that may contain other content."""
|
||||||
|
json_pattern = r"```json\s*([\s\S]*?)\s*```"
|
||||||
|
match = re.search(json_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
|
||||||
|
# Try to find JSON without the markdown code block
|
||||||
|
json_pattern = r"({[\s\S]*})"
|
||||||
|
match = re.search(json_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
|
||||||
|
raise ValueError("No JSON found in the response")
|
||||||
|
|
||||||
|
def extract_markdown_from_text(self, text: str) -> str:
|
||||||
|
"""Extract Markdown string from text that may contain other content."""
|
||||||
|
markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
|
||||||
|
match = re.search(markdown_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(2).strip()
|
||||||
|
|
||||||
|
raise ValueError("No Markdown found in the response")
|
||||||
|
|
||||||
|
# Register the base agent
|
||||||
|
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)
|
422
src/utils/agents/persona_generator.py
Normal file
422
src/utils/agents/persona_generator.py
Normal file
@ -0,0 +1,422 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Literal,
|
||||||
|
ClassVar,
|
||||||
|
cast,
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
List,
|
||||||
|
Optional
|
||||||
|
# override
|
||||||
|
) # NOTE: You must import Optional for late binding to work
|
||||||
|
import inspect
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
from .base import Agent, agent_registry, LLMMessage
|
||||||
|
from ..message import Message
|
||||||
|
from ..rag import ChromaDBGetResponse
|
||||||
|
from ..setup_logging import setup_logging
|
||||||
|
from .. import defines
|
||||||
|
from ..user import User
|
||||||
|
|
||||||
|
logger = setup_logging()
|
||||||
|
|
||||||
|
seed = int(time.time())
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
emptyUser = {
|
||||||
|
"profile_url": "",
|
||||||
|
"description": "",
|
||||||
|
"rag_content_size": 0,
|
||||||
|
"username": "",
|
||||||
|
"first_name": "",
|
||||||
|
"last_name": "",
|
||||||
|
"full_name": "",
|
||||||
|
"contact_info": {},
|
||||||
|
"questions": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
generate_persona_system_prompt = """\
|
||||||
|
You are a casing director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
|
||||||
|
|
||||||
|
All response field MUST BE IN ENGLISH, regardless of ethnicity.
|
||||||
|
|
||||||
|
You will be provided with defaults to use if not specified by the user:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"age": number,
|
||||||
|
"gender": "male" | "female",
|
||||||
|
"ethnicity": string,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Additional information provided in the user message can override those defaults.
|
||||||
|
|
||||||
|
You need to randomly assign an English username (can include numbers), a first name, last name, and a two English sentence description of that individual's work given the demographics provided.
|
||||||
|
|
||||||
|
Your response must be in JSON.
|
||||||
|
Provide only the JSON response, and match the field names EXACTLY.
|
||||||
|
Provide all information in English ONLY, with no other commentary:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters)
|
||||||
|
"first_name": string,
|
||||||
|
"last_name": string,
|
||||||
|
"description": string, # One to two sentence description of their job
|
||||||
|
"location": string, # In the location, provide ALL of: City, State/Region, and Country
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Make sure to provide a username and that the field name for the job description is "description".
|
||||||
|
"""
|
||||||
|
|
||||||
|
generate_resume_system_prompt = """
|
||||||
|
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
|
||||||
|
of that is to create an in-depth resume for the person. You will be provided with the following information:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"full_name": string, # Person full name
|
||||||
|
"location": string, # Location of residence
|
||||||
|
"age": number, # Age of candidate
|
||||||
|
"description": string # A brief description of the person
|
||||||
|
```
|
||||||
|
|
||||||
|
Use that information to invent a full career resume. Include sections such as:
|
||||||
|
|
||||||
|
* Contact information
|
||||||
|
* Job goal
|
||||||
|
* Top skills
|
||||||
|
* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent
|
||||||
|
* In the work history, provide company names, years of employment, and their role
|
||||||
|
* Education
|
||||||
|
|
||||||
|
Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class PersonaGenerator(Agent):
|
||||||
|
agent_type: Literal["persona"] = "persona" # type: ignore
|
||||||
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
system_prompt: str = generate_persona_system_prompt
|
||||||
|
age: int = Field(default_factory=lambda: random.randint(22, 67))
|
||||||
|
gender: str = Field(default_factory=lambda: random.choice(["male", "female"]))
|
||||||
|
ethnicity: Literal[
|
||||||
|
"Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"
|
||||||
|
] = Field(
|
||||||
|
default_factory=lambda: random.choices(
|
||||||
|
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||||
|
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
|
||||||
|
k=1
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
username: str = ""
|
||||||
|
|
||||||
|
llm: Any = Field(default=None, exclude=True)
|
||||||
|
model: str = Field(default=None, exclude=True)
|
||||||
|
|
||||||
|
def randomize(self):
|
||||||
|
self.age = random.randint(22, 67)
|
||||||
|
self.gender = random.choice(["male", "female"])
|
||||||
|
# Use random.choices with explicit type casting to satisfy Literal type
|
||||||
|
self.ethnicity = cast(
|
||||||
|
Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||||
|
random.choices(
|
||||||
|
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
|
||||||
|
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
|
||||||
|
k=1
|
||||||
|
)[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
if not self.context:
|
||||||
|
raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
message.tunables.enable_tools = False
|
||||||
|
message.tunables.enable_rag = False
|
||||||
|
message.tunables.enable_context = False
|
||||||
|
|
||||||
|
message.prompt = f"""\
|
||||||
|
```json
|
||||||
|
{json.dumps({
|
||||||
|
"age": self.age,
|
||||||
|
"gender": self.gender,
|
||||||
|
"ethnicity": self.ethnicity
|
||||||
|
})}
|
||||||
|
```
|
||||||
|
{message.prompt}
|
||||||
|
"""
|
||||||
|
message.status = "done"
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
async def process_message(
|
||||||
|
self, llm: Any, model: str, message: Message
|
||||||
|
) -> AsyncGenerator[Message, None]:
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
if not self.context:
|
||||||
|
raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
self.llm = llm
|
||||||
|
self.model = model
|
||||||
|
original_prompt = message.prompt
|
||||||
|
|
||||||
|
spinner: List[str] = ["\\", "|", "/", "-"]
|
||||||
|
tick: int = 0
|
||||||
|
while self.context.processing:
|
||||||
|
logger.info(
|
||||||
|
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||||
|
)
|
||||||
|
message.status = "waiting"
|
||||||
|
message.response = (
|
||||||
|
f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||||
|
)
|
||||||
|
tick = (tick + 1) % len(spinner)
|
||||||
|
yield message
|
||||||
|
await asyncio.sleep(1) # Allow the event loop to process the write
|
||||||
|
|
||||||
|
self.context.processing = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
#
|
||||||
|
# Generate the persona
|
||||||
|
#
|
||||||
|
async for message in self.call_llm(
|
||||||
|
message=message, system_prompt=self.system_prompt, prompt=original_prompt
|
||||||
|
):
|
||||||
|
if message.status != "done":
|
||||||
|
yield message
|
||||||
|
if message.status == "error":
|
||||||
|
raise Exception(message.response)
|
||||||
|
|
||||||
|
json_str = self.extract_json_from_text(message.response)
|
||||||
|
try:
|
||||||
|
persona = json.loads(json_str) | {
|
||||||
|
"age": self.age,
|
||||||
|
"gender": self.gender,
|
||||||
|
"ethnicity": self.ethnicity
|
||||||
|
}
|
||||||
|
if not persona.get("full_name", None):
|
||||||
|
persona["full_name"] = f"{persona['first_name']} {persona['last_name']}"
|
||||||
|
self.username = persona.get("username", None)
|
||||||
|
if not self.username:
|
||||||
|
raise ValueError("LLM did not generate a username")
|
||||||
|
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||||
|
while os.path.exists(user_dir):
|
||||||
|
match = re.match(r"^(.*?)(\d*)$", persona["username"])
|
||||||
|
if match:
|
||||||
|
base = match.group(1)
|
||||||
|
num = match.group(2)
|
||||||
|
iteration = int(num) + 1 if num else 1
|
||||||
|
persona["username"] = f"{base}{iteration}"
|
||||||
|
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||||
|
|
||||||
|
for key in persona:
|
||||||
|
if isinstance(persona[key], str):
|
||||||
|
persona[key] = persona[key].strip()
|
||||||
|
# Mark this persona as AI generated
|
||||||
|
persona["is_ai"] = True
|
||||||
|
except Exception as e:
|
||||||
|
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
|
||||||
|
message.status = "error"
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(message.response)
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
# Persona generated
|
||||||
|
message.response = json.dumps(persona)
|
||||||
|
message.status = "partial"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
#
|
||||||
|
# Generate the resume
|
||||||
|
#
|
||||||
|
message.status = "thinking"
|
||||||
|
message.response = f"Generating resume for {persona['full_name']}..."
|
||||||
|
yield message
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
```json
|
||||||
|
{{
|
||||||
|
"full_name": "{persona["full_name"]}",
|
||||||
|
"location": "{persona["location"]}",
|
||||||
|
"age": {persona["age"]},
|
||||||
|
"description": {persona["description"]}
|
||||||
|
}}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async for message in self.call_llm(
|
||||||
|
message=message, system_prompt=generate_resume_system_prompt, prompt=prompt
|
||||||
|
):
|
||||||
|
if message.status != "done":
|
||||||
|
yield message
|
||||||
|
if message.status == "error":
|
||||||
|
raise Exception(message.response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
|
||||||
|
message.status = "error"
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(message.response)
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
resume = self.extract_markdown_from_text(message.response)
|
||||||
|
if resume:
|
||||||
|
user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir)
|
||||||
|
os.makedirs(user_resume_dir, exist_ok=True)
|
||||||
|
user_resume_file = os.path.join(user_resume_dir, defines.resume_doc)
|
||||||
|
with open(user_resume_file, "w") as f:
|
||||||
|
f.write(resume)
|
||||||
|
|
||||||
|
# Resume generated
|
||||||
|
message.response = resume
|
||||||
|
message.status = "partial"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
#
|
||||||
|
# Generate RAG database
|
||||||
|
#
|
||||||
|
message.status = "thinking"
|
||||||
|
message.response = f"Generating RAG content from resume..."
|
||||||
|
yield message
|
||||||
|
|
||||||
|
# Prior to instancing a new User, the json data has to be created
|
||||||
|
# so the system can process it
|
||||||
|
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||||
|
os.makedirs(user_dir, exist_ok=True)
|
||||||
|
user_info = os.path.join(user_dir, "info.json")
|
||||||
|
with open(user_info, "w") as f:
|
||||||
|
f.write(json.dumps(persona, indent=2))
|
||||||
|
|
||||||
|
user = User(llm=self.llm, username=self.username)
|
||||||
|
await user.initialize()
|
||||||
|
await user.file_watcher.initialize_collection()
|
||||||
|
# RAG content generated
|
||||||
|
message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store."
|
||||||
|
|
||||||
|
#
|
||||||
|
# Write out the completed user information
|
||||||
|
#
|
||||||
|
with open(user_info, "w") as f:
|
||||||
|
f.write(json.dumps(persona, indent=2))
|
||||||
|
|
||||||
|
# Image generated
|
||||||
|
message.status = "done"
|
||||||
|
message.response = json.dumps(persona)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
message.status = "error"
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(message.response)
|
||||||
|
message.response = f"Error in persona generation: {str(e)}"
|
||||||
|
logger.error(message.response)
|
||||||
|
self.randomize() # Randomize for next generation
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
# Done processing, add message to conversation
|
||||||
|
self.context.processing = False
|
||||||
|
self.randomize() # Randomize for next generation
|
||||||
|
# Return the final message
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
messages: List[LLMMessage] = [
|
||||||
|
LLMMessage(role="system", content=system_prompt),
|
||||||
|
LLMMessage(role="user", content=prompt),
|
||||||
|
]
|
||||||
|
message.metadata.options = {
|
||||||
|
"seed": 8911,
|
||||||
|
"num_ctx": self.context_size,
|
||||||
|
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||||
|
}
|
||||||
|
|
||||||
|
message.status = "streaming"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
last_chunk_time = 0
|
||||||
|
message.chunk = ""
|
||||||
|
message.response = ""
|
||||||
|
for response in self.llm.chat(
|
||||||
|
model=self.model,
|
||||||
|
messages=messages,
|
||||||
|
options={
|
||||||
|
**message.metadata.options,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
):
|
||||||
|
if not response:
|
||||||
|
message.status = "error"
|
||||||
|
message.response = "No response from LLM."
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
message.status = "streaming"
|
||||||
|
message.chunk += response.message.content
|
||||||
|
message.response += response.message.content
|
||||||
|
|
||||||
|
if not response.done:
|
||||||
|
now = time.perf_counter()
|
||||||
|
if now - last_chunk_time > 0.25:
|
||||||
|
yield message
|
||||||
|
last_chunk_time = now
|
||||||
|
message.chunk = ""
|
||||||
|
|
||||||
|
if response.done:
|
||||||
|
self.collect_metrics(response)
|
||||||
|
message.metadata.eval_count += response.eval_count
|
||||||
|
message.metadata.eval_duration += response.eval_duration
|
||||||
|
message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||||
|
message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||||
|
self.context_tokens = response.prompt_eval_count + response.eval_count
|
||||||
|
message.chunk = ""
|
||||||
|
message.status = "done"
|
||||||
|
yield message
|
||||||
|
|
||||||
|
def extract_json_from_text(self, text: str) -> str:
|
||||||
|
"""Extract JSON string from text that may contain other content."""
|
||||||
|
json_pattern = r"```json\s*([\s\S]*?)\s*```"
|
||||||
|
match = re.search(json_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
|
||||||
|
# Try to find JSON without the markdown code block
|
||||||
|
json_pattern = r"({[\s\S]*})"
|
||||||
|
match = re.search(json_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
|
||||||
|
raise ValueError("No JSON found in the response")
|
||||||
|
|
||||||
|
def extract_markdown_from_text(self, text: str) -> str:
|
||||||
|
"""Extract Markdown string from text that may contain other content."""
|
||||||
|
markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
|
||||||
|
match = re.search(markdown_pattern, text)
|
||||||
|
if match:
|
||||||
|
return match.group(2).strip()
|
||||||
|
|
||||||
|
raise ValueError("No Markdown found in the response")
|
||||||
|
|
||||||
|
# Register the base agent
|
||||||
|
agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)
|
Loading…
x
Reference in New Issue
Block a user