279 lines
9.8 KiB
Python

from __future__ import annotations
from pydantic import BaseModel
from typing import Any, AsyncGenerator
import traceback
import asyncio
import time
import os
import uuid
import asyncio
import time
from typing import AsyncGenerator
from threading import Thread
import queue
import uuid
from .image_model_cache import ImageModelCache
from models import ApiActivityType, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageStatus
from logger import logger
from image_generator.image_model_cache import ImageModelCache
# Heuristic time estimates (in seconds) for different models and devices at 512x512.
# Keyed by model family, then device string; "load" is one-time pipeline load cost,
# "per_step" is the cost of a single inference step at 512x512 (scaled by resolution
# elsewhere). Values are rough calibration numbers, not measurements.
TIME_ESTIMATES = {
    "flux": {
        "cuda": {"load": 10, "per_step": 0.8},
        "xpu": {"load": 15, "per_step": 1.0},
        "cpu": {"load": 30, "per_step": 10.0},
    }
}
class ImageRequest(BaseModel):
    """Parameters for a single image-generation request.

    Carries the prompt, output location, and diffusion settings passed
    through to the Flux pipeline.
    """
    session_id: str       # chat session this request (and its status updates) belongs to
    filepath: str         # destination path where the generated image is saved
    prompt: str           # text prompt for the diffusion model
    model: str = "black-forest-labs/FLUX.1-schnell"  # HF model id for the pipeline
    iterations: int = 4   # number of inference (denoising) steps
    height: int = 256     # output image height in pixels
    width: int = 256      # output image width in pixels
    guidance_scale: float = 7.5  # classifier-free guidance strength
# Global model cache instance — shared across requests so the (expensive)
# pipeline load happens once per model/device combination.
model_cache = ImageModelCache()
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Run Flux image generation synchronously, publishing status via a queue.

    Intended to run in a background thread: every progress, completion, or
    error message is pushed onto ``status_queue`` so the async side
    (``async_generate_image``) can stream it to the client.

    Args:
        pipe: Loaded diffusers Flux pipeline (callable).
        params: Generation parameters (prompt, dimensions, iterations, ...).
        status_queue: Thread-safe queue used to publish ChatMessage* objects.
        task_id: Identifier for this task (kept for interface compatibility;
            not currently referenced inside the worker).
    """
    try:
        status_queue.put(ChatMessageStatus(
            session_id=params.session_id,
            content="Initializing image generation.",
            activity=ApiActivityType.GENERATING_IMAGE,
        ))

        start_gen_time = time.time()

        def status_callback(pipeline, step, timestep, callback_kwargs):
            # Invoked by diffusers after each denoising step
            # (callback_on_step_end); relays per-step progress.
            progress = int((step + 1) / params.iterations * 100)
            status_queue.put(ChatMessageStatus(
                session_id=params.session_id,
                content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
                activity=ApiActivityType.GENERATING_IMAGE,
            ))
            return callback_kwargs

        # BUG FIX: honor the caller-supplied guidance scale instead of a
        # hard-coded 7.5. ImageRequest.guidance_scale defaults to 7.5, so
        # behavior is unchanged for callers that never set it.
        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        # Guard against division by zero when iterations == 0.
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time

        logger.info(f"Saving to {params.filepath}")
        image.save(params.filepath)

        # Terminal DONE message — async_generate_image treats this as completion.
        status_queue.put(ChatMessage(
            session_id=params.session_id,
            status=ApiStatusType.DONE,
            content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
        ))
    except Exception as e:
        # Report the failure through the queue so the consumer can surface it.
        logger.error(traceback.format_exc())
        logger.error(e)
        status_queue.put(ChatMessageError(
            session_id=params.session_id,
            content=f"Error during image generation: {str(e)}",
        ))
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
    """Stream status updates for a Flux generation running in a background thread.

    Spawns ``flux_worker`` in a daemon thread and relays its queued messages
    as they arrive, emitting a heartbeat every 2 seconds while idle.

    Yields:
        ChatMessageStatus progress/heartbeat messages, followed by a terminal
        ChatMessage (DONE) or ChatMessageError.
    """
    task_id = str(uuid.uuid4())
    status_queue: queue.Queue = queue.Queue()
    worker_thread = None
    try:
        # Start background worker thread (daemon so it can't block shutdown).
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True,
        )
        worker_thread.start()

        # Initial status
        yield ChatMessageStatus(
            session_id=params.session_id,
            content=f"Starting image generation with task ID {task_id}",
            activity=ApiActivityType.THINKING,
        )

        # Relay status updates while the worker runs.
        completed = False
        last_heartbeat = time.time()
        while not completed and worker_thread.is_alive():
            try:
                # Non-blocking poll so heartbeats keep flowing while idle.
                status_update = status_queue.get_nowait()
                yield status_update
                if status_update.status == ApiStatusType.DONE:
                    logger.info(f"Image generation completed for task {task_id}")
                    completed = True
                last_heartbeat = time.time()
            except queue.Empty:
                current_time = time.time()
                if current_time - last_heartbeat > 2:  # Heartbeat every 2 seconds
                    yield ChatMessageStatus(
                        session_id=params.session_id,
                        content=f"Heartbeat for task {task_id}",
                        activity=ApiActivityType.HEARTBEAT,
                    )
                    last_heartbeat = current_time
                # Brief sleep to prevent busy waiting
                await asyncio.sleep(0.1)

        # BUG FIX: the worker may enqueue its final message and exit before we
        # poll again; is_alive() then goes false and the old code dropped the
        # terminal message. Drain whatever is left in the queue first.
        while not completed:
            try:
                status_update = status_queue.get_nowait()
            except queue.Empty:
                break
            yield status_update
            if status_update.status == ApiStatusType.DONE:
                logger.info(f"Image generation completed for task {task_id}")
                completed = True

        # Fallbacks for a missed completion signal.
        if not completed:
            if worker_thread.is_alive():
                # Thread still running but we gave up waiting for DONE.
                yield ChatMessageError(
                    session_id=params.session_id,
                    content=f"Generation timeout for task {task_id}. The process may still be running.",
                )
            else:
                # Thread exited without a DONE message reaching us.
                yield ChatMessage(
                    session_id=params.session_id,
                    status=ApiStatusType.DONE,
                    content=f"Generation completed for task {task_id}.",
                )
    except Exception as e:
        error_status = ChatMessageError(
            session_id=params.session_id,
            content=f'Server error: {str(e)}'
        )
        logger.error(error_status)
        yield error_status
    finally:
        # worker_thread is always bound (initialized to None above), so the
        # old `'worker_thread' in locals()` check was redundant.
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup
def status(session_id: str, status: str) -> ChatMessageStatus:
    """Build a GENERATING_IMAGE status message for the given session."""
    return ChatMessageStatus(
        session_id=session_id,
        activity=ApiActivityType.GENERATING_IMAGE,
        content=status,
    )
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with specified dimensions and yield status updates.

    Validates the request, loads (or reuses) the cached pipeline, then
    delegates to ``async_generate_image`` and relays its stream, replacing
    the worker's DONE message with a total-time summary.

    Yields:
        ChatMessageStatus progress messages, then a terminal ChatMessage
        (DONE) on success or a ChatMessageError on failure.
    """
    session_id = request.session_id
    prompt = request.prompt.strip()
    try:
        # Validate prompt
        if not prompt:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Prompt cannot be empty."
            )
            logger.error(error_message.content)
            yield error_message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Height and width must be positive integers."
            )
            logger.error(error_message.content)
            yield error_message
            return

        # BUG FIX: os.makedirs("") raises FileNotFoundError when the filepath
        # has no directory component — only create the directory if present.
        filedir = os.path.dirname(request.filepath)
        if filedir:
            os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"
        # Initial time estimate, scaled by resolution relative to 512x512.
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale

        # Initialize or get cached pipeline. The estimate above was previously
        # computed but never shown; surface it in the loading status.
        start_time = time.time()
        yield status(session_id, f"Loading generative image model (estimated {estimated_total:.0f}s total)...")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(session_id, f"Model loaded in {load_time:.1f} seconds.")

        progress = None
        async for progress in async_generate_image(pipe, request):
            if progress.status == ApiStatusType.ERROR:
                yield progress
                return
            # Suppress the worker's DONE; a summary message is emitted below.
            if progress.status != ApiStatusType.DONE:
                yield progress

        if not progress:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Image generation failed to produce a valid response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        # Final result
        total_time = time.time() - start_time
        yield ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            content=f"Image generated successfully in {total_time:.1f} seconds.",
        )
    except Exception as e:
        error_message = ChatMessageError(
            session_id=session_id,
            content=f"Error during image generation: {str(e)}"
        )
        logger.error(traceback.format_exc())
        logger.error(error_message.content)
        yield error_message
        return