274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
from __future__ import annotations
|
|
from pydantic import BaseModel, Field # type: ignore
|
|
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
|
import inspect
|
|
import random
|
|
import re
|
|
import json
|
|
import traceback
|
|
import asyncio
|
|
import time
|
|
import os
|
|
import gc
|
|
import tempfile
|
|
import uuid
|
|
import torch # type: ignore
|
|
import asyncio
|
|
import time
|
|
import json
|
|
from typing import AsyncGenerator
|
|
from threading import Thread
|
|
import queue
|
|
import uuid
|
|
|
|
from .agents.base import Agent, agent_registry, LLMMessage
|
|
from .message import Message
|
|
from .rag import ChromaDBGetResponse
|
|
from .setup_logging import setup_logging
|
|
from .image_model_cache import ImageModelCache
|
|
|
|
# Module-level logger, obtained from the project's setup_logging() helper.
logger = setup_logging()
|
|
|
|
# Rough timing heuristics, in seconds, for a 512x512 image. The outer key is
# the model family and the inner key the device; "load" is added once and
# "per_step" is multiplied by the number of inference steps.
TIME_ESTIMATES = {
    "flux": dict(
        cuda=dict(load=10, per_step=0.8),
        xpu=dict(load=15, per_step=1.0),
        cpu=dict(load=30, per_step=10.0),
    ),
}
|
|
|
|
class ImageRequest(BaseModel):
    """Parameters describing a single image-generation request."""

    # Path the generated image is saved to.
    filepath: str

    # Text prompt handed to the pipeline.
    prompt: str

    # Identifier of the model to load.
    model: str = "black-forest-labs/FLUX.1-schnell"

    # Number of inference steps.
    iterations: int = 4

    # Requested output size; must be positive (checked by generate_image).
    height: int = 256
    width: int = 256

    # Guidance scale for the pipeline call.
    guidance_scale: float = 7.5
|
|
|
|
# Module-level pipeline cache used by generate_image(); it is constructed
# with a timeout_seconds value of 15 minutes.
model_cache = ImageModelCache(timeout_seconds=15 * 60)
|
|
|
|
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Run one Flux generation synchronously and report progress via a queue.

    Meant to be used as a thread target: every outcome, including failure, is
    reported by putting a status dict on ``status_queue`` instead of being
    returned or raised.

    Args:
        pipe: Callable pipeline. It is called with the prompt plus
            ``num_inference_steps``, ``guidance_scale``, ``height``, ``width``
            and ``callback_on_step_end``; its result must expose
            ``.images[0]`` with a ``save(path)`` method.
        params: Request carrying the prompt, step count, image size, guidance
            scale and output file path.
        status_queue: Receives dicts whose ``"status"`` is ``"running"``,
            ``"completed"`` or ``"error"``.
        task_id: Identifier of the task. The worker does not use it; it is
            kept so the thread-target signature stays unchanged.
    """
    try:
        # Placeholder heuristic: seconds per inference step at 512x512.
        estimates = {"per_step": 0.5}
        resolution_scale = (params.height * params.width) / (512 * 512)
        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale

        status_queue.put({
            "status": "running",
            "message": f"Starting Flux image generation with {params.iterations} inference steps",
            "estimated_time_remaining": estimated_gen_time,
            "progress": 0,
        })

        start_gen_time = time.time()
        # Guard the divisions below against a zero or negative step count.
        total_steps = max(params.iterations, 1)

        def status_callback(pipeline, step, timestep, callback_kwargs):
            # ``step`` is zero-based, so report the number of finished steps.
            steps_done = step + 1
            elapsed = time.time() - start_gen_time
            steps_left = max(params.iterations - steps_done, 0)

            status_queue.put({
                "status": "running",
                "message": f"Processing step {steps_done}/{params.iterations}",
                # Extrapolate from the average time per finished step.
                "estimated_time_remaining": elapsed / steps_done * steps_left,
                "progress": min(int(steps_done / total_steps * 100), 100),
            })

            # The pipeline reads values back from the returned dict, so the
            # kwargs must be handed back; returning None breaks the pipeline.
            return callback_kwargs

        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time

        image.save(params.filepath)

        status_queue.put({
            "status": "completed",
            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
            "progress": 100,
            "generation_time": gen_time,
            "per_step_time": per_step_time,
            # Report the file that was actually written.
            "image_path": params.filepath,
        })

    except Exception as e:
        # Surface the failure to the consumer; nothing is re-raised because
        # this function runs on a background thread.
        status_queue.put({
            "status": "error",
            "message": f"Generation failed: {str(e)}",
            "error": str(e),
            "progress": 0,
        })
|
|
|
|
|
|
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
    """Run ``flux_worker`` on a background thread and stream its status dicts.

    Yields, in order: one ``'starting'`` dict, every dict the worker puts on
    its queue (each tagged with ``'task_id'``), and a ``'heartbeat'`` dict
    whenever more than two seconds pass without a worker update. The stream
    ends after the first ``'completed'`` or ``'error'`` status.

    The queue is always drained before the stream ends, so a worker that
    finishes (or fails) between two polls still has its final status
    delivered. If the worker thread exits without ever reporting a terminal
    status, an ``'error'`` dict is emitted instead.

    Args:
        pipe: Pipeline object passed through to ``flux_worker``.
        params: Generation request passed through to ``flux_worker``.
    """
    task_id = str(uuid.uuid4())
    status_queue: queue.Queue = queue.Queue()
    worker_thread = None
    terminal_statuses = ('completed', 'error')
    heartbeat_interval = 2  # seconds of silence before a heartbeat is sent
    poll_interval = 0.1  # seconds to wait when the queue is empty

    try:
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True,
        )
        worker_thread.start()

        yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}

        completed = False
        last_heartbeat = time.time()

        while not completed:
            # Sample liveness BEFORE polling: if the thread was already dead
            # and the queue is empty afterwards, nothing more can arrive.
            worker_alive = worker_thread.is_alive()

            try:
                status_update = status_queue.get_nowait()
            except queue.Empty:
                if not worker_alive:
                    break

                current_time = time.time()
                if current_time - last_heartbeat > heartbeat_interval:
                    yield {
                        'status': 'heartbeat',
                        'task_id': task_id,
                        'timestamp': current_time,
                    }
                    last_heartbeat = current_time

                # Only sleep when idle, so queued updates are drained promptly.
                await asyncio.sleep(poll_interval)
                continue

            status_update['task_id'] = task_id
            yield status_update

            if status_update.get('status') in terminal_statuses:
                completed = True

            last_heartbeat = time.time()

        if not completed:
            # The worker died without reporting a result; do not claim success.
            failure = 'Generation worker exited without reporting a result'
            yield {
                'status': 'error',
                'task_id': task_id,
                'message': failure,
                'error': failure,
            }

    except Exception as e:
        yield {
            'status': 'error',
            'task_id': task_id,
            'message': f'Server error: {str(e)}',
            'error': str(e),
        }

    finally:
        # Give a still-running worker a brief chance to finish on early exit.
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)
|
|
|
|
def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message:
    """Mark ``message`` as in progress and set ``status`` as its response text.

    The message is mutated in place and returned. ``progress`` and
    ``estimated_time_remaining`` are accepted but currently not used.
    """
    message.status, message.response = "thinking", status
    return message
|
|
|
|
async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
    """Generate an image for ``request``, yielding ``message`` on each update.

    The same ``message`` object is mutated and re-yielded every time:
    ``status`` and ``response`` are overwritten with the latest state.
    Validation failures and exceptions are reported as a message with
    status ``"error"``. A final ``"done"`` message is only sent when the
    generation stream did not report an error or timeout.

    Args:
        message: Message object used to carry status back to the caller.
        request: Prompt, size, step count, model name and output path.
    """
    try:
        # Validate prompt
        prompt = request.prompt.strip()
        if not prompt:
            message.status = "error"
            message.response = "Prompt cannot be empty"
            yield message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            message.status = "error"
            message.response = "Height and width must be positive"
            yield message
            return

        filedir = os.path.dirname(request.filepath)
        filename = os.path.basename(request.filepath)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if filedir:
            os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"

        yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}")

        # Initial time estimate, scaled by resolution relative to 512x512
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

        # Initialize or get cached pipeline
        start_time = time.time()
        yield status(message, f"Loading {model_type} model: {request.model}")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10)

        failed = False
        async for update in async_generate_image(pipe, request):
            update_status = update.get("status", "thinking")
            message.status = update_status
            message.response = json.dumps(update)
            yield message
            if update_status in ("error", "timeout"):
                failed = True

        if failed:
            # The failure was already delivered; do not follow it with "done".
            return

        # Final result
        total_time = time.time() - start_time
        message.status = "done"
        message.response = json.dumps({
            "status": f"Image generation complete in {total_time:.1f} seconds",
            "progress": 100,
            "filename": request.filepath,
        })
        yield message

    except Exception as e:
        message.status = "error"
        message.response = str(e)
        # Log before yielding: the consumer may stop iterating after an
        # error message, in which case code after the yield never runs.
        logger.error(traceback.format_exc())
        logger.error(message.response)
        yield message