from __future__ import annotations from datetime import UTC, datetime from pydantic import BaseModel, Field # type: ignore from typing import Dict, Literal, Any, AsyncGenerator, Optional import inspect import random import re import json import traceback import asyncio import time import os import gc import tempfile import uuid import torch # type: ignore import asyncio import time import json from typing import AsyncGenerator from threading import Thread import queue import uuid from .image_model_cache import ImageModelCache from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType from logger import logger from image_generator.image_model_cache import ImageModelCache # Heuristic time estimates (in seconds) for different models and devices at 512x512 TIME_ESTIMATES = { "flux": { "cuda": {"load": 10, "per_step": 0.8}, "xpu": {"load": 15, "per_step": 1.0}, "cpu": {"load": 30, "per_step": 10.0}, } } class ImageRequest(BaseModel): filepath: str prompt: str model: str = "black-forest-labs/FLUX.1-schnell" iterations: int = 4 height: int = 256 width: int = 256 guidance_scale: float = 7.5 # Global model cache instance model_cache = ImageModelCache() def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str): """Background worker for Flux image generation""" try: # Your existing estimates calculation estimates = {"per_step": 0.5} # Replace with your actual estimates resolution_scale = (params.height * params.width) / (512 * 512) # Flux: Run generation in the background and yield progress updates estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale status_queue.put({ "status": "running", "message": f"Initializing image generation...", "estimated_time_remaining": estimated_gen_time, "progress": 0 }) # Start the generation task start_gen_time = time.time() # Simulate your pipe call with progress updates def status_callback(pipeline, step, timestep, callback_kwargs): # Send progress updates progress = int((step+1) / params.iterations * 100) status_queue.put({ "status": "running", "message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.", "progress": progress }) return callback_kwargs # Replace this block with your actual Flux pipe call: image = pipe( params.prompt, num_inference_steps=params.iterations, guidance_scale=7.5, height=params.height, width=params.width, callback_on_step_end=status_callback, ).images[0] gen_time = time.time() - start_gen_time per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time logger.info(f"Saving to {params.filepath}") image.save(params.filepath) # Final completion status status_queue.put({ "status": "completed", "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", "progress": 100, "generation_time": gen_time, "per_step_time": per_step_time, "image_path": params.filepath }) except Exception as e: logger.error(traceback.format_exc()) logger.error(e) status_queue.put({ "status": "error", "message": f"Generation failed: {str(e)}", "error": str(e), "progress": 0 }) async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]: """ Single async function that handles background Flux generation with status streaming """ task_id = str(uuid.uuid4()) status_queue = queue.Queue() worker_thread = None try: # Start background worker thread worker_thread = Thread( target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True ) worker_thread.start() # Initial status yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'} # Stream status updates completed = False last_heartbeat = time.time() while not completed and worker_thread.is_alive(): try: # Try to get status update (non-blocking) status_update = status_queue.get_nowait() # Add task_id to status update status_update['task_id'] = task_id # Send status update yield status_update # Check if completed if status_update.get('status') in ['completed', 'error']: completed = True last_heartbeat = time.time() except queue.Empty: # No new status, send heartbeat if needed current_time = time.time() if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds heartbeat = { 'status': 'heartbeat', 'task_id': task_id, 'timestamp': current_time } yield heartbeat last_heartbeat = current_time # Brief sleep to prevent busy waiting await asyncio.sleep(0.1) # Handle thread completion or timeout if not completed: if worker_thread.is_alive(): # Thread still running but we might have missed the completion signal timeout_status = { 'status': 'timeout', 'task_id': task_id, 'message': 'Generation timed out or connection lost' } yield timeout_status else: # Thread completed but we might have missed the final status final_status = { 'status': 'completed', 'task_id': task_id, 'message': 'Generation completed' } yield final_status except Exception as e: error_status = { 'status': 'error', 'task_id': task_id, 'message': f'Server error: {str(e)}', 'error': str(e) } logger.error(error_status) yield error_status finally: # Cleanup: ensure thread completion if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive(): worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup def status(chat_message: ChatMessage, status: str, progress: float = 0, estimated_time_remaining="...") -> ChatMessage: """Update chat message status and return it.""" message = chat_message.copy(deep=True) message.id = str(uuid.uuid4()) message.timestamp = datetime.now(UTC) message.type = ChatMessageType.THINKING message.status = ChatStatusType.STREAMING message.content = status return message async def generate_image(user_message: ChatMessage, request: ImageRequest) -> AsyncGenerator[ChatMessage, None]: """Generate an image with specified dimensions and yield status updates with time estimates.""" chat_message = ChatMessage( session_id=user_message.session_id, tunables=user_message.tunables, status=ChatStatusType.INITIALIZING, type=ChatMessageType.PREPARING, sender=ChatSenderType.ASSISTANT, content="", timestamp=datetime.now(UTC) ) try: # Validate prompt prompt = user_message.content.strip() if not prompt: chat_message.status = ChatStatusType.ERROR chat_message.content = "Prompt cannot be empty" yield chat_message return # Validate dimensions if request.height <= 0 or request.width <= 0: chat_message.status = ChatStatusType.ERROR chat_message.content = "Height and width must be positive" yield chat_message return filedir = os.path.dirname(request.filepath) filename = os.path.basename(request.filepath) os.makedirs(filedir, exist_ok=True) model_type = "flux" device = "cpu" yield status(chat_message, f"Starting image generation...") # Get initial time estimate, scaled by resolution estimates = TIME_ESTIMATES[model_type][device] resolution_scale = (request.height * request.width) / (512 * 512) estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale yield status(chat_message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}") # Initialize or get cached pipeline start_time = time.time() yield status(chat_message, f"Loading generative image model...") pipe = await model_cache.get_pipeline(request.model, device) load_time = time.time() - start_time yield status(chat_message, f"Model loaded in {load_time:.1f} seconds.", progress=10) async for status_message in async_generate_image(pipe, request): chat_message.content = json.dumps(status_message) # Merge properties from async_generate_image over the message... chat_message.type = ChatMessageType.HEARTBEAT if status_message.get("status") == "heartbeat" else ChatMessageType.THINKING if chat_message.type != ChatMessageType.HEARTBEAT: logger.info(chat_message.content) yield chat_message # Final result total_time = time.time() - start_time chat_message.status = ChatStatusType.DONE chat_message.type = ChatMessageType.RESPONSE chat_message.content = json.dumps({ "status": f"Image generation complete in {total_time:.1f} seconds", "progress": 100, "filename": request.filepath }) yield chat_message except Exception as e: chat_message.status = ChatStatusType.ERROR chat_message.content = str(e) yield chat_message logger.error(traceback.format_exc()) logger.error(chat_message.content) return