From 74dea019fa79bbfe4cf0551033091ec2c7c2a63e Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Thu, 22 May 2025 15:25:52 -0700 Subject: [PATCH] Gen image works --- frontend/src/Components/Message.tsx | 2 +- .../src/NewApp/Components/BackstoryLayout.tsx | 1 + src/utils/__init__.py | 4 +- src/utils/agents/base.py | 1 + src/utils/agents/persona.py | 448 ------------------ src/utils/context.py | 3 +- src/utils/profile_image.py | 279 +++++++---- 7 files changed, 191 insertions(+), 547 deletions(-) delete mode 100644 src/utils/agents/persona.py diff --git a/frontend/src/Components/Message.tsx b/frontend/src/Components/Message.tsx index fdf8eb9..115d75c 100644 --- a/frontend/src/Components/Message.tsx +++ b/frontend/src/Components/Message.tsx @@ -58,7 +58,7 @@ type BackstoryMessage = { full_content?: string; response?: string; // Set when status === 'done', 'partial', or 'error' chunk?: string; // Used when status === 'streaming' - timestamp?: string; + timestamp?: number; disableCopy?: boolean, user?: string, title?: string, diff --git a/frontend/src/NewApp/Components/BackstoryLayout.tsx b/frontend/src/NewApp/Components/BackstoryLayout.tsx index 9d8fcfb..64eb702 100644 --- a/frontend/src/NewApp/Components/BackstoryLayout.tsx +++ b/frontend/src/NewApp/Components/BackstoryLayout.tsx @@ -115,6 +115,7 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => { borderRadius: 2, minHeight: '80vh', width: "100%", + flexDirection: "column", }}> {children} diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 5239f5a..c87fc5f 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -17,7 +17,7 @@ from .setup_logging import setup_logging from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all from .metrics import Metrics from .check_serializable import check_serializable -from .profile_image import generate_image_status, ImageRequest +from .profile_image import generate_image, ImageRequest __all__ = [ "Agent", @@ -34,7 +34,7 @@ __all__ = [ "check_serializable", "logger", "User", - "generate_image_status", "ImageRequest" + "generate_image", "ImageRequest" ] __all__.extend(agents_all) # type: ignore diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 2e6045b..91a18fc 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -50,6 +50,7 @@ class Agent(BaseModel, ABC): # Agent management with pydantic agent_type: Literal["base"] = "base" _agent_type: ClassVar[str] = agent_type # Add this for registration + agent_persist: bool = True # Whether this agent will persist in the context # Tunables (sets default for new Messages attached to this agent) tunables: Tunables = Field(default_factory=Tunables) diff --git a/src/utils/agents/persona.py b/src/utils/agents/persona.py deleted file mode 100644 index b14c311..0000000 --- a/src/utils/agents/persona.py +++ /dev/null @@ -1,448 +0,0 @@ -from __future__ import annotations -from pydantic import model_validator, Field, BaseModel # type: ignore -from typing import ( - Dict, - Literal, - ClassVar, - cast, - Any, - AsyncGenerator, - List, - Optional -# override -) # NOTE: You must import Optional for late binding to work -import inspect -import random -import re -import json -import traceback -import asyncio -import time -import asyncio -import time -import os - -from . base import Agent, agent_registry, LLMMessage -from .. message import Message -from .. rag import ChromaDBGetResponse -from .. setup_logging import setup_logging -from .. 
profile_image import generate_image_status, ImageRequest -from .. import defines -from .. user import User - -logger = setup_logging() - -seed = int(time.time()) -random.seed(seed) - -emptyUser = { - "profile_url": "", - "description": "", - "rag_content_size": 0, - "username": "", - "first_name": "", - "last_name": "", - "full_name": "", - "contact_info": {}, - "questions": [], -} - -generate_persona_system_prompt = """\ -You are a casing director for a movie. Your job is to provide information on ficticious personas for use in a screen play. - -All response field MUST BE IN ENGLISH, regardless of ethnicity. - -You will be provided with defaults to use if not specified by the user: - -```json -{ -"age": number, -"gender": "male" | "female", -"ethnicity": string, -} -``` - -Additional information provided in the user message can override those defaults. - -You need to randomly assign an English username (can include numbers), a first name, last name, and a two English sentence description of that individual's work given the demographics provided. - -Your response must be in JSON. -Provide only the JSON response, and match the field names EXACTLY. -Provide all information in English ONLY, with no other commentary: - -```json -{ -"username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters) -"first_name": string, -"last_name": string, -"description": string, # One to two sentence description of their job -"location": string, # In the location, provide ALL of: City, State/Region, and Country -} -``` - -Make sure to provide a username and that the field name for the job description is "description". -""" - -generate_resume_system_prompt = """ -You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part -of that is to create an in-depth resume for the person. You will be provided with the following information: - -```json -"full_name": string, # Person full name -"location": string, # Location of residence -"age": number, # Age of candidate -"description": string # A brief description of the person -``` - -Use that information to invent a full career resume. Include sections such as: - -* Contact information -* Job goal -* Top skills -* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent -* In the work history, provide company names, years of employment, and their role -* Education - -Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume. 
-""" - -class PersonaGenerator(Agent): - agent_type: Literal["persona"] = "persona" # type: ignore - _agent_type: ClassVar[str] = agent_type # Add this for registration - - system_prompt: str = generate_persona_system_prompt - age: int = Field(default_factory=lambda: random.randint(22, 67)) - gender: str = Field(default_factory=lambda: random.choice(["male", "female"])) - ethnicity: Literal[ - "Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial" - ] = Field( - default_factory=lambda: random.choices( - ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], - weights=[57.69, 15.38, 19.23, 5.77, 1.92], - k=1 - )[0] - ) - username: str = "" - - llm: Any = Field(default=None, exclude=True) - model: str = Field(default=None, exclude=True) - - def randomize(self): - self.age = random.randint(22, 67) - self.gender = random.choice(["male", "female"]) - # Use random.choices with explicit type casting to satisfy Literal type - self.ethnicity = cast( - Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], - random.choices( - ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"], - weights=[57.69, 15.38, 19.23, 5.77, 1.92], - k=1 - )[0] - ) - - async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]: - logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") - - if not self.context: - raise ValueError("Context is not set for this agent.") - - message.tunables.enable_tools = False - message.tunables.enable_rag = False - message.tunables.enable_context = False - - message.prompt = f"""\ -```json -{json.dumps({ - "age": self.age, - "gender": self.gender, - "ethnicity": self.ethnicity -})} -``` -{message.prompt} -""" - message.status = "done" - yield message - return - - async def process_message( - self, llm: Any, model: str, message: Message - ) -> AsyncGenerator[Message, None]: - logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") - if not self.context: - raise ValueError("Context is not set for this agent.") - - self.llm = llm - self.model = model - original_prompt = message.prompt - - spinner: List[str] = ["\\", "|", "/", "-"] - tick: int = 0 - while self.context.processing: - logger.info( - "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time" - ) - message.status = "waiting" - message.response = ( - f"Busy processing another request. Please wait. 
{spinner[tick]}" - ) - tick = (tick + 1) % len(spinner) - yield message - await asyncio.sleep(1) # Allow the event loop to process the write - - self.context.processing = True - - try: - - # - # Generate the persona - # - async for message in self.call_llm( - message=message, system_prompt=self.system_prompt, prompt=original_prompt - ): - if message.status != "done": - yield message - if message.status == "error": - raise Exception(message.response) - - json_str = self.extract_json_from_text(message.response) - try: - persona = json.loads(json_str) | { - "age": self.age, - "gender": self.gender, - "ethnicity": self.ethnicity - } - if not persona.get("full_name", None): - persona["full_name"] = f"{persona['first_name']} {persona['last_name']}" - self.username = persona.get("username", None) - if not self.username: - raise ValueError("LLM did not generate a username") - user_dir = os.path.join(defines.user_dir, persona["username"]) - while os.path.exists(user_dir): - match = re.match(r"^(.*?)(\d*)$", persona["username"]) - if match: - base = match.group(1) - num = match.group(2) - iteration = int(num) + 1 if num else 1 - persona["username"] = f"{base}{iteration}" - user_dir = os.path.join(defines.user_dir, persona["username"]) - - for key in persona: - if isinstance(persona[key], str): - persona[key] = persona[key].strip() - # Mark this persona as AI generated - persona["is_ai"] = True - except Exception as e: - message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}" - message.status = "error" - logger.error(traceback.format_exc()) - logger.error(message.response) - yield message - return - - # Persona generated - message.response = json.dumps(persona) - message.status = "partial" - yield message - - # - # Generate the resume - # - message.status = "thinking" - message.response = f"Generating resume for {persona['full_name']}..." - yield message - - prompt = f""" -```json -{{ - "full_name": "{persona["full_name"]}", - "location": "{persona["location"]}", - "age": {persona["age"]}, - "description": {persona["description"]} -}} -``` -""" - try: - async for message in self.call_llm( - message=message, system_prompt=generate_resume_system_prompt, prompt=prompt - ): - if message.status != "done": - yield message - if message.status == "error": - raise Exception(message.response) - - except Exception as e: - message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}" - message.status = "error" - logger.error(traceback.format_exc()) - logger.error(message.response) - yield message - return - - resume = self.extract_markdown_from_text(message.response) - if resume: - user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir) - os.makedirs(user_resume_dir, exist_ok=True) - user_resume_file = os.path.join(user_resume_dir, defines.resume_doc) - with open(user_resume_file, "w") as f: - f.write(resume) - - # Resume generated - message.response = resume - message.status = "partial" - yield message - - # - # Generate RAG database - # - message.status = "thinking" - message.response = f"Generating RAG content from resume..." 
- yield message - - # Prior to instancing a new User, the json data has to be created - # so the system can process it - user_dir = os.path.join(defines.user_dir, persona["username"]) - os.makedirs(user_dir, exist_ok=True) - user_info = os.path.join(user_dir, "info.json") - with open(user_info, "w") as f: - f.write(json.dumps(persona, indent=2)) - - user = User(llm=self.llm, username=self.username) - await user.initialize() - await user.file_watcher.initialize_collection() - # RAG content generated - message.status = "partial" - message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store." - yield message - - # - # Generate the profile picture - # - prompt = f"A photorealistic profile picture of a {persona["age"]} year old {persona["gender"]} {persona["ethnicity"]} person." - if original_prompt: - prompt = f"{prompt} {original_prompt}" - message.status = "thinking" - message.response = prompt - yield message - - logger.info("Beginning image generation...") - request = ImageRequest(filepath=os.path.join(defines.user_dir, persona["username"], f"profile.png"), prompt=prompt) - placeholder = Message(prompt=prompt) - async for placeholder in generate_image_status( - message=placeholder, - **request.model_dump() - ): - logger.info("Image generation continue...") - if placeholder.status != "done": - placeholder.response = placeholder.response - yield placeholder - logger.info("Image generation done...") - persona["has_profile"] = True - - # - # Write out the completed user information - # - with open(user_info, "w") as f: - f.write(json.dumps(persona, indent=2)) - - # Image generated - message.status = "done" - message.response = json.dumps(persona) - - except Exception as e: - message.status = "error" - logger.error(traceback.format_exc()) - logger.error(message.response) - message.response = f"Error in persona generation: {str(e)}" - logger.error(message.response) - self.randomize() # Randomize for next generation - yield message - return - - # Done processing, add message to conversation - self.context.processing = False - self.randomize() # Randomize for next generation - # Return the final message - yield message - return - - async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7): - logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") - - messages: List[LLMMessage] = [ - LLMMessage(role="system", content=system_prompt), - LLMMessage(role="user", content=prompt), - ] - message.metadata.options = { - "seed": 8911, - "num_ctx": self.context_size, - "temperature": temperature, # Higher temperature to encourage tool usage - } - - message.status = "streaming" - yield message - - last_chunk_time = 0 - message.chunk = "" - message.response = "" - for response in self.llm.chat( - model=self.model, - messages=messages, - options={ - **message.metadata.options, - }, - stream=True, - ): - if not response: - message.status = "error" - message.response = "No response from LLM." 
- yield message - return - - message.status = "streaming" - message.chunk += response.message.content - message.response += response.message.content - - if not response.done: - now = time.perf_counter() - if now - last_chunk_time > 0.25: - yield message - last_chunk_time = now - message.chunk = "" - - if response.done: - self.collect_metrics(response) - message.metadata.eval_count += response.eval_count - message.metadata.eval_duration += response.eval_duration - message.metadata.prompt_eval_count += response.prompt_eval_count - message.metadata.prompt_eval_duration += response.prompt_eval_duration - self.context_tokens = response.prompt_eval_count + response.eval_count - message.chunk = "" - message.status = "done" - yield message - - def extract_json_from_text(self, text: str) -> str: - """Extract JSON string from text that may contain other content.""" - json_pattern = r"```json\s*([\s\S]*?)\s*```" - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - # Try to find JSON without the markdown code block - json_pattern = r"({[\s\S]*})" - match = re.search(json_pattern, text) - if match: - return match.group(1).strip() - - raise ValueError("No JSON found in the response") - - def extract_markdown_from_text(self, text: str) -> str: - """Extract Markdown string from text that may contain other content.""" - markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```" - match = re.search(markdown_pattern, text) - if match: - return match.group(2).strip() - - raise ValueError("No Markdown found in the response") - -# Register the base agent -agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator) diff --git a/src/utils/context.py b/src/utils/context.py index 013c8e2..0ff1434 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -129,7 +129,8 @@ class Context(BaseModel): agent = agent_cls(agent_type=agent_type, **kwargs) # set_context after constructor to initialize any non-serialized data agent.set_context(self) - self.agents.append(agent) + if agent.agent_persist: # If an agent is not set to persist, do not add it to the context + self.agents.append(agent) return agent raise ValueError(f"No agent class found for agent_type: {agent_type}") diff --git a/src/utils/profile_image.py b/src/utils/profile_image.py index 44e0a59..a22747e 100644 --- a/src/utils/profile_image.py +++ b/src/utils/profile_image.py @@ -13,6 +13,13 @@ import gc import tempfile import uuid import torch # type: ignore +import asyncio +import time +import json +from typing import AsyncGenerator +from threading import Thread +import queue +import uuid from .agents.base import Agent, agent_registry, LLMMessage from .message import Message @@ -24,11 +31,6 @@ logger = setup_logging() # Heuristic time estimates (in seconds) for different models and devices at 512x512 TIME_ESTIMATES = { - "stable-diffusion": { - "cuda": {"load": 5, "per_step": 0.5}, - "xpu": {"load": 7, "per_step": 0.7}, - "cpu": {"load": 20, "per_step": 5.0}, - }, "flux": { "cuda": {"load": 10, "per_step": 0.8}, "xpu": {"load": 15, "per_step": 1.0}, @@ -43,24 +45,176 @@ class ImageRequest(BaseModel): iterations: int = 4 height: int = 256 width: int = 256 + guidance_scale: float = 7.5 # Global model cache instance model_cache = ImageModelCache(timeout_seconds=60 * 15) # 15 minutes +def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str): + """Background worker for Flux image generation""" + try: + # Your existing estimates calculation + estimates = {"per_step": 0.5} # Replace 
with your actual estimates + resolution_scale = (params.height * params.width) / (512 * 512) + + # Flux: Run generation in the background and yield progress updates + estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale + status_queue.put({ + "status": "running", + "message": f"Starting Flux image generation with {params.iterations} inference steps", + "estimated_time_remaining": estimated_gen_time, + "progress": 0 + }) + + # Start the generation task + start_gen_time = time.time() + + # Simulate your pipe call with progress updates + def status_callback(pipeline, step, timestep, callback_kwargs): + # Send progress updates + elapsed = time.time() - start_gen_time + progress = int((step + 1) / params.iterations * 100) + + status_queue.put({ + "status": "running", + "message": f"Processing step {step}/{params.iterations}", + "progress": progress + }) + + # Replace this block with your actual Flux pipe call: + image = pipe( + params.prompt, + num_inference_steps=params.iterations, + guidance_scale=7.5, + height=params.height, + width=params.width, + callback_on_step_end=status_callback, + ).images[0] + + # Simulate image generation completion + gen_time = time.time() - start_gen_time + per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time + + image.save(params.filepath) + + # Final completion status + status_queue.put({ + "status": "completed", + "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", + "progress": 100, + "generation_time": gen_time, + "per_step_time": per_step_time, + "image_path": f"generated_{task_id}.png" # Replace with actual image path/URL + }) + + except Exception as e: + status_queue.put({ + "status": "error", + "message": f"Generation failed: {str(e)}", + "error": str(e), + "progress": 0 + }) + + +async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]: + """ + Single async function that handles background Flux generation with status streaming + """ + task_id = str(uuid.uuid4()) + status_queue = queue.Queue() + worker_thread = None + + try: + # Start background worker thread + worker_thread = Thread( + target=flux_worker, + args=(pipe, params, status_queue, task_id), + daemon=True + ) + worker_thread.start() + + # Initial status + yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'} + + # Stream status updates + completed = False + last_heartbeat = time.time() + + while not completed and worker_thread.is_alive(): + try: + # Try to get status update (non-blocking) + status_update = status_queue.get_nowait() + + # Add task_id to status update + status_update['task_id'] = task_id + + # Send status update + yield status_update + + # Check if completed + if status_update.get('status') in ['completed', 'error']: + completed = True + + last_heartbeat = time.time() + + except queue.Empty: + # No new status, send heartbeat if needed + current_time = time.time() + if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds + heartbeat = { + 'status': 'heartbeat', + 'task_id': task_id, + 'timestamp': current_time + } + yield heartbeat + last_heartbeat = current_time + + # Brief sleep to prevent busy waiting + await asyncio.sleep(0.1) + + # Handle thread completion or timeout + if not completed: + if worker_thread.is_alive(): + # Thread still running but we might have missed the completion signal + timeout_status = { + 'status': 'timeout', + 'task_id': task_id, + 'message': 'Generation timed 
out or connection lost' + } + yield timeout_status + else: + # Thread completed but we might have missed the final status + final_status = { + 'status': 'completed', + 'task_id': task_id, + 'message': 'Generation completed' + } + yield final_status + + except Exception as e: + error_status = { + 'status': 'error', + 'task_id': task_id, + 'message': f'Server error: {str(e)}', + 'error': str(e) + } + yield error_status + + finally: + # Cleanup: ensure thread completion + if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive(): + worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup + def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message: message.status = "thinking" - message.response = json.dumps({ - "status": status, - "progress": progress, - "estimated_time_remaining": estimated_time_remaining - }) + message.response = status return message -async def generate_image_status(message: Message, model: str, prompt: str, iterations: int, filepath: str, height: int = 512, width: int = 512) -> AsyncGenerator[Message, None]: +async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]: """Generate an image with specified dimensions and yield status updates with time estimates.""" try: # Validate prompt - prompt = prompt.strip() + prompt = request.prompt.strip() if not prompt: message.status = "error" message.response = "Prompt cannot be empty" @@ -68,111 +222,46 @@ async def generate_image_status(message: Message, model: str, prompt: str, itera return # Validate dimensions - if height <= 0 or width <= 0: + if request.height <= 0 or request.width <= 0: message.status = "error" message.response = "Height and width must be positive" yield message return - if re.match(r".*stable-diffusion.*", model): - if height % 8 != 0 or width % 8 != 0: - message.status = "error" - message.response = "Stable Diffusion requires height and width to be multiples of 8" - yield message - return - filedir = os.path.dirname(filepath) - filename = os.path.basename(filepath) + filedir = os.path.dirname(request.filepath) + filename = os.path.basename(request.filepath) os.makedirs(filedir, exist_ok=True) - model_type = "stable-diffusion" if re.match(r".*stable-diffusion.*", model) else "flux" + model_type = "flux" + device = "cpu" - if model_type == "flux": - device = "cpu" - else: - device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu" - - yield status(message, f"Starting image generation for prompt: {prompt} {width}x{height} as {filename} using {device}") + yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}") # Get initial time estimate, scaled by resolution estimates = TIME_ESTIMATES[model_type][device] - resolution_scale = (height * width) / (512 * 512) - estimated_total = estimates["load"] + estimates["per_step"] * iterations * resolution_scale - yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {width}x{height}") + resolution_scale = (request.height * request.width) / (512 * 512) + estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale + yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}") # Initialize or get cached pipeline start_time = time.time() - yield status(message, f"Loading {model_type} model: {model}") - pipe = await 
model_cache.get_pipeline(model, device) + yield status(message, f"Loading {model_type} model: {request.model}") + pipe = await model_cache.get_pipeline(request.model, device) load_time = time.time() - start_time - yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {iterations} inference steps", progress=10) - - # Generate image with progress tracking - start_gen_time = time.time() - - if model_type == "stable-diffusion": - steps_completed = 0 - last_step_time = start_gen_time - - def progress_callback(step: int, timestep: int, latents: torch.Tensor): - nonlocal steps_completed, last_step_time - steps_completed += 1 - current_time = time.time() - step_time = current_time - last_step_time - last_step_time = current_time - progress = (steps_completed / iterations) * 80 + 10 # Scale from 10% to 90% - remaining_steps = iterations - steps_completed - estimated_remaining = step_time * remaining_steps * resolution_scale - yield status(message=message, status=f"Step {steps_completed}/{iterations} completed", progress=progress, estimated_time_remaining=str(estimated_remaining)) - - async def capture_progress(step: int, timestep: int, latents: torch.Tensor): - for msg in progress_callback(step, timestep, latents): - yield msg - - yield status(message, f"Generating image with {iterations} inference steps") - image = pipe( - prompt, - num_inference_steps=iterations, - guidance_scale=7.5, - height=height, - width=width, - callback=capture_progress, - callback_steps=1 - ).images[0] - else: - # Flux: Run generation in the background and yield progress updates - estimated_gen_time = estimates["per_step"] * iterations * resolution_scale - yield status(message, f"Starting Flux image generation with {iterations} inference steps", estimated_time_remaining=estimated_gen_time) - - image = pipe( - prompt, - num_inference_steps=iterations, - guidance_scale=7.5, - height=height, - width=width - ).images[0] - - # Start the generation task - start_gen_time = time.time() - - gen_time = time.time() - start_gen_time - per_step_time = gen_time / iterations if iterations > 0 else gen_time - yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90) - - gen_time = time.time() - start_gen_time - per_step_time = gen_time / iterations if iterations > 0 else gen_time - yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90) - - # Save image - yield status(message, f"Saving image to {filepath}", 95) - image.save(filepath) + yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10) + async for update in async_generate_image(pipe, request): + message.status = update.get("status", "thinking") + message.response = json.dumps(update) # Merge properties from async_generate_image over the message... + yield message + # Final result total_time = time.time() - start_time message.status = "done" message.response = json.dumps({ "status": f"Image generation complete in {total_time:.1f} seconds", "progress": 100, - "filename": filepath + "filename": request.filepath }) yield message
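
For reference, a minimal sketch of how a caller might consume the reworked `generate_image()` generator introduced above. This is not part of the patch: the import paths and model id are assumptions, and `ImageRequest` is assumed to also carry `model`, `filepath`, and `prompt` fields, as the function body implies (`request.model`, `request.filepath`, `request.prompt`).

```python
import asyncio
import json

# Assumed import paths; adjust to wherever the utils package is importable from.
from utils import generate_image, ImageRequest
from utils.message import Message


async def make_profile_image() -> None:
    request = ImageRequest(
        model="black-forest-labs/FLUX.1-schnell",  # hypothetical Flux model id
        filepath="/tmp/profile.png",
        prompt="A photorealistic profile picture of a software engineer.",
        iterations=4,
        height=256,
        width=256,
    )
    message = Message(prompt=request.prompt)

    async for update in generate_image(message=message, request=request):
        if update.status == "error":
            raise RuntimeError(update.response)
        if update.status == "done":
            # The final response is the JSON blob with "status", "progress", "filename".
            print("saved:", json.loads(update.response)["filename"])
        else:
            # Intermediate updates carry either plain status text (from status())
            # or the JSON-encoded dicts produced by async_generate_image().
            print(update.status, update.response)


asyncio.run(make_profile_image())
```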
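
One detail worth flagging in `flux_worker()`: with the `callback_on_step_end` hook used there, recent diffusers pipelines generally expect the callback to return the `callback_kwargs` dict (the pipeline calls `.pop()` on the return value), so a callback that returns `None` can abort the step loop. A defensive variant of the status callback, under that assumption, could look like:

```python
def status_callback(pipeline, step, timestep, callback_kwargs):
    # Report progress through the shared queue; `params` and `status_queue`
    # come from the enclosing flux_worker() scope.
    progress = int((step + 1) / params.iterations * 100)
    status_queue.put({
        "status": "running",
        "message": f"Processing step {step + 1}/{params.iterations}",
        "progress": progress,
    })
    return callback_kwargs  # preserve the contract expected by the pipeline
```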