Gen image works
parent 840ad9159b
commit 74dea019fa
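This commit reworks profile image generation: the old `generate_image_status` helper is replaced by `generate_image`, which takes an `ImageRequest`, runs the Flux pipeline on a background thread (`flux_worker` / `async_generate_image`), and streams status updates through the `Message`; the standalone persona-generator agent that called the old helper is removed, and agents gain an `agent_persist` flag. A minimal sketch of how a caller might drive the new API follows; the import paths, model identifier, and output path are illustrative and not part of this commit:

```python
import asyncio

# Illustrative imports; the real package uses relative imports
# (`from .profile_image import generate_image, ImageRequest`).
from profile_image import generate_image, ImageRequest  # hypothetical module path
from message import Message                             # hypothetical module path


async def main() -> None:
    # Field names mirror the ImageRequest usage visible in this diff;
    # the model identifier and output path are placeholders.
    request = ImageRequest(
        model="flux",
        prompt="A photorealistic profile picture.",
        filepath="/tmp/profile.png",
        iterations=4,
        height=256,
        width=256,
    )
    message = Message(prompt=request.prompt)

    # generate_image is an async generator: it yields the same Message
    # object with an updated status ("thinking", "error", "done") and a
    # status text or JSON payload in response as generation progresses.
    async for update in generate_image(message=message, request=request):
        print(update.status, update.response)


asyncio.run(main())
```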
@@ -58,7 +58,7 @@ type BackstoryMessage = {
   full_content?: string;
   response?: string; // Set when status === 'done', 'partial', or 'error'
   chunk?: string; // Used when status === 'streaming'
-  timestamp?: string;
+  timestamp?: number;
   disableCopy?: boolean,
   user?: string,
   title?: string,
@@ -115,6 +115,7 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
       borderRadius: 2,
       minHeight: '80vh',
       width: "100%",
+      flexDirection: "column",
     }}>
       {children}
     </Paper>
@@ -17,7 +17,7 @@ from .setup_logging import setup_logging
 from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
 from .metrics import Metrics
 from .check_serializable import check_serializable
-from .profile_image import generate_image_status, ImageRequest
+from .profile_image import generate_image, ImageRequest
 
 __all__ = [
     "Agent",
@@ -34,7 +34,7 @@ __all__ = [
     "check_serializable",
     "logger",
     "User",
-    "generate_image_status", "ImageRequest"
+    "generate_image", "ImageRequest"
 ]
 
 __all__.extend(agents_all) # type: ignore
@@ -50,6 +50,7 @@ class Agent(BaseModel, ABC):
     # Agent management with pydantic
     agent_type: Literal["base"] = "base"
     _agent_type: ClassVar[str] = agent_type # Add this for registration
+    agent_persist: bool = True # Whether this agent will persist in the context
 
     # Tunables (sets default for new Messages attached to this agent)
     tunables: Tunables = Field(default_factory=Tunables)
@@ -1,448 +0,0 @@
-from __future__ import annotations
-from pydantic import model_validator, Field, BaseModel # type: ignore
-from typing import (
-    Dict,
-    Literal,
-    ClassVar,
-    cast,
-    Any,
-    AsyncGenerator,
-    List,
-    Optional
-    # override
-) # NOTE: You must import Optional for late binding to work
-import inspect
-import random
-import re
-import json
-import traceback
-import asyncio
-import time
-import asyncio
-import time
-import os
-
-from . base import Agent, agent_registry, LLMMessage
-from .. message import Message
-from .. rag import ChromaDBGetResponse
-from .. setup_logging import setup_logging
-from .. profile_image import generate_image_status, ImageRequest
-from .. import defines
-from .. user import User
-
-logger = setup_logging()
-
-seed = int(time.time())
-random.seed(seed)
-
-emptyUser = {
-    "profile_url": "",
-    "description": "",
-    "rag_content_size": 0,
-    "username": "",
-    "first_name": "",
-    "last_name": "",
-    "full_name": "",
-    "contact_info": {},
-    "questions": [],
-}
-
-generate_persona_system_prompt = """\
-You are a casing director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
-
-All response field MUST BE IN ENGLISH, regardless of ethnicity.
-
-You will be provided with defaults to use if not specified by the user:
-
-```json
-{
-  "age": number,
-  "gender": "male" | "female",
-  "ethnicity": string,
-}
-```
-
-Additional information provided in the user message can override those defaults.
-
-You need to randomly assign an English username (can include numbers), a first name, last name, and a two English sentence description of that individual's work given the demographics provided.
-
-Your response must be in JSON.
-Provide only the JSON response, and match the field names EXACTLY.
-Provide all information in English ONLY, with no other commentary:
-
-```json
-{
-  "username": string, # A likely-to-be unique username, no more than 15 characters (can include numbers and letters but no special characters)
-  "first_name": string,
-  "last_name": string,
-  "description": string, # One to two sentence description of their job
-  "location": string, # In the location, provide ALL of: City, State/Region, and Country
-}
-```
-
-Make sure to provide a username and that the field name for the job description is "description".
-"""
-
-generate_resume_system_prompt = """
-You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
-of that is to create an in-depth resume for the person. You will be provided with the following information:
-
-```json
-  "full_name": string, # Person full name
-  "location": string, # Location of residence
-  "age": number, # Age of candidate
-  "description": string # A brief description of the person
-```
-
-Use that information to invent a full career resume. Include sections such as:
-
-* Contact information
-* Job goal
-* Top skills
-* Detailed work history. If they are under the age of 25, you might include skills, hobbies, or volunteering they may have done while an adolescent
-* In the work history, provide company names, years of employment, and their role
-* Education
-
-Provide the resume in Markdown format. DO NOT provide any commentary before or after the resume.
-"""
-
-class PersonaGenerator(Agent):
-    agent_type: Literal["persona"] = "persona" # type: ignore
-    _agent_type: ClassVar[str] = agent_type # Add this for registration
-
-    system_prompt: str = generate_persona_system_prompt
-    age: int = Field(default_factory=lambda: random.randint(22, 67))
-    gender: str = Field(default_factory=lambda: random.choice(["male", "female"]))
-    ethnicity: Literal[
-        "Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"
-    ] = Field(
-        default_factory=lambda: random.choices(
-            ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
-            weights=[57.69, 15.38, 19.23, 5.77, 1.92],
-            k=1
-        )[0]
-    )
-    username: str = ""
-
-    llm: Any = Field(default=None, exclude=True)
-    model: str = Field(default=None, exclude=True)
-
-    def randomize(self):
-        self.age = random.randint(22, 67)
-        self.gender = random.choice(["male", "female"])
-        # Use random.choices with explicit type casting to satisfy Literal type
-        self.ethnicity = cast(
-            Literal["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
-            random.choices(
-                ["Asian", "African", "Caucasian", "Hispanic/Latino", "Mixed/Multiracial"],
-                weights=[57.69, 15.38, 19.23, 5.77, 1.92],
-                k=1
-            )[0]
-        )
-
-    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
-
-        if not self.context:
-            raise ValueError("Context is not set for this agent.")
-
-        message.tunables.enable_tools = False
-        message.tunables.enable_rag = False
-        message.tunables.enable_context = False
-
-        message.prompt = f"""\
-```json
-{json.dumps({
-    "age": self.age,
-    "gender": self.gender,
-    "ethnicity": self.ethnicity
-})}
-```
-{message.prompt}
-"""
-        message.status = "done"
-        yield message
-        return
-
-    async def process_message(
-        self, llm: Any, model: str, message: Message
-    ) -> AsyncGenerator[Message, None]:
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
-        if not self.context:
-            raise ValueError("Context is not set for this agent.")
-
-        self.llm = llm
-        self.model = model
-        original_prompt = message.prompt
-
-        spinner: List[str] = ["\\", "|", "/", "-"]
-        tick: int = 0
-        while self.context.processing:
-            logger.info(
-                "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
-            )
-            message.status = "waiting"
-            message.response = (
-                f"Busy processing another request. Please wait. {spinner[tick]}"
-            )
-            tick = (tick + 1) % len(spinner)
-            yield message
-            await asyncio.sleep(1) # Allow the event loop to process the write
-
-        self.context.processing = True
-
-        try:
-
-            #
-            # Generate the persona
-            #
-            async for message in self.call_llm(
-                message=message, system_prompt=self.system_prompt, prompt=original_prompt
-            ):
-                if message.status != "done":
-                    yield message
-                if message.status == "error":
-                    raise Exception(message.response)
-
-            json_str = self.extract_json_from_text(message.response)
-            try:
-                persona = json.loads(json_str) | {
-                    "age": self.age,
-                    "gender": self.gender,
-                    "ethnicity": self.ethnicity
-                }
-                if not persona.get("full_name", None):
-                    persona["full_name"] = f"{persona['first_name']} {persona['last_name']}"
-                self.username = persona.get("username", None)
-                if not self.username:
-                    raise ValueError("LLM did not generate a username")
-                user_dir = os.path.join(defines.user_dir, persona["username"])
-                while os.path.exists(user_dir):
-                    match = re.match(r"^(.*?)(\d*)$", persona["username"])
-                    if match:
-                        base = match.group(1)
-                        num = match.group(2)
-                        iteration = int(num) + 1 if num else 1
-                        persona["username"] = f"{base}{iteration}"
-                    user_dir = os.path.join(defines.user_dir, persona["username"])
-
-                for key in persona:
-                    if isinstance(persona[key], str):
-                        persona[key] = persona[key].strip()
-                # Mark this persona as AI generated
-                persona["is_ai"] = True
-            except Exception as e:
-                message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
-                message.status = "error"
-                logger.error(traceback.format_exc())
-                logger.error(message.response)
-                yield message
-                return
-
-            # Persona generated
-            message.response = json.dumps(persona)
-            message.status = "partial"
-            yield message
-
-            #
-            # Generate the resume
-            #
-            message.status = "thinking"
-            message.response = f"Generating resume for {persona['full_name']}..."
-            yield message
-
-            prompt = f"""
-```json
-{{
-    "full_name": "{persona["full_name"]}",
-    "location": "{persona["location"]}",
-    "age": {persona["age"]},
-    "description": {persona["description"]}
-}}
-```
-"""
-            try:
-                async for message in self.call_llm(
-                    message=message, system_prompt=generate_resume_system_prompt, prompt=prompt
-                ):
-                    if message.status != "done":
-                        yield message
-                    if message.status == "error":
-                        raise Exception(message.response)
-
-            except Exception as e:
-                message.response = f"Unable to parse LLM returned content: {json_str} {str(e)}"
-                message.status = "error"
-                logger.error(traceback.format_exc())
-                logger.error(message.response)
-                yield message
-                return
-
-            resume = self.extract_markdown_from_text(message.response)
-            if resume:
-                user_resume_dir = os.path.join(defines.user_dir, persona["username"], defines.resume_doc_dir)
-                os.makedirs(user_resume_dir, exist_ok=True)
-                user_resume_file = os.path.join(user_resume_dir, defines.resume_doc)
-                with open(user_resume_file, "w") as f:
-                    f.write(resume)
-
-            # Resume generated
-            message.response = resume
-            message.status = "partial"
-            yield message
-
-            #
-            # Generate RAG database
-            #
-            message.status = "thinking"
-            message.response = f"Generating RAG content from resume..."
-            yield message
-
-            # Prior to instancing a new User, the json data has to be created
-            # so the system can process it
-            user_dir = os.path.join(defines.user_dir, persona["username"])
-            os.makedirs(user_dir, exist_ok=True)
-            user_info = os.path.join(user_dir, "info.json")
-            with open(user_info, "w") as f:
-                f.write(json.dumps(persona, indent=2))
-
-            user = User(llm=self.llm, username=self.username)
-            await user.initialize()
-            await user.file_watcher.initialize_collection()
-            # RAG content generated
-            message.status = "partial"
-            message.response = f"{user.file_watcher.collection.count()} entries created in RAG vector store."
-            yield message
-
-            #
-            # Generate the profile picture
-            #
-            prompt = f"A photorealistic profile picture of a {persona["age"]} year old {persona["gender"]} {persona["ethnicity"]} person."
-            if original_prompt:
-                prompt = f"{prompt} {original_prompt}"
-            message.status = "thinking"
-            message.response = prompt
-            yield message
-
-            logger.info("Beginning image generation...")
-            request = ImageRequest(filepath=os.path.join(defines.user_dir, persona["username"], f"profile.png"), prompt=prompt)
-            placeholder = Message(prompt=prompt)
-            async for placeholder in generate_image_status(
-                message=placeholder,
-                **request.model_dump()
-            ):
-                logger.info("Image generation continue...")
-                if placeholder.status != "done":
-                    placeholder.response = placeholder.response
-                    yield placeholder
-            logger.info("Image generation done...")
-            persona["has_profile"] = True
-
-            #
-            # Write out the completed user information
-            #
-            with open(user_info, "w") as f:
-                f.write(json.dumps(persona, indent=2))
-
-            # Image generated
-            message.status = "done"
-            message.response = json.dumps(persona)
-
-        except Exception as e:
-            message.status = "error"
-            logger.error(traceback.format_exc())
-            logger.error(message.response)
-            message.response = f"Error in persona generation: {str(e)}"
-            logger.error(message.response)
-            self.randomize() # Randomize for next generation
-            yield message
-            return
-
-        # Done processing, add message to conversation
-        self.context.processing = False
-        self.randomize() # Randomize for next generation
-        # Return the final message
-        yield message
-        return
-
-    async def call_llm(self, message: Message, system_prompt, prompt, temperature=0.7):
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
-
-        messages: List[LLMMessage] = [
-            LLMMessage(role="system", content=system_prompt),
-            LLMMessage(role="user", content=prompt),
-        ]
-        message.metadata.options = {
-            "seed": 8911,
-            "num_ctx": self.context_size,
-            "temperature": temperature, # Higher temperature to encourage tool usage
-        }
-
-        message.status = "streaming"
-        yield message
-
-        last_chunk_time = 0
-        message.chunk = ""
-        message.response = ""
-        for response in self.llm.chat(
-            model=self.model,
-            messages=messages,
-            options={
-                **message.metadata.options,
-            },
-            stream=True,
-        ):
-            if not response:
-                message.status = "error"
-                message.response = "No response from LLM."
-                yield message
-                return
-
-            message.status = "streaming"
-            message.chunk += response.message.content
-            message.response += response.message.content
-
-            if not response.done:
-                now = time.perf_counter()
-                if now - last_chunk_time > 0.25:
-                    yield message
-                    last_chunk_time = now
-                    message.chunk = ""
-
-            if response.done:
-                self.collect_metrics(response)
-                message.metadata.eval_count += response.eval_count
-                message.metadata.eval_duration += response.eval_duration
-                message.metadata.prompt_eval_count += response.prompt_eval_count
-                message.metadata.prompt_eval_duration += response.prompt_eval_duration
-                self.context_tokens = response.prompt_eval_count + response.eval_count
-                message.chunk = ""
-                message.status = "done"
-                yield message
-
-    def extract_json_from_text(self, text: str) -> str:
-        """Extract JSON string from text that may contain other content."""
-        json_pattern = r"```json\s*([\s\S]*?)\s*```"
-        match = re.search(json_pattern, text)
-        if match:
-            return match.group(1).strip()
-
-        # Try to find JSON without the markdown code block
-        json_pattern = r"({[\s\S]*})"
-        match = re.search(json_pattern, text)
-        if match:
-            return match.group(1).strip()
-
-        raise ValueError("No JSON found in the response")
-
-    def extract_markdown_from_text(self, text: str) -> str:
-        """Extract Markdown string from text that may contain other content."""
-        markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
-        match = re.search(markdown_pattern, text)
-        if match:
-            return match.group(2).strip()
-
-        raise ValueError("No Markdown found in the response")
-
-# Register the base agent
-agent_registry.register(PersonaGenerator._agent_type, PersonaGenerator)
@@ -129,7 +129,8 @@ class Context(BaseModel):
                 agent = agent_cls(agent_type=agent_type, **kwargs)
                 # set_context after constructor to initialize any non-serialized data
                 agent.set_context(self)
-                self.agents.append(agent)
+                if agent.agent_persist: # If an agent is not set to persist, do not add it to the context
+                    self.agents.append(agent)
                 return agent
 
         raise ValueError(f"No agent class found for agent_type: {agent_type}")
@@ -13,6 +13,13 @@ import gc
 import tempfile
 import uuid
 import torch # type: ignore
+import asyncio
+import time
+import json
+from typing import AsyncGenerator
+from threading import Thread
+import queue
+import uuid
 
 from .agents.base import Agent, agent_registry, LLMMessage
 from .message import Message
@@ -24,11 +31,6 @@ logger = setup_logging()
 
 # Heuristic time estimates (in seconds) for different models and devices at 512x512
 TIME_ESTIMATES = {
-    "stable-diffusion": {
-        "cuda": {"load": 5, "per_step": 0.5},
-        "xpu": {"load": 7, "per_step": 0.7},
-        "cpu": {"load": 20, "per_step": 5.0},
-    },
     "flux": {
         "cuda": {"load": 10, "per_step": 0.8},
         "xpu": {"load": 15, "per_step": 1.0},
@@ -43,24 +45,176 @@ class ImageRequest(BaseModel):
     iterations: int = 4
     height: int = 256
     width: int = 256
+    guidance_scale: float = 7.5
 
 # Global model cache instance
 model_cache = ImageModelCache(timeout_seconds=60 * 15) # 15 minutes
 
+def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
+    """Background worker for Flux image generation"""
+    try:
+        # Your existing estimates calculation
+        estimates = {"per_step": 0.5} # Replace with your actual estimates
+        resolution_scale = (params.height * params.width) / (512 * 512)
+
+        # Flux: Run generation in the background and yield progress updates
+        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
+        status_queue.put({
+            "status": "running",
+            "message": f"Starting Flux image generation with {params.iterations} inference steps",
+            "estimated_time_remaining": estimated_gen_time,
+            "progress": 0
+        })
+
+        # Start the generation task
+        start_gen_time = time.time()
+
+        # Simulate your pipe call with progress updates
+        def status_callback(pipeline, step, timestep, callback_kwargs):
+            # Send progress updates
+            elapsed = time.time() - start_gen_time
+            progress = int((step + 1) / params.iterations * 100)
+
+            status_queue.put({
+                "status": "running",
+                "message": f"Processing step {step}/{params.iterations}",
+                "progress": progress
+            })
+
+        # Replace this block with your actual Flux pipe call:
+        image = pipe(
+            params.prompt,
+            num_inference_steps=params.iterations,
+            guidance_scale=7.5,
+            height=params.height,
+            width=params.width,
+            callback_on_step_end=status_callback,
+        ).images[0]
+
+        # Simulate image generation completion
+        gen_time = time.time() - start_gen_time
+        per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time
+
+        image.save(params.filepath)
+
+        # Final completion status
+        status_queue.put({
+            "status": "completed",
+            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
+            "progress": 100,
+            "generation_time": gen_time,
+            "per_step_time": per_step_time,
+            "image_path": f"generated_{task_id}.png" # Replace with actual image path/URL
+        })
+
+    except Exception as e:
+        status_queue.put({
+            "status": "error",
+            "message": f"Generation failed: {str(e)}",
+            "error": str(e),
+            "progress": 0
+        })
+
+
+async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
+    """
+    Single async function that handles background Flux generation with status streaming
+    """
+    task_id = str(uuid.uuid4())
+    status_queue = queue.Queue()
+    worker_thread = None
+
+    try:
+        # Start background worker thread
+        worker_thread = Thread(
+            target=flux_worker,
+            args=(pipe, params, status_queue, task_id),
+            daemon=True
+        )
+        worker_thread.start()
+
+        # Initial status
+        yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}
+
+        # Stream status updates
+        completed = False
+        last_heartbeat = time.time()
+
+        while not completed and worker_thread.is_alive():
+            try:
+                # Try to get status update (non-blocking)
+                status_update = status_queue.get_nowait()
+
+                # Add task_id to status update
+                status_update['task_id'] = task_id
+
+                # Send status update
+                yield status_update
+
+                # Check if completed
+                if status_update.get('status') in ['completed', 'error']:
+                    completed = True
+
+                last_heartbeat = time.time()
+
+            except queue.Empty:
+                # No new status, send heartbeat if needed
+                current_time = time.time()
+                if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds
+                    heartbeat = {
+                        'status': 'heartbeat',
+                        'task_id': task_id,
+                        'timestamp': current_time
+                    }
+                    yield heartbeat
+                    last_heartbeat = current_time
+
+                # Brief sleep to prevent busy waiting
+                await asyncio.sleep(0.1)
+
+        # Handle thread completion or timeout
+        if not completed:
+            if worker_thread.is_alive():
+                # Thread still running but we might have missed the completion signal
+                timeout_status = {
+                    'status': 'timeout',
+                    'task_id': task_id,
+                    'message': 'Generation timed out or connection lost'
+                }
+                yield timeout_status
+            else:
+                # Thread completed but we might have missed the final status
+                final_status = {
+                    'status': 'completed',
+                    'task_id': task_id,
+                    'message': 'Generation completed'
+                }
+                yield final_status
+
+    except Exception as e:
+        error_status = {
+            'status': 'error',
+            'task_id': task_id,
+            'message': f'Server error: {str(e)}',
+            'error': str(e)
+        }
+        yield error_status
+
+    finally:
+        # Cleanup: ensure thread completion
+        if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
+            worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
+
 def status(message: Message, status: str, progress: float = 0, estimated_time_remaining="...") -> Message:
     message.status = "thinking"
-    message.response = json.dumps({
-        "status": status,
-        "progress": progress,
-        "estimated_time_remaining": estimated_time_remaining
-    })
+    message.response = status
     return message
 
-async def generate_image_status(message: Message, model: str, prompt: str, iterations: int, filepath: str, height: int = 512, width: int = 512) -> AsyncGenerator[Message, None]:
+async def generate_image(message: Message, request: ImageRequest) -> AsyncGenerator[Message, None]:
     """Generate an image with specified dimensions and yield status updates with time estimates."""
     try:
         # Validate prompt
-        prompt = prompt.strip()
+        prompt = request.prompt.strip()
         if not prompt:
             message.status = "error"
             message.response = "Prompt cannot be empty"
@@ -68,103 +222,38 @@ async def generate_image_status(message: Message, model: str, prompt: str, itera
             return
 
         # Validate dimensions
-        if height <= 0 or width <= 0:
+        if request.height <= 0 or request.width <= 0:
             message.status = "error"
             message.response = "Height and width must be positive"
             yield message
             return
-        if re.match(r".*stable-diffusion.*", model):
-            if height % 8 != 0 or width % 8 != 0:
-                message.status = "error"
-                message.response = "Stable Diffusion requires height and width to be multiples of 8"
-                yield message
-                return
 
-        filedir = os.path.dirname(filepath)
-        filename = os.path.basename(filepath)
+        filedir = os.path.dirname(request.filepath)
+        filename = os.path.basename(request.filepath)
         os.makedirs(filedir, exist_ok=True)
 
-        model_type = "stable-diffusion" if re.match(r".*stable-diffusion.*", model) else "flux"
-
-        if model_type == "flux":
-            device = "cpu"
-        else:
-            device = "cuda" if torch.cuda.is_available() else "xpu" if torch.xpu.is_available() else "cpu"
+        model_type = "flux"
+        device = "cpu"
 
-        yield status(message, f"Starting image generation for prompt: {prompt} {width}x{height} as {filename} using {device}")
+        yield status(message, f"Starting image generation for prompt: {prompt} {request.width}x{request.height} as {filename} using {device}")
 
         # Get initial time estimate, scaled by resolution
         estimates = TIME_ESTIMATES[model_type][device]
-        resolution_scale = (height * width) / (512 * 512)
-        estimated_total = estimates["load"] + estimates["per_step"] * iterations * resolution_scale
-        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {width}x{height}")
+        resolution_scale = (request.height * request.width) / (512 * 512)
+        estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
+        yield status(message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")
 
         # Initialize or get cached pipeline
        start_time = time.time()
-        yield status(message, f"Loading {model_type} model: {model}")
-        pipe = await model_cache.get_pipeline(model, device)
+        yield status(message, f"Loading {model_type} model: {request.model}")
+        pipe = await model_cache.get_pipeline(request.model, device)
         load_time = time.time() - start_time
-        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {iterations} inference steps", progress=10)
+        yield status(message, f"Model loaded in {load_time:.1f} seconds. Generating image with {request.iterations} inference steps", progress=10)
 
-        # Generate image with progress tracking
-        start_gen_time = time.time()
-
-        if model_type == "stable-diffusion":
-            steps_completed = 0
-            last_step_time = start_gen_time
-
-            def progress_callback(step: int, timestep: int, latents: torch.Tensor):
-                nonlocal steps_completed, last_step_time
-                steps_completed += 1
-                current_time = time.time()
-                step_time = current_time - last_step_time
-                last_step_time = current_time
-                progress = (steps_completed / iterations) * 80 + 10 # Scale from 10% to 90%
-                remaining_steps = iterations - steps_completed
-                estimated_remaining = step_time * remaining_steps * resolution_scale
-                yield status(message=message, status=f"Step {steps_completed}/{iterations} completed", progress=progress, estimated_time_remaining=str(estimated_remaining))
-
-            async def capture_progress(step: int, timestep: int, latents: torch.Tensor):
-                for msg in progress_callback(step, timestep, latents):
-                    yield msg
-
-            yield status(message, f"Generating image with {iterations} inference steps")
-            image = pipe(
-                prompt,
-                num_inference_steps=iterations,
-                guidance_scale=7.5,
-                height=height,
-                width=width,
-                callback=capture_progress,
-                callback_steps=1
-            ).images[0]
-        else:
-            # Flux: Run generation in the background and yield progress updates
-            estimated_gen_time = estimates["per_step"] * iterations * resolution_scale
-            yield status(message, f"Starting Flux image generation with {iterations} inference steps", estimated_time_remaining=estimated_gen_time)
-
-            image = pipe(
-                prompt,
-                num_inference_steps=iterations,
-                guidance_scale=7.5,
-                height=height,
-                width=width
-            ).images[0]
-
-            # Start the generation task
-            start_gen_time = time.time()
-
-            gen_time = time.time() - start_gen_time
-            per_step_time = gen_time / iterations if iterations > 0 else gen_time
-            yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)
-
-        gen_time = time.time() - start_gen_time
-        per_step_time = gen_time / iterations if iterations > 0 else gen_time
-        yield status(message, f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", 90)
-
-        # Save image
-        yield status(message, f"Saving image to {filepath}", 95)
-        image.save(filepath)
+        async for update in async_generate_image(pipe, request):
+            message.status = update.get("status", "thinking")
+            message.response = json.dumps(update) # Merge properties from async_generate_image over the message...
+            yield message
 
         # Final result
         total_time = time.time() - start_time
@@ -172,7 +261,7 @@ async def generate_image_status(message: Message, model: str, prompt: str, itera
         message.response = json.dumps({
             "status": f"Image generation complete in {total_time:.1f} seconds",
             "progress": 100,
-            "filename": filepath
+            "filename": request.filepath
         })
         yield message
 