from __future__ import annotations from pydantic import BaseModel from typing import Any, AsyncGenerator import traceback import asyncio import time import os import uuid import asyncio import time from typing import AsyncGenerator from threading import Thread import queue import uuid from .image_model_cache import ImageModelCache from models import ApiActivityType, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageStatus from logger import logger from image_generator.image_model_cache import ImageModelCache # Heuristic time estimates (in seconds) for different models and devices at 512x512 TIME_ESTIMATES = { "flux": { "cuda": {"load": 10, "per_step": 0.8}, "xpu": {"load": 15, "per_step": 1.0}, "cpu": {"load": 30, "per_step": 10.0}, } } class ImageRequest(BaseModel): session_id: str filepath: str prompt: str model: str = "black-forest-labs/FLUX.1-schnell" iterations: int = 4 height: int = 256 width: int = 256 guidance_scale: float = 7.5 # Global model cache instance model_cache = ImageModelCache() def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str): """Background worker for Flux image generation""" try: # Flux: Run generation in the background and yield progress updates status_queue.put(ChatMessageStatus( session_id=params.session_id, content=f"Initializing image generation.", activity=ApiActivityType.GENERATING_IMAGE, )) # Start the generation task start_gen_time = time.time() # Simulate your pipe call with progress updates def status_callback(pipeline, step, timestep, callback_kwargs): # Send progress updates progress = int((step+1) / params.iterations * 100) status_queue.put(ChatMessageStatus( session_id=params.session_id, content=f"Processing step {step+1}/{params.iterations} ({progress}%)", activity=ApiActivityType.GENERATING_IMAGE, )) return callback_kwargs # Replace this block with your actual Flux pipe call: image = pipe( params.prompt, num_inference_steps=params.iterations, guidance_scale=7.5, height=params.height, width=params.width, callback_on_step_end=status_callback, ).images[0] gen_time = time.time() - start_gen_time per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time logger.info(f"Saving to {params.filepath}") image.save(params.filepath) # Final completion status status_queue.put(ChatMessage( session_id=params.session_id, status=ApiStatusType.DONE, content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", )) except Exception as e: logger.error(traceback.format_exc()) logger.error(e) status_queue.put(ChatMessageError( session_id=params.session_id, content=f"Error during image generation: {str(e)}", )) async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]: """ Single async function that handles background Flux generation with status streaming """ task_id = str(uuid.uuid4()) status_queue = queue.Queue() worker_thread = None try: # Start background worker thread worker_thread = Thread( target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True ) worker_thread.start() # Initial status status_message = ChatMessageStatus( session_id=params.session_id, content=f"Starting image generation with task ID {task_id}", activity=ApiActivityType.THINKING ) yield status_message # Stream status updates completed = False last_heartbeat = time.time() while not completed and worker_thread.is_alive(): try: # Try to get status update (non-blocking) status_update = status_queue.get_nowait() # Send status update yield status_update # Check if completed if status_update.status == ApiStatusType.DONE: logger.info(f"Image generation completed for task {task_id}") completed = True last_heartbeat = time.time() except queue.Empty: # No new status, send heartbeat if needed current_time = time.time() if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds heartbeat = ChatMessageStatus( session_id=params.session_id, content=f"Heartbeat for task {task_id}", activity=ApiActivityType.HEARTBEAT, ) yield heartbeat last_heartbeat = current_time # Brief sleep to prevent busy waiting await asyncio.sleep(0.1) # Handle thread completion or timeout if not completed: if worker_thread.is_alive(): # Thread still running but we might have missed the completion signal timeout_status = ChatMessageError( session_id=params.session_id, content=f"Generation timeout for task {task_id}. The process may still be running.", ) yield timeout_status else: # Thread completed but we might have missed the final status final_status = ChatMessage( session_id=params.session_id, status=ApiStatusType.DONE, content=f"Generation completed for task {task_id}.", ) yield final_status except Exception as e: error_status = ChatMessageError( session_id=params.session_id, content=f'Server error: {str(e)}' ) logger.error(error_status) yield error_status finally: # Cleanup: ensure thread completion if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive(): worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup def status(session_id: str, status: str) -> ChatMessageStatus: """Update chat message status and return it.""" chat_message = ChatMessageStatus( session_id=session_id, activity=ApiActivityType.GENERATING_IMAGE, content=status, ) return chat_message async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]: """Generate an image with specified dimensions and yield status updates with time estimates.""" session_id = request.session_id prompt = request.prompt.strip() try: # Validate prompt if not prompt: error_message = ChatMessageError( session_id=session_id, content="Prompt cannot be empty." ) logger.error(error_message.content) yield error_message return # Validate dimensions if request.height <= 0 or request.width <= 0: error_message = ChatMessageError( session_id=session_id, content="Height and width must be positive integers." ) logger.error(error_message.content) yield error_message return filedir = os.path.dirname(request.filepath) filename = os.path.basename(request.filepath) os.makedirs(filedir, exist_ok=True) model_type = "flux" device = "cpu" # Get initial time estimate, scaled by resolution estimates = TIME_ESTIMATES[model_type][device] resolution_scale = (request.height * request.width) / (512 * 512) estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale # Initialize or get cached pipeline start_time = time.time() yield status(session_id, f"Loading generative image model...") pipe = await model_cache.get_pipeline(request.model, device) load_time = time.time() - start_time yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",) progress = None async for progress in async_generate_image(pipe, request): if progress.status == ApiStatusType.ERROR: yield progress return if progress.status != ApiStatusType.DONE: yield progress if not progress: error_message = ChatMessageError( session_id=session_id, content="Image generation failed to produce a valid response." ) logger.error(f"⚠️ {error_message.content}") yield error_message return # Final result total_time = time.time() - start_time chat_message = ChatMessage( session_id=session_id, status=ApiStatusType.DONE, content=f"Image generated successfully in {total_time:.1f} seconds.", ) yield chat_message except Exception as e: error_message = ChatMessageError( session_id=session_id, content=f"Error during image generation: {str(e)}" ) logger.error(traceback.format_exc()) logger.error(error_message.content) yield error_message return