diff --git a/frontend/src/NewApp/Components/Pulse.tsx b/frontend/src/NewApp/Components/Pulse.tsx new file mode 100644 index 0000000..db01dcb --- /dev/null +++ b/frontend/src/NewApp/Components/Pulse.tsx @@ -0,0 +1,180 @@ +import React, { useEffect, useState, useRef } from 'react'; +import { SxProps } from '@mui/material'; +import Box from '@mui/material/Box'; + +interface PulseProps { + timestamp: number | string; + sx?: SxProps; +} + +const Pulse: React.FC = ({ timestamp, sx }) => { + const [isAnimating, setIsAnimating] = useState(false); + const [animationKey, setAnimationKey] = useState(0); + const previousTimestamp = useRef(null); + + useEffect(() => { + if (timestamp && timestamp !== previousTimestamp.current) { + previousTimestamp.current = timestamp; + setAnimationKey(prev => prev + 1); + setIsAnimating(true); + + // Reset animation state after animation completes + const timer = setTimeout(() => { + setIsAnimating(false); + }, 1000); + + return () => clearTimeout(timer); + } + }, [timestamp]); + + const containerStyle: React.CSSProperties = { + position: 'relative', + width: 80, + height: 80, + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + }; + + const baseCoreStyle: React.CSSProperties = { + width: 24, + height: 24, + borderRadius: '50%', + backgroundColor: '#2196f3', + position: 'relative', + zIndex: 3, + }; + + const coreStyle: React.CSSProperties = { + ...baseCoreStyle, + animation: isAnimating ? 'pulse-glow 1s ease-out' : 'none', + }; + + const pulseRing1Style: React.CSSProperties = { + position: 'absolute', + width: 24, + height: 24, + borderRadius: '50%', + backgroundColor: '#2196f3', + zIndex: 2, + animation: 'pulse-expand 1s ease-out forwards', + }; + + const pulseRing2Style: React.CSSProperties = { + position: 'absolute', + width: 24, + height: 24, + borderRadius: '50%', + backgroundColor: '#64b5f6', + zIndex: 1, + animation: 'pulse-expand 1s ease-out 0.2s forwards', + }; + + const rippleStyle: React.CSSProperties = { + position: 'absolute', + width: 32, + height: 32, + borderRadius: '50%', + border: '2px solid #2196f3', + backgroundColor: 'transparent', + zIndex: 0, + animation: 'ripple-expand 1s ease-out forwards', + }; + + const outerRippleStyle: React.CSSProperties = { + position: 'absolute', + width: 40, + height: 40, + borderRadius: '50%', + border: '1px solid #90caf9', + backgroundColor: 'transparent', + zIndex: 0, + animation: 'ripple-expand 1s ease-out 0.3s forwards', + }; + + + + return ( + <> + + + + {/* Base circle */} +
+ + {/* Pulse rings */} + {isAnimating && ( + <> + {/* Primary pulse ring */} +
+ + {/* Secondary pulse ring with delay */} +
+ + {/* Ripple effect */} +
+ + {/* Outer ripple */} +
+ + )} + + + + + ); +}; + +export { Pulse } ; \ No newline at end of file diff --git a/frontend/src/NewApp/Pages/GenerateCandidate.tsx b/frontend/src/NewApp/Pages/GenerateCandidate.tsx new file mode 100644 index 0000000..7f8f138 --- /dev/null +++ b/frontend/src/NewApp/Pages/GenerateCandidate.tsx @@ -0,0 +1,330 @@ +import React, { useEffect, useState, useRef } from 'react'; +import Avatar from '@mui/material/Avatar'; +import Box from '@mui/material/Box'; +import Tooltip from '@mui/material/Tooltip'; +import Button from '@mui/material/Button'; +import Paper from '@mui/material/Paper'; +import IconButton from '@mui/material/IconButton'; +import CancelIcon from '@mui/icons-material/Cancel'; +import SendIcon from '@mui/icons-material/Send'; +import PropagateLoader from 'react-spinners/PropagateLoader'; +import { CandidateInfo } from '../Components/CandidateInfo'; +import { Query } from '../../Components/ChatQuery' +import { streamQueryResponse, StreamQueryController } from '../Components/streamQueryResponse'; +import { connectionBase } from 'Global'; +import { UserInfo } from '../Components/UserContext'; +import { BackstoryElementProps } from 'Components/BackstoryTab'; +import { BackstoryTextField, BackstoryTextFieldRef } from 'Components/BackstoryTextField'; +import { jsonrepair } from 'jsonrepair'; +import { StyledMarkdown } from 'Components/StyledMarkdown'; +import { Scrollable } from 'Components/Scrollable'; +import { useForkRef } from '@mui/material'; +import { BackstoryMessage } from 'Components/Message'; +import { Pulse } from 'NewApp/Components/Pulse'; + +const emptyUser : UserInfo = { + type: 'candidate', + description: "[blank]", + rag_content_size: 0, + username: "[blank]", + first_name: "[blank]", + last_name: "[blank]", + full_name: "[blank] [blank]", + contact_info: {}, + questions: [], + isAuthenticated: false, + has_profile: false +}; + +const GenerateCandidate = (props: BackstoryElementProps) => { + const {sessionId, setSnack, submitQuery} = props; + const [streaming, setStreaming] = useState(''); + const [processing, setProcessing] = useState(false); + const [user, setUser] = useState(emptyUser); + const controllerRef = useRef(null); + const backstoryTextRef = useRef(null); + const promptRef = useRef(null); + const stateRef = useRef(0); /* Generating persona */ + const userRef = useRef(user); + const [prompt, setPrompt] = useState(''); + const [resume, setResume] = useState(''); + const [canGenImage, setCanGenImage] = useState(false); + const [hasProfile, setHasProfile] = useState(false); + const [status, setStatus] = useState(''); + const [timestamp, setTimestamp] = useState(0); + + const generateProfile = () => { + if (controllerRef.current) { + return; + } + setProcessing(true); + setCanGenImage(false); + stateRef.current = 3; + const start = Date.now(); + + controllerRef.current = streamQueryResponse({ + query: { + prompt: `A photorealistic profile picture of a ${user.age} year old ${user.gender} ${user.ethnicity} person. ${prompt}`, + agent_options: { + username: userRef.current.username, + filename: "profile.png" + } + }, + type: "image", + sessionId, + connectionBase, + onComplete: (msg) => { + console.log({ msg, state: stateRef.current, prompt: promptRef.current || '' }); + switch (msg.status) { + case "partial": + case "done": + if (msg.status === "done") { + setProcessing(false); + controllerRef.current = null; + stateRef.current = 0; + setCanGenImage(true); + setHasProfile(true); + } + break; + case "error": + console.log(`Error generating persona: ${msg.response} after ${Date.now() - start}`); + setSnack(msg.response || "", "error"); + setProcessing(false); + setUser({...userRef.current}); + controllerRef.current = null; + stateRef.current = 0; + setCanGenImage(true); + setHasProfile(true); /* Hack for now */ + break; + default: + const data = JSON.parse(msg.response || ''); + if (data.timestamp) { + setTimestamp(data.timestamp); + } else { + setTimestamp(Date.now()) + } + if (data.message) { + setStatus(data.message); + } + break; + } + } + }); + + }; + + const generatePersona = (query: Query) => { + if (controllerRef.current) { + return; + } + setPrompt(query.prompt); + promptRef.current = query.prompt; + stateRef.current = 0; + setUser(emptyUser); + setStreaming(''); + setResume(''); + setProcessing(true); + setCanGenImage(false); + + controllerRef.current = streamQueryResponse({ + query, + type: "persona", + sessionId, + connectionBase, + onComplete: (msg) => { + switch (msg.status) { + case "partial": + case "done": + switch (stateRef.current) { + case 0: /* Generating persona */ + let partialUser = JSON.parse(jsonrepair((msg.response || '').trim())); + if (!partialUser.full_name) { + partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`; + } + console.log(partialUser); + setUser({...partialUser}); + stateRef.current++; /* Generating resume */ + break; + case 1: /* Generating resume */ + stateRef.current++; /* RAG generation */ + setResume(msg.response || ''); + break; + case 2: /* RAG generation */ + stateRef.current++; /* Image generation */ + break; + } + if (msg.status === "done") { + setProcessing(false); + setCanGenImage(true); + controllerRef.current = null; + stateRef.current = 0; + setTimeout(() => { + generateProfile(); + }, 0); + } + break; + case "thinking": + setStatus(msg.response || ''); + break; + + case "error": + console.log(`Error generating persona: ${msg.response}`); + setSnack(msg.response || "", "error"); + setProcessing(false); + setUser({...userRef.current}); + controllerRef.current = null; + stateRef.current = 0; + break; + } + }, + onStreaming: (chunk) => { + setStreaming(chunk); + } + }); + }; + + const cancelQuery = () => { + if (controllerRef.current) { + controllerRef.current.abort(); + controllerRef.current = null; + stateRef.current = 0; + setProcessing(false); + } + } + + useEffect(() => { + promptRef.current = prompt; + }, [prompt]); + + useEffect(() => { + userRef.current = user; + }, [user]); + + useEffect(() => { + if (streaming.trim().length === 0) { + return; + } + + try { + switch (stateRef.current) { + case 0: /* Generating persona */ + const partialUser = {...emptyUser, ...JSON.parse(jsonrepair(`${streaming.trim()}...`))}; + if (!partialUser.full_name) { + partialUser.full_name = `${partialUser.first_name} ${partialUser.last_name}`; + } + setUser(partialUser); + break; + case 1: /* Generating resume */ + setResume(streaming); + break; + case 3: /* RAG streaming */ + break; + case 4: /* Image streaming */ + break; + } + } catch { + } + }, [streaming]); + + if (!sessionId) { + return <>; + } + + const onEnter = (value: string) => { + if (processing) { + return; + } + const query: Query = { + prompt: value, + } + generatePersona(query); + }; + + return ( + + { user && } + {processing && + + {status} + + + } + + + + { processing && } + + + + + + + + { resume !== '' && } + + + + + + + + + { /* This span is used to wrap the IconButton to ensure Tooltip works even when disabled */} + { cancelQuery(); }} + sx={{ display: "flex", margin: 'auto 0px' }} + size="large" + edge="start" + disabled={controllerRef.current === null || !sessionId || processing === false} + > + + + + + + ); +}; + +export { + GenerateCandidate +}; \ No newline at end of file diff --git a/src/utils/agents/image_generator.py b/src/utils/agents/image_generator.py new file mode 100644 index 0000000..fe8d493 --- /dev/null +++ b/src/utils/agents/image_generator.py @@ -0,0 +1,218 @@ +from __future__ import annotations +from pydantic import model_validator, Field, BaseModel # type: ignore +from typing import ( + Dict, + Literal, + ClassVar, + cast, + Any, + AsyncGenerator, + List, + Optional +# override +) # NOTE: You must import Optional for late binding to work +import inspect +import random +import re +import json +import traceback +import asyncio +import time +import asyncio +import time +import os + +from . base import Agent, agent_registry, LLMMessage +from .. message import Message +from .. rag import ChromaDBGetResponse +from .. setup_logging import setup_logging +from .. profile_image import generate_image, ImageRequest +from .. import defines +from .. user import User + +logger = setup_logging() + +seed = int(time.time()) +random.seed(seed) + +class ImageGenerator(Agent): + agent_type: Literal["image"] = "image" # type: ignore + _agent_type: ClassVar[str] = agent_type # Add this for registration + agent_persist: bool = False + + system_prompt: str = "" # No system prompt is used + username: str + filename: str + + llm: Any = Field(default=None, exclude=True) + model: str = Field(default=None, exclude=True) + user: Dict[str, Any] = Field(default={}, exclude=True) + + async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]: + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + message.status = "done" + yield message + return + + async def process_message( + self, llm: Any, model: str, message: Message + ) -> AsyncGenerator[Message, None]: + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + if not self.context: + raise ValueError("Context is not set for this agent.") + + self.llm = llm + self.model = model + + spinner: List[str] = ["\\", "|", "/", "-"] + tick: int = 0 + while self.context.processing: + logger.info( + "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time" + ) + message.status = "waiting" + message.response = ( + f"Busy processing another request. Please wait. {spinner[tick]}" + ) + tick = (tick + 1) % len(spinner) + yield message + await asyncio.sleep(1) # Allow the event loop to process the write + + self.context.processing = True + + try: + # + # Generate the profile picture + # + prompt = message.prompt + message.status = "thinking" + message.response = f"Generating: {prompt}" + yield message + + user_info = os.path.join(defines.user_dir, self.username, defines.user_info_file) + with open(user_info, "r") as f: + self.user = json.loads(f.read()) + + logger.info("Beginning image generation...", self.user) + logger.info("TODO: Add safety checks for filename... actually figure out an entirely different way to figure out where to store them.") + self.filename = "profile.png" + request = ImageRequest(filepath=os.path.join(defines.user_dir, self.user["username"], self.filename), prompt=prompt) + async for message in generate_image( + message=message, + request=request + ): + if message.status != "done": + yield message + logger.info("Image generation done...") + images = self.user.get("images", []) + if self.filename not in images: + images.append(self.filename) + if self.filename == "profile.png": + self.user["has_profile"] = True + + # + # Write out the completed user information + # + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) + + # Image generated + message.status = "done" + message.response = json.dumps(self.user) + + except Exception as e: + message.status = "error" + logger.error(traceback.format_exc()) + logger.error(message.response) + message.response = f"Error in image generation: {str(e)}" + logger.error(message.response) + yield message + return + + # Done processing, add message to conversation + self.context.processing = False + # Return the final message + yield message + return + + async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7): + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + + messages: List[LLMMessage] = [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=prompt), + ] + message.metadata.options = { + "seed": 8911, + "num_ctx": self.context_size, + "temperature": temperature, # Higher temperature to encourage tool usage + } + + message.status = "streaming" + yield message + + last_chunk_time = 0 + message.chunk = "" + message.response = "" + for response in self.llm.chat( + model=self.model, + messages=messages, + options={ + **message.metadata.options, + }, + stream=True, + ): + if not response: + message.status = "error" + message.response = "No response from LLM." + yield message + return + + message.status = "streaming" + message.chunk += response.message.content + message.response += response.message.content + + if not response.done: + now = time.perf_counter() + if now - last_chunk_time > 0.25: + yield message + last_chunk_time = now + message.chunk = "" + + if response.done: + self.collect_metrics(response) + message.metadata.eval_count += response.eval_count + message.metadata.eval_duration += response.eval_duration + message.metadata.prompt_eval_count += response.prompt_eval_count + message.metadata.prompt_eval_duration += response.prompt_eval_duration + self.context_tokens = response.prompt_eval_count + response.eval_count + message.chunk = "" + message.status = "done" + yield message + + def extract_json_from_text(self, text: str) -> str: + """Extract JSON string from text that may contain other content.""" + json_pattern = r"```json\s*([\s\S]*?)\s*```" + match = re.search(json_pattern, text) + if match: + return match.group(1).strip() + + # Try to find JSON without the markdown code block + json_pattern = r"({[\s\S]*})" + match = re.search(json_pattern, text) + if match: + return match.group(1).strip() + + raise ValueError("No JSON found in the response") + + def extract_markdown_from_text(self, text: str) -> str: + """Extract Markdown string from text that may contain other content.""" + markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```" + match = re.search(markdown_pattern, text) + if match: + return match.group(2).strip() + + raise ValueError("No Markdown found in the response") + +# Register the base agent +agent_registry.register(ImageGenerator._agent_type, ImageGenerator) diff --git a/src/utils/agents/persona_generator.py b/src/utils/agents/persona_generator.py new file mode 100644 index 0000000..d2a9151 --- /dev/null +++ b/src/utils/agents/persona_generator.py @@ -0,0 +1,422 @@ +from __future__ import annotations +from pydantic import model_validator, Field, BaseModel # type: ignore +from typing import ( + Dict, + Literal, + ClassVar, + cast, + Any, + AsyncGenerator, + List, + Optional +# override +) # NOTE: You must import Optional for late binding to work +import inspect +import random +import re +import json +import traceback +import asyncio +import time +import asyncio +import time +import os + +from .base import Agent, agent_registry, LLMMessage +from ..message import Message +from ..rag import ChromaDBGetResponse +from ..setup_logging import setup_logging +from .. import defines +from ..user import User + +logger = setup_logging() + +seed = int(time.time()) +random.seed(seed) + +emptyUser = { + "profile_url": "", + "description": "", + "rag_content_size": 0, + "username": "", + "first_name": "", + "last_name": "", + "full_name": "", + "contact_info": {}, + "questions": [], +} + +generate_persona_system_prompt = """\ +You are a casing director for a movie. Your job is to provide information on ficticious personas for use in a screen play. + +All response field MUST BE IN ENGLISH, regardless of ethnicity. + +You will be provided with defaults to use if not specified by the user: + +```json +{ +"age": number, +"gender": "male" | "female", +"ethnicity": string, +} +``` + +Additional information provided in the user message can override those defaults. + +You need to randomly assign an English username (can include numbers), a first name, last name, and a two English sentence description of that individual's work given the demographics provided. + +Your response must be in JSON. +Provide only the JSON response, and match the field names EXACTLY. +Provide all information in English ONLY, with no other commentary: + +```json +{ +"username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters) +"first_name": string, +"last_name": string, +"description": string, # One to two sentence description of their job +"location": string, # In the location, provide ALL of: City, State/Region, and Country +} +``` + +Make sure to provide a username and that the field name for the job description is "description". +""" + +generate_resume_system_prompt = """ +You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part +of that is to create an in-depth resume for the person. You will be provided with the following information: + +```json +"full_name": string, # Person full name +"location": string, # Location of residence +"age": number, # Age of candidate +"description": string # A brief description of the person +``` + +Use that information to invent a full career resume. Include sections such as: + +* Contact information +* Job goal +* Top skills +* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent +* In the work history, provide company names, years of employment, and their role +* Education + +Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume. +""" + +class PersonaGenerator(Agent): + agent_type: Literal["persona"] = "persona" # type: ignore + _agent_type: ClassVar[str] = agent_type # Add this for registration + agent_persist: bool = False + + system_prompt: str = generate_persona_system_prompt + age: int = Field(default_factory=lambda: random.randint(22, 67)) + gender: str = Field(default_factory=lambda: random.choice(["male", "female"])) + ethnicity: Literal[ + "Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial" + ] = Field( + default_factory=lambda: random.choices( + ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], + weights=[57.69, 15.38, 19.23, 5.77, 1.92], + k=1 + )[0] + ) + username: str = "" + + llm: Any = Field(default=None, exclude=True) + model: str = Field(default=None, exclude=True) + + def randomize(self): + self.age = random.randint(22, 67) + self.gender = random.choice(["male", "female"]) + # Use random.choices with explicit type casting to satisfy Literal type + self.ethnicity = cast( + Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], + random.choices( + ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], + weights=[57.69, 15.38, 19.23, 5.77, 1.92], + k=1 + )[0] + ) + + async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]: + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + + if not self.context: + raise ValueError("Context is not set for this agent.") + + message.tunables.enable_tools = False + message.tunables.enable_rag = False + message.tunables.enable_context = False + + message.prompt = f"""\ +```json +{json.dumps({ + "age": self.age, + "gender": self.gender, + "ethnicity": self.ethnicity +})} +``` +{message.prompt} +""" + message.status = "done" + yield message + return + + async def process_message( + self, llm: Any, model: str, message: Message + ) -> AsyncGenerator[Message, None]: + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + if not self.context: + raise ValueError("Context is not set for this agent.") + + self.llm = llm + self.model = model + original_prompt = message.prompt + + spinner: List[str] = ["\\", "|", "/", "-"] + tick: int = 0 + while self.context.processing: + logger.info( + "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time" + ) + message.status = "waiting" + message.response = ( + f"Busy processing another request. Please wait. {spinner[tick]}" + ) + tick = (tick + 1) % len(spinner) + yield message + await asyncio.sleep(1) # Allow the event loop to process the write + + self.context.processing = True + + try: + + # + # Generate the persona + # + async for message in self.call_llm( + message=message, system_prompt=self.system_prompt, prompt=original_prompt + ): + if message.status != "done": + yield message + if message.status == "error": + raise Exception(message.response) + + json_str = self.extract_json_from_text(message.response) + try: + persona = json.loads(json_str) | { + "age": self.age, + "gender": self.gender, + "ethnicity": self.ethnicity + } + if not persona.get("full_name", None): + persona["full_name"] = f"{persona['first_name']} {persona['last_name']}" + self.username = persona.get("username", None) + if not self.username: + raise ValueError("LLM did not generate a username") + user_dir = os.path.join(defines.user_dir, persona["username"]) + while os.path.exists(user_dir): + match = re.match(r"^(.*?)(\d*)$", persona["username"]) + if match: + base = match.group(1) + num = match.group(2) + iteration = int(num) + 1 if num else 1 + persona["username"] = f"{base}{iteration}" + user_dir = os.path.join(defines.user_dir, persona["username"]) + + for key in persona: + if isinstance(persona[key], str): + persona[key] = persona[key].strip() + # Mark this persona as AI generated + persona["is_ai"] = True + except Exception as e: + message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}" + message.status = "error" + logger.error(traceback.format_exc()) + logger.error(message.response) + yield message + return + + # Persona generated + message.response = json.dumps(persona) + message.status = "partial" + yield message + + # + # Generate the resume + # + message.status = "thinking" + message.response = f"Generating resume for {persona['full_name']}..." + yield message + + prompt = f""" +```json +{{ + "full_name": "{persona["full_name"]}", + "location": "{persona["location"]}", + "age": {persona["age"]}, + "description": {persona["description"]} +}} +``` +""" + try: + async for message in self.call_llm( + message=message, system_prompt=generate_resume_system_prompt, prompt=prompt + ): + if message.status != "done": + yield message + if message.status == "error": + raise Exception(message.response) + + except Exception as e: + message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}" + message.status = "error" + logger.error(traceback.format_exc()) + logger.error(message.response) + yield message + return + + resume = self.extract_markdown_from_text(message.response) + if resume: + user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir) + os.makedirs(user_resume_dir, exist_ok=True) + user_resume_file = os.path.join(user_resume_dir, defines.resume_doc) + with open(user_resume_file, "w") as f: + f.write(resume) + + # Resume generated + message.response = resume + message.status = "partial" + yield message + + # + # Generate RAG database + # + message.status = "thinking" + message.response = f"Generating RAG content from resume..." + yield message + + # Prior to instancing a new User, the json data has to be created + # so the system can process it + user_dir = os.path.join(defines.user_dir, persona["username"]) + os.makedirs(user_dir, exist_ok=True) + user_info = os.path.join(user_dir, "info.json") + with open(user_info, "w") as f: + f.write(json.dumps(persona, indent=2)) + + user = User(llm=self.llm, username=self.username) + await user.initialize() + await user.file_watcher.initialize_collection() + # RAG content generated + message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store." + + # + # Write out the completed user information + # + with open(user_info, "w") as f: + f.write(json.dumps(persona, indent=2)) + + # Image generated + message.status = "done" + message.response = json.dumps(persona) + + except Exception as e: + message.status = "error" + logger.error(traceback.format_exc()) + logger.error(message.response) + message.response = f"Error in persona generation: {str(e)}" + logger.error(message.response) + self.randomize() # Randomize for next generation + yield message + return + + # Done processing, add message to conversation + self.context.processing = False + self.randomize() # Randomize for next generation + # Return the final message + yield message + return + + async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7): + logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + + messages: List[LLMMessage] = [ + LLMMessage(role="system", content=system_prompt), + LLMMessage(role="user", content=prompt), + ] + message.metadata.options = { + "seed": 8911, + "num_ctx": self.context_size, + "temperature": temperature, # Higher temperature to encourage tool usage + } + + message.status = "streaming" + yield message + + last_chunk_time = 0 + message.chunk = "" + message.response = "" + for response in self.llm.chat( + model=self.model, + messages=messages, + options={ + **message.metadata.options, + }, + stream=True, + ): + if not response: + message.status = "error" + message.response = "No response from LLM." + yield message + return + + message.status = "streaming" + message.chunk += response.message.content + message.response += response.message.content + + if not response.done: + now = time.perf_counter() + if now - last_chunk_time > 0.25: + yield message + last_chunk_time = now + message.chunk = "" + + if response.done: + self.collect_metrics(response) + message.metadata.eval_count += response.eval_count + message.metadata.eval_duration += response.eval_duration + message.metadata.prompt_eval_count += response.prompt_eval_count + message.metadata.prompt_eval_duration += response.prompt_eval_duration + self.context_tokens = response.prompt_eval_count + response.eval_count + message.chunk = "" + message.status = "done" + yield message + + def extract_json_from_text(self, text: str) -> str: + """Extract JSON string from text that may contain other content.""" + json_pattern = r"```json\s*([\s\S]*?)\s*```" + match = re.search(json_pattern, text) + if match: + return match.group(1).strip() + + # Try to find JSON without the markdown code block + json_pattern = r"({[\s\S]*})" + match = re.search(json_pattern, text) + if match: + return match.group(1).strip() + + raise ValueError("No JSON found in the response") + + def extract_markdown_from_text(self, text: str) -> str: + """Extract Markdown string from text that may contain other content.""" + markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```" + match = re.search(markdown_pattern, text) + if match: + return match.group(2).strip() + + raise ValueError("No Markdown found in the response") + +# Register the base agent +agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)