# Flux image-generation helpers: background worker thread plus async status streaming.

from __future__ import annotations
from datetime import UTC, datetime
from pydantic import BaseModel, Field # type: ignore
from typing import Dict, Literal, Any, AsyncGenerator, Optional
import inspect
import random
import re
import json
import traceback
import asyncio
import time
import os
import gc
import tempfile
import uuid
import torch # type: ignore
import asyncio
import time
import json
from typing import AsyncGenerator
from threading import Thread
import queue
import uuid
from .image_model_cache import ImageModelCache
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
from logger import logger
from image_generator.image_model_cache import ImageModelCache
# Heuristic time estimates (in seconds) for different models and devices at 512x512.
# "load": one-off pipeline load time; "per_step": seconds per inference step.
# Consumers scale "per_step" by (height * width) / (512 * 512) for other resolutions.
TIME_ESTIMATES = {
    "flux": {
        "cuda": {"load": 10, "per_step": 0.8},
        "xpu": {"load": 15, "per_step": 1.0},
        "cpu": {"load": 30, "per_step": 10.0},
    }
}
class ImageRequest(BaseModel):
    """Validated parameters for a single image-generation request."""

    filepath: str  # destination path where the rendered image is saved
    prompt: str  # text prompt forwarded to the diffusion pipeline
    model: str = "black-forest-labs/FLUX.1-schnell"  # model identifier passed to the model cache
    iterations: int = 4  # number of inference steps (num_inference_steps)
    height: int = 256  # output image height in pixels
    width: int = 256  # output image width in pixels
    guidance_scale: float = 7.5  # classifier-free guidance strength
# Global model cache instance, shared by all requests so repeated generations
# can reuse an already-loaded pipeline (see model_cache.get_pipeline below).
model_cache = ImageModelCache()
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Background worker that runs a Flux pipeline and reports progress.

    Runs synchronously in a worker thread. Every progress and terminal event
    is pushed onto ``status_queue`` as a dict with at least ``status``,
    ``message`` and ``progress`` keys; the async consumer drains the queue.

    Args:
        pipe: Loaded diffusers-style pipeline; called with the prompt and
            generation options and expected to return an object exposing
            ``.images``.
        params: Generation parameters (prompt, size, steps, output path).
        status_queue: Thread-safe queue for status updates.
        task_id: Identifier of this generation task. Unused here; the consumer
            stamps it onto every update it forwards.
    """
    try:
        # Rough per-step heuristic at 512x512, scaled by requested resolution.
        per_step_estimate = 0.5
        resolution_scale = (params.height * params.width) / (512 * 512)
        estimated_gen_time = per_step_estimate * params.iterations * resolution_scale
        status_queue.put({
            "status": "running",
            "message": "Initializing image generation...",
            "estimated_time_remaining": estimated_gen_time,
            "progress": 0,
        })

        start_gen_time = time.time()

        def status_callback(pipeline, step, timestep, callback_kwargs):
            # Invoked by the pipeline after each denoising step; relays progress.
            progress = int((step + 1) / params.iterations * 100)
            status_queue.put({
                "status": "running",
                "message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.",
                "progress": progress,
            })
            return callback_kwargs

        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            # Bug fix: was hard-coded to 7.5, silently ignoring the
            # guidance_scale the caller set on the request.
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time

        logger.info(f"Saving to {params.filepath}")
        image.save(params.filepath)

        # Terminal success status, including timing for future estimates.
        status_queue.put({
            "status": "completed",
            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
            "progress": 100,
            "generation_time": gen_time,
            "per_step_time": per_step_time,
            "image_path": params.filepath,
        })
    except Exception as e:
        # Terminal error status; the consumer surfaces it to the client.
        logger.error(traceback.format_exc())
        logger.error(e)
        status_queue.put({
            "status": "error",
            "message": f"Generation failed: {str(e)}",
            "error": str(e),
            "progress": 0,
        })
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
    """Run flux_worker in a background thread and stream its status dicts.

    Yields dicts carrying a ``status`` key ("starting", "running",
    "heartbeat", "completed", "error", "timeout"); every forwarded worker
    update is stamped with a generated ``task_id``. Heartbeats are emitted
    after roughly 2 seconds of silence so a client connection stays alive.
    """
    task_id = str(uuid.uuid4())
    status_queue: queue.Queue = queue.Queue()
    worker_thread: Optional[Thread] = None
    try:
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True,
        )
        worker_thread.start()

        yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}

        completed = False
        last_heartbeat = time.time()

        while not completed and worker_thread.is_alive():
            try:
                status_update = status_queue.get_nowait()
                status_update['task_id'] = task_id
                yield status_update
                if status_update.get('status') in ('completed', 'error'):
                    completed = True
                last_heartbeat = time.time()
            except queue.Empty:
                now = time.time()
                if now - last_heartbeat > 2:  # heartbeat every ~2s of silence
                    yield {'status': 'heartbeat', 'task_id': task_id, 'timestamp': now}
                    last_heartbeat = now
                await asyncio.sleep(0.1)  # avoid busy-waiting

        # Bug fix: the worker can finish between our last get_nowait() and the
        # loop-condition check, leaving its terminal "completed"/"error" dict
        # in the queue. Previously those updates were discarded and a fake
        # success status was fabricated, masking real errors. Drain first.
        while not completed:
            try:
                status_update = status_queue.get_nowait()
            except queue.Empty:
                break
            status_update['task_id'] = task_id
            yield status_update
            if status_update.get('status') in ('completed', 'error'):
                completed = True

        if not completed:
            if worker_thread.is_alive():
                # Still running but we stopped streaming (lost signal/timeout).
                yield {
                    'status': 'timeout',
                    'task_id': task_id,
                    'message': 'Generation timed out or connection lost',
                }
            else:
                # Thread exited without ever posting a terminal status.
                yield {
                    'status': 'completed',
                    'task_id': task_id,
                    'message': 'Generation completed',
                }
    except Exception as e:
        error_status = {
            'status': 'error',
            'task_id': task_id,
            'message': f'Server error: {str(e)}',
            'error': str(e),
        }
        logger.error(error_status)
        yield error_status
    finally:
        # Best-effort cleanup; thread is daemonized so we never block long.
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)
def status(chat_message: ChatMessage, status: str, progress: float = 0, estimated_time_remaining="...") -> ChatMessage:
    """Clone a chat message as a fresh streaming progress update.

    Deep-copies ``chat_message``, gives the copy a new id and timestamp,
    marks it THINKING/STREAMING, and sets ``status`` as its content.
    NOTE(review): ``progress`` and ``estimated_time_remaining`` are accepted
    but not used anywhere in this function — confirm whether they should be
    surfaced on the returned message.
    """
    update = chat_message.copy(deep=True)
    update.id = str(uuid.uuid4())
    update.timestamp = datetime.now(UTC)
    update.type = ChatMessageType.THINKING
    update.status = ChatStatusType.STREAMING
    update.content = status
    return update
async def generate_image(user_message: ChatMessage, request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with the requested dimensions, yielding ChatMessage
    status updates (estimates, progress, heartbeats) followed by a final
    RESPONSE message whose status reflects the actual outcome.
    """
    chat_message = ChatMessage(
        session_id=user_message.session_id,
        tunables=user_message.tunables,
        status=ChatStatusType.INITIALIZING,
        type=ChatMessageType.PREPARING,
        sender=ChatSenderType.ASSISTANT,
        content="",
        timestamp=datetime.now(UTC)
    )
    try:
        # Validate prompt. NOTE(review): generation itself uses
        # request.prompt; user_message.content is only validated here —
        # confirm the caller keeps the two in sync.
        prompt = user_message.content.strip()
        if not prompt:
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = "Prompt cannot be empty"
            yield chat_message
            return

        # Validate dimensions before touching the filesystem or the model.
        if request.height <= 0 or request.width <= 0:
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = "Height and width must be positive"
            yield chat_message
            return

        # Ensure the output directory exists before the worker tries to save.
        filedir = os.path.dirname(request.filepath)
        os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"
        yield status(chat_message, "Starting image generation...")

        # Initial wall-clock estimate, scaled by resolution relative to 512x512.
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
        yield status(chat_message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

        # Initialize or fetch the cached pipeline.
        start_time = time.time()
        yield status(chat_message, "Loading generative image model...")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(chat_message, f"Model loaded in {load_time:.1f} seconds.", progress=10)

        # Stream worker updates, remembering the terminal status so the final
        # message reflects failures (previously this always reported DONE,
        # even after the worker yielded "error" or "timeout").
        final_status: Optional[str] = None
        async for status_message in async_generate_image(pipe, request):
            final_status = status_message.get("status")
            chat_message.content = json.dumps(status_message)
            chat_message.type = (
                ChatMessageType.HEARTBEAT
                if final_status == "heartbeat"
                else ChatMessageType.THINKING
            )
            if chat_message.type != ChatMessageType.HEARTBEAT:
                logger.info(chat_message.content)
            yield chat_message

        # Final result, keyed off the worker's terminal status.
        total_time = time.time() - start_time
        chat_message.type = ChatMessageType.RESPONSE
        if final_status in ("error", "timeout"):
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = json.dumps({
                "status": f"Image generation failed after {total_time:.1f} seconds",
                "progress": 0,
                "filename": request.filepath
            })
        else:
            chat_message.status = ChatStatusType.DONE
            chat_message.content = json.dumps({
                "status": f"Image generation complete in {total_time:.1f} seconds",
                "progress": 100,
                "filename": request.filepath
            })
        yield chat_message
    except Exception as e:
        chat_message.status = ChatStatusType.ERROR
        chat_message.content = str(e)
        yield chat_message
        logger.error(traceback.format_exc())
        logger.error(chat_message.content)
        return