from __future__ import annotations

import asyncio
import json
import os
import queue
import time
import traceback
import uuid
from threading import Thread
from typing import Any, AsyncGenerator, Dict

from pydantic import BaseModel  # type: ignore

from .message import Message
from .setup_logging import setup_logging
from .image_model_cache import ImageModelCache

logger = setup_logging()

# Heuristic time estimates (in seconds) for different models and devices at 512x512
TIME_ESTIMATES = {
    "flux": {
        "cuda": {"load": 10, "per_step": 0.8},
        "xpu": {"load": 15, "per_step": 1.0},
        "cpu": {"load": 30, "per_step": 10.0},
    }
}


class ImageRequest(BaseModel):
    filepath: str
    prompt: str
    model: str = "black-forest-labs/FLUX.1-schnell"
    iterations: int = 4
    height: int = 256
    width: int = 256
    guidance_scale: float = 7.5


# Global model cache instance
model_cache = ImageModelCache(timeout_seconds=60 * 15)  # 15 minutes


def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str) -> None:
    """Run Flux image generation on a background thread, reporting progress via status_queue."""
    try:
        # Flat per-step heuristic; the worker does not know the device, so it cannot
        # index into TIME_ESTIMATES directly.
        estimates = {"per_step": 0.5}
        resolution_scale = (params.height * params.width) / (512 * 512)
        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
        status_queue.put({
            "status": "running",
            "message": f"Starting Flux image generation with {params.iterations} inference steps",
            "estimated_time_remaining": estimated_gen_time,
            "progress": 0,
        })

        start_gen_time = time.time()

        def status_callback(pipeline, step, timestep, callback_kwargs):
            # Report per-step progress; diffusers expects callback_kwargs to be returned.
            elapsed = time.time() - start_gen_time
            progress = int((step + 1) / params.iterations * 100)
            status_queue.put({
                "status": "running",
                "message": f"Processing step {step + 1}/{params.iterations}",
                "estimated_time_remaining": max(estimated_gen_time - elapsed, 0),
                "progress": progress,
            })
            return callback_kwargs

        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time
        image.save(params.filepath)

        # Final completion status
        status_queue.put({
            "status": "completed",
            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
            "progress": 100,
            "generation_time": gen_time,
            "per_step_time": per_step_time,
            "image_path": params.filepath,
        })
    except Exception as e:
        status_queue.put({
            "status": "error",
            "message": f"Generation failed: {str(e)}",
            "error": str(e),
            "progress": 0,
        })


async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
    """Run Flux generation on a worker thread and stream status updates asynchronously."""
    task_id = str(uuid.uuid4())
    status_queue: queue.Queue = queue.Queue()
    worker_thread = None
    try:
        # Start background worker thread
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True,
        )
        worker_thread.start()

        # Initial status
        yield {"status": "starting", "task_id": task_id, "message": "Initializing image generation"}

        # Stream status updates until the worker reports completion or exits
        completed = False
        last_heartbeat = time.time()
        while not completed and worker_thread.is_alive():
            try:
                # Try to get a status update (non-blocking)
                status_update = status_queue.get_nowait()
                status_update["task_id"] = task_id
                yield status_update
                if status_update.get("status") in ("completed", "error"):
                    completed = True
                last_heartbeat = time.time()
            except queue.Empty:
                # No new status; send a heartbeat every 2 seconds so clients know we are alive
                current_time = time.time()
                if current_time - last_heartbeat > 2:
                    yield {
                        "status": "heartbeat",
                        "task_id": task_id,
                        "timestamp": current_time,
                    }
                    last_heartbeat = current_time
                # Brief sleep to prevent busy-waiting
                await asyncio.sleep(0.1)

        # The worker may have posted its final status between the last poll and its
        # exit; drain the queue so that status is not lost.
        while not completed:
            try:
                status_update = status_queue.get_nowait()
                status_update["task_id"] = task_id
                yield status_update
                if status_update.get("status") in ("completed", "error"):
                    completed = True
            except queue.Empty:
                break

        # Handle a worker that is still running, or one that died without a final status
        if not completed:
            if worker_thread.is_alive():
                yield {
                    "status": "timeout",
                    "task_id": task_id,
                    "message": "Generation timed out or connection lost",
                }
            else:
                yield {
                    "status": "completed",
                    "task_id": task_id,
                    "message": "Generation completed",
                }

    except Exception as e:
        yield {
            "status": "error",
            "task_id": task_id,
            "message": f"Server error: {str(e)}",
            "error": str(e),
        }
    finally:
        # Cleanup: give the worker thread a moment to finish
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup


def status(message: Message, text: str, progress: float = 0, estimated_time_remaining: Any = "...") -> Message:
    """Update message in place with interim status text.

    progress and estimated_time_remaining are accepted for call-site symmetry but
    are not yet serialized into the response.
    """
    message.status = "thinking"
    message.response = text
    return message


async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates."""
    try:
        # Validate prompt
        prompt = request.prompt.strip()
        if not prompt:
            message.status = "error"
            message.response = "Prompt cannot be empty"
            yield message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            message.status = "error"
            message.response = "Height and width must be positive"
            yield message
            return

        filedir = os.path.dirname(request.filepath)
        filename = os.path.basename(request.filepath)
        if filedir:
            os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"  # Fixed to CPU for now; TIME_ESTIMATES also covers "cuda" and "xpu".
        yield status(
            message,
            f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}",
        )

        # Get initial time estimate, scaled by resolution
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

        # Initialize or fetch the cached pipeline
        start_time = time.time()
        yield status(message, f"Loading {model_type} model: {request.model}")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(
            message,
            f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps",
            progress=10,
        )
        async for update in async_generate_image(pipe, request):
            # Forward each worker update, serializing its payload into the message response
            message.status = update.get("status", "thinking")
            message.response = json.dumps(update)
            yield message

        # Final result
        total_time = time.time() - start_time
        message.status = "done"
        message.response = json.dumps({
            "status": f"Image generation complete in {total_time:.1f} seconds",
            "progress": 100,
            "filename": request.filepath,
        })
        yield message
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(str(e))
        message.status = "error"
        message.response = str(e)
        yield message
        return
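
# --- Usage sketch ---
# A minimal, illustrative driver for generate_image(), not part of the original
# module. Assumptions: Message() is constructible without arguments, and the module
# is executed inside its package (the relative imports above require it, e.g. via
# `python -m <package>.<module>`). The prompt, filepath, and dimensions below are
# placeholders, not project defaults.
async def _demo() -> None:
    request = ImageRequest(
        filepath="/tmp/flux_demo.png",  # hypothetical output path
        prompt="a lighthouse at dusk, oil painting",  # hypothetical prompt
        iterations=4,
        height=256,
        width=256,
    )
    message = Message()  # assumption: Message supports default construction
    async for update in generate_image(message, request):
        # Interim updates carry plain text or a JSON payload in .response
        print(f"[{update.status}] {update.response}")


if __name__ == "__main__":
    asyncio.run(_demo())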