279 lines
9.8 KiB
Python
279 lines
9.8 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel
|
|
from typing import Any, AsyncGenerator
|
|
import traceback
|
|
import asyncio
|
|
import time
|
|
import os
|
|
import uuid
|
|
import asyncio
|
|
import time
|
|
from typing import AsyncGenerator
|
|
from threading import Thread
|
|
import queue
|
|
import uuid
|
|
|
|
from .image_model_cache import ImageModelCache
|
|
|
|
from models import ApiActivityType, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageStatus
|
|
from logger import logger
|
|
|
|
from image_generator.image_model_cache import ImageModelCache
|
|
|
|
# Heuristic time estimates (in seconds) for different models and devices at 512x512
# "load" is the one-time pipeline load cost; "per_step" is the cost of one
# inference step at the 512x512 baseline. generate_image scales the per-step
# cost by the requested resolution relative to that baseline.
TIME_ESTIMATES = {
    "flux": {
        "cuda": {"load": 10, "per_step": 0.8},
        "xpu": {"load": 15, "per_step": 1.0},
        "cpu": {"load": 30, "per_step": 10.0},
    }
}
|
|
|
|
class ImageRequest(BaseModel):
    """Request payload for a single image-generation job."""

    session_id: str    # chat session that status/result messages are tagged with
    filepath: str      # path where the finished image is saved (dir is created if missing)
    prompt: str        # text prompt; validated non-empty in generate_image
    model: str = "black-forest-labs/FLUX.1-schnell"  # model id passed to the pipeline cache
    iterations: int = 4        # num_inference_steps for the pipeline
    height: int = 256          # output height in pixels; must be > 0
    width: int = 256           # output width in pixels; must be > 0
    guidance_scale: float = 7.5  # classifier-free guidance strength
|
|
|
|
# Global model cache instance
# Shared across requests; presumably caches loaded pipelines per (model, device)
# so the expensive load happens once — see model_cache.get_pipeline usage below.
model_cache = ImageModelCache()
|
|
|
|
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Background worker for Flux image generation.

    Runs synchronously on a worker thread: drives the diffusion pipeline,
    pushes a ChatMessageStatus onto ``status_queue`` after each step, saves
    the resulting image to ``params.filepath``, and finally puts either a
    ChatMessage with status DONE or a ChatMessageError on failure.

    Args:
        pipe: Loaded diffusers-style pipeline; called with the prompt and
            generation parameters.
        params: The image-generation request.
        status_queue: Thread-safe queue used to stream status to the async
            consumer (async_generate_image).
        task_id: Correlation id for this task (currently unused here; kept
            for interface stability with async_generate_image).
    """
    try:
        # Flux: Run generation in the background and yield progress updates
        status_queue.put(ChatMessageStatus(
            session_id=params.session_id,
            content="Initializing image generation.",
            activity=ApiActivityType.GENERATING_IMAGE,
        ))

        # Start the generation task
        start_gen_time = time.time()

        def status_callback(pipeline, step, timestep, callback_kwargs):
            # Report per-step progress; steps are 0-based, hence the +1.
            progress = int((step+1) / params.iterations * 100)

            status_queue.put(ChatMessageStatus(
                session_id=params.session_id,
                content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
                activity=ApiActivityType.GENERATING_IMAGE,
            ))
            return callback_kwargs

        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            # FIX: honor the request's guidance_scale instead of a hard-coded
            # 7.5 (the ImageRequest field already defaults to 7.5, so existing
            # callers see identical behavior).
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        # Guard against division by zero when iterations == 0.
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time

        logger.info(f"Saving to {params.filepath}")
        image.save(params.filepath)

        # Final completion status: DONE tells the consumer to stop streaming.
        status_queue.put(ChatMessage(
            session_id=params.session_id,
            status=ApiStatusType.DONE,
            content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
        ))

    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
        status_queue.put(ChatMessageError(
            session_id=params.session_id,
            content=f"Error during image generation: {str(e)}",
        ))
|
|
|
|
|
|
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
    """
    Stream image-generation progress from a background worker thread.

    Starts flux_worker on a daemon thread and relays every message it puts on
    the shared queue, interleaving heartbeat messages while the worker is
    quiet. Terminates once a DONE message is seen (or the worker dies), and
    synthesizes a best-effort terminal message if the completion signal was
    never observed.

    Args:
        pipe: Loaded pipeline handed straight to flux_worker.
        params: The image-generation request.

    Yields:
        ChatMessageStatus progress/heartbeat updates, then a terminal
        ChatMessage (DONE) or ChatMessageError.
    """
    task_id = str(uuid.uuid4())
    status_queue = queue.Queue()
    worker_thread = None

    try:
        # Run the blocking diffusion work off the event loop.
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True
        )
        worker_thread.start()

        # Initial status
        status_message = ChatMessageStatus(
            session_id=params.session_id,
            content=f"Starting image generation with task ID {task_id}",
            activity=ApiActivityType.THINKING
        )
        yield status_message

        # Stream status updates
        completed = False
        last_heartbeat = time.time()

        # FIX: keep draining after the thread exits so the worker's final
        # DONE/error message is not lost when it finishes between polls (the
        # old condition stopped as soon as is_alive() went False, even with
        # messages still queued).
        while not completed and (worker_thread.is_alive() or not status_queue.empty()):
            try:
                # Try to get status update (non-blocking)
                status_update = status_queue.get_nowait()

                # Send status update
                yield status_update

                # Check if completed
                if status_update.status == ApiStatusType.DONE:
                    logger.info(f"Image generation completed for task {task_id}")
                    completed = True

                last_heartbeat = time.time()

            except queue.Empty:
                # No new status; emit a heartbeat at most every 2 seconds so
                # the client knows the stream is still alive.
                current_time = time.time()
                if current_time - last_heartbeat > 2:
                    heartbeat = ChatMessageStatus(
                        session_id=params.session_id,
                        content=f"Heartbeat for task {task_id}",
                        activity=ApiActivityType.HEARTBEAT,
                    )
                    yield heartbeat
                    last_heartbeat = current_time

            # Brief sleep to prevent busy waiting
            await asyncio.sleep(0.1)

        # Handle thread completion or timeout
        if not completed:
            if worker_thread.is_alive():
                # Thread still running but we might have missed the completion signal
                timeout_status = ChatMessageError(
                    session_id=params.session_id,
                    content=f"Generation timeout for task {task_id}. The process may still be running.",
                )
                yield timeout_status
            else:
                # Thread completed but we might have missed the final status
                final_status = ChatMessage(
                    session_id=params.session_id,
                    status=ApiStatusType.DONE,
                    content=f"Generation completed for task {task_id}.",
                )
                yield final_status

    except Exception as e:
        error_status = ChatMessageError(
            session_id=params.session_id,
            content=f'Server error: {str(e)}'
        )
        logger.error(error_status)
        yield error_status

    finally:
        # Cleanup: give the worker a moment to finish. (The redundant
        # `'worker_thread' in locals()` check was removed — the name is always
        # bound because it is initialized to None above.)
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup
|
|
|
|
def status(session_id: str, status: str) -> ChatMessageStatus:
    """Build a GENERATING_IMAGE status message for the given session."""
    return ChatMessageStatus(
        session_id=session_id,
        activity=ApiActivityType.GENERATING_IMAGE,
        content=status,
    )
|
|
|
|
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates.

    Validates the request, loads (or reuses) the cached pipeline, then
    delegates to async_generate_image and forwards its progress messages.
    Errors are reported as ChatMessageError and terminate the stream.

    Args:
        request: The image-generation request (prompt, size, model, ...).

    Yields:
        ChatMessageStatus progress updates, then a final ChatMessage (DONE)
        or ChatMessageError.
    """
    session_id = request.session_id
    prompt = request.prompt.strip()
    try:
        # Validate prompt
        if not prompt:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Prompt cannot be empty."
            )
            logger.error(error_message.content)
            yield error_message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Height and width must be positive integers."
            )
            logger.error(error_message.content)
            yield error_message
            return

        # Ensure the output directory exists before generation starts.
        # (Removed the unused `filename = os.path.basename(...)` local.)
        filedir = os.path.dirname(request.filepath)
        os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"

        # Get initial time estimate, scaled by resolution relative to the
        # 512x512 baseline the table was calibrated at.
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale

        # Initialize or get cached pipeline
        start_time = time.time()
        # FIX: surface the computed estimate — it was previously calculated
        # but never shown, despite the docstring promising time estimates.
        yield status(session_id, f"Loading generative image model (estimated total {estimated_total:.0f}s)...")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(session_id, f"Model loaded in {load_time:.1f} seconds.")

        progress = None
        async for progress in async_generate_image(pipe, request):
            if progress.status == ApiStatusType.ERROR:
                yield progress
                return
            if progress.status != ApiStatusType.DONE:
                yield progress

        # Explicit None check: the loop never ran, so no message was produced.
        if progress is None:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Image generation failed to produce a valid response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        # Final result (total includes model load + generation time).
        total_time = time.time() - start_time
        chat_message = ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            content=f"Image generated successfully in {total_time:.1f} seconds.",
        )
        yield chat_message

    except Exception as e:
        error_message = ChatMessageError(
            session_id=session_id,
            content=f"Error during image generation: {str(e)}"
        )
        logger.error(traceback.format_exc())
        logger.error(error_message.content)
        yield error_message
        return