Gen image works
This commit is contained in:
parent 840ad9159b
commit 74dea019fa
@@ -58,7 +58,7 @@ type BackstoryMessage = {
  full_content?: string;
  response?: string; // Set when status === 'done', 'partial', or 'error'
  chunk?: string; // Used when status === 'streaming'
  timestamp?: string;
  timestamp?: number;
  disableCopy?: boolean,
  user?: string,
  title?: string,
@@ -115,6 +115,7 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
    borderRadius: 2,
    minHeight: '80vh',
    width: "100%",
    flexDirection: "column",
  }}>
    {children}
  </Paper>
@@ -17,7 +17,7 @@ from .setup_logging import setup_logging
from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
from .metrics import Metrics
from .check_serializable import check_serializable
from .profile_image import generate_image_status, ImageRequest
from .profile_image import generate_image, ImageRequest

__all__ = [
    "Agent",
@@ -34,7 +34,7 @@ __all__ = [
    "check_serializable",
    "logger",
    "User",
    "generate_image_status", "ImageRequest"
    "generate_image", "ImageRequest"
]

__all__.extend(agents_all)  # type: ignore
@@ -50,6 +50,7 @@ class Agent(BaseModel, ABC):
    # Agent management with pydantic
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration
    agent_persist: bool = True  # Whether this agent will persist in the context

    # Tunables (sets default for new Messages attached to this agent)
    tunables: Tunables = Field(default_factory=Tunables)
@@ -1,448 +0,0 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel  # type: ignore
from typing import (
    Dict,
    Literal,
    ClassVar,
    cast,
    Any,
    AsyncGenerator,
    List,
    Optional
    # override
)  # NOTE: You must import Optional for late binding to work
import inspect
import random
import re
import json
import traceback
import asyncio
import time
import os

from .base import Agent, agent_registry, LLMMessage
from ..message import Message
from ..rag import ChromaDBGetResponse
from ..setup_logging import setup_logging
from ..profile_image import generate_image_status, ImageRequest
from .. import defines
from ..user import User

logger = setup_logging()

seed = int(time.time())
random.seed(seed)

emptyUser = {
    "profile_url": "",
    "description": "",
    "rag_content_size": 0,
    "username": "",
    "first_name": "",
    "last_name": "",
    "full_name": "",
    "contact_info": {},
    "questions": [],
}

generate_persona_system_prompt = """\
You are a casting director for a movie. Your job is to provide information on fictitious personas for use in a screenplay.

All response fields MUST BE IN ENGLISH, regardless of ethnicity.

You will be provided with defaults to use if not specified by the user:

```json
{
  "age": number,
  "gender": "male" | "female",
  "ethnicity": string,
}
```

Additional information provided in the user message can override those defaults.

You need to randomly assign an English username (can include numbers), a first name, a last name, and a two-sentence English description of that individual's work given the demographics provided.

Your response must be in JSON.
Provide only the JSON response, and match the field names EXACTLY.
Provide all information in English ONLY, with no other commentary:

```json
{
  "username": string, # A likely-to-be-unique username, no more than 15 characters (can include numbers and letters but no special characters)
  "first_name": string,
  "last_name": string,
  "description": string, # One to two sentence description of their job
  "location": string, # In the location, provide ALL of: City, State/Region, and Country
}
```

Make sure to provide a username and that the field name for the job description is "description".
"""

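For reference, a persona object that satisfies the schema this prompt asks for might look like the following; the values are purely illustrative and are not produced anywhere in this commit.

```python
# Illustrative example only: the shape of the JSON the persona prompt requests.
example_persona_response = {
    "username": "jdoe42",       # <= 15 characters, letters and numbers only
    "first_name": "Jane",
    "last_name": "Doe",
    "description": "Runs quality assurance for a mid-size robotics firm.",
    "location": "Portland, Oregon, United States",
}
```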
generate_resume_system_prompt = """
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
of that is to create an in-depth resume for the person. You will be provided with the following information:

```json
"full_name": string, # Person's full name
"location": string, # Location of residence
"age": number, # Age of candidate
"description": string # A brief description of the person
```

Use that information to invent a full career resume. Include sections such as:

* Contact information
* Job goal
* Top skills
* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent
* In the work history, provide company names, years of employment, and their role
* Education

Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume.
"""

class PersonaGenerator(Agent):
    agent_type: Literal["persona"] = "persona"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    system_prompt: str = generate_persona_system_prompt
    age: int = Field(default_factory=lambda: random.randint(22, 67))
    gender: str = Field(default_factory=lambda: random.choice(["male", "female"]))
    ethnicity: Literal[
        "Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"
    ] = Field(
        default_factory=lambda: random.choices(
            ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
            weights=[57.69, 15.38, 19.23, 5.77, 1.92],
            k=1
        )[0]
    )
    username: str = ""

    llm: Any = Field(default=None, exclude=True)
    model: str = Field(default=None, exclude=True)

    def randomize(self):
        self.age = random.randint(22, 67)
        self.gender = random.choice(["male", "female"])
        # Use random.choices with explicit type casting to satisfy the Literal type
        self.ethnicity = cast(
            Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
            random.choices(
                ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
                weights=[57.69, 15.38, 19.23, 5.77, 1.92],
                k=1
            )[0]
        )

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        if not self.context:
            raise ValueError("Context is not set for this agent.")

        message.tunables.enable_tools = False
        message.tunables.enable_rag = False
        message.tunables.enable_context = False

        message.prompt = f"""\
```json
{json.dumps({
    "age": self.age,
    "gender": self.gender,
    "ethnicity": self.ethnicity
})}
```
{message.prompt}
"""
        message.status = "done"
        yield message
        return

    async def process_message(
        self, llm: Any, model: str, message: Message
    ) -> AsyncGenerator[Message, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
        if not self.context:
            raise ValueError("Context is not set for this agent.")

        self.llm = llm
        self.model = model
        original_prompt = message.prompt

        spinner: List[str] = ["\\", "|", "/", "-"]
        tick: int = 0
        while self.context.processing:
            logger.info(
                "TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time"
            )
            message.status = "waiting"
            message.response = (
                f"Busy processing another request. Please wait. {spinner[tick]}"
            )
            tick = (tick + 1) % len(spinner)
            yield message
            await asyncio.sleep(1)  # Allow the event loop to process the write

        self.context.processing = True

        try:

            #
            # Generate the persona
            #
            async for message in self.call_llm(
                message=message, system_prompt=self.system_prompt, prompt=original_prompt
            ):
                if message.status != "done":
                    yield message
                if message.status == "error":
                    raise Exception(message.response)

            json_str = self.extract_json_from_text(message.response)
            try:
                persona = json.loads(json_str) | {
                    "age": self.age,
                    "gender": self.gender,
                    "ethnicity": self.ethnicity
                }
                if not persona.get("full_name", None):
                    persona["full_name"] = f"{persona['first_name']} {persona['last_name']}"
                self.username = persona.get("username", None)
                if not self.username:
                    raise ValueError("LLM did not generate a username")
                user_dir = os.path.join(defines.user_dir, persona["username"])
                while os.path.exists(user_dir):
                    match = re.match(r"^(.*?)(\d*)$", persona["username"])
                    if match:
                        base = match.group(1)
                        num = match.group(2)
                        iteration = int(num) + 1 if num else 1
                        persona["username"] = f"{base}{iteration}"
                    user_dir = os.path.join(defines.user_dir, persona["username"])

                for key in persona:
                    if isinstance(persona[key], str):
                        persona[key] = persona[key].strip()
                # Mark this persona as AI generated
                persona["is_ai"] = True
            except Exception as e:
                message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
                message.status = "error"
                logger.error(traceback.format_exc())
                logger.error(message.response)
                yield message
                return

            # Persona generated
            message.response = json.dumps(persona)
            message.status = "partial"
            yield message

            #
            # Generate the resume
            #
            message.status = "thinking"
            message.response = f"Generating resume for {persona['full_name']}..."
            yield message

            prompt = f"""
```json
{{
    "full_name": "{persona["full_name"]}",
    "location": "{persona["location"]}",
    "age": {persona["age"]},
    "description": {persona["description"]}
}}
```
"""
            try:
                async for message in self.call_llm(
                    message=message, system_prompt=generate_resume_system_prompt, prompt=prompt
                ):
                    if message.status != "done":
                        yield message
                    if message.status == "error":
                        raise Exception(message.response)

            except Exception as e:
                message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
                message.status = "error"
                logger.error(traceback.format_exc())
                logger.error(message.response)
                yield message
                return

            resume = self.extract_markdown_from_text(message.response)
            if resume:
                user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir)
                os.makedirs(user_resume_dir, exist_ok=True)
                user_resume_file = os.path.join(user_resume_dir, defines.resume_doc)
                with open(user_resume_file, "w") as f:
                    f.write(resume)

            # Resume generated
            message.response = resume
            message.status = "partial"
            yield message

            #
            # Generate RAG database
            #
            message.status = "thinking"
            message.response = f"Generating RAG content from resume..."
            yield message

            # Prior to instancing a new User, the json data has to be created
            # so the system can process it
            user_dir = os.path.join(defines.user_dir, persona["username"])
            os.makedirs(user_dir, exist_ok=True)
            user_info = os.path.join(user_dir, "info.json")
            with open(user_info, "w") as f:
                f.write(json.dumps(persona, indent=2))

            user = User(llm=self.llm, username=self.username)
            await user.initialize()
            await user.file_watcher.initialize_collection()
            # RAG content generated
            message.status = "partial"
            message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store."
            yield message

            #
            # Generate the profile picture
            #
            prompt = f"A photorealistic profile picture of a {persona["age"]} year old {persona["gender"]} {persona["ethnicity"]} person."
            if original_prompt:
                prompt = f"{prompt} {original_prompt}"
            message.status = "thinking"
            message.response = prompt
            yield message

            logger.info("Beginning image generation...")
            request = ImageRequest(filepath=os.path.join(defines.user_dir, persona["username"], f"profile.png"), prompt=prompt)
            placeholder = Message(prompt=prompt)
            async for placeholder in generate_image_status(
                message=placeholder,
                **request.model_dump()
            ):
                logger.info("Image generation continuing...")
                if placeholder.status != "done":
                    placeholder.response = placeholder.response
                    yield placeholder
            logger.info("Image generation done...")
            persona["has_profile"] = True

            #
            # Write out the completed user information
            #
            with open(user_info, "w") as f:
                f.write(json.dumps(persona, indent=2))

            # Image generated
            message.status = "done"
            message.response = json.dumps(persona)

        except Exception as e:
            message.status = "error"
            logger.error(traceback.format_exc())
            logger.error(message.response)
            message.response = f"Error in persona generation: {str(e)}"
            logger.error(message.response)
            self.randomize()  # Randomize for next generation
            yield message
            return

        # Done processing, add message to conversation
        self.context.processing = False
        self.randomize()  # Randomize for next generation
        # Return the final message
        yield message
        return

    async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        messages: List[LLMMessage] = [
            LLMMessage(role="system", content=system_prompt),
            LLMMessage(role="user", content=prompt),
        ]
        message.metadata.options = {
            "seed": 8911,
            "num_ctx": self.context_size,
            "temperature": temperature,  # Higher temperature to encourage tool usage
        }

        message.status = "streaming"
        yield message

        last_chunk_time = 0
        message.chunk = ""
        message.response = ""
        for response in self.llm.chat(
            model=self.model,
            messages=messages,
            options={
                **message.metadata.options,
            },
            stream=True,
        ):
            if not response:
                message.status = "error"
                message.response = "No response from LLM."
                yield message
                return

            message.status = "streaming"
            message.chunk += response.message.content
            message.response += response.message.content

            if not response.done:
                now = time.perf_counter()
                if now - last_chunk_time > 0.25:
                    yield message
                    last_chunk_time = now
                    message.chunk = ""

            if response.done:
                self.collect_metrics(response)
                message.metadata.eval_count += response.eval_count
                message.metadata.eval_duration += response.eval_duration
                message.metadata.prompt_eval_count += response.prompt_eval_count
                message.metadata.prompt_eval_duration += response.prompt_eval_duration
                self.context_tokens = response.prompt_eval_count + response.eval_count
                message.chunk = ""
                message.status = "done"
                yield message

    def extract_json_from_text(self, text: str) -> str:
        """Extract JSON string from text that may contain other content."""
        json_pattern = r"```json\s*([\s\S]*?)\s*```"
        match = re.search(json_pattern, text)
        if match:
            return match.group(1).strip()

        # Try to find JSON without the markdown code block
        json_pattern = r"({[\s\S]*})"
        match = re.search(json_pattern, text)
        if match:
            return match.group(1).strip()

        raise ValueError("No JSON found in the response")

    def extract_markdown_from_text(self, text: str) -> str:
        """Extract Markdown string from text that may contain other content."""
        markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
        match = re.search(markdown_pattern, text)
        if match:
            return match.group(2).strip()

        raise ValueError("No Markdown found in the response")


# Register the persona generator agent
agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)
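As a rough sketch of how this streaming agent is intended to be driven (the `agent`, `llm`, and `model` values below are placeholders supplied by the caller, not part of this file):

```python
# Hypothetical driver for PersonaGenerator.process_message.
async def drive_persona_generation(agent, llm, model, message):
    final = None
    async for update in agent.process_message(llm=llm, model=model, message=message):
        if update.status == "error":
            raise RuntimeError(update.response)
        final = update  # "waiting", "thinking", "streaming", and "partial" updates pass through here
    return final        # the last update has status "done" and the persona JSON in .response
```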
@@ -129,7 +129,8 @@ class Context(BaseModel):
            agent = agent_cls(agent_type=agent_type, **kwargs)
            # set_context after constructor to initialize any non-serialized data
            agent.set_context(self)
            self.agents.append(agent)
            if agent.agent_persist:  # If an agent is not set to persist, do not add it to the context
                self.agents.append(agent)
            return agent

        raise ValueError(f"No agent class found for agent_type: {agent_type}")
@@ -13,6 +13,13 @@ import gc
import tempfile
import uuid
import torch  # type: ignore
import asyncio
import time
import json
from typing import AsyncGenerator
from threading import Thread
import queue
import uuid

from .agents.base import Agent, agent_registry, LLMMessage
from .message import Message
@@ -24,11 +31,6 @@ logger = setup_logging()

# Heuristic time estimates (in seconds) for different models and devices at 512x512
TIME_ESTIMATES = {
    "stable-diffusion": {
        "cuda": {"load": 5, "per_step": 0.5},
        "xpu": {"load": 7, "per_step": 0.7},
        "cpu": {"load": 20, "per_step": 5.0},
    },
    "flux": {
        "cuda": {"load": 10, "per_step": 0.8},
        "xpu": {"load": 15, "per_step": 1.0},
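Later in this module these heuristics are combined as roughly load + per_step * iterations * (height * width) / (512 * 512). A quick worked example using the table above and the ImageRequest defaults shown below:

```python
# Worked example of the time-estimate heuristic; numbers come from the table above.
est = TIME_ESTIMATES["stable-diffusion"]["cpu"]        # {"load": 20, "per_step": 5.0}
iterations, height, width = 4, 256, 256                # ImageRequest defaults
resolution_scale = (height * width) / (512 * 512)      # 0.25
estimated_total = est["load"] + est["per_step"] * iterations * resolution_scale
print(f"~{estimated_total:.1f} s")                     # ~25.0 s
```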
@@ -43,24 +45,176 @@ class ImageRequest(BaseModel):
    iterations: int = 4
    height: int = 256
    width: int = 256
    guidance_scale: float = 7.5

# Global model cache instance
model_cache = ImageModelCache(timeout_seconds=60 * 15)  # 15 minutes

def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Background worker for Flux image generation"""
    try:
        # Your existing estimates calculation
        estimates = {"per_step": 0.5}  # Replace with your actual estimates
        resolution_scale = (params.height * params.width) / (512 * 512)

        # Flux: Run generation in the background and yield progress updates
        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
        status_queue.put({
            "status": "running",
            "message": f"Starting Flux image generation with {params.iterations} inference steps",
            "estimated_time_remaining": estimated_gen_time,
            "progress": 0
        })

        # Start the generation task
        start_gen_time = time.time()

        # Simulate your pipe call with progress updates
        def status_callback(pipeline, step, timestep, callback_kwargs):
            # Send progress updates
            elapsed = time.time() - start_gen_time
            progress = int((step + 1) / params.iterations * 100)

            status_queue.put({
                "status": "running",
                "message": f"Processing step {step}/{params.iterations}",
                "progress": progress
            })

        # Replace this block with your actual Flux pipe call:
        image = pipe(
            params.prompt,
            num_inference_steps=params.iterations,
            guidance_scale=7.5,
            height=params.height,
            width=params.width,
            callback_on_step_end=status_callback,
        ).images[0]

        # Simulate image generation completion
        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time

        image.save(params.filepath)

        # Final completion status
        status_queue.put({
            "status": "completed",
            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
            "progress": 100,
            "generation_time": gen_time,
            "per_step_time": per_step_time,
            "image_path": f"generated_{task_id}.png"  # Replace with actual image path/URL
        })

    except Exception as e:
        status_queue.put({
            "status": "error",
            "message": f"Generation failed: {str(e)}",
            "error": str(e),
            "progress": 0
        })


async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Single async function that handles background Flux generation with status streaming
    """
    task_id = str(uuid.uuid4())
    status_queue = queue.Queue()
    worker_thread = None

    try:
        # Start background worker thread
        worker_thread = Thread(
            target=flux_worker,
            args=(pipe, params, status_queue, task_id),
            daemon=True
        )
        worker_thread.start()

        # Initial status
        yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}

        # Stream status updates
        completed = False
        last_heartbeat = time.time()

        while not completed and worker_thread.is_alive():
            try:
                # Try to get status update (non-blocking)
                status_update = status_queue.get_nowait()

                # Add task_id to status update
                status_update['task_id'] = task_id

                # Send status update
                yield status_update

                # Check if completed
                if status_update.get('status') in ['completed', 'error']:
                    completed = True

                last_heartbeat = time.time()

            except queue.Empty:
                # No new status, send heartbeat if needed
                current_time = time.time()
                if current_time - last_heartbeat > 2:  # Heartbeat every 2 seconds
                    heartbeat = {
                        'status': 'heartbeat',
                        'task_id': task_id,
                        'timestamp': current_time
                    }
                    yield heartbeat
                    last_heartbeat = current_time

            # Brief sleep to prevent busy waiting
            await asyncio.sleep(0.1)

        # Handle thread completion or timeout
        if not completed:
            if worker_thread.is_alive():
                # Thread still running but we might have missed the completion signal
                timeout_status = {
                    'status': 'timeout',
                    'task_id': task_id,
                    'message': 'Generation timed out or connection lost'
                }
                yield timeout_status
            else:
                # Thread completed but we might have missed the final status
                final_status = {
                    'status': 'completed',
                    'task_id': task_id,
                    'message': 'Generation completed'
                }
                yield final_status

    except Exception as e:
        error_status = {
            'status': 'error',
            'task_id': task_id,
            'message': f'Server error: {str(e)}',
            'error': str(e)
        }
        yield error_status

    finally:
        # Cleanup: ensure thread completion
        if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup

def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message:
    message.status = "thinking"
    message.response = json.dumps({
        "status": status,
        "progress": progress,
        "estimated_time_remaining": estimated_time_remaining
    })
    message.response = status
    return message

async def generate_image_status(message: Message, model: str, prompt: str, iterations: int, filepath: str, height: int = 512, width: int = 512) -> AsyncGenerator[Message, None]:
async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates."""
    try:
        # Validate prompt
        prompt = prompt.strip()
        prompt = request.prompt.strip()
        if not prompt:
            message.status = "error"
            message.response = "Prompt cannot be empty"
@@ -68,111 +222,46 @@ async def generate_image_status(message: Message, model: str, prompt: str, itera
            return

        # Validate dimensions
        if height <= 0 or width <= 0:
        if request.height <= 0 or request.width <= 0:
            message.status = "error"
            message.response = "Height and width must be positive"
            yield message
            return
        if re.match(r".*stable-diffusion.*", model):
            if height % 8 != 0 or width % 8 != 0:
                message.status = "error"
                message.response = "Stable Diffusion requires height and width to be multiples of 8"
                yield message
                return

        filedir = os.path.dirname(filepath)
        filename = os.path.basename(filepath)
        filedir = os.path.dirname(request.filepath)
        filename = os.path.basename(request.filepath)
        os.makedirs(filedir, exist_ok=True)

        model_type = "stable-diffusion" if re.match(r".*stable-diffusion.*", model) else "flux"
        model_type = "flux"
        device = "cpu"

        if model_type == "flux":
            device = "cpu"
        else:
            device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"

        yield status(message, f"Starting image generation for prompt: {prompt} {width}x{height} as {filename} using {device}")
        yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}")

        # Get initial time estimate, scaled by resolution
        estimates = TIME_ESTIMATES[model_type][device]
        resolution_scale = (height * width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * iterations * resolution_scale
        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {width}x{height}")
        resolution_scale = (request.height * request.width) / (512 * 512)
        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

        # Initialize or get cached pipeline
        start_time = time.time()
        yield status(message, f"Loading {model_type} model: {model}")
        pipe = await model_cache.get_pipeline(model, device)
        yield status(message, f"Loading {model_type} model: {request.model}")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {iterations} inference steps", progress=10)

        # Generate image with progress tracking
        start_gen_time = time.time()

        if model_type == "stable-diffusion":
            steps_completed = 0
            last_step_time = start_gen_time

            def progress_callback(step: int, timestep: int, latents: torch.Tensor):
                nonlocal steps_completed, last_step_time
                steps_completed += 1
                current_time = time.time()
                step_time = current_time - last_step_time
                last_step_time = current_time
                progress = (steps_completed / iterations) * 80 + 10  # Scale from 10% to 90%
                remaining_steps = iterations - steps_completed
                estimated_remaining = step_time * remaining_steps * resolution_scale
                yield status(message=message, status=f"Step {steps_completed}/{iterations} completed", progress=progress, estimated_time_remaining=str(estimated_remaining))

            async def capture_progress(step: int, timestep: int, latents: torch.Tensor):
                for msg in progress_callback(step, timestep, latents):
                    yield msg

            yield status(message, f"Generating image with {iterations} inference steps")
            image = pipe(
                prompt,
                num_inference_steps=iterations,
                guidance_scale=7.5,
                height=height,
                width=width,
                callback=capture_progress,
                callback_steps=1
            ).images[0]
        else:
            # Flux: Run generation in the background and yield progress updates
            estimated_gen_time = estimates["per_step"] * iterations * resolution_scale
            yield status(message, f"Starting Flux image generation with {iterations} inference steps", estimated_time_remaining=estimated_gen_time)

            image = pipe(
                prompt,
                num_inference_steps=iterations,
                guidance_scale=7.5,
                height=height,
                width=width
            ).images[0]

        # Start the generation task
        start_gen_time = time.time()

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / iterations if iterations > 0 else gen_time
        yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)

        gen_time = time.time() - start_gen_time
        per_step_time = gen_time / iterations if iterations > 0 else gen_time
        yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)

        # Save image
        yield status(message, f"Saving image to {filepath}", 95)
        image.save(filepath)
        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10)

        async for update in async_generate_image(pipe, request):
            message.status = update.get("status", "thinking")
            message.response = json.dumps(update)  # Merge properties from async_generate_image over the message...
            yield message

        # Final result
        total_time = time.time() - start_time
        message.status = "done"
        message.response = json.dumps({
            "status": f"Image generation complete in {total_time:.1f} seconds",
            "progress": 100,
            "filename": filepath
            "filename": request.filepath
        })
        yield message
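Taken together, the reworked entry point accepts a single ImageRequest instead of the old positional arguments. A minimal sketch of a call site under this commit follows; the model identifier, output path, and prompt are placeholders, not values taken from the repository.

```python
# Hypothetical usage of the new generate_image API introduced by this commit.
request = ImageRequest(
    model="black-forest-labs/FLUX.1-schnell",  # placeholder model identifier
    filepath="/tmp/profile.png",               # placeholder output path
    prompt="A photorealistic profile picture of a software engineer.",
    iterations=4,
    height=256,
    width=256,
)
message = Message(prompt=request.prompt)

async for update in generate_image(message=message, request=request):
    # Each update carries a JSON status payload in update.response.
    print(update.status, update.response)
```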