Gen image works

This commit is contained in:
James Ketr 2025-05-22 15:25:52 -07:00
parent 840ad9159b
commit 74dea019fa
7 changed files with 191 additions and 547 deletions

View File

@@ -58,7 +58,7 @@ type BackstoryMessage = {
full_content?: string;
response?: string; // Set when status === 'done', 'partial', or 'error'
chunk?: string; // Used when status === 'streaming'
timestamp?: string;
timestamp?: number;
disableCopy?: boolean,
user?: string,
title?: string,

View File

@@ -115,6 +115,7 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
borderRadius: 2,
minHeight: '80vh',
width: "100%",
flexDirection: "column",
}}>
{children}
</Paper>

View File

@@ -17,7 +17,7 @@ from .setup_logging import setup_logging
from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
from .metrics import Metrics
from .check_serializable import check_serializable
from .profile_image import generate_image_status, ImageRequest
from .profile_image import generate_image, ImageRequest
__all__ = [
"Agent",
@@ -34,7 +34,7 @@ __all__ = [
"check_serializable",
"logger",
"User",
"generate_image_status", "ImageRequest"
"generate_image", "ImageRequest"
]
__all__.extend(agents_all) # type: ignore

View File

@@ -50,6 +50,7 @@ class Agent(BaseModel, ABC):
# Agent management with pydantic
agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = True # Whether this agent will persist in the context
# Tunables (sets default for new Messages attached to this agent)
tunables: Tunables = Field(default_factory=Tunables)

View File

@@ -1,448 +0,0 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
cast,
Any,
AsyncGenerator,
List,
Optional
# override
) # NOTE: You must import Optional for late binding to work
import inspect
import random
import re
import json
import traceback
import asyncio
import time
import os
from .base import Agent, agent_registry, LLMMessage
from ..message import Message
from ..rag import ChromaDBGetResponse
from ..setup_logging import setup_logging
from ..profile_image import generate_image_status, ImageRequest
from .. import defines
from ..user import User
logger = setup_logging()
seed = int(time.time())
random.seed(seed)
emptyUser = {
"profile_url": "",
"description": "",
"rag_content_size": 0,
"username": "",
"first_name": "",
"last_name": "",
"full_name": "",
"contact_info": {},
"questions": [],
}
generate_persona_system_prompt = """\
You are a casting director for a movie. Your job is to provide information on fictitious personas for use in a screenplay.
All response fields MUST BE IN ENGLISH, regardless of ethnicity.
You will be provided with defaults to use if not specified by the user:
```json
{
"age": number,
"gender": "male" | "female",
"ethnicity": string,
}
```
Additional information provided in the user message can override those defaults.
You need to randomly assign an English username (can include numbers), a first name, a last name, and a two-sentence English description of that individual's work, given the demographics provided.
Your response must be in JSON.
Provide only the JSON response, and match the field names EXACTLY.
Provide all information in English ONLY, with no other commentary:
```json
{
"username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters)
"first_name": string,
"last_name": string,
"description": string, # One to two sentence description of their job
"location": string, # In the location, provide ALL of: City, State/Region, and Country
}
```
Make sure to provide a username and that the field name for the job description is "description".
"""
generate_resume_system_prompt = """
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
of that is to create an in-depth resume for the person. You will be provided with the following information:
```json
"full_name": string, # Person full name
"location": string, # Location of residence
"age": number, # Age of candidate
"description": string # A brief description of the person
```
Use that information to invent a full career resume. Include sections such as:
* Contact information
* Job goal
* Top skills
* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done as an adolescent
* In the work history, provide company names, years of employment, and their role
* Education
Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume.
"""
class PersonaGenerator(Agent):
agent_type: Literal["persona"] = "persona" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = generate_persona_system_prompt
age: int = Field(default_factory=lambda: random.randint(22, 67))
gender: str = Field(default_factory=lambda: random.choice(["male", "female"]))
ethnicity: Literal[
"Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"
] = Field(
default_factory=lambda: random.choices(
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
k=1
)[0]
)
username: str = ""
llm: Any = Field(default=None, exclude=True)
model: Optional[str] = Field(default=None, exclude=True)
def randomize(self):
self.age = random.randint(22, 67)
self.gender = random.choice(["male", "female"])
# Use random.choices with explicit type casting to satisfy Literal type
self.ethnicity = cast(
Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
random.choices(
["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
weights=[57.69, 15.38, 19.23, 5.77, 1.92],
k=1
)[0]
)
async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
message.tunables.enable_tools = False
message.tunables.enable_rag = False
message.tunables.enable_context = False
message.prompt = f"""\
```json
{json.dumps({
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity
})}
```
{message.prompt}
"""
message.status = "done"
yield message
return
async def process_message(
self, llm: Any, model: str, message: Message
) -> AsyncGenerator[Message, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
self.llm = llm
self.model = model
original_prompt = message.prompt
spinner: List[str] = ["\\", "|", "/", "-"]
tick: int = 0
while self.context.processing:
logger.info(
"TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
)
message.status = "waiting"
message.response = (
f"Busy processing another request. Please wait. {spinner[tick]}"
)
tick = (tick + 1) % len(spinner)
yield message
await asyncio.sleep(1) # Allow the event loop to process the write
self.context.processing = True
try:
#
# Generate the persona
#
async for message in self.call_llm(
message=message, system_prompt=self.system_prompt, prompt=original_prompt
):
if message.status != "done":
yield message
if message.status == "error":
raise Exception(message.response)
json_str = self.extract_json_from_text(message.response)
try:
persona = json.loads(json_str) | {
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity
}
if not persona.get("full_name", None):
persona["full_name"] = f"{persona['first_name']} {persona['last_name']}"
self.username = persona.get("username", None)
if not self.username:
raise ValueError("LLM did not generate a username")
user_dir = os.path.join(defines.user_dir, persona["username"])
while os.path.exists(user_dir):
match = re.match(r"^(.*?)(\d*)$", persona["username"])
if match:
base = match.group(1)
num = match.group(2)
iteration = int(num) + 1 if num else 1
persona["username"] = f"{base}{iteration}"
user_dir = os.path.join(defines.user_dir, persona["username"])
self.username = persona["username"] # Re-sync username in case a collision suffix was appended
for key in persona:
if isinstance(persona[key], str):
persona[key] = persona[key].strip()
# Mark this persona as AI generated
persona["is_ai"] = True
except Exception as e:
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
message.status = "error"
logger.error(traceback.format_exc())
logger.error(message.response)
yield message
return
# Persona generated
message.response = json.dumps(persona)
message.status = "partial"
yield message
#
# Generate the resume
#
message.status = "thinking"
message.response = f"Generating resume for {persona['full_name']}..."
yield message
prompt = f"""
```json
{{
"full_name": "{persona["full_name"]}",
"location": "{persona["location"]}",
"age": {persona["age"]},
"description": {persona["description"]}
}}
```
"""
try:
async for message in self.call_llm(
message=message, system_prompt=generate_resume_system_prompt, prompt=prompt
):
if message.status != "done":
yield message
if message.status == "error":
raise Exception(message.response)
except Exception as e:
message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
message.status = "error"
logger.error(traceback.format_exc())
logger.error(message.response)
yield message
return
resume = self.extract_markdown_from_text(message.response)
if resume:
user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir)
os.makedirs(user_resume_dir, exist_ok=True)
user_resume_file = os.path.join(user_resume_dir, defines.resume_doc)
with open(user_resume_file, "w") as f:
f.write(resume)
# Resume generated
message.response = resume
message.status = "partial"
yield message
#
# Generate RAG database
#
message.status = "thinking"
message.response = f"Generating RAG content from resume..."
yield message
# Prior to instancing a new User, the json data has to be created
# so the system can process it
user_dir = os.path.join(defines.user_dir, persona["username"])
os.makedirs(user_dir, exist_ok=True)
user_info = os.path.join(user_dir, "info.json")
with open(user_info, "w") as f:
f.write(json.dumps(persona, indent=2))
user = User(llm=self.llm, username=self.username)
await user.initialize()
await user.file_watcher.initialize_collection()
# RAG content generated
message.status = "partial"
message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store."
yield message
#
# Generate the profile picture
#
prompt = f"A photorealistic profile picture of a {persona["age"]} year old {persona["gender"]} {persona["ethnicity"]} person."
if original_prompt:
prompt = f"{prompt} {original_prompt}"
message.status = "thinking"
message.response = prompt
yield message
logger.info("Beginning image generation...")
request = ImageRequest(filepath=os.path.join(defines.user_dir, persona["username"], f"profile.png"), prompt=prompt)
placeholder = Message(prompt=prompt)
async for placeholder in generate_image_status(
message=placeholder,
**request.model_dump()
):
logger.info("Image generation continue...")
if placeholder.status != "done":
placeholder.response = placeholder.response
yield placeholder
logger.info("Image generation done...")
persona["has_profile"] = True
#
# Write out the completed user information
#
with open(user_info, "w") as f:
f.write(json.dumps(persona, indent=2))
# Image generated
message.status = "done"
message.response = json.dumps(persona)
except Exception as e:
message.status = "error"
logger.error(traceback.format_exc())
message.response = f"Error in persona generation: {str(e)}"
logger.error(message.response)
self.context.processing = False # Release the processing flag so later requests are not blocked
self.randomize() # Randomize for next generation
yield message
return
# Done processing, add message to conversation
self.context.processing = False
self.randomize() # Randomize for next generation
# Return the final message
yield message
return
async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
messages: List[LLMMessage] = [
LLMMessage(role="system", content=system_prompt),
LLMMessage(role="user", content=prompt),
]
message.metadata.options = {
"seed": 8911,
"num_ctx": self.context_size,
"temperature": temperature, # Higher temperature to encourage tool usage
}
message.status = "streaming"
yield message
last_chunk_time = 0
message.chunk = ""
message.response = ""
for response in self.llm.chat(
model=self.model,
messages=messages,
options={
**message.metadata.options,
},
stream=True,
):
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.status = "streaming"
message.chunk += response.message.content
message.response += response.message.content
if not response.done:
now = time.perf_counter()
if now - last_chunk_time > 0.25:
yield message
last_chunk_time = now
message.chunk = ""
if response.done:
self.collect_metrics(response)
message.metadata.eval_count += response.eval_count
message.metadata.eval_duration += response.eval_duration
message.metadata.prompt_eval_count += response.prompt_eval_count
message.metadata.prompt_eval_duration += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.chunk = ""
message.status = "done"
yield message
def extract_json_from_text(self, text: str) -> str:
"""Extract JSON string from text that may contain other content."""
json_pattern = r"```json\s*([\s\S]*?)\s*```"
match = re.search(json_pattern, text)
if match:
return match.group(1).strip()
# Try to find JSON without the markdown code block
json_pattern = r"({[\s\S]*})"
match = re.search(json_pattern, text)
if match:
return match.group(1).strip()
raise ValueError("No JSON found in the response")
def extract_markdown_from_text(self, text: str) -> str:
"""Extract Markdown string from text that may contain other content."""
markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
match = re.search(markdown_pattern, text)
if match:
return match.group(2).strip()
raise ValueError("No Markdown found in the response")
# Register the persona generator agent
agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)

View File

@@ -129,7 +129,8 @@ class Context(BaseModel):
agent = agent_cls(agent_type=agent_type, **kwargs)
# set_context after constructor to initialize any non-serialized data
agent.set_context(self)
self.agents.append(agent)
if agent.agent_persist: # If an agent is not set to persist, do not add it to the context
self.agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")

View File

@@ -13,6 +13,13 @@ import gc
import tempfile
import uuid
import torch # type: ignore
import asyncio
import time
import json
from typing import Any, AsyncGenerator, Dict
from threading import Thread
import queue
from .agents.base import Agent, agent_registry, LLMMessage
from .message import Message
@@ -24,11 +31,6 @@ logger = setup_logging()
# Heuristic time estimates (in seconds) for different models and devices at 512x512
TIME_ESTIMATES = {
"stable-diffusion": {
"cuda": {"load": 5, "per_step": 0.5},
"xpu": {"load": 7, "per_step": 0.7},
"cpu": {"load": 20, "per_step": 5.0},
},
"flux": {
"cuda": {"load": 10, "per_step": 0.8},
"xpu": {"load": 15, "per_step": 1.0},
@@ -43,24 +45,176 @@ class ImageRequest(BaseModel):
iterations: int = 4
height: int = 256
width: int = 256
guidance_scale: float = 7.5
# Global model cache instance
model_cache = ImageModelCache(timeout_seconds=60 * 15) # 15 minutes
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
"""Background worker for Flux image generation"""
try:
# Rough per-step estimate used only for the initial time projection
estimates = {"per_step": 0.5}
resolution_scale = (params.height * params.width) / (512 * 512)
# Flux: Run generation in the background and yield progress updates
estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
status_queue.put({
"status": "running",
"message": f"Starting Flux image generation with {params.iterations} inference steps",
"estimated_time_remaining": estimated_gen_time,
"progress": 0
})
# Start the generation task
start_gen_time = time.time()
# Per-step callback used to report progress from inside the pipeline
def status_callback(pipeline, step, timestep, callback_kwargs):
# Send a progress update for each completed inference step
progress = int((step + 1) / params.iterations * 100)
status_queue.put({
"status": "running",
"message": f"Processing step {step + 1}/{params.iterations}",
"progress": progress
})
return callback_kwargs # diffusers callbacks must return the callback_kwargs dict
# Run the Flux pipeline; progress is reported from status_callback at each step
image = pipe(
params.prompt,
num_inference_steps=params.iterations,
guidance_scale=params.guidance_scale,
height=params.height,
width=params.width,
callback_on_step_end=status_callback,
).images[0]
# Generation finished; record timing and save the image
gen_time = time.time() - start_gen_time
per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time
image.save(params.filepath)
# Final completion status
status_queue.put({
"status": "completed",
"message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
"progress": 100,
"generation_time": gen_time,
"per_step_time": per_step_time,
"image_path": f"generated_{task_id}.png" # Replace with actual image path/URL
})
except Exception as e:
status_queue.put({
"status": "error",
"message": f"Generation failed: {str(e)}",
"error": str(e),
"progress": 0
})
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
"""
Single async function that handles background Flux generation with status streaming
"""
task_id = str(uuid.uuid4())
status_queue = queue.Queue()
worker_thread = None
try:
# Start background worker thread
worker_thread = Thread(
target=flux_worker,
args=(pipe, params, status_queue, task_id),
daemon=True
)
worker_thread.start()
# Initial status
yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}
# Stream status updates
completed = False
last_heartbeat = time.time()
while not completed and worker_thread.is_alive():
try:
# Try to get status update (non-blocking)
status_update = status_queue.get_nowait()
# Add task_id to status update
status_update['task_id'] = task_id
# Send status update
yield status_update
# Check if completed
if status_update.get('status') in ['completed', 'error']:
completed = True
last_heartbeat = time.time()
except queue.Empty:
# No new status, send heartbeat if needed
current_time = time.time()
if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds
heartbeat = {
'status': 'heartbeat',
'task_id': task_id,
'timestamp': current_time
}
yield heartbeat
last_heartbeat = current_time
# Brief sleep to prevent busy waiting
await asyncio.sleep(0.1)
# Handle thread completion or timeout
if not completed:
if worker_thread.is_alive():
# Thread still running but we might have missed the completion signal
timeout_status = {
'status': 'timeout',
'task_id': task_id,
'message': 'Generation timed out or connection lost'
}
yield timeout_status
else:
# Thread completed but we might have missed the final status
final_status = {
'status': 'completed',
'task_id': task_id,
'message': 'Generation completed'
}
yield final_status
except Exception as e:
error_status = {
'status': 'error',
'task_id': task_id,
'message': f'Server error: {str(e)}',
'error': str(e)
}
yield error_status
finally:
# Cleanup: ensure thread completion
if worker_thread is not None and worker_thread.is_alive():
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
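# Illustrative-only consumption of async_generate_image (assumes a loaded `pipe`
# and an ImageRequest `req`); each yielded dict carries at least "status" and
# "task_id", plus "progress"/"message" for worker updates:
#
#   async for update in async_generate_image(pipe, req):
#       print(update["status"], update.get("progress"))
#
# The periodic "heartbeat" entries keep a long-lived streaming response alive
# while the worker thread runs the blocking diffusers pipeline off the event loop.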
def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message:
message.status = "thinking"
message.response = json.dumps({
"status": status,
"progress": progress,
"estimated_time_remaining": estimated_time_remaining
})
return message
async def generate_image_status(message: Message, model: str, prompt: str, iterations: int, filepath: str, height: int = 512, width: int = 512) -> AsyncGenerator[Message, None]:
async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
"""Generate an image with specified dimensions and yield status updates with time estimates."""
try:
# Validate prompt
prompt = prompt.strip()
prompt = request.prompt.strip()
if not prompt:
message.status = "error"
message.response = "Prompt cannot be empty"
@@ -68,111 +222,46 @@ async def generate_image_status(message: Message, model: str, prompt: str, itera
return
# Validate dimensions
if height <= 0 or width <= 0:
if request.height <= 0 or request.width <= 0:
message.status = "error"
message.response = "Height and width must be positive"
yield message
return
if re.match(r".*stable-diffusion.*", model):
if height % 8 != 0 or width % 8 != 0:
message.status = "error"
message.response = "Stable Diffusion requires height and width to be multiples of 8"
yield message
return
filedir = os.path.dirname(filepath)
filename = os.path.basename(filepath)
filedir = os.path.dirname(request.filepath)
filename = os.path.basename(request.filepath)
os.makedirs(filedir, exist_ok=True)
model_type = "stable-diffusion" if re.match(r".*stable-diffusion.*", model) else "flux"
model_type = "flux"
device = "cpu"
if model_type == "flux":
device = "cpu"
else:
device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
yield status(message, f"Starting image generation for prompt: {prompt} {width}x{height} as {filename} using {device}")
yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}")
# Get initial time estimate, scaled by resolution
estimates = TIME_ESTIMATES[model_type][device]
resolution_scale = (height * width) / (512 * 512)
estimated_total = estimates["load"] + estimates["per_step"] * iterations * resolution_scale
yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {width}x{height}")
resolution_scale = (request.height * request.width) / (512 * 512)
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")
# Initialize or get cached pipeline
start_time = time.time()
yield status(message, f"Loading {model_type} model: {model}")
pipe = await model_cache.get_pipeline(model, device)
yield status(message, f"Loading {model_type} model: {request.model}")
pipe = await model_cache.get_pipeline(request.model, device)
load_time = time.time() - start_time
yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {iterations} inference steps", progress=10)
# Generate image with progress tracking
start_gen_time = time.time()
if model_type == "stable-diffusion":
steps_completed = 0
last_step_time = start_gen_time
def progress_callback(step: int, timestep: int, latents: torch.Tensor):
nonlocal steps_completed, last_step_time
steps_completed += 1
current_time = time.time()
step_time = current_time - last_step_time
last_step_time = current_time
progress = (steps_completed / iterations) * 80 + 10 # Scale from 10% to 90%
remaining_steps = iterations - steps_completed
estimated_remaining = step_time * remaining_steps * resolution_scale
yield status(message=message, status=f"Step {steps_completed}/{iterations} completed", progress=progress, estimated_time_remaining=str(estimated_remaining))
async def capture_progress(step: int, timestep: int, latents: torch.Tensor):
for msg in progress_callback(step, timestep, latents):
yield msg
yield status(message, f"Generating image with {iterations} inference steps")
image = pipe(
prompt,
num_inference_steps=iterations,
guidance_scale=7.5,
height=height,
width=width,
callback=capture_progress,
callback_steps=1
).images[0]
else:
# Flux: Run generation in the background and yield progress updates
estimated_gen_time = estimates["per_step"] * iterations * resolution_scale
yield status(message, f"Starting Flux image generation with {iterations} inference steps", estimated_time_remaining=estimated_gen_time)
image = pipe(
prompt,
num_inference_steps=iterations,
guidance_scale=7.5,
height=height,
width=width
).images[0]
# Start the generation task
start_gen_time = time.time()
gen_time = time.time() - start_gen_time
per_step_time = gen_time / iterations if iterations > 0 else gen_time
yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)
gen_time = time.time() - start_gen_time
per_step_time = gen_time / iterations if iterations > 0 else gen_time
yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)
# Save image
yield status(message, f"Saving image to {filepath}", 95)
image.save(filepath)
yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10)
async for update in async_generate_image(pipe, request):
message.status = update.get("status", "thinking")
message.response = json.dumps(update) # Pass the worker's status dict through to the client as JSON
yield message
# Final result
total_time = time.time() - start_time
message.status = "done"
message.response = json.dumps({
"status": f"Image generation complete in {total_time:.1f} seconds",
"progress": 100,
"filename": filepath
"filename": request.filepath
})
yield message
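
For reference, a minimal consumption sketch of the new generate_image entry point. The package import paths, the Flux model id, and the output directory below are assumptions based on the hunks in this commit (the package name is never shown), not the application's actual call site.

```python
import asyncio
import os

# Assumed import paths; the commit only shows relative imports such as
# "from .profile_image import generate_image, ImageRequest".
from utils import generate_image, ImageRequest  # hypothetical package name
from utils.message import Message               # hypothetical module path

async def main() -> None:
    request = ImageRequest(
        model="black-forest-labs/FLUX.1-schnell",  # hypothetical Flux model id
        filepath=os.path.join("/tmp", "profile.png"),
        prompt="A photorealistic profile picture of a 34 year old engineer.",
        iterations=4,
        height=256,
        width=256,
    )
    message = Message(prompt=request.prompt)
    async for update in generate_image(message=message, request=request):
        # Each yielded Message carries a status plus either plain text or a
        # JSON-encoded progress payload in its response field.
        print(update.status, update.response)

if __name__ == "__main__":
    asyncio.run(main())
```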