backstory/src/utils/profile_image.py
2025-05-22 15:25:52 -07:00

274 lines
9.7 KiB
Python

from __future__ import annotations
from pydantic import BaseModel, Field # type: ignore
from typing import Dict, Literal, Any, AsyncGenerator, Optional
import inspect
import random
import re
import json
import traceback
import asyncio
import time
import os
import gc
import tempfile
import uuid
import torch # type: ignore
import asyncio
import time
import json
from typing import AsyncGenerator
from threading import Thread
import queue
import uuid
from .agents.base import Agent, agent_registry, LLMMessage
from .message import Message
from .rag import ChromaDBGetResponse
from .setup_logging import setup_logging
from .image_model_cache import ImageModelCache
# Module-level logger, obtained from the project's setup_logging() helper.
logger = setup_logging()
# Rough timing heuristics, in seconds, per model family and device, quoted for a
# 512x512 image: "load" is the one-off model load cost, "per_step" the cost of
# a single inference step.
TIME_ESTIMATES = {
    "flux": {
        device_name: {"load": load_seconds, "per_step": step_seconds}
        for device_name, load_seconds, step_seconds in (
            ("cuda", 10, 0.8),
            ("xpu", 15, 1.0),
            ("cpu", 30, 10.0),
        )
    }
}
class ImageRequest(BaseModel):
    """Parameters describing a single image-generation request."""

    # Destination path; the generated image is saved to this location.
    filepath: str
    # Text prompt handed to the pipeline.
    prompt: str
    # Model identifier passed to the pipeline cache when loading the pipeline.
    model: str = "black-forest-labs/FLUX.1-schnell"
    # Number of inference steps requested from the pipeline.
    iterations: int = 4
    # Requested output image height and width.
    height: int = 256
    width: int = 256
    # Guidance scale setting for generation.
    guidance_scale: float = 7.5
# Module-wide pipeline cache shared by every request; constructed with a
# 15-minute timeout (expressed in seconds).
model_cache = ImageModelCache(timeout_seconds=15 * 60)
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Run one Flux generation synchronously and report progress on a queue.

    Designed to be used as a thread target: every outcome is communicated by
    putting a dict on ``status_queue``; nothing is returned and no exception
    derived from ``Exception`` escapes.

    Args:
        pipe: Callable pipeline. It is invoked with the prompt plus
            ``num_inference_steps``, ``guidance_scale``, ``height``, ``width``
            and ``callback_on_step_end``, and must return an object whose
            ``images[0]`` exposes ``save(path)``.
        params: Request parameters (prompt, iterations, size, guidance scale,
            output path).
        status_queue: Queue receiving status dicts. Each dict has a ``status``
            of ``"running"``, ``"completed"`` or ``"error"``, a ``message``
            and a ``progress`` percentage.
        task_id: Identifier of the task. Kept for interface compatibility; the
            worker does not tag its messages with it.
    """
    try:
        # Placeholder per-step cost used only for the initial estimate; it is
        # scaled by the requested resolution relative to 512x512.
        estimates = {"per_step": 0.5}
        resolution_scale = (params.height * params.width) / (512 * 512)
        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
        status_queue.put({
            "status": "running",
            "message": f"Starting Flux image generation with {params.iterations} inference steps",
            "estimated_time_remaining": estimated_gen_time,
            "progress": 0,
        })

        start_gen_time = time.time()

        def status_callback(pipeline, step, timestep, callback_kwargs):
            """Publish per-step progress; ``step`` is the 0-based step index."""
            steps_done = step + 1
            elapsed = time.time() - start_gen_time
            steps_left = max(params.iterations - steps_done, 0)
            status_queue.put({
                "status": "running",
                "message": f"Processing step {steps_done}/{params.iterations}",
                # Extrapolate from the average time per completed step.
                "estimated_time_remaining": elapsed / steps_done * steps_left,
                "progress": int(steps_done / params.iterations * 100),
            })
            # The pipeline reads tensors back out of the callback's return
            # value, so the kwargs must be handed back (returning None breaks
            # generation on the first step).
            return callback_kwargs

        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            guidance_scale=params.guidance_scale,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time
        image.save(params.filepath)

        status_queue.put({
            "status": "completed",
            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
            "progress": 100,
            "generation_time": gen_time,
            "per_step_time": per_step_time,
            # Report the file that was actually written.
            "image_path": params.filepath,
        })
    except Exception as e:
        # Thread boundary: convert any failure into a status message so the
        # consumer of the queue learns about it.
        status_queue.put({
            "status": "error",
            "message": f"Generation failed: {str(e)}",
            "error": str(e),
            "progress": 0,
        })
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
    """Run ``flux_worker`` in a background thread and stream its status updates.

    Yields dicts, each carrying a ``task_id``:

    * one ``{'status': 'starting', ...}`` message right after the thread starts;
    * every status dict the worker puts on its queue, in order;
    * ``{'status': 'heartbeat', ...}`` roughly every 2 seconds while the
      worker is running but has nothing new to report;
    * a final ``'error'`` message if the worker exits without ever reporting
      a terminal (``'completed'`` / ``'error'``) status, or if this function
      itself fails.

    The stream ends after the first terminal status.

    Args:
        pipe: Pipeline object forwarded to the worker.
        params: Request parameters forwarded to the worker.
    """
    task_id = str(uuid.uuid4())
    status_queue: queue.Queue = queue.Queue()
    worker_thread = None
    terminal_statuses = ('completed', 'error')
    heartbeat_interval = 2  # seconds of silence before a heartbeat is sent
    poll_interval = 0.1  # seconds to sleep when the queue is empty
    try:
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True,
        )
        worker_thread.start()

        yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}

        completed = False
        last_heartbeat = time.time()
        while not completed:
            # Sample liveness BEFORE polling: if the worker was already dead
            # and the queue is empty afterwards, nothing more can arrive.
            # Checking liveness alone (without draining) would drop the
            # worker's final messages when it finishes quickly.
            worker_alive = worker_thread.is_alive()
            try:
                status_update = status_queue.get_nowait()
            except queue.Empty:
                if not worker_alive:
                    break
                current_time = time.time()
                if current_time - last_heartbeat > heartbeat_interval:
                    yield {
                        'status': 'heartbeat',
                        'task_id': task_id,
                        'timestamp': current_time,
                    }
                    last_heartbeat = current_time
                # Yield control to the event loop instead of busy-waiting.
                await asyncio.sleep(poll_interval)
                continue

            status_update['task_id'] = task_id
            yield status_update
            completed = status_update.get('status') in terminal_statuses
            last_heartbeat = time.time()

        if not completed:
            # The worker died without reporting an outcome; do not pretend
            # the generation succeeded.
            failure = 'Worker thread exited without reporting a result'
            yield {
                'status': 'error',
                'task_id': task_id,
                'message': failure,
                'error': failure,
            }
    except Exception as e:
        yield {
            'status': 'error',
            'task_id': task_id,
            'message': f'Server error: {str(e)}',
            'error': str(e),
        }
    finally:
        # Give a still-running worker a brief chance to finish; it is a
        # daemon thread, so it is not waited on beyond this.
        if worker_thread is not None and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)
def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message:
    """Mark ``message`` as an in-progress update carrying the given text.

    Sets ``message.status`` to ``"thinking"`` and ``message.response`` to
    ``status``, then returns the same (mutated) message object.
    ``progress`` and ``estimated_time_remaining`` are accepted but not used.
    """
    message.status, message.response = "thinking", status
    return message
async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
    """Generate an image for ``request``, yielding ``message`` as it progresses.

    The same ``message`` object is mutated and re-yielded for every update:

    * validation failures yield one message with status ``"error"`` and stop;
    * setup and model-loading steps yield ``"thinking"`` messages;
    * each update from ``async_generate_image`` is yielded with the update's
      status and the JSON-encoded update as the response;
    * if generation reports ``"error"`` or ``"timeout"`` the stream stops
      there; otherwise a final ``"done"`` message with the total time and the
      output path is yielded.

    Any exception raised along the way is logged and reported as a single
    ``"error"`` message whose response is the exception text.

    Args:
        message: Message object updated in place and yielded to the caller.
        request: Image parameters (prompt, size, model, output path, steps).
    """
    try:
        # Validate prompt
        prompt = request.prompt.strip()
        if not prompt:
            message.status = "error"
            message.response = "Prompt cannot be empty"
            yield message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            message.status = "error"
            message.response = "Height and width must be positive"
            yield message
            return

        filedir = os.path.dirname(request.filepath)
        filename = os.path.basename(request.filepath)
        # A bare filename has no directory part, and os.makedirs("") raises.
        if filedir:
            os.makedirs(filedir, exist_ok=True)

        model_type = "flux"
        device = "cpu"

        yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}")

        # Initial time estimate, scaled by resolution relative to 512x512
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

        # Initialize or get cached pipeline
        start_time = time.time()
        yield status(message, f"Loading {model_type} model: {request.model}")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10)

        generation_failed = False
        async for update in async_generate_image(pipe, request):
            update_status = update.get("status", "thinking")
            message.status = update_status
            message.response = json.dumps(update)
            yield message
            # Remember failures but let the inner generator run to its end so
            # its cleanup executes normally.
            if update_status in ("error", "timeout"):
                generation_failed = True

        if generation_failed:
            # The failure has already been yielded; do not follow it with a
            # misleading "done" message.
            return

        # Final result
        total_time = time.time() - start_time
        message.status = "done"
        message.response = json.dumps({
            "status": f"Image generation complete in {total_time:.1f} seconds",
            "progress": 100,
            "filename": request.filepath,
        })
        yield message
    except Exception as e:
        # Log before yielding: the consumer may stop iterating once it sees
        # the error, in which case code after the yield would never run.
        logger.error(traceback.format_exc())
        message.status = "error"
        message.response = str(e)
        logger.error(message.response)
        yield message
        return