From 5c16c780a372efbe52e471ceeccd3cbe1db13550 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Tue, 27 May 2025 12:16:31 -0700 Subject: [PATCH] Multi-modal with image generation --- Dockerfile | 5 + frontend/src/Components/Message.tsx | 29 ++-- .../src/NewApp/Components/BackstoryRoutes.tsx | 2 + .../src/NewApp/Components/CandidateInfo.tsx | 18 +-- frontend/src/NewApp/Components/Document.tsx | 2 +- .../src/NewApp/Components/GenerateImage.tsx | 150 ++++++++++++++++++ frontend/src/NewApp/Components/Quote.tsx | 113 +++++++------ .../Components/StyledMarkdown.css | 0 .../Components/StyledMarkdown.tsx | 26 ++- .../NewApp/Components/streamQueryResponse.tsx | 2 +- .../src/NewApp/Pages/GenerateCandidate.tsx | 2 +- src/server.py | 36 ++++- src/utils/agents/base.py | 19 ++- src/utils/agents/chat.py | 20 +++ src/utils/agents/image_generator.py | 102 +++++++++--- src/utils/profile_image.py | 2 +- src/utils/tools/basetools.py | 52 +++++- 17 files changed, 463 insertions(+), 117 deletions(-) create mode 100644 frontend/src/NewApp/Components/GenerateImage.tsx rename frontend/src/{ => NewApp}/Components/StyledMarkdown.css (100%) rename frontend/src/{ => NewApp}/Components/StyledMarkdown.tsx (85%) diff --git a/Dockerfile b/Dockerfile index c18fbb7..f656e69 100644 --- a/Dockerfile +++ b/Dockerfile @@ -267,7 +267,12 @@ WORKDIR /opt/ollama #ENV OLLAMA_VERSION=https://github.com/intel/ipex-llm/releases/download/v2.3.0-nightly/ollama-ipex-llm-2.3.0b20250415-ubuntu.tgz # NOTE: NO longer at github.com/intel -- now at ipex-llm + +# This version does not work: +# ENV OLLAMA_VERSION=https://github.com/ipex-llm/ipex-llm/releases/download/v2.3.0-nightly/ollama-ipex-llm-2.3.0b20250429-ubuntu.tgz + ENV OLLAMA_VERSION=https://github.com/ipex-llm/ipex-llm/releases/download/v2.2.0/ollama-ipex-llm-2.2.0-ubuntu.tgz + #ENV OLLAMA_VERSION=https://github.com/ipex-llm/ipex-llm/releases/download/v2.3.0-nightly/ollama-ipex-llm-2.3.0b20250429-ubuntu.tgz RUN wget -qO - ${OLLAMA_VERSION} | \ tar --strip-components=1 -C . -xzv diff --git a/frontend/src/Components/Message.tsx b/frontend/src/Components/Message.tsx index ddf4739..ec05219 100644 --- a/frontend/src/Components/Message.tsx +++ b/frontend/src/Components/Message.tsx @@ -21,7 +21,7 @@ import { SxProps, Theme } from '@mui/material'; import JsonView from '@uiw/react-json-view'; import { ChatBubble } from './ChatBubble'; -import { StyledMarkdown } from './StyledMarkdown'; +import { StyledMarkdown } from '../NewApp/Components/StyledMarkdown'; import { VectorVisualizer } from './VectorVisualizer'; import { SetSnackType } from './Snack'; @@ -175,18 +175,21 @@ const MessageMeta = (props: MessageMetaProps) => { {tool.name} - - { - if (typeof (children) === "string" && children.match("\n")) { - return
{children}
- } - }} - /> -
+ {tool.content !== "null" && + + { + if (typeof (children) === "string" && children.match("\n")) { + return
{children}
+ } + }} + /> +
+ } + {tool.content === "null" && "No response from tool call"} ) } diff --git a/frontend/src/NewApp/Components/BackstoryRoutes.tsx b/frontend/src/NewApp/Components/BackstoryRoutes.tsx index 3755820..9c5b1b5 100644 --- a/frontend/src/NewApp/Components/BackstoryRoutes.tsx +++ b/frontend/src/NewApp/Components/BackstoryRoutes.tsx @@ -18,6 +18,7 @@ import { CandidateListingPage } from '../Pages/CandidateListingPage'; import { JobAnalysisPage } from '../Pages/JobAnalysisPage'; import { DemoComponent } from "NewApp/Pages/DemoComponent"; import { GenerateCandidate } from "NewApp/Pages/GenerateCandidate"; +import { ControlsPage } from '../../Pages/ControlsPage'; const DashboardPage = () => (Dashboard); const ProfilePage = () => (Profile); @@ -49,6 +50,7 @@ const getBackstoryDynamicRoutes = (props : BackstoryDynamicRoutesProps, user?: U } />, } />, } />, + } />, ]; if (user === undefined || user === null) { diff --git a/frontend/src/NewApp/Components/CandidateInfo.tsx b/frontend/src/NewApp/Components/CandidateInfo.tsx index 4a5956d..260fb40 100644 --- a/frontend/src/NewApp/Components/CandidateInfo.tsx +++ b/frontend/src/NewApp/Components/CandidateInfo.tsx @@ -1,5 +1,5 @@ import React from 'react'; -import { Box, Link, Typography, Avatar,Grid, Chip, SxProps } from '@mui/material'; +import { Box, Link, Typography, Avatar, Grid, Chip, SxProps, CardHeader } from '@mui/material'; import { Card, CardContent, @@ -53,13 +53,9 @@ const CandidateInfo: React.FC = (props: CandidateInfoProps) ...sx }} > - setSelectedCandidate(candidate)} - sx={{ height: '100%', display: 'flex', flexDirection: 'column', alignItems: 'stretch' }} - > - + - + = (props: CandidateInfoProps) - {candidate.rag_content_size !== undefined && candidate.rag_content_size > 0 && 0 && + ) => { navigate('/knowledge-explorer'); event.stopPropagation() }} label={formatRagSize(candidate.rag_content_size)} color="primary" @@ -142,9 +139,8 @@ const CandidateInfo: React.FC = (props: CandidateInfoProps) - - - + + ); }; diff --git a/frontend/src/NewApp/Components/Document.tsx b/frontend/src/NewApp/Components/Document.tsx index 5a84c73..f06ff2d 100644 --- a/frontend/src/NewApp/Components/Document.tsx +++ b/frontend/src/NewApp/Components/Document.tsx @@ -4,7 +4,7 @@ import { Box, Typography } from '@mui/material'; import { Message } from '../../Components/Message'; import { ChatBubble } from '../../Components/ChatBubble'; import { BackstoryElementProps } from '../../Components/BackstoryTab'; -import { StyledMarkdown } from '../../Components/StyledMarkdown'; +import { StyledMarkdown } from './StyledMarkdown'; interface DocumentProps extends BackstoryElementProps { filepath?: string; diff --git a/frontend/src/NewApp/Components/GenerateImage.tsx b/frontend/src/NewApp/Components/GenerateImage.tsx new file mode 100644 index 0000000..dc1facf --- /dev/null +++ b/frontend/src/NewApp/Components/GenerateImage.tsx @@ -0,0 +1,150 @@ +import React, { useEffect, useState, useRef, useCallback } from 'react'; +import Avatar from '@mui/material/Avatar'; +import Box from '@mui/material/Box'; +import Tooltip from '@mui/material/Tooltip'; +import Button from '@mui/material/Button'; +import Paper from '@mui/material/Paper'; +import IconButton from '@mui/material/IconButton'; +import CancelIcon from '@mui/icons-material/Cancel'; +import SendIcon from '@mui/icons-material/Send'; +import PropagateLoader from 'react-spinners/PropagateLoader'; +import { CandidateInfo } from '../Components/CandidateInfo'; +import { Query } from '../../Components/ChatQuery' +import { Quote } from 'NewApp/Components/Quote'; +import { streamQueryResponse, StreamQueryController } from '../Components/streamQueryResponse'; +import { connectionBase } from 'Global'; +import { UserInfo } from '../Components/UserContext'; +import { BackstoryElementProps } from 'Components/BackstoryTab'; +import { BackstoryTextField, BackstoryTextFieldRef } from 'Components/BackstoryTextField'; +import { jsonrepair } from 'jsonrepair'; +import { StyledMarkdown } from 'NewApp/Components/StyledMarkdown'; +import { Scrollable } from 'Components/Scrollable'; +import { Pulse } from 'NewApp/Components/Pulse'; +import { useUser } from '../Components/UserContext'; + +interface GenerateImageProps extends BackstoryElementProps { + prompt: string +}; + +const GenerateImage = (props: GenerateImageProps) => { + const { user } = useUser(); + const {sessionId, setSnack, prompt} = props; + const [processing, setProcessing] = useState(false); + const [status, setStatus] = useState(''); + const [timestamp, setTimestamp] = useState(0); + const [image, setImage] = useState(''); + + // Only keep refs that are truly necessary + const controllerRef = useRef(null); + + // Effect to trigger profile generation when user data is ready + useEffect(() => { + if (controllerRef.current) { + console.log("Controller already active, skipping profile generation"); + return; + } + if (!prompt) { + return; + } + setStatus('Starting image generation...'); + setProcessing(true); + const start = Date.now(); + + controllerRef.current = streamQueryResponse({ + query: { + prompt: prompt, + agent_options: { + username: user?.username, + } + }, + type: "image", + sessionId, + connectionBase, + onComplete: (msg) => { + switch (msg.status) { + case "partial": + case "done": + if (msg.status === "done") { + if (!msg.response) { + setSnack("Image generation failed", "error"); + } else { + setImage(msg.response); + } + setProcessing(false); + controllerRef.current = null; + } + break; + case "error": + console.log(`Error generating profile: ${msg.response} after ${Date.now() - start}`); + setSnack(msg.response || "", "error"); + setProcessing(false); + controllerRef.current = null; + break; + default: + const data = JSON.parse(msg.response || ''); + if (msg.status !== "heartbeat") { + console.log(data); + } + if (data.timestamp) { + setTimestamp(data.timestamp); + } else { + setTimestamp(Date.now()) + } + if (data.message) { + setStatus(data.message); + } + break; + } + } + }); + }, [user, prompt, sessionId, setSnack]); + + if (!sessionId) { + return <>; + } + + return ( + + {image !== '' && } + { prompt && + + } + {processing && + + { status && + + Generation status + {status} + + } + + + } + ); +}; + +export { + GenerateImage +}; \ No newline at end of file diff --git a/frontend/src/NewApp/Components/Quote.tsx b/frontend/src/NewApp/Components/Quote.tsx index 9be6dca..26264bb 100644 --- a/frontend/src/NewApp/Components/Quote.tsx +++ b/frontend/src/NewApp/Components/Quote.tsx @@ -1,34 +1,38 @@ - - - - import React from 'react'; -import { Box, Typography, Paper } from '@mui/material'; +import { Box, Typography, Paper, SxProps } from '@mui/material'; import { styled } from '@mui/material/styles'; -const QuoteContainer = styled(Paper)(({ theme }) => ({ +interface QuoteContainerProps { + size?: 'normal' | 'small'; +} + +const QuoteContainer = styled(Paper, { + shouldForwardProp: (prop) => prop !== 'size', +})(({ theme, size = 'normal' }) => ({ position: 'relative', - padding: theme.spacing(4), - margin: theme.spacing(2), - background: 'linear-gradient(135deg, #FFFFFF 0%, #D3CDBF 100%)', // White to Warm Gray - borderRadius: theme.spacing(2), - boxShadow: '0 8px 32px rgba(26, 37, 54, 0.15)', // Midnight Blue shadow + padding: size === 'small' ? theme.spacing(1) : theme.spacing(4), + margin: size === 'small' ? theme.spacing(0.5) : theme.spacing(2), + background: 'linear-gradient(135deg, #FFFFFF 0%, #D3CDBF 100%)', + borderRadius: size === 'small' ? theme.spacing(1) : theme.spacing(2), + boxShadow: '0 8px 32px rgba(26, 37, 54, 0.15)', overflow: 'hidden', - border: '1px solid rgba(74, 122, 125, 0.2)', // Subtle Dusty Teal border + border: '1px solid rgba(74, 122, 125, 0.2)', '&::before': { content: '""', position: 'absolute', top: 0, left: 0, right: 0, - height: '4px', - background: 'linear-gradient(90deg, #4A7A7D 0%, #D4A017 100%)', // Dusty Teal to Golden Ochre - } + height: size === 'small' ? '2px' : '4px', + background: 'linear-gradient(90deg, #4A7A7D 0%, #D4A017 100%)', + }, })); -const QuoteText = styled(Typography)(({ theme }) => ({ - fontSize: '1.2rem', - lineHeight: 1.6, +const QuoteText = styled(Typography, { + shouldForwardProp: (prop) => prop !== 'size', +})(({ theme, size = 'normal' }) => ({ + fontSize: size === 'small' ? '0.9rem' : '1.2rem', + lineHeight: size === 'small' ? 1.4 : 1.6, fontStyle: 'italic', color: '#2E2E2E', // Charcoal Black position: 'relative', @@ -38,8 +42,10 @@ const QuoteText = styled(Typography)(({ theme }) => ({ fontWeight: 400, })); -const QuoteMark = styled(Typography)(({ theme }) => ({ - fontSize: '4rem', +const QuoteMark = styled(Typography, { + shouldForwardProp: (prop) => prop !== 'size', +})(({ theme, size = 'normal' }) => ({ + fontSize: size === 'small' ? '2.5rem' : '4rem', fontFamily: '"Georgia", "Times New Roman", serif', fontWeight: 'bold', opacity: 0.15, @@ -49,60 +55,65 @@ const QuoteMark = styled(Typography)(({ theme }) => ({ userSelect: 'none', })); -const OpeningQuote = styled(QuoteMark)({ - top: '10px', - left: '15px', -}); +const OpeningQuote = styled(QuoteMark)(({ size = 'normal' }: QuoteContainerProps) => ({ + top: size === 'small' ? '5px' : '10px', + left: size === 'small' ? '8px' : '15px', +})); -const ClosingQuote = styled(QuoteMark)({ - bottom: '10px', - right: '15px', +const ClosingQuote = styled(QuoteMark)(({ size = 'normal' }: QuoteContainerProps) => ({ + bottom: size === 'small' ? '5px' : '10px', + right: size === 'small' ? '8px' : '15px', transform: 'rotate(180deg)', -}); +})); -const AuthorText = styled(Typography)(({ theme }) => ({ - marginTop: theme.spacing(2), +const AuthorText = styled(Typography, { + shouldForwardProp: (prop) => prop !== 'size', +})(({ theme, size = 'normal' }) => ({ + marginTop: size === 'small' ? theme.spacing(1) : theme.spacing(2), textAlign: 'right', fontStyle: 'normal', fontWeight: 500, color: '#1A2536', // Midnight Blue - fontSize: '0.95rem', + fontSize: size === 'small' ? '0.8rem' : '0.95rem', '&::before': { content: '"— "', color: '#D4A017', // Golden Ochre dash - } + }, })); -const AccentLine = styled(Box)({ - width: '60px', - height: '2px', +const AccentLine = styled(Box, { + shouldForwardProp: (prop) => prop !== 'size', +})(({ theme, size = 'normal' }) => ({ + width: size === 'small' ? '40px' : '60px', + height: size === 'small' ? '1px' : '2px', background: 'linear-gradient(90deg, #D4A017 0%, #4A7A7D 100%)', // Golden Ochre to Dusty Teal - margin: '1rem auto', + margin: size === 'small' ? '0.5rem auto' : '1rem auto', borderRadius: '1px', -}); - +})); interface QuoteProps { - quote?: string, - author?: string -}; + quote?: string; + author?: string; + size?: 'small' | 'normal'; + sx?: SxProps; +} -const Quote = (props : QuoteProps) => { - const { quote, author } = props; +const Quote = (props: QuoteProps) => { + const { quote, author, size = 'normal', sx } = props; return ( - - " - " + + " + " - + {quote} - + {author && ( - + {author} )} @@ -111,6 +122,4 @@ const Quote = (props : QuoteProps) => { ); }; -export { - Quote -}; \ No newline at end of file +export { Quote }; \ No newline at end of file diff --git a/frontend/src/Components/StyledMarkdown.css b/frontend/src/NewApp/Components/StyledMarkdown.css similarity index 100% rename from frontend/src/Components/StyledMarkdown.css rename to frontend/src/NewApp/Components/StyledMarkdown.css diff --git a/frontend/src/Components/StyledMarkdown.tsx b/frontend/src/NewApp/Components/StyledMarkdown.tsx similarity index 85% rename from frontend/src/Components/StyledMarkdown.tsx rename to frontend/src/NewApp/Components/StyledMarkdown.tsx index ca23d40..578740c 100644 --- a/frontend/src/Components/StyledMarkdown.tsx +++ b/frontend/src/NewApp/Components/StyledMarkdown.tsx @@ -2,16 +2,17 @@ import React from 'react'; import { MuiMarkdown } from 'mui-markdown'; import { SxProps, useTheme } from '@mui/material/styles'; import { Link } from '@mui/material'; -import { ChatQuery } from './ChatQuery'; +import { ChatQuery } from '../../Components/ChatQuery'; import Box from '@mui/material/Box'; import JsonView from '@uiw/react-json-view'; import { vscodeTheme } from '@uiw/react-json-view/vscode'; -import { Mermaid } from './Mermaid'; -import { Scrollable } from './Scrollable'; +import { Mermaid } from '../../Components/Mermaid'; +import { Scrollable } from '../../Components/Scrollable'; import { jsonrepair } from 'jsonrepair'; +import { GenerateImage } from './GenerateImage'; import './StyledMarkdown.css'; -import { BackstoryElementProps } from './BackstoryTab'; +import { BackstoryElementProps } from '../../Components/BackstoryTab'; interface StyledMarkdownProps extends BackstoryElementProps { className?: string, @@ -20,7 +21,7 @@ interface StyledMarkdownProps extends BackstoryElementProps { }; const StyledMarkdown: React.FC = (props: StyledMarkdownProps) => { - const { className, sessionId, content, submitQuery, sx, streaming } = props; + const { className, sessionId, content, submitQuery, sx, streaming, setSnack } = props; const theme = useTheme(); const overrides: any = { @@ -107,8 +108,19 @@ const StyledMarkdown: React.FC = (props: StyledMarkdownProp return props.query; } }, - } - }; + }, + GenerateImage: { + component: (props: { prompt: string }) => { + const prompt = props.prompt.replace(/(\w+):/g, '"$1":'); + try { + return + } catch (e) { + console.log("StyledMarkdown error:", prompt, e); + return props.prompt; + } + }, + }, +}; return { await processLine(line); } catch (e) { console.error('Error processing line:', e); - console.log(line); + console.error(line); } } } diff --git a/frontend/src/NewApp/Pages/GenerateCandidate.tsx b/frontend/src/NewApp/Pages/GenerateCandidate.tsx index de94705..bee3872 100644 --- a/frontend/src/NewApp/Pages/GenerateCandidate.tsx +++ b/frontend/src/NewApp/Pages/GenerateCandidate.tsx @@ -17,7 +17,7 @@ import { UserInfo } from '../Components/UserContext'; import { BackstoryElementProps } from 'Components/BackstoryTab'; import { BackstoryTextField, BackstoryTextFieldRef } from 'Components/BackstoryTextField'; import { jsonrepair } from 'jsonrepair'; -import { StyledMarkdown } from 'Components/StyledMarkdown'; +import { StyledMarkdown } from 'NewApp/Components/StyledMarkdown'; import { Scrollable } from 'Components/Scrollable'; import { Pulse } from 'NewApp/Components/Pulse'; diff --git a/src/server.py b/src/server.py index dc10250..867aea6 100644 --- a/src/server.py +++ b/src/server.py @@ -279,6 +279,14 @@ class WebServer: self.setup_routes() + def sanitize_input(self, input: str): + # Validate input: allow only alphanumeric, underscores, and hyphens + if not re.match(r'^[a-zA-Z0-9._-]+$', input): # alphanumeric, _, -, and . are valid + raise ValueError("Invalid input format.") + if re.match(r'\.\.', input): # two ticks in a row is invalid + raise ValueError("Invalid input format.") + + def setup_routes(self): # @self.app.get("/") # async def root(): @@ -778,6 +786,30 @@ class WebServer: logger.error(f"get_users error: {str(e)}") return JSONResponse({ "error": "Unable to parse users"}, 500) + @self.app.get("/api/u/{username}/images/{image_id}/{context_id}") + async def get_user_image(username: str, image_id: str, context_id: str, request: Request): + logger.info(f"{request.method} {request.url.path}") + try: + self.sanitize_input(context_id) + self.sanitize_input(username) + self.sanitize_input(image_id) + + if not User.exists(username): + return JSONResponse({"error": f"User {username} not found."}, status_code=404) + context = await self.load_context(context_id) + if not context: + return JSONResponse({"error": f"Context {context_id} not found."}, status_code=404) + image_path = os.path.join(defines.user_dir, username, "images", image_id) + if not os.path.exists(image_path): + return JSONResponse({ "error": "User {username} does not image {image_id}"}, status_code=404) + return FileResponse(image_path) + except ValueError as e: + return JSONResponse({ "error": f"Invalid input: {image_id}" }, 400) + except Exception as e: + logger.error(traceback.format_exc()) + logger.error(e) + return JSONResponse({ "error": f"Unable to get image {username} {image_id}"}, 500) + @self.app.get("/api/u/{username}/profile/{context_id}") async def get_user_profile(username: str, context_id: str, request: Request): logger.info(f"{request.method} {request.url.path}") @@ -792,7 +824,7 @@ class WebServer: return JSONResponse({ "error": "User {username} does not have a profile picture"}, status_code=404) return FileResponse(profile_path) except Exception as e: - return JSONResponse({ "error": "Unable to load user {username}"}, 500) + return JSONResponse({ "error": f"Unable to load user {username}"}, 500) @self.app.post("/api/u/{username}/{context_id}") async def post_user(username: str, context_id: str, request: Request): @@ -843,7 +875,7 @@ class WebServer: await self.save_context(context_id) return JSONResponse(user_data) except Exception as e: - return JSONResponse({ "error": "Unable to load user {username}"}, 500) + return JSONResponse({ "error": f"Unable to load user {username}"}, 500) @self.app.post("/api/context/u/{username}") async def create_user_context(username: str, request: Request): diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 91a18fc..2a3673e 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -33,7 +33,7 @@ from .types import agent_registry from .. import defines from ..message import Message, Tunables from ..metrics import Metrics -from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__ +from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools # type: ignore -- dynamically added to __all__ from ..conversation import Conversation class LLMMessage(BaseModel): @@ -240,6 +240,22 @@ class Agent(BaseModel, ABC): llm=llm, model=model, url=url, question=question ) + case "GenerateImage": + prompt = arguments.get("prompt", None) + if not prompt: + logger.info("No prompt supplied to GenerateImage") + ret = { "error": "No prompt supplied to GenerateImage" } + + # Additional status update for long-running operations + message.response = ( + f"Generating image for {prompt}..." + ) + yield message + ret = await GenerateImage( + llm=llm, model=model, prompt=prompt + ) + logger.info("GenerateImage returning", ret) + case "DateTime": tz = arguments.get("timezone") ret = DateTime(tz) @@ -255,6 +271,7 @@ class Agent(BaseModel, ABC): ret = WeatherForecast(city, state) case _: + logger.error(f"Requested tool {tool} does not exist") ret = None # Build response for this tool diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index 84c7b09..5529407 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -24,6 +24,26 @@ When answering queries, follow these steps: - If there is information in the <|context|> or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it as 'the latest information' or 'recent data' instead of mentioning '<|context|>' (etc.) or quoting it directly. - Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|resume|>. +CRITICAL INSTRUCTIONS FOR IMAGE GENERATION: + +1. When the user requests to generate an image, inject the following into the response: . Do this when users request images, drawings, or visual content. +3. MANDATORY: You must respond with EXACTLY this format: +4. FORBIDDEN: DO NOT use markdown image syntax ![](url) +5. FORBIDDEN: DO NOT create fake URLs or file paths +6. FORBIDDEN: DO NOT use any other image embedding format + +CORRECT EXAMPLE: +User: "Draw a cat" +Your response: "" + +WRONG EXAMPLES (DO NOT DO THIS): +- ![](https://example.com/...) +- ![Cat image](any_url) +- + +The format is the ONLY way to display images in this system. +DO NOT make up a URL for an image or provide markdown syntax for embedding an image. Only use , and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so. """ diff --git a/src/utils/agents/image_generator.py b/src/utils/agents/image_generator.py index a3f2000..3506d50 100644 --- a/src/utils/agents/image_generator.py +++ b/src/utils/agents/image_generator.py @@ -21,6 +21,7 @@ import time import asyncio import time import os +import hashlib from . base import Agent, agent_registry, LLMMessage from .. message import Message @@ -42,7 +43,7 @@ class ImageGenerator(Agent): system_prompt: str = "" # No system prompt is used username: str - filename: str + filename: str = "" llm: Any = Field(default=None, exclude=True) model: str = Field(default=None, exclude=True) @@ -64,19 +65,53 @@ class ImageGenerator(Agent): self.llm = llm self.model = model - spinner: List[str] = ["\\", "|", "/", "-"] - tick: int = 0 - while self.context.processing: - logger.info( - "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time" - ) - message.status = "waiting" - message.response = ( - f"Busy processing another request. Please wait. {spinner[tick]}" - ) - tick = (tick + 1) % len(spinner) - yield message - await asyncio.sleep(1) # Allow the event loop to process the write + prompt = message.prompt + user_info = os.path.join(defines.user_dir, self.username, defines.user_info_file) + with open(user_info, "r") as f: + self.user = json.loads(f.read()) + images = self.user.get("images", None) + if not images: + self.user["images"] = {} + images = self.user["images"] + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) + + if self.filename != "profile.png": + # Convert the prompt to a hash using MD5 + hash_object = hashlib.md5(prompt.encode('utf-8')) + # Get the hexadecimal representation of the hash + hash_string = hash_object.hexdigest() + # Return the filename with the specified extension + self.filename = f"{hash_string}.png" + file_path = os.path.join(defines.user_dir, self.username, "images") + os.makedirs(file_path, exist_ok=True) + file_path=os.path.join(file_path, self.filename) + else: + file_path = os.path.join(defines.user_dir, self.username, self.filename) + + images[self.filename] = { "status": "thinking", "response": "Generating image" } + + if self.context.processing: + tick: int = 0 + while self.context.processing: + with open(user_info, "r") as f: + self.user = json.loads(f.read()) + image = self.user["images"].get(self.filename, { "status": "waiting", "response": f"Waiting for image generation slot.{'.' * tick}"}) + message.status = image["status"] + message.response = image["response"] + tick = (tick + 1) % 5 + yield message + await asyncio.sleep(1) # Allow the event loop to process the write + + # Processing of active image is complete. Check if it was this image, and return it if so + with open(user_info, "r") as f: + self.user = json.loads(f.read()) + image = self.user["images"].get(self.filename, None) + if image: + message.status = image["status"] + message.response = image["response"] + yield message + return self.context.processing = True @@ -84,25 +119,30 @@ class ImageGenerator(Agent): # # Generate the profile picture # - prompt = message.prompt message.status = "thinking" message.response = f"Generating: {prompt}" yield message - user_info = os.path.join(defines.user_dir, self.username, defines.user_info_file) - with open(user_info, "r") as f: - self.user = json.loads(f.read()) - logger.info("Beginning image generation...", self.user) - logger.info("TODO: Add safety checks for filename... actually figure out an entirely different way to figure out where to store them.") - self.filename = "profile.png" - file_path = os.path.join(defines.user_dir, self.user["username"], self.filename) + if os.path.exists(file_path) and self.filename != "profile.png": + logger.info(f"Image already exists: {file_path}") + message.status = "done" + message.response = f"/api/u/{self.username}/images/{self.filename}" + self.user["images"][self.filename] = { "status": message.status, "response": message.response} + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) + yield message + return + logger.info(f"Image generation: {file_path} <- {prompt}") request = ImageRequest(filepath=file_path, prompt=prompt) async for message in generate_image( message=message, request=request ): + self.user["images"][self.filename] = { "status": message.status, "response": message.response} + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) if message.status != "done": yield message logger.info("Image generation done...") @@ -110,12 +150,13 @@ class ImageGenerator(Agent): logger.info(f"Generated image does not exist: {file_path}") logger.error(f"{message.status} {message.response}") else: - images = self.user.get("images", []) - if self.filename not in images: - images.append(self.filename) if self.filename == "profile.png": self.user["has_profile"] = True + self.user["images"][self.filename] = { "status": message.status, "response": message.response} + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) + # # Write out the completed user information # @@ -124,7 +165,10 @@ class ImageGenerator(Agent): # Image generated message.status = "done" - message.response = json.dumps(self.user) + if self.filename != "profile.png": + message.response = f"/apu/u/{self.username}/images/{self.filename}" + else: + message.response = f"/api/u/{self.username}/profile" except Exception as e: message.status = "error" @@ -132,10 +176,16 @@ class ImageGenerator(Agent): logger.error(message.response) message.response = f"Error in image generation: {str(e)}" logger.error(message.response) + self.user["images"][self.filename] = { "status": message.status, "response": message.response} + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) yield message return # Done processing, add message to conversation + self.user["images"][self.filename] = { "status": message.status, "response": message.response} + with open(user_info, "w") as f: + f.write(json.dumps(self.user)) self.context.processing = False # Return the final message yield message diff --git a/src/utils/profile_image.py b/src/utils/profile_image.py index d966dfe..300a913 100644 --- a/src/utils/profile_image.py +++ b/src/utils/profile_image.py @@ -76,7 +76,7 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task status_queue.put({ "status": "running", - "message": f"Processing step {step+1}/{params.iterations} ({progress}%)", + "message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.", "progress": progress }) return callback_kwargs diff --git a/src/utils/tools/basetools.py b/src/utils/tools/basetools.py index c43cbea..6598bf5 100644 --- a/src/utils/tools/basetools.py +++ b/src/utils/tools/basetools.py @@ -279,6 +279,8 @@ def DateTime(timezone="America/Los_Angeles"): except Exception as e: return {"error": f"Invalid timezone {timezone}: {str(e)}"} +async def GenerateImage(llm, model: str, prompt: str): + return { "image_id": "image-a830a83-bd831" } async def AnalyzeSite(llm, model: str, url: str, question: str): """ @@ -346,6 +348,7 @@ async def AnalyzeSite(llm, model: str, url: str, question: str): return f"Error processing the website content: {str(e)}" + # %% class Function(BaseModel): name: str @@ -358,6 +361,53 @@ class Tool(BaseModel): function: Function tools : List[Tool] = [ +# Tool.model_validate({ +# "type": "function", +# "function": { +# "name": "GenerateImage", +# "description": """\ +# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION: + +# 1. Call this tool when users request images, drawings, or visual content +# 2. This tool returns an image_id (e.g., "img_abc123") +# 3. MANDATORY: You must respond with EXACTLY this format: +# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url) +# 5. FORBIDDEN: DO NOT create fake URLs or file paths +# 6. FORBIDDEN: DO NOT use any other image embedding format + +# CORRECT EXAMPLE: +# User: "Draw a cat" +# Tool returns: {"image_id": "img_xyz789"} +# Your response: "Here's your cat image: " + +# WRONG EXAMPLES (DO NOT DO THIS): +# - ![](https://example.com/...) +# - ![Cat image](any_url) +# - + +# The format is the ONLY way to display images in this system. +# """, +# "parameters": { +# "type": "object", +# "properties": { +# "prompt": { +# "type": "string", +# "description": "Detailed image description including style, colors, subject, composition" +# } +# }, +# "required": ["prompt"] +# }, +# "returns": { +# "type": "object", +# "properties": { +# "image_id": { +# "type": "string", +# "description": "Unique identifier for the generated image. Use this EXACTLY in " +# } +# } +# } +# } +# }), Tool.model_validate({ "type": "function", "function": { @@ -471,6 +521,6 @@ def all_tools() -> List[ToolEntry]: def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True] -tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"] +tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] # __all__.extend(__tool_functions__) # type: ignore