diff --git a/frontend/src/components/Conversation.tsx b/frontend/src/components/Conversation.tsx
index c7b2bb5..c5eb631 100644
--- a/frontend/src/components/Conversation.tsx
+++ b/frontend/src/components/Conversation.tsx
@@ -15,14 +15,14 @@ import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from 'utils/Global';
 import { useAuth } from "hooks/AuthContext";
 import { StreamingResponse } from 'services/api-client';
-import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { PaginatedResponse } from 'types/conversion';
 import './Conversation.css';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
 
 const defaultMessage: ChatMessage = {
-  type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+  status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };
 
 const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
@@ -249,8 +249,7 @@ const Conversation = forwardRef((props: C
       ...conversationRef.current,
       {
         ...defaultMessage,
-        type: 'user',
-        sender: 'user',
+        type: 'text',
         content: query.prompt,
       }
     ]);
@@ -260,34 +259,30 @@ const Conversation = forwardRef((props: C
     );
 
     const chatMessage: ChatMessageUser = {
+      role: "user",
       sessionId: chatSession.id,
       content: query.prompt,
       tunables: query.tunables,
       status: "done",
-      type: "user",
-      sender: "user",
+      type: "text",
       timestamp: new Date()
     };
 
     controllerRef.current = apiClient.sendMessageStream(chatMessage, {
-      onMessage: (msg: ChatMessageBase) => {
+      onMessage: (msg: ChatMessage) => {
         console.log("onMessage:", msg);
-        if (msg.type === "response") {
-          setConversation([
-            ...conversationRef.current,
-            msg
-          ]);
-          setStreamingMessage(undefined);
-          setProcessingMessage(undefined);
-          setProcessing(false);
-        } else {
-          setProcessingMessage(msg);
-        }
+        setConversation([
+          ...conversationRef.current,
+          msg
+        ]);
+        setStreamingMessage(undefined);
+        setProcessingMessage(undefined);
+        setProcessing(false);
         if (onResponse) {
           onResponse(msg);
         }
       },
-      onError: (error: string | ChatMessageBase) => {
+      onError: (error: string | ChatMessageError) => {
         console.log("onError:", error);
-        // Type-guard to determine if this is a ChatMessageBase or a string
+        // Type-guard to determine if this is a ChatMessageError or a string
         if (typeof error === "object" && error !== null && "content" in error) {
@@ -298,12 +293,12 @@ const Conversation = forwardRef((props: C
           setProcessingMessage({ ...defaultMessage, content: error as string });
         }
       },
-      onStreaming: (chunk: ChatMessageBase) => {
+      onStreaming: (chunk: ChatMessageStreaming) => {
         console.log("onStreaming:", chunk);
         setStreamingMessage({ ...defaultMessage, ...chunk });
       },
-      onStatusChange: (status: string) => {
-        console.log("onStatusChange:", status);
+      onStatus: (status: ChatMessageStatus) => {
+        console.log("onStatus:", status);
       },
       onComplete: () => {
        console.log("onComplete");
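The Conversation changes above replace per-message `type` sniffing with callbacks that are discriminated by message status. A minimal caller sketch, assuming the `sendMessageStream` signature from `services/api-client` and the message types from `types/types` shown later in this diff; `sendAndRender`, `handleDone`, and the loosely typed `apiClient` parameter are illustrative placeholders:

```typescript
import { ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser } from 'types/types';

// Illustrative consumer wiring the four status-discriminated callbacks.
const sendAndRender = (apiClient: any, chatMessage: ChatMessageUser, handleDone: (m: ChatMessage) => void) =>
  apiClient.sendMessageStream(chatMessage, {
    // status === "done": a complete message, ready for the transcript
    onMessage: (msg: ChatMessage) => handleDone(msg),
    // status === "streaming": accumulated partial content for live rendering
    onStreaming: (chunk: ChatMessageStreaming) => console.log(chunk.content),
    // status === "status": progress/heartbeat activity, not part of the transcript
    onStatus: (status: ChatMessageStatus) => console.log(status.activity, status.content),
    // errors arrive either as a plain string or a full ChatMessageError
    onError: (error: string | ChatMessageError) =>
      console.error(typeof error === "string" ? error : error.content),
  });
```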
diff --git a/frontend/src/components/JobMatchAnalysis.tsx b/frontend/src/components/JobMatchAnalysis.tsx
index b3dba6f..e7a1391 100644
--- a/frontend/src/components/JobMatchAnalysis.tsx
+++ b/frontend/src/components/JobMatchAnalysis.tsx
@@ -20,7 +20,7 @@ import CheckCircleIcon from '@mui/icons-material/CheckCircle';
 import ErrorIcon from '@mui/icons-material/Error';
 import PendingIcon from '@mui/icons-material/Pending';
 import WarningIcon from '@mui/icons-material/Warning';
-import { Candidate, ChatMessage, ChatMessageBase, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
+import { Candidate, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
 import { useAuth } from 'hooks/AuthContext';
 import { BackstoryPageProps } from './BackstoryTab';
 import { toCamelCase } from 'types/conversion';
@@ -31,8 +31,8 @@ interface JobAnalysisProps extends BackstoryPageProps {
   candidate: Candidate;
 }
 
-const defaultMessage: ChatMessageUser = {
-  type: "preparing", status: "done", sender: "user", sessionId: "", timestamp: new Date(), content: ""
+const defaultMessage: ChatMessage = {
+  status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };
 
 const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
@@ -62,23 +62,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
     return;
   }
 
-  const createSession = async () => {
-    try {
-      const session: ChatSession = await apiClient.createCandidateChatSession(
-        candidate.username,
-        'job_requirements',
-        `Generate requirements for ${job.title}`
-      );
-      setSnack("Job analysis session started");
-      setRequirementsSession(session);
-    } catch (error) {
-      console.log(error);
-      setSnack("Unable to create requirements session", "error");
-    }
-    setCreatingSession(false);
-  };
-  setCreatingSession(true);
-  createSession();
+  setCreatingSession(true);
+  apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
+    .then(session => {
+      setRequirementsSession(session);
+    })
+    .catch(() => {
+      setSnack('Unable to load chat session', 'error');
+    })
+    .finally(() => {
+      setCreatingSession(false);
+    });
 }, [requirementsSession, apiClient, candidate]);
 
 // Fetch initial requirements
@@ -94,50 +90,48 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
     apiClient.sendMessageStream(chatMessage, {
       onMessage: (msg: ChatMessage) => {
         console.log(`onMessage: ${msg.type}`, msg);
-        if (msg.type === "response") {
-          const job: Job = toCamelCase(JSON.parse(msg.content || ''));
-          const requirements: { requirement: string, domain: string }[] = [];
-          if (job.requirements?.technicalSkills) {
-            job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
-            job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
-          }
-          if (job.requirements?.experienceRequirements) {
-            job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
-            job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
-          }
-          if (job.requirements?.softSkills) {
-            job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
-          }
-          if (job.requirements?.experience) {
-            job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
-          }
-          if (job.requirements?.education) {
-            job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
-          }
-          if (job.requirements?.certifications) {
-            job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
-          }
-          if (job.requirements?.preferredAttributes) {
-            job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
-          }
-
-          const initialSkillMatches = requirements.map(req => ({
-            requirement: req.requirement,
-            domain: req.domain,
-            status: 'waiting' as const,
-            matchScore: 0,
-            assessment: '',
-            description: '',
-            citations: []
-          }));
-
-          setRequirements(requirements);
-          setSkillMatches(initialSkillMatches);
-          setStatusMessage(null);
-          setLoadingRequirements(false);
-        }
+        const job: Job = toCamelCase(JSON.parse(msg.content || ''));
+        const requirements: { requirement: string, domain: string }[] = [];
+        if (job.requirements?.technicalSkills) {
+          job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
+          job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
+        }
+        if (job.requirements?.experienceRequirements) {
+          job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
+          job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
+        }
+        if (job.requirements?.softSkills) {
+          job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
+        }
+        if (job.requirements?.experience) {
+          job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
+        }
+        if (job.requirements?.education) {
+          job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
+        }
+        if (job.requirements?.certifications) {
+          job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
+        }
+        if (job.requirements?.preferredAttributes) {
+          job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
+        }
+
+        const initialSkillMatches = requirements.map(req => ({
+          requirement: req.requirement,
+          domain: req.domain,
+          status: 'waiting' as const,
+          matchScore: 0,
+          assessment: '',
+          description: '',
+          citations: []
+        }));
+
+        setRequirements(requirements);
+        setSkillMatches(initialSkillMatches);
+        setStatusMessage(null);
+        setLoadingRequirements(false);
       },
-      onError: (error: string | ChatMessageBase) => {
+      onError: (error: string | ChatMessageError) => {
         console.log("onError:", error);
-        // Type-guard to determine if this is a ChatMessageBase or a string
+        // Type-guard to determine if this is a ChatMessageError or a string
         if (typeof error === "object" && error !== null && "content" in error) {
@@ -147,11 +141,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
         }
         setLoadingRequirements(false);
       },
-      onStreaming: (chunk: ChatMessageBase) => {
+      onStreaming: (chunk: ChatMessageStreaming) => {
         // console.log("onStreaming:", chunk);
       },
-      onStatusChange: (status: string) => {
-        console.log(`onStatusChange: ${status}`);
+      onStatus: (status: ChatMessageStatus) => {
+        console.log("onStatus:", status);
       },
       onComplete: () => {
        console.log("onComplete");
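The `onMessage` handler above flattens `JobRequirements` into `{ requirement, domain }` rows with nine near-identical branches. A table-driven sketch of the same transformation, assuming the `Job` shape from `types/types`; `flattenRequirements` and `RequirementRow` are illustrative names, not part of the patch:

```typescript
import { Job } from 'types/types';

interface RequirementRow { requirement: string; domain: string; }

// Sketch: one pass over (items, domain) pairs instead of repeated if-blocks.
const flattenRequirements = (job: Job): RequirementRow[] => {
  const r = job.requirements;
  const groups: [string[] | undefined, string][] = [
    [r?.technicalSkills?.required, 'Technical Skills (required)'],
    [r?.technicalSkills?.preferred, 'Technical Skills (preferred)'],
    [r?.experienceRequirements?.required, 'Experience (required)'],
    [r?.experienceRequirements?.preferred, 'Experience (preferred)'],
    [r?.softSkills, 'Soft Skills'],
    [r?.experience, 'Experience'],
    [r?.education, 'Education'],
    [r?.certifications, 'Certifications'],
    [r?.preferredAttributes, 'Preferred Attributes'],
  ];
  return groups.flatMap(([items, domain]) =>
    (items ?? []).map(requirement => ({ requirement, domain })));
};
```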
diff --git a/frontend/src/components/Message.tsx b/frontend/src/components/Message.tsx
index 7ff22af..432de91 100644
--- a/frontend/src/components/Message.tsx
+++ b/frontend/src/components/Message.tsx
@@ -32,9 +32,9 @@ import { SetSnackType } from './Snack';
 import { CopyBubble } from './CopyBubble';
 import { Scrollable } from './Scrollable';
 import { BackstoryElementProps } from './BackstoryTab';
-import { ChatMessage, ChatSession, ChatMessageType, ChatMessageMetaData, ChromaDBGetResponse } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageMetaData, ChromaDBGetResponse, ApiActivityType, ChatMessageUser, ChatMessageError, ChatMessageStatus, ChatSenderType } from 'types/types';
 
-const getStyle = (theme: Theme, type: ChatMessageType): any => {
+const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
   const defaultRadius = '16px';
   const defaultStyle = {
     padding: theme.spacing(1, 2),
@@ -56,7 +56,7 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
   };
 
   const styles: any = {
-    response: {
+    assistant: {
       ...defaultStyle,
       backgroundColor: theme.palette.primary.main,
       border: `1px solid ${theme.palette.secondary.main}`,
@@ -90,9 +90,10 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
       boxShadow: '0 1px 3px rgba(216, 58, 58, 0.15)',
     },
     'fact-check': 'qualifications',
+    generating: 'status',
     'job-description': 'content',
     'job-requirements': 'qualifications',
-    info: {
+    information: {
       ...defaultStyle,
       backgroundColor: '#BFD8D8',
       border: `1px solid ${theme.palette.secondary.main}`,
@@ -156,25 +157,29 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
     }
   }
 
+  if (!(type in styles)) {
+    console.log(`Style does not exist for: ${type}`);
+  }
+
   return styles[type];
 }
 
-const getIcon = (messageType: ChatMessageType): React.ReactNode | null => {
+const getIcon = (activityType: ApiActivityType | ChatSenderType | "error"): React.ReactNode | null => {
   const icons: any = {
     error: ,
     generating: ,
-    info: ,
+    information: ,
     preparing: ,
     processing: ,
     system: ,
     thinking: ,
     tooling: ,
   };
-  return icons[messageType] || null;
+  return icons[activityType] || null;
 }
 
 interface MessageProps extends BackstoryElementProps {
-  message: ChatMessage,
+  message: ChatMessageUser | ChatMessage | ChatMessageError | ChatMessageStatus,
   title?: string,
   chatSession?: ChatSession,
   className?: string,
@@ -319,7 +324,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };
 
 interface MessageContainerProps {
-  type: ChatMessageType,
+  type: ApiActivityType | ChatSenderType | "error",
   metadataView?: React.ReactNode | null,
   messageView?: React.ReactNode | null,
   sx?: SxProps,
@@ -360,13 +365,21 @@ const Message = (props: MessageProps) => {
     setSnack
   };
   const theme = useTheme();
-  const style: any = getStyle(theme, message.type);
+  const type: ApiActivityType | ChatSenderType | "error" = ('activity' in message) ? message.activity : (message.status === 'error') ? 'error' : (message as ChatMessage).role;
+  const style: any = getStyle(theme, type);
 
   const handleMetaExpandClick = () => {
     setMetaExpanded(!metaExpanded);
   };
 
-  const content = message.content?.trim();
+  let content;
+  if (typeof (message.content) === "string") {
+    content = message.content.trim();
+  } else {
+    console.error("message content is not a string");
+    return (<></>);
+  }
+
   if (!content) {
     return (<></>)
   };
@@ -376,7 +389,8 @@ const Message = (props: MessageProps) => {
   );
 
   let metadataView = (<></>);
-  if (message.metadata) {
+  const metadata: ChatMessageMetaData | null = ('metadata' in message) ? (message.metadata as ChatMessageMetaData || null) : null;
+  if (metadata) {
     metadataView = (
@@ -395,18 +409,18 @@ const Message = (props: MessageProps) => {
-        
+        
     );
   }
 
-  const copyContent = message.sender === 'assistant' ? message.content : undefined;
+  const copyContent = (type === 'assistant') ? message.content : undefined;
   if (!expandable) {
     /* When not expandable, the styles are applied directly to MessageContainer */
     return (<>
-      {messageView && }
+      {messageView && }
     </>);
   }
@@ -435,7 +449,7 @@ const Message = (props: MessageProps) => {
       {title || ''}
-      
+      
     );
diff --git a/frontend/src/components/VectorVisualizer.tsx b/frontend/src/components/VectorVisualizer.tsx
index d64ec35..7755cf8 100644
--- a/frontend/src/components/VectorVisualizer.tsx
+++ b/frontend/src/components/VectorVisualizer.tsx
@@ -25,11 +25,6 @@ import { useAuth } from 'hooks/AuthContext';
 import * as Types from 'types/types';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
 import { useNavigate } from 'react-router-dom';
-import { Message } from './Message';
-const defaultMessage: Types.ChatMessageBase = {
-  type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
-};
-
 interface VectorVisualizerProps extends BackstoryPageProps {
   inline?: boolean;
diff --git a/frontend/src/pages/CandidateChatPage.tsx b/frontend/src/pages/CandidateChatPage.tsx
index 7b0703f..132a235 100644
--- a/frontend/src/pages/CandidateChatPage.tsx
+++ b/frontend/src/pages/CandidateChatPage.tsx
@@ -12,7 +12,7 @@ import { Send as SendIcon } from '@mui/icons-material';
 import { useAuth } from 'hooks/AuthContext';
-import { ChatMessageBase, ChatMessage, ChatSession, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { ConversationHandle } from 'components/Conversation';
 import { BackstoryPageProps } from 'components/BackstoryTab';
 import { Message } from 'components/Message';
@@ -25,7 +25,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryT
 import { BackstoryQuery } from 'components/BackstoryQuery';
 
 const defaultMessage: ChatMessage = {
-  type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+  status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
 };
 
 const CandidateChatPage = forwardRef((props: BackstoryPageProps, ref) => {
@@ -33,7 +33,7 @@ const CandidateChatPage = forwardRef((pr
   const navigate = useNavigate();
   const { selectedCandidate } = useSelectedCandidate()
   const theme = useTheme();
-  const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
+  const [processingMessage, setProcessingMessage] = useState<ChatMessage | ChatMessageStatus | ChatMessageError | null>(null);
   const [streamingMessage, setStreamingMessage] = useState(null);
   const backstoryTextRef = useRef(null);
@@ -48,32 +48,6 @@ const CandidateChatPage = forwardRef((pr
   const [streaming, setStreaming] = useState(false);
   const messagesEndRef = useRef(null);
 
-  // Load sessions for the selectedCandidate
-  const loadSessions = async () => {
-    if (!selectedCandidate) return;
-
-    try {
-      setLoading(true);
-      const result = await apiClient.getCandidateChatSessions(selectedCandidate.username);
-      let session = null;
-      if (result.sessions.data.length === 0) {
-        session = await apiClient.createCandidateChatSession(
-          selectedCandidate.username,
-          'candidate_chat',
-          `Backstory chat about ${selectedCandidate.fullName}`
-        );
-      } else {
-        session = result.sessions.data[0];
-      }
-      setChatSession(session);
-      setLoading(false);
-    } catch (error) {
-      setSnack('Unable to load chat session', 'error');
-    } finally {
-      setLoading(false);
-    }
-  };
-
   // Load messages for current session
   const loadMessages = async () => {
     if (!chatSession?.id) return;
@@ -107,20 +81,22 @@ const CandidateChatPage = forwardRef((pr
   // Send message
   const sendMessage = async (message: string) => {
-    if (!message.trim() || !chatSession?.id || streaming) return;
+    if (!message.trim() || !chatSession?.id || streaming || !selectedCandidate) return;
 
     const messageContent = message;
     setStreaming(true);
 
     const chatMessage: ChatMessageUser = {
       sessionId: chatSession.id,
+      role: "user",
       content: messageContent,
       status: "done",
-      type: "user",
-      sender: "user",
+      type: "text",
       timestamp: new Date()
     };
 
+    setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
+
     setMessages(prev => {
       const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
       return [...filtered, chatMessage] as any;
@@ -129,34 +105,31 @@ const CandidateChatPage = forwardRef((pr
     try {
       apiClient.sendMessageStream(chatMessage, {
         onMessage: (msg: ChatMessage) => {
-          console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
-          if (msg.type === "response") {
-            setMessages(prev => {
-              const filtered = prev.filter((m: any) => m.id !== msg.id);
-              return [...filtered, msg] as any;
-            });
-            setStreamingMessage(null);
-            setProcessingMessage(null);
-          } else {
-            setProcessingMessage(msg);
-          }
+          setMessages(prev => {
+            const filtered = prev.filter((m: any) => m.id !== msg.id);
+            return [...filtered, msg] as any;
+          });
+          setStreamingMessage(null);
+          setProcessingMessage(null);
         },
-        onError: (error: string | ChatMessageBase) => {
+        onError: (error: string | ChatMessageError) => {
           console.log("onError:", error);
-          // Type-guard to determine if this is a ChatMessageBase or a string
+          // Type-guard to determine if this is a ChatMessageError or a string
           if (typeof error === "object" && error !== null && "content" in error) {
-            setProcessingMessage(error as ChatMessage);
+            setProcessingMessage(error);
           } else {
-            setProcessingMessage({ ...defaultMessage, content: error as string });
+            setProcessingMessage({ ...defaultMessage, status: "error", content: error });
           }
           setStreaming(false);
         },
-        onStreaming: (chunk: ChatMessageBase) => {
+        onStreaming: (chunk: ChatMessageStreaming) => {
           // console.log("onStreaming:", chunk);
-          setStreamingMessage({ ...defaultMessage, ...chunk });
+          setStreamingMessage({ ...chunk, role: 'assistant' });
         },
-        onStatusChange: (status: string) => {
-          console.log(`onStatusChange: ${status}`);
+        onStatus: (status: ChatMessageStatus) => {
+          setProcessingMessage(status);
         },
         onComplete: () => {
           console.log("onComplete");
@@ -178,7 +151,19 @@ const CandidateChatPage = forwardRef((pr
   // Load sessions when username changes
   useEffect(() => {
-    loadSessions();
+    if (!selectedCandidate) return;
+    setLoading(true);
+    apiClient.getOrCreateChatSession(selectedCandidate, `Backstory chat with ${selectedCandidate.fullName}`, 'candidate_chat')
+      .then(session => {
+        setChatSession(session);
+      })
+      .catch(() => {
+        setSnack('Unable to load chat session', 'error');
+      })
+      .finally(() => {
+        setLoading(false);
+      });
   }, [selectedCandidate]);
 
   // Load messages when session changes
@@ -195,9 +180,9 @@ const CandidateChatPage = forwardRef((pr
     const welcomeMessage: ChatMessage = {
       sessionId: chatSession?.id || '',
-      type: "info",
+      role: "information",
+      type: "text",
       status: "done",
-      sender: "system",
       timestamp: new Date(),
       content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}.
 Ask any questions you have about ${selectedCandidate.firstName}.`
     };
@@ -252,7 +237,7 @@ const CandidateChatPage = forwardRef((pr
       }}>
       {messages.length === 0 && }
-      {messages.map((message: ChatMessageBase) => (
+      {messages.map((message: ChatMessage) => (
       ))}
       {processingMessage !== null && (
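The get-or-create bootstrap above now appears almost verbatim in JobMatchAnalysis, CandidateChatPage, and GenerateCandidate. A sketch of a shared hook that would fold the three copies together; `useChatSession` is a hypothetical name, and the loosely typed `apiClient` parameter stands in for the auth context's client:

```typescript
import { useEffect, useState } from 'react';
import { Candidate, ChatContextType, ChatSession } from 'types/types';

// Hypothetical hook: one home for the get-or-create + loading-flag dance.
const useChatSession = (apiClient: any, candidate: Candidate | null, title: string, contextType: ChatContextType) => {
  const [session, setSession] = useState<ChatSession | null>(null);
  const [loading, setLoading] = useState(false);
  useEffect(() => {
    if (!candidate || session) return;
    setLoading(true);
    apiClient.getOrCreateChatSession(candidate, title, contextType)
      .then(setSession)
      .catch(() => setSession(null)) // surface via snack in the caller if desired
      .finally(() => setLoading(false));
  }, [apiClient, candidate, title, contextType, session]);
  return { session, loading };
};
```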
- "generate_image", - "Profile image generation" - ); - setChatSession(response); - console.log(`Chat session created for generate_image`, response); - setSnack(`Chat session created for generate_image: ${response.id}`); - } catch (e) { - console.error(e); - setSnack("Unable to create image generation session.", "error"); - } - }; - - setLoading(true); - createChatSession().then(() => { setLoading(false) }); - }, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack]); - - const generatePersona = async (prompt: string) => { - const userMessage: ChatMessageUser = { - content: prompt, - sessionId: "", - sender: "user", - status: "done", - type: "user", - timestamp: new Date() - }; - setPrompt(prompt || ''); - setProcessing(true); - setProcessingMessage({ ...defaultMessage, content: "Generating persona..." }); try { - const result = await apiClient.createCandidateAI(userMessage); - console.log(result.message, result); - setGeneratedUser(result.candidate); - setResume(result.resume); - setCanGenImage(true); - setShouldGenerateProfile(true); // Reset the flag + setLoading(true); + apiClient.getOrCreateChatSession(generatedUser, `Profile image generator for ${generatedUser.fullName}`, 'generate_image') + .then(session => { + setChatSession(session); + setLoading(false); + }); } catch (error) { - console.error(error); - setPrompt(''); - setResume(null); - setProcessing(false); - setProcessingMessage(null); - setSnack("Unable to generate AI persona", "error"); + setSnack('Unable to load chat session', 'error'); + } finally { + setLoading(false); } - }; + }, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack, apiClient]); const cancelQuery = useCallback(() => { if (controllerRef.current) { controllerRef.current.cancel(); controllerRef.current = null; - setState(0); setProcessing(false); } }, []); @@ -146,8 +75,38 @@ const GenerateCandidate = (props: BackstoryElementProps) => { if (processing) { return; } + + const generatePersona = async (prompt: string) => { + const userMessage: ChatMessageUser = { + type: "text", + role: "user", + content: prompt, + sessionId: "", + status: "done", + timestamp: new Date() + }; + setPrompt(prompt || ''); + setProcessing(true); + setProcessingMessage({ ...defaultMessage, content: "Generating persona..." }); + try { + const result = await apiClient.createCandidateAI(userMessage); + console.log(result.message, result); + setGeneratedUser(result.candidate); + setResume(result.resume); + setCanGenImage(true); + setShouldGenerateProfile(true); // Reset the flag + } catch (error) { + console.error(error); + setPrompt(''); + setResume(null); + setProcessing(false); + setProcessingMessage(null); + setSnack("Unable to generate AI persona", "error"); + } + }; + generatePersona(value); - }, [processing, generatePersona]); + }, [processing, apiClient, setSnack]); const handleSendClick = useCallback(() => { const value = (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || ""; @@ -173,13 +132,12 @@ const GenerateCandidate = (props: BackstoryElementProps) => { setProcessingMessage({ ...defaultMessage, content: 'Starting image generation...' 
}); setProcessing(true); setCanGenImage(false); - setState(3); const chatMessage: ChatMessageUser = { sessionId: chatSession.id || '', + role: "user", status: "done", - type: "user", - sender: "user", + type: "text", timestamp: new Date(), content: prompt }; @@ -187,34 +145,23 @@ const GenerateCandidate = (props: BackstoryElementProps) => { controllerRef.current = apiClient.sendMessageStream(chatMessage, { onMessage: async (msg: ChatMessage) => { console.log(`onMessage: ${msg.type} ${msg.content}`, msg); - if (msg.type === "heartbeat" && msg.content) { - const heartbeat = JSON.parse(msg.content); - setTimestamp(heartbeat.timestamp); - } - if (msg.type === "thinking" && msg.content) { - const status = JSON.parse(msg.content); - setProcessingMessage({ ...defaultMessage, content: status.message }); - } - if (msg.type === "response") { - controllerRef.current = null; - try { - await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" }); - const { success, message } = await apiClient.deleteChatSession(chatSession.id || ''); - console.log(`Profile generated for ${username} and chat session was ${!success ? 'not ' : ''} deleted: ${message}}`); - setGeneratedUser({ - ...generatedUser, - profileImage: "profile.png" - } as CandidateAI); - setState(0); - setCanGenImage(true); - setShouldGenerateProfile(false); - } catch (error) { - console.error(error); - setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error"); - } + controllerRef.current = null; + try { + await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" }); + const { success, message } = await apiClient.deleteChatSession(chatSession.id || ''); + console.log(`Profile generated for ${username} and chat session was ${!success ? 
'not ' : ''} deleted: ${message}}`); + setGeneratedUser({ + ...generatedUser, + profileImage: "profile.png" + } as CandidateAI); + setCanGenImage(true); + setShouldGenerateProfile(false); + } catch (error) { + console.error(error); + setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error"); } }, - onError: (error) => { + onError: (error: string | ChatMessageError) => { console.log("onError:", error); // Type-guard to determine if this is a ChatMessageBase or a string if (typeof error === "object" && error !== null && "content" in error) { @@ -223,27 +170,28 @@ const GenerateCandidate = (props: BackstoryElementProps) => { setSnack(error as string, "error"); } setProcessingMessage(null); - setStreaming(false); setProcessing(false); controllerRef.current = null; - setState(0); setCanGenImage(true); setShouldGenerateProfile(false); }, onComplete: () => { setProcessingMessage(null); - setStreaming(false); setProcessing(false); controllerRef.current = null; - setState(0); setCanGenImage(true); setShouldGenerateProfile(false); }, - onStatusChange: (status: string) => { + onStatus: (status: ChatMessageStatus) => { + if (status.activity === "heartbeat" && status.content) { + setTimestamp(status.timestamp?.toISOString() || ''); + } else if (status.content) { + setProcessingMessage({ ...defaultMessage, content: status.content }); + } console.log(`onStatusChange: ${status}`); }, }); - }, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack]); + }, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack, apiClient]); if (!user?.isAdmin) { return (You must be logged in as an admin to generate AI candidates.); diff --git a/frontend/src/pages/LoadingPage.tsx b/frontend/src/pages/LoadingPage.tsx index a6cb580..4655c2a 100644 --- a/frontend/src/pages/LoadingPage.tsx +++ b/frontend/src/pages/LoadingPage.tsx @@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types'; const LoadingPage = (props: BackstoryPageProps) => { const preamble: ChatMessage = { - sender: 'system', - type: 'preparing', + role: 'assistant', + type: 'text', status: 'done', sessionId: '', content: 'Please wait while connecting to Backstory...', diff --git a/frontend/src/pages/LoginRequired.tsx b/frontend/src/pages/LoginRequired.tsx index 945f060..460df04 100644 --- a/frontend/src/pages/LoginRequired.tsx +++ b/frontend/src/pages/LoginRequired.tsx @@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types'; const LoginRequired = (props: BackstoryPageProps) => { const preamble: ChatMessage = { - sender: 'system', - type: 'preparing', + role: 'assistant', + type: 'text', status: 'done', sessionId: '', content: 'You must be logged to view this feature.', diff --git a/frontend/src/services/api-client.ts b/frontend/src/services/api-client.ts index 5a20c97..d33f4fc 100644 --- a/frontend/src/services/api-client.ts +++ b/frontend/src/services/api-client.ts @@ -39,11 +39,11 @@ import { // ============================ interface StreamingOptions { - onStatusChange?: (status: Types.ChatStatusType) => void; + onStatus?: (status: Types.ChatMessageStatus) => void; onMessage?: (message: Types.ChatMessage) => void; - onStreaming?: (chunk: Types.ChatMessageBase) => void; + onStreaming?: (chunk: Types.ChatMessageStreaming) => void; onComplete?: () => void; - onError?: (error: string | Types.ChatMessageBase) => void; + onError?: (error: string | Types.ChatMessageError) => void; onWarn?: (warning: string) => void; signal?: AbortSignal; } @@ -761,6 +761,20 @@ class ApiClient { return result; } + async 
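The GenerateCandidate page above shows the new `onStatus` contract: heartbeats and user-visible progress both arrive as `ChatMessageStatus`, distinguished by `activity`. A small sketch of that branch in isolation, assuming the `ChatMessageStatus` shape from `types/types`; `describeStatus` is an illustrative helper, not part of the patch:

```typescript
import { ChatMessageStatus } from 'types/types';

// Heartbeats keep the connection alive; everything else is progress to display.
const describeStatus = (status: ChatMessageStatus): { heartbeatAt?: string; progress?: string } => {
  if (status.activity === "heartbeat") {
    return { heartbeatAt: status.timestamp?.toISOString() };
  }
  return { progress: typeof status.content === "string" ? status.content : JSON.stringify(status.content) };
};
```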
diff --git a/frontend/src/services/api-client.ts b/frontend/src/services/api-client.ts
index 5a20c97..d33f4fc 100644
--- a/frontend/src/services/api-client.ts
+++ b/frontend/src/services/api-client.ts
@@ -39,11 +39,11 @@ import {
 // ============================
 
 interface StreamingOptions {
-  onStatusChange?: (status: Types.ChatStatusType) => void;
+  onStatus?: (status: Types.ChatMessageStatus) => void;
   onMessage?: (message: Types.ChatMessage) => void;
-  onStreaming?: (chunk: Types.ChatMessageBase) => void;
+  onStreaming?: (chunk: Types.ChatMessageStreaming) => void;
   onComplete?: () => void;
-  onError?: (error: string | Types.ChatMessageBase) => void;
+  onError?: (error: string | Types.ChatMessageError) => void;
   onWarn?: (warning: string) => void;
   signal?: AbortSignal;
 }
@@ -761,6 +761,20 @@ class ApiClient {
     return result;
   }
 
+  async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType): Promise<Types.ChatSession> {
+    const result = await this.getCandidateChatSessions(candidate.username);
+    /* Reuse an existing session with this title if present; otherwise create one */
+    let session = result.sessions.data.find(session => session.title === title);
+    if (!session) {
+      session = await this.createCandidateChatSession(
+        candidate.username,
+        context_type,
+        title
+      );
+    }
+    return session;
+  }
+
   async getCandidateSimilarContent(query: string): Promise {
     const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
@@ -965,7 +979,7 @@ class ApiClient {
    * Send message with streaming response support and date conversion
    */
   sendMessageStream(
-    chatMessage: Types.ChatMessageBase,
+    chatMessage: Types.ChatMessageUser,
     options: StreamingOptions = {}
   ): StreamingResponse {
     const abortController = new AbortController();
@@ -998,9 +1012,8 @@ class ApiClient {
       const decoder = new TextDecoder();
       let buffer = '';
-      let incomingMessage: Types.ChatMessage | null = null;
+      let streamingMessage: Types.ChatMessageStreaming | null = null;
       const incomingMessageList: Types.ChatMessage[] = [];
-      let incomingStatus : Types.ChatStatusType | null = null;
       try {
         while (true) {
           const { done, value } = await reader.read();
@@ -1022,39 +1035,41 @@ class ApiClient {
           try {
             if (line.startsWith('data: ')) {
               const data = line.slice(5).trim();
-              const incoming: Types.ChatMessageBase = JSON.parse(data);
+              const incoming: any = JSON.parse(data);
 
-              // Convert date fields for incoming messages
-              const convertedIncoming = convertChatMessageFromApi(incoming);
+              console.log(incoming.status, incoming);
 
-              // Trigger callbacks based on status
-              if (convertedIncoming.status !== incomingStatus) {
-                options.onStatusChange?.(convertedIncoming.status);
-                incomingStatus = convertedIncoming.status;
-              }
               // Handle different status types
-              switch (convertedIncoming.status) {
+              switch (incoming.status) {
                 case 'streaming':
-                  if (incomingMessage === null) {
-                    incomingMessage = {...convertedIncoming};
+                  const streaming = Types.convertChatMessageStreamingFromApi(incoming);
+                  if (streamingMessage === null) {
+                    streamingMessage = {...streaming};
                   } else {
                     // Can't do a simple += as typescript thinks .content might not be there
-                    incomingMessage.content = (incomingMessage?.content || '') + convertedIncoming.content;
+                    streamingMessage.content = (streamingMessage?.content || '') + streaming.content;
                     // Update timestamp to latest
-                    incomingMessage.timestamp = convertedIncoming.timestamp;
+                    streamingMessage.timestamp = streaming.timestamp;
                   }
-                  options.onStreaming?.(convertedIncoming);
+                  options.onStreaming?.(streamingMessage);
                   break;
 
-                case 'error':
-                  options.onError?.(convertedIncoming);
+                case 'status':
+                  const status = Types.convertChatMessageStatusFromApi(incoming);
+                  options.onStatus?.(status);
                   break;
 
-                default:
-                  incomingMessageList.push(convertedIncoming);
+                case 'error':
+                  const error = Types.convertChatMessageErrorFromApi(incoming);
+                  options.onError?.(error);
+                  break;
+
+                case 'done':
+                  const message = Types.convertChatMessageFromApi(incoming);
+                  incomingMessageList.push(message);
                   try {
-                    options.onMessage?.(convertedIncoming);
+                    options.onMessage?.(message);
                   } catch (error) {
                     console.error('onMessage handler failed: ', error);
                   }
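The read loop above now dispatches on the wire-level `status` field instead of diffing message types. A condensed sketch of that dispatch in isolation, assuming the converter functions from `types/types` shown below; `dispatchSseEvent` and the `acc` accumulator are illustrative names, not part of the patch:

```typescript
import * as Types from 'types/types';

// Sketch of the per-event dispatch used inside sendMessageStream's read loop.
// `options` is the StreamingOptions bag; `acc` accumulates streamed content.
const dispatchSseEvent = (raw: string, acc: { content: string }, options: {
  onStreaming?: (m: Types.ChatMessageStreaming) => void;
  onStatus?: (m: Types.ChatMessageStatus) => void;
  onError?: (e: string | Types.ChatMessageError) => void;
  onMessage?: (m: Types.ChatMessage) => void;
}) => {
  const incoming = JSON.parse(raw);
  switch (incoming.status) {
    case 'streaming': {
      const chunk = Types.convertChatMessageStreamingFromApi(incoming);
      acc.content += chunk.content;
      options.onStreaming?.({ ...chunk, content: acc.content });
      break;
    }
    case 'status':
      options.onStatus?.(Types.convertChatMessageStatusFromApi(incoming));
      break;
    case 'error':
      options.onError?.(Types.convertChatMessageErrorFromApi(incoming));
      break;
    case 'done':
      options.onMessage?.(Types.convertChatMessageFromApi(incoming));
      break;
  }
};
```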
diff --git a/frontend/src/types/types.ts b/frontend/src/types/types.ts
index 5283da4..88d0200 100644
--- a/frontend/src/types/types.ts
+++ b/frontend/src/types/types.ts
@@ -1,6 +1,6 @@
 // Generated TypeScript types from Pydantic models
 // Source: src/backend/models.py
-// Generated on: 2025-06-04T17:02:08.242818
+// Generated on: 2025-06-05T00:24:02.132276
 // DO NOT EDIT MANUALLY - This file is auto-generated
 
 // ============================
@@ -11,15 +11,17 @@
 export type AIModelType = "qwen2.5" | "flux-schnell";
 
 export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";
 
+export type ApiActivityType = "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+
+export type ApiMessageType = "binary" | "text" | "json";
+
+export type ApiStatusType = "streaming" | "status" | "done" | "error";
+
 export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";
 
 export type ChatContextType = "job_search" | "job_requirements" | "candidate_chat" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile" | "generate_image" | "rag_search" | "skill_match";
 
-export type ChatMessageType = "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-
-export type ChatSenderType = "user" | "assistant" | "agent" | "system";
-
-export type ChatStatusType = "initializing" | "streaming" | "status" | "done" | "error";
+export type ChatSenderType = "user" | "assistant" | "system" | "information" | "warning" | "error";
 
 export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";
@@ -88,6 +90,15 @@ export interface Analytics {
   segment?: string;
 }
 
+export interface ApiMessage {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+}
+
 export interface ApiResponse {
   success: boolean;
   data?: any;
@@ -284,24 +295,22 @@ export interface ChatMessage {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
+  role: "user" | "assistant" | "system" | "information" | "warning" | "error";
   content: string;
+  tunables?: Tunables;
   metadata?: ChatMessageMetaData;
 }
 
-export interface ChatMessageBase {
+export interface ChatMessageError {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
   content: string;
 }
@@ -328,25 +337,44 @@ export interface ChatMessageRagSearch {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
-  content: string;
   dimensions: number;
+  content: Array<ChromaDBGetResponse>;
+}
+
+export interface ChatMessageStatus {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  activity: "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+  content: any;
+}
+
+export interface ChatMessageStreaming {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  content: string;
 }
 
 export interface ChatMessageUser {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
+  role: "user" | "assistant" | "system" | "information" | "warning" | "error";
   content: string;
+  tunables?: Tunables;
 }
 
 export interface ChatOptions {
@@ -644,6 +672,20 @@ export interface JobRequirements {
   preferredAttributes?: Array<string>;
 }
 
+export interface JobRequirementsMessage {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  title?: string;
+  summary?: string;
+  company?: string;
+  description: string;
+  requirements?: JobRequirements;
+}
+
 export interface JobResponse {
   success: boolean;
   data?: Job;
@@ -923,6 +965,19 @@ export function convertAnalyticsFromApi(data: any): Analytics {
     timestamp: new Date(data.timestamp),
   };
 }
+/**
+ * Convert ApiMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertApiMessageFromApi(data: any): ApiMessage {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert ApplicationDecision from API response, parsing date fields
  * Date fields: date
@@ -1067,10 +1122,10 @@ export function convertChatMessageFromApi(data: any): ChatMessage {
   };
 }
 /**
- * Convert ChatMessageBase from API response, parsing date fields
+ * Convert ChatMessageError from API response, parsing date fields
  * Date fields: timestamp
  */
-export function convertChatMessageBaseFromApi(data: any): ChatMessageBase {
+export function convertChatMessageErrorFromApi(data: any): ChatMessageError {
   if (!data) return data;
 
   return {
@@ -1092,6 +1147,32 @@ export function convertChatMessageRagSearchFromApi(data: any): ChatMessageRagSea
     timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
   };
 }
+/**
+ * Convert ChatMessageStatus from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStatusFromApi(data: any): ChatMessageStatus {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
+/**
+ * Convert ChatMessageStreaming from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStreamingFromApi(data: any): ChatMessageStreaming {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert ChatMessageUser from API response, parsing date fields
  * Date fields: timestamp
@@ -1268,6 +1349,19 @@ export function convertJobFullFromApi(data: any): JobFull {
     featuredUntil: data.featuredUntil ? new Date(data.featuredUntil) : undefined,
   };
 }
+/**
+ * Convert JobRequirementsMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertJobRequirementsMessageFromApi(data: any): JobRequirementsMessage {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert MessageReaction from API response, parsing date fields
  * Date fields: timestamp
@@ -1348,6 +1442,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
   switch (modelType) {
     case 'Analytics':
       return convertAnalyticsFromApi(data) as T;
+    case 'ApiMessage':
+      return convertApiMessageFromApi(data) as T;
     case 'ApplicationDecision':
       return convertApplicationDecisionFromApi(data) as T;
     case 'Attachment':
@@ -1366,10 +1462,14 @@ export function convertFromApi<T>(data: any, modelType: string): T {
       return convertCertificationFromApi(data) as T;
     case 'ChatMessage':
       return convertChatMessageFromApi(data) as T;
-    case 'ChatMessageBase':
-      return convertChatMessageBaseFromApi(data) as T;
+    case 'ChatMessageError':
+      return convertChatMessageErrorFromApi(data) as T;
     case 'ChatMessageRagSearch':
       return convertChatMessageRagSearchFromApi(data) as T;
+    case 'ChatMessageStatus':
+      return convertChatMessageStatusFromApi(data) as T;
+    case 'ChatMessageStreaming':
+      return convertChatMessageStreamingFromApi(data) as T;
     case 'ChatMessageUser':
       return convertChatMessageUserFromApi(data) as T;
     case 'ChatSession':
@@ -1394,6 +1494,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
       return convertJobApplicationFromApi(data) as T;
    case 'JobFull':
      return convertJobFullFromApi(data) as T;
+    case 'JobRequirementsMessage':
+      return convertJobRequirementsMessageFromApi(data) as T;
    case 'MessageReaction':
      return convertMessageReactionFromApi(data) as T;
    case 'RAGConfiguration':
      return convertRAGConfigurationFromApi(data) as T;
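The new converters above are identical except for the target type. A generic sketch of the shared shape; `withParsedTimestamp` is a hypothetical helper (the generated file keeps the functions separate so regeneration stays mechanical):

```typescript
// Sketch: the repeated timestamp-only converters all reduce to this.
function withParsedTimestamp<T extends { timestamp?: Date }>(data: any): T {
  if (!data) return data;
  return {
    ...data,
    // Convert timestamp from ISO string to Date
    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
  } as T;
}

// Usage, e.g.: const status = withParsedTimestamp<Types.ChatMessageStatus>(incoming);
```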
diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py
index 0ea0917..7f7f26a 100644
--- a/src/backend/agents/base.py
+++ b/src/backend/agents/base.py
@@ -24,7 +24,7 @@ from datetime import datetime, UTC
 from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
 import numpy as np  # type: ignore
 
-from models import ( LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate)
+from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
 from logger import logger
 import defines
 from .registry import agent_registry
@@ -300,12 +300,35 @@ class Agent(BaseModel, ABC):
         )
         self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
 
+    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
+        """
+        Extracts the RAG context from the rag_message.
+        """
+        if not rag_message.content:
+            return ""
+
+        context = []
+        for chroma_results in rag_message.content:
+            for index, metadata in enumerate(chroma_results.metadatas):
+                content = "\n".join([
+                    line.strip()
+                    for line in chroma_results.documents[index].split("\n")
+                    if line
+                ]).strip()
+                context.append(f"""
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results.ids[index]}
+Content: {content}
+""")
+        return "\n".join(context)
+
     async def generate_rag_results(
         self,
-        chat_message: ChatMessage,
+        session_id: str,
+        prompt: str,
         top_k: int=defines.default_rag_top_k,
         threshold: float=defines.default_rag_threshold,
-    ) -> AsyncGenerator[ChatMessage, None]:
+    ) -> AsyncGenerator[ChatMessageRagSearch | ChatMessageError | ChatMessageStatus, None]:
         """
         Generate RAG results for the given query.
 
@@ -315,223 +338,193 @@ class Agent(BaseModel, ABC):
         Returns:
             A list of dictionaries containing the RAG results.
         """
-        rag_message = ChatMessage(
-            session_id=chat_message.session_id,
-            tunables=chat_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC),
-            metadata=ChatMessageMetaData()
-        )
-
         if not self.user:
-            logger.error("No user set for RAG generation")
-            rag_message.status = ChatStatusType.DONE
-            rag_message.content = ""
-            yield rag_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No user set for RAG generation."
+            )
+            yield error_message
             return
 
-        try:
-            entries: int = 0
-            user: Candidate = self.user
-            rag_content: str = ""
-            for rag in user.rags:
-                if not rag.enabled:
-                    continue
-                status_message = ChatMessage(
-                    session_id=chat_message.session_id,
-                    sender=ChatSenderType.AGENT,
-                    status = ChatStatusType.INITIALIZING,
-                    type = ChatMessageType.SEARCHING,
-                    content = f"Checking RAG context {rag.name}...")
-                yield status_message
+        results: List[ChromaDBGetResponse] = []
+        user: Candidate = self.user
+        for rag in user.rags:
+            if not rag.enabled:
+                continue
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.SEARCHING,
+                content=f"Searching RAG context {rag.name}..."
+            )
+            yield status_message
 
+            try:
                 chroma_results = user.file_watcher.find_similar(
-                    query=chat_message.content, top_k=top_k, threshold=threshold
+                    query=prompt, top_k=top_k, threshold=threshold
                 )
-                if chroma_results:
-                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+                if not chroma_results:
+                    continue
+                query_embedding = np.array(chroma_results["query_embedding"]).flatten()
 
-                    umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
-                    umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
+                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
+                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
 
-                    rag_metadata = ChromaDBGetResponse(
-                        query=chat_message.content,
-                        query_embedding=query_embedding.tolist(),
-                        name=rag.name,
-                        ids=chroma_results.get("ids", []),
-                        embeddings=chroma_results.get("embeddings", []),
-                        documents=chroma_results.get("documents", []),
-                        metadatas=chroma_results.get("metadatas", []),
-                        umap_embedding_2d=umap_2d.tolist(),
-                        umap_embedding_3d=umap_3d.tolist(),
-                        size=user.file_watcher.collection.count()
-                    )
+                rag_metadata = ChromaDBGetResponse(
+                    name=rag.name,
+                    query_embedding=query_embedding.tolist(),
+                    ids=chroma_results.get("ids", []),
+                    embeddings=chroma_results.get("embeddings", []),
+                    documents=chroma_results.get("documents", []),
+                    metadatas=chroma_results.get("metadatas", []),
+                    umap_embedding_2d=umap_2d.tolist(),
+                    umap_embedding_3d=umap_3d.tolist(),
+                )
+                results.append(rag_metadata)
+            except Exception as e:
+                continue_message = ChatMessageStatus(
+                    session_id=session_id,
+                    activity=ApiActivityType.SEARCHING,
+                    content=f"Error searching RAG context {rag.name}: {str(e)}"
+                )
+                yield continue_message
 
-                    entries += len(rag_metadata.documents)
-                    rag_message.metadata.rag_results.append(rag_metadata)
-
-                    for index, metadata in enumerate(chroma_results["metadatas"]):
-                        content = "\n".join(
-                            [
-                                line.strip()
-                                for line in chroma_results["documents"][index].split("\n")
-                                if line
-                            ]
-                        ).strip()
-                        rag_content += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
-Document reference: {chroma_results["ids"][index]}
-Content: { content }
-"""
-            rag_message.content = rag_content.strip()
-            rag_message.type = ChatMessageType.RAG_RESULT
-            rag_message.status = ChatStatusType.DONE
-            yield rag_message
-            return
-
-        except Exception as e:
-            rag_message.status = ChatStatusType.ERROR
-            rag_message.content = f"Error generating RAG results: {str(e)}"
-            logger.error(traceback.format_exc())
-            logger.error(rag_message.content)
-            yield rag_message
-            return
+        final_message = ChatMessageRagSearch(
+            session_id=session_id,
+            content=results,
+            status=ApiStatusType.DONE,
+        )
+        yield final_message
+        return
 
-    async def llm_one_shot(self, llm: Any, model: str, user_message: ChatMessageUser, system_prompt: str, temperature=0.7):
-        chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.AGENT,
-            content="",
-            timestamp=datetime.now(UTC)
-        )
+    async def llm_one_shot(
+        self,
+        llm: Any, model: str,
+        session_id: str, prompt: str, system_prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
 
         self.set_optimal_context_size(
-            llm, model, prompt=chat_message.content
+            llm=llm, model=model, prompt=prompt
         )
-        chat_message.metadata = ChatMessageMetaData()
-        chat_message.metadata.options = ChatOptions(
+        options = ChatOptions(
             seed=8911,
             num_ctx=self.context_size,
-            temperature=temperature, # Higher temperature to encourage tool usage
+            temperature=temperature,
         )
 
         messages: List[LLMMessage] = [
             LLMMessage(role="system", content=system_prompt),
-            LLMMessage(role="user", content=user_message.content),
+            LLMMessage(role="user", content=prompt),
         ]
-        # Reset the response for streaming
-        chat_message.content = ""
-        chat_message.type = ChatMessageType.GENERATING
-        chat_message.status = ChatStatusType.STREAMING
+        status_message = ChatMessageStatus(
+            session_id=session_id,
+            activity=ApiActivityType.GENERATING,
+            content="Generating response..."
+        )
+        yield status_message
 
-        logger.info(f"Message options: {chat_message.metadata.options.model_dump(exclude_unset=True)}")
+        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
 
         response = None
+        content = ""
         for response in llm.chat(
             model=model,
             messages=messages,
             options={
-                **chat_message.metadata.options.model_dump(exclude_unset=True),
+                **options.model_dump(exclude_unset=True),
             },
             stream=True,
         ):
             if not response:
-                chat_message.status = ChatStatusType.ERROR
-                chat_message.content = "No response from LLM."
-                yield chat_message
+                error_message = ChatMessageError(
+                    session_id=session_id,
+                    content="No response from LLM."
+                )
+                yield error_message
                 return
 
-            chat_message.content += response.message.content
+            content += response.message.content
 
             if not response.done:
-                chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
-                chat_chunk.content = response.message.content
-                yield chat_message
-                continue
+                streaming_message = ChatMessageStreaming(
+                    session_id=session_id,
+                    content=response.message.content,
+                    status=ApiStatusType.STREAMING,
+                )
+                yield streaming_message
 
         if not response:
-            chat_message.status = ChatStatusType.ERROR
-            chat_message.content = "No response from LLM."
-            yield chat_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No response from LLM."
+            )
+            yield error_message
             return
 
         self.collect_metrics(response)
-        chat_message.metadata.eval_count += response.eval_count
-        chat_message.metadata.eval_duration += response.eval_duration
-        chat_message.metadata.prompt_eval_count += response.prompt_eval_count
-        chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
         self.context_tokens = (
             response.prompt_eval_count + response.eval_count
         )
-        chat_message.type = ChatMessageType.RESPONSE
-        chat_message.status = ChatStatusType.DONE
-        yield chat_message
+        chat_message = ChatMessage(
+            session_id=session_id,
+            tunables=tunables,
+            status=ApiStatusType.DONE,
+            content=content,
+            metadata=ChatMessageMetaData(
+                options=options,
+                eval_count=response.eval_count,
+                eval_duration=response.eval_duration,
+                prompt_eval_count=response.prompt_eval_count,
+                prompt_eval_duration=response.prompt_eval_duration,
+            )
+        )
+        yield chat_message
+        return
 
     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
-        chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC)
-        )
-
-        self.set_optimal_context_size(
-            llm, model, prompt=chat_message.content
-        )
-
-        chat_message.metadata = ChatMessageMetaData()
-        chat_message.metadata.options = ChatOptions(
-            seed=8911,
-            num_ctx=self.context_size,
-            temperature=temperature, # Higher temperature to encourage tool usage
-        )
-
-        # Create a dict for storing various timing stats
-        chat_message.metadata.timers = {}
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
+        if not self.user:
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No user set for chat generation."
+            )
+            yield error_message
+            return
+
+        user_message = ChatMessageUser(
+            session_id=session_id,
+            content=prompt,
+        )
+        user = self.user
 
         self.metrics.generate_count.labels(agent=self.agent_type).inc()
         with self.metrics.generate_duration.labels(agent=self.agent_type).time():
+            context = None
+            if self.user:
+                rag_message = None
+                async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
+                    if rag_message.status == ApiStatusType.ERROR:
+                        yield rag_message
+                        return
+                    # Forward interim status updates while the search runs
+                    if rag_message.status == ApiStatusType.STATUS:
+                        yield rag_message
 
-            rag_message : Optional[ChatMessage] = None
-            async for rag_message in self.generate_rag_results(chat_message=user_message):
-                if rag_message.status == ChatStatusType.ERROR:
-                    chat_message.status = rag_message.status
-                    chat_message.content = rag_message.content
-                    yield chat_message
-                    return
-                yield rag_message
-
-            rag_context = ""
-            if rag_message:
-                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
-                chat_message.metadata.rag_results = rag_results
-                for chroma_results in rag_results:
-                    for index, metadata in enumerate(chroma_results.metadatas):
-                        content = "\n".join([
-                            line.strip()
-                            for line in chroma_results.documents[index].split("\n")
-                            if line
-                        ]).strip()
-                        rag_context += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
-Document reference: {chroma_results.ids[index]}
-Content: { content }
-
-"""
+                if not isinstance(rag_message, ChatMessageRagSearch):
+                    raise ValueError(
+                        f"Expected ChatMessageRagSearch, got {type(rag_message)}"
+                    )
+
+                context = self.get_rag_context(rag_message)
 
             # Create a pruned down message list based purely on the prompt and responses,
             # discarding the full preamble generated by prepare_message
             messages: List[LLMMessage] = [
                 LLMMessage(role="system", content=self.system_prompt),
             ]
             # Add the conversation history to the messages
             messages.extend([
-                LLMMessage(role=m.sender, content=m.content.strip())
+                LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
                 for m in self.conversation
             ])
             # Add the RAG context to the messages if available
-            if rag_context and user:
+            if context:
                 messages.append(
                     LLMMessage(
                         role="user",
-                        content=f"<|context|>\nThe following is context information about {user.full_name}:\n{rag_context.strip()}\n\n\nPrompt to respond to:\n{user_message.content.strip()}\n"
+                        content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n\n\nPrompt to respond to:\n{prompt}\n"
                     )
                 )
             else:
                 # Only the actual user query is provided with the full context message
                 messages.append(
-                    LLMMessage(role=user_message.sender, content=user_message.content.strip())
+                    LLMMessage(role="user", content=prompt)
                 )
 
-            chat_message.metadata.llm_history = messages
+            llm_history = messages
 
             # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
             # message.metadata.tools = {
@@ -660,57 +653,89 @@ Content: { content }
             # return
 
             # not use_tools
-            chat_message.type = ChatMessageType.THINKING
-            chat_message.content = f"Generating response..."
-            yield chat_message
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.GENERATING,
+                content="Generating response..."
+            )
+            yield status_message
 
-            # Reset the response for streaming
-            chat_message.content = ""
+            # Set the response for streaming
+            self.set_optimal_context_size(
+                llm, model, prompt=prompt
+            )
+
+            options = ChatOptions(
+                seed=8911,
+                num_ctx=self.context_size,
+                temperature=temperature,
+            )
+            logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
 
+            content = ""
             start_time = time.perf_counter()
-            chat_message.type = ChatMessageType.GENERATING
-            chat_message.status = ChatStatusType.STREAMING
-
+            response = None
             for response in llm.chat(
                 model=model,
                 messages=messages,
                 options={
-                    **chat_message.metadata.options.model_dump(exclude_unset=True),
+                    **options.model_dump(exclude_unset=True),
                 },
                 stream=True,
             ):
                 if not response:
-                    chat_message.status = ChatStatusType.ERROR
-                    chat_message.content = "No response from LLM."
-                    yield chat_message
+                    error_message = ChatMessageError(
+                        session_id=session_id,
+                        content="No response from LLM."
+                    )
+                    yield error_message
                     return
 
-                chat_message.content += response.message.content
+                content += response.message.content
 
                 if not response.done:
-                    chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
-                    chat_chunk.content = response.message.content
-                    yield chat_message
-                    continue
-
-                if response.done:
-                    self.collect_metrics(response)
-                    chat_message.metadata.eval_count += response.eval_count
-                    chat_message.metadata.eval_duration += response.eval_duration
-                    chat_message.metadata.prompt_eval_count += response.prompt_eval_count
-                    chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
-                    self.context_tokens = (
-                        response.prompt_eval_count + response.eval_count
+                    streaming_message = ChatMessageStreaming(
+                        session_id=session_id,
+                        content=response.message.content,
                     )
-                    chat_message.type = ChatMessageType.RESPONSE
-                    chat_message.status = ChatStatusType.DONE
-                    yield chat_message
+                    yield streaming_message
 
+            if not response:
+                error_message = ChatMessageError(
+                    session_id=session_id,
+                    content="No response from LLM."
+ ) + yield error_message + return + + self.collect_metrics(response) + self.context_tokens = ( + response.prompt_eval_count + response.eval_count + ) end_time = time.perf_counter() - chat_message.metadata.timers["streamed"] = end_time - start_time + + chat_message = ChatMessage( + session_id=session_id, + tunables=tunables, + status=ApiStatusType.DONE, + content=content, + metadata = ChatMessageMetaData( + options=options, + eval_count=response.eval_count, + eval_duration=response.eval_duration, + prompt_eval_count=response.prompt_eval_count, + prompt_eval_duration=response.prompt_eval_duration, + timers={ + "llm_streamed": end_time - start_time, + "llm_with_tools": 0, # Placeholder for tool processing time + }, + ) + ) + # Add the user and chat messages to the conversation self.conversation.append(user_message) self.conversation.append(chat_message) + yield chat_message return # async def process_message( diff --git a/src/backend/agents/candidate_chat.py b/src/backend/agents/candidate_chat.py index b38948b..1084b38 100644 --- a/src/backend/agents/candidate_chat.py +++ b/src/backend/agents/candidate_chat.py @@ -7,7 +7,7 @@ from .base import Agent, agent_registry from logger import logger from .registry import agent_registry -from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType, ChatMessageUser, Candidate) +from models import ( ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatQuery, ChatMessage, Tunables, ApiStatusType, ChatMessageUser, Candidate) system_message = f""" @@ -34,8 +34,15 @@ class CandidateChat(Agent): system_prompt: str = system_message async def generate( - self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7 - ): + self, llm: Any, model: str, + session_id: str, prompt: str, + tunables: Optional[Tunables] = None, + temperature=0.7 + ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: + user = self.user + if not user: + logger.error("User is not set for CandidateChat agent.") + raise ValueError("User must be set before generating candidate chat responses.") self.system_prompt = f""" You are a helpful expert system representing a {user.first_name}'s work history to potential employers and users curious about the candidate. You want to incorporate as many facts and details about {user.first_name} as possible. @@ -49,7 +56,10 @@ Use that spelling instead of any spelling you may find in the <|context|>. {system_message} """ - async for message in super().generate(llm, model, user_message, user, temperature): + async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables): + if message.status == ApiStatusType.ERROR: + yield message + return yield message # Register the base agent diff --git a/src/backend/agents/general.py b/src/backend/agents/general.py index 83d6f94..83b6a5d 100644 --- a/src/backend/agents/general.py +++ b/src/backend/agents/general.py @@ -7,7 +7,7 @@ from .base import Agent, agent_registry from logger import logger from .registry import agent_registry -from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType) +from models import ( ChatQuery, ChatMessage, Tunables, ApiStatusType) system_message = f""" Launched on {datetime.now().isoformat()}. 
diff --git a/src/backend/agents/generate_image.py b/src/backend/agents/generate_image.py index 0dc95d1..fdab6b1 100644 --- a/src/backend/agents/generate_image.py +++ b/src/backend/agents/generate_image.py @@ -25,7 +25,7 @@ import os import hashlib from .base import Agent, agent_registry, LLMMessage -from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType +from models import ActivityType, ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables import model_cast from logger import logger import defines @@ -44,43 +44,48 @@ class ImageGenerator(Agent): system_prompt: str = "" # No system prompt is used async def generate( - self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7 - ) -> AsyncGenerator[ChatMessage, None]: - logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") + self, llm: Any, model: str, + session_id: str, prompt: str, + tunables: Optional[Tunables] = None, + temperature=0.7 + ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: + if not self.user: + logger.error("User is not set for ImageGenerator agent.") + raise ValueError("User must be set before generating images.") + user = self.user file_path = os.path.join(defines.user_dir, user.username, "profile.png") - chat_message = ChatMessage( - session_id=user_message.session_id, - tunables=user_message.tunables, - status=ChatStatusType.INITIALIZING, - type=ChatMessageType.PREPARING, - sender=ChatSenderType.ASSISTANT, - content="", - timestamp=datetime.now(UTC) - ) - - chat_message.metadata = ChatMessageMetaData() try: # # Generate the profile picture # - chat_message.content = f"Generating: {user_message.content}" - yield chat_message + status_message = ChatMessageStatus( + session_id=session_id, + activity=ApiActivityType.GENERATING_IMAGE, + content=f"Generating {prompt}...", + ) + yield status_message - logger.info(f"Image generation: {file_path} <- {user_message.content}") - request = ImageRequest(filepath=file_path, prompt=user_message.content, iterations=4, height=256, width=256, guidance_scale=7.5) + logger.info(f"Image generation: {file_path} <- {prompt}") + request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5) generated_message = None async for generated_message in generate_image( - user_message=user_message, request=request ): - if generated_message.status != "done": + if generated_message.status == ApiStatusType.ERROR: yield generated_message + return + if generated_message.status != ApiStatusType.DONE: + yield generated_message + continue if generated_message is None: - chat_message.status = ChatStatusType.ERROR - chat_message.content = "Image generation failed." - yield chat_message + error_message = ChatMessageError( + session_id=session_id, + content="Image generation failed to produce a valid response." 
+                )
+                logger.error(f"⚠️ {error_message.content}")
+                yield error_message
                 return

             logger.info("Image generation done...")
@@ -88,18 +93,23 @@ class ImageGenerator(Agent):
             user.profile_image = "profile.png"

             # Image generated
-            generated_message.status = ChatStatusType.DONE
-            generated_message.content = f"{defines.api_prefix}/profile/{user.username}"
-            yield generated_message
+            generated_image = ChatMessage(
+                session_id=session_id,
+                status=ApiStatusType.DONE,
+                content = f"{defines.api_prefix}/profile/{user.username}",
+                metadata=generated_message.metadata
+            )
+            yield generated_image
             return

         except Exception as e:
-            chat_message.status = ChatStatusType.ERROR
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content=f"Error generating image: {str(e)}"
+            )
             logger.error(traceback.format_exc())
-            logger.error(chat_message.content)
-            chat_message.content = f"Error in image generation: {str(e)}"
-            logger.error(chat_message.content)
-            yield chat_message
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return

 # Register the base agent
diff --git a/src/backend/agents/generate_persona.py b/src/backend/agents/generate_persona.py
index fdef77a..fbbefbf 100644
--- a/src/backend/agents/generate_persona.py
+++ b/src/backend/agents/generate_persona.py
@@ -27,7 +27,7 @@ import random
 from names_dataset import NameDataset, NameWrapper  # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
+from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -316,14 +316,16 @@ class GeneratePersona(Agent):
         self.full_name = f"{self.first_name} {self.last_name}"

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
-    ):
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         self.randomize()

-        status_message = ChatMessage(session_id=user_message.session_id)
-        original_prompt = user_message.content
+        original_prompt = prompt.strip()

-        user_message.content = f"""\
+        prompt = f"""\
 ```json
 {json.dumps({
     "age": self.age,
@@ -337,7 +339,7 @@ class GeneratePersona(Agent):
 """
         if original_prompt:
-            user_message.content += f"""
+            prompt += f"""
 Incorporate the following into the job description: {original_prompt}
 """
@@ -347,17 +349,25 @@ Incorporate the following into the job description: {original_prompt}
         logger.info(f"🤖 Generating persona for {self.full_name}")
         generating_message = None
         async for generating_message in self.llm_one_shot(
-            llm=llm, model=model,
-            user_message=user_message,
+            llm=llm, model=model,
+            session_id=session_id,
+            prompt=prompt,
             system_prompt=generate_persona_system_prompt,
             temperature=temperature,
         ):
-            if generating_message.status == ChatStatusType.ERROR:
+            if generating_message.status == ApiStatusType.ERROR:
                 logger.error(f"Error generating persona: {generating_message.content}")
-                raise Exception(generating_message.content)
+                yield generating_message
+                return
+            if generating_message.status != ApiStatusType.DONE:
+                yield generating_message
         if not 
generating_message: - raise Exception("No response from LLM during persona generation") + error_message = ChatMessageError( + session_id=session_id, + content="Persona generation failed to generate a response." + ) + yield error_message + return json_str = self.extract_json_from_text(generating_message.content) try: @@ -405,37 +415,39 @@ Incorporate the following into the job description: {original_prompt} persona["location"] = None persona["is_ai"] = True except Exception as e: - generating_message.content = f"Unable to parse LLM returned content: {json_str} {str(e)}" - generating_message.status = ChatStatusType.ERROR + error_message = ChatMessageError( + session_id=session_id, + content=f"Error parsing LLM response: {str(e)}\n\n{json_str}" + ) + logger.error(f"❌ Error parsing LLM response: {error_message.content}") logger.error(traceback.format_exc()) logger.error(generating_message.content) - yield generating_message + yield error_message return logger.info(f"✅ Persona for {persona['username']} generated successfully") # Persona generated - status_message = ChatMessage( - session_id=user_message.session_id, - status = ChatStatusType.STATUS, - type = ChatMessageType.RESPONSE, + persona_message = ChatMessage( + session_id=session_id, + status=ApiStatusType.DONE, + type=ApiMessageType.JSON, content = json.dumps(persona) ) - yield status_message + yield persona_message # # Generate the resume # - status_message = ChatMessage( - session_id=user_message.session_id, - status = ChatStatusType.STATUS, - type = ChatMessageType.THINKING, + status_message = ChatMessageStatus( + session_id=session_id, + activity = ApiActivityType.THINKING, content = f"Generating resume for {persona['full_name']}..." ) logger.info(f"🤖 {status_message.content}") yield status_message - user_message.content = f""" + content = f""" ```json {{ "full_name": "{persona["full_name"]}", @@ -449,30 +461,39 @@ Incorporate the following into the job description: {original_prompt} ``` """ if original_prompt: - user_message.content += f""" + content += f""" Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}.""" async for generating_message in self.llm_one_shot( - llm=llm, model=model, - user_message=user_message, + llm=llm, model=model, + session_id=session_id, + prompt=content, system_prompt=generate_resume_system_prompt, temperature=temperature, ): - if generating_message.status == ChatStatusType.ERROR: + if generating_message.status == ApiStatusType.ERROR: logger.error(f"❌ Error generating resume: {generating_message.content}") raise Exception(generating_message.content) + if generating_message.status != ApiStatusType.DONE: + yield generating_message if not generating_message: - raise Exception("No response from LLM during persona generation") + error_message = ChatMessageError( + session_id=session_id, + content="Resume generation failed to generate a response." 
+            )
+            logger.error(f"❌ {error_message.content}")
+            yield error_message
+            return

         resume = self.extract_markdown_from_text(generating_message.content)
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            status=ChatStatusType.DONE,
-            type=ChatMessageType.RESPONSE,
+        resume_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            type=ApiMessageType.TEXT,
             content=resume
         )
-        yield status_message
+        yield resume_message
         return

     def extract_json_from_text(self, text: str) -> str:
diff --git a/src/backend/agents/job_requirements.py b/src/backend/agents/job_requirements.py
index 10ac8b1..332f5ff 100644
--- a/src/backend/agents/job_requirements.py
+++ b/src/backend/agents/job_requirements.py
@@ -20,7 +20,7 @@ import asyncio
 import numpy as np  # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, JobRequirements
+from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -81,56 +81,51 @@ class JobRequirementsAgent(Agent):
         return system_prompt, prompt

     async def analyze_job_requirements(
-        self, llm: Any, model: str, user_message: ChatMessage
-    ) -> AsyncGenerator[ChatMessage, None]:
+        self, llm: Any, model: str, session_id: str, prompt: str
+    ) -> AsyncGenerator[ChatMessage | ChatMessageError, None]:
         """Analyze job requirements from job description."""
-        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
-        analyze_message = user_message.model_copy()
-        analyze_message.content = prompt
+        system_prompt, prompt = self.create_job_analysis_prompt(prompt)

         generated_message = None
-        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
-            if generated_message.status == ChatStatusType.ERROR:
-                generated_message.content = "Error analyzing job requirements."
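+        # llm_one_shot is a streaming generator: it yields status and streaming
+        # chunks while the LLM runs and finishes with a single DONE ChatMessage;
+        # errors arrive as ChatMessageError and end the stream.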
+        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+            if generated_message.status == ApiStatusType.ERROR:
                 yield generated_message
                 return
-
-        if not generated_message:
-            status_message = ChatMessage(
-                session_id=user_message.session_id,
-                sender=ChatSenderType.AGENT,
-                status = ChatStatusType.ERROR,
-                type = ChatMessageType.ERROR,
-                content = "Job requirements analysis failed to generate a response.")
-            yield status_message
-            return
+            if generated_message.status != ApiStatusType.DONE:
+                yield generated_message

-        generated_message.status = ChatStatusType.DONE
-        generated_message.type = ChatMessageType.RESPONSE
+        if not generated_message:
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Job requirements analysis failed to generate a response.")
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
+            return
+
         yield generated_message
         return

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ChatMessage, None]:
         # Stage 1A: Analyze job requirements
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            sender=ChatSenderType.AGENT,
-            status=ChatStatusType.STATUS,
-            type=ChatMessageType.THINKING,
+        status_message = ChatMessageStatus(
+            session_id=session_id,
+            activity=ApiActivityType.THINKING,
             content = f"Analyzing job requirements")
         yield status_message

         generated_message = None
-        async for generated_message in self.analyze_job_requirements(llm, model, user_message):
-            if generated_message.status == ChatStatusType.ERROR:
-                status_message.status = ChatStatusType.ERROR
-                status_message.content = generated_message.content
-                yield status_message
+        async for generated_message in self.analyze_job_requirements(llm, model, session_id, prompt):
+            if generated_message.status == ApiStatusType.ERROR:
+                yield generated_message
                 return
+            if generated_message.status != ApiStatusType.DONE:
+                yield generated_message
+
         if not generated_message:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = "Job requirements analysis failed."
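+            # The stream can finish without ever producing a DONE message,
+            # so surface that case explicitly rather than returning nothing.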
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.THINKING,
+                content="Job requirements analysis failed to generate a response.")
+            logger.error(f"⚠️ {status_message.content}")
             yield status_message
             return

@@ -150,34 +145,33 @@ class JobRequirementsAgent(Agent):
             if not job_requirements:
                 raise ValueError("Job requirements data is empty or invalid.")
         except json.JSONDecodeError as e:
-            status_message.status = ChatStatusType.ERROR
+            status_message.status = ApiStatusType.ERROR
             status_message.content = f"Failed to parse job requirements JSON: {str(e)}\n\n{job_requirements_data}"
             logger.error(f"⚠️ {status_message.content}")
             yield status_message
             return
         except ValueError as e:
-            status_message.status = ChatStatusType.ERROR
+            status_message.status = ApiStatusType.ERROR
             status_message.content = f"Job requirements validation error: {str(e)}\n\n{job_requirements_data}"
             logger.error(f"⚠️ {status_message.content}")
             yield status_message
             return
         except Exception as e:
-            status_message.status = ChatStatusType.ERROR
+            status_message.status = ApiStatusType.ERROR
             status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
             logger.error(traceback.format_exc())
             logger.error(f"⚠️ {status_message.content}")
             yield status_message
             return

-        status_message.status = ChatStatusType.DONE
-        status_message.type = ChatMessageType.RESPONSE
-        job_data = {
-            "company": company_name,
-            "title": job_title,
-            "summary": job_summary,
-            "requirements": job_requirements.model_dump(mode="json", exclude_unset=True)
-        }
-        status_message.content = json.dumps(job_data)
-        yield status_message
+        job_requirements_message = JobRequirementsMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            requirements=job_requirements,
+            company=company_name,
+            title=job_title,
+            summary=job_summary,
+            description=prompt  # assumed: the original job description text (field is required)
+        )
+        yield job_requirements_message
         logger.info(f"✅ Job requirements analysis completed successfully.")
         return
diff --git a/src/backend/agents/rag_search.py b/src/backend/agents/rag_search.py
index 663ed72..3f6abad 100644
--- a/src/backend/agents/rag_search.py
+++ b/src/backend/agents/rag_search.py
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
 from logger import logger
 from .registry import agent_registry

-from models import ( ChatMessage, ChatStatusType, ChatMessage, ChatOptions, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate )
+from models import ( ChatMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ChatMessageMetaData, Candidate, Tunables )
 from rag import ( ChromaDBGetResponse )

 class Chat(Agent):
@@ -19,8 +19,11 @@ class Chat(Agent):
     _agent_type: ClassVar[str] = agent_type  # Add this for registration

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessage, user: Candidate, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage, None]:
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         """
         Generate a response based on the user message and the provided LLM.

@@ -34,65 +37,25 @@ class Chat(Agent):
         Yields:
             ChatMessage: The generated response.
""" - logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") - - if user.id != user_message.sender_id: - logger.error(f"User {user.username} id does not match message {user_message.sender_id}") - raise ValueError("User does not match message sender") - - chat_message = ChatMessage( - session_id=user_message.session_id, - tunables=user_message.tunables, - status=ChatStatusType.INITIALIZING, - type=ChatMessageType.PREPARING, - sender=ChatSenderType.ASSISTANT, - content="", - timestamp=datetime.now(UTC) - ) - - chat_message.metadata = ChatMessageMetaData() - chat_message.metadata.options = ChatOptions( - seed=8911, - num_ctx=self.context_size, - temperature=temperature, # Higher temperature to encourage tool usage - ) - - # Create a dict for storing various timing stats - chat_message.metadata.timers = {} - - self.metrics.generate_count.labels(agent=self.agent_type).inc() - with self.metrics.generate_duration.labels(agent=self.agent_type).time(): - - rag_message : Optional[ChatMessage] = None - async for rag_message in self.generate_rag_results(chat_message=user_message): - if rag_message.status == ChatStatusType.ERROR: - chat_message.status = rag_message.status - chat_message.content = rag_message.content - yield chat_message - return + rag_message = None + async for rag_message in self.generate_rag_results(session_id, prompt): + if rag_message.status == ApiStatusType.ERROR: + yield rag_message + return + if rag_message.status != ApiStatusType.DONE: yield rag_message - if rag_message: - chat_message.content = "" - rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results - chat_message.metadata.rag_results = rag_results - for chroma_results in rag_results: - for index, metadata in enumerate(chroma_results.metadatas): - content = "\n".join([ - line.strip() - for line in chroma_results.documents[index].split("\n") - if line - ]).strip() - chat_message.content += f""" -Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} -Document reference: {chroma_results.ids[index]} -Content: { content } - -""" - - chat_message.status = ChatStatusType.DONE - chat_message.type = ChatMessageType.RAG_RESULT - yield chat_message + if not isinstance(rag_message, ChatMessageRagSearch): + logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}") + error_message = ChatMessageError( + session_id=session_id, + content="RAG search did not return a valid response." 
+            )
+            yield error_message
+            return
+
+        rag_message.status = ApiStatusType.DONE
+        yield rag_message

 # Register the base agent
 agent_registry.register(Chat._agent_type, Chat)
diff --git a/src/backend/agents/skill_match.py b/src/backend/agents/skill_match.py
index e4b6887..4e86768 100644
--- a/src/backend/agents/skill_match.py
+++ b/src/backend/agents/skill_match.py
@@ -20,7 +20,7 @@ import asyncio
 import numpy as np  # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, SkillMatch
+from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -29,13 +29,13 @@ class SkillMatchAgent(Agent):
     agent_type: Literal["skill_match"] = "skill_match"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration

-    def generate_skill_assessment_prompt(self, skill, rag_content):
+    def generate_skill_assessment_prompt(self, skill, rag_context):
         """
         Generate a system prompt to query the LLM for evidence of a specific skill

         Parameters:
         - skill (str): The specific skill to assess from job requirements
-        - rag_content (str): Additional RAG content queried from candidate documents
+        - rag_context (str): Additional RAG content queried from candidate documents

         Returns:
         - str: A system prompt tailored to assess the specific skill
@@ -98,7 +98,7 @@ Adhere strictly to the JSON output format requested. Do not include any addition

 RESPOND WITH ONLY VALID JSON USING THE EXACT FORMAT SPECIFIED.

-{rag_content}
+{rag_context}

 JSON RESPONSE:"""

@@ -106,110 +106,102 @@ JSON RESPONSE:"""
         return system_prompt, prompt

     async def analyze_job_requirements(
-        self, llm: Any, model: str, user_message: ChatMessage
+        self, llm: Any, model: str, session_id: str, requirement: str
     ) -> AsyncGenerator[ChatMessage, None]:
         """Analyze job requirements from job description."""
-        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
-        analyze_message = user_message.model_copy()
-        analyze_message.content = prompt
+        system_prompt, prompt = self.create_job_analysis_prompt(requirement)
+
         generated_message = None
-        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
-            if generated_message.status == ChatStatusType.ERROR:
+        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+            if generated_message.status == ApiStatusType.ERROR:
                 generated_message.content = "Error analyzing job requirements."
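+                # Replaces the raw LLM error detail with a generic message
+                # before forwarding it to the caller.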
yield generated_message return - + if generated_message.status != ApiStatusType.DONE: + yield generated_message + if not generated_message: - status_message = ChatMessage( - session_id=user_message.session_id, - sender=ChatSenderType.AGENT, - status = ChatStatusType.ERROR, - type = ChatMessageType.ERROR, + error_message = ChatMessageError( + session_id=session_id, content = "Job requirements analysis failed to generate a response.") - yield status_message + yield error_message return - generated_message.status = ChatStatusType.DONE - generated_message.type = ChatMessageType.RESPONSE - yield generated_message + chat_message = ChatMessage( + session_id=session_id, + status=ApiStatusType.DONE, + content=generated_message.content, + metadata=generated_message.metadata, + ) + yield chat_message return async def generate( - self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7 - ) -> AsyncGenerator[ChatMessage, None]: + self, llm: Any, model: str, + session_id: str, prompt: str, + tunables: Optional[Tunables] = None, + temperature=0.7 + ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: # Stage 1A: Analyze job requirements - status_message = ChatMessage( - session_id=user_message.session_id, - sender=ChatSenderType.AGENT, - status=ChatStatusType.STATUS, - type=ChatMessageType.THINKING, - content = f"Analyzing job requirements") - yield status_message - rag_message = None - async for rag_message in self.generate_rag_results(chat_message=user_message): - if rag_message.status == ChatStatusType.ERROR: - status_message.status = ChatStatusType.ERROR - status_message.content = rag_message.content - logger.error(f"⚠️ {status_message.content}") - yield status_message + async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt): + if rag_message.status == ApiStatusType.ERROR: + yield rag_message return + if rag_message.status != ApiStatusType.DONE: + yield rag_message if rag_message is None: - status_message.status = ChatStatusType.ERROR - status_message.content = "Failed to retrieve RAG context." - logger.error(f"⚠️ {status_message.content}") - yield status_message + error_message = ChatMessageError( + session_id=session_id, + content="RAG search did not return a valid response." 
+ ) + logger.error(f"⚠️ {error_message.content}") + yield error_message return logger.info(f"🔍 RAG content retrieved: {len(rag_message.content)} bytes") - - system_prompt, prompt = self.generate_skill_assessment_prompt(skill=user_message.content, rag_content=rag_message.content) + rag_context = self.get_rag_context(rag_message) + system_prompt, prompt = self.generate_skill_assessment_prompt(skill=prompt, rag_context=rag_context) - user_message.content = prompt skill_assessment = None - async for skill_assessment in self.llm_one_shot(llm=llm, model=model, user_message=user_message, system_prompt=system_prompt, temperature=0.7): - if skill_assessment.status == ChatStatusType.ERROR: - status_message.status = ChatStatusType.ERROR - status_message.content = skill_assessment.content - logger.error(f"⚠️ {status_message.content}") - yield status_message + async for skill_assessment in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7): + if skill_assessment.status == ApiStatusType.ERROR: + logger.error(f"⚠️ {skill_assessment.content}") + yield skill_assessment return + if skill_assessment.status != ApiStatusType.DONE: + yield skill_assessment + if skill_assessment is None: - status_message.status = ChatStatusType.ERROR - status_message.content = "Failed to generate skill assessment." - logger.error(f"⚠️ {status_message.content}") - yield status_message + error_message = ChatMessageError( + session_id=session_id, + content="Skill assessment failed to generate a response." + ) + logger.error(f"⚠️ {error_message.content}") + yield error_message return json_str = self.extract_json_from_text(skill_assessment.content) skill_assessment_data = "" try: skill_assessment_data = json.loads(json_str).get("skill_assessment", {}) - except json.JSONDecodeError as e: - status_message.status = ChatStatusType.ERROR - status_message.content = f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment_data}" - logger.error(f"⚠️ {status_message.content}") - yield status_message - return - except ValueError as e: - status_message.status = ChatStatusType.ERROR - status_message.content = f"Skill assessment validation error: {str(e)}\n\n{skill_assessment_data}" - logger.error(f"⚠️ {status_message.content}") - yield status_message - return except Exception as e: - status_message.status = ChatStatusType.ERROR - status_message.content = f"Unexpected error processing Skill assessment: {str(e)}\n\n{skill_assessment_data}" - logger.error(traceback.format_exc()) - logger.error(f"⚠️ {status_message.content}") - yield status_message + error_message = ChatMessageError( + session_id=session_id, + content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment.content}" + ) + logger.error(f"⚠️ {error_message.content}") + yield error_message return - status_message.status = ChatStatusType.DONE - status_message.type = ChatMessageType.RESPONSE - status_message.content = json.dumps(skill_assessment_data) - yield status_message - + + skill_assessment_message = ChatMessage( + session_id=session_id, + status=ApiStatusType.DONE, + content=json.dumps(skill_assessment_data), + metadata=skill_assessment.metadata + ) + yield skill_assessment_message logger.info(f"✅ Skill assessment completed successfully.") return diff --git a/src/backend/entities/candidate_entity.py b/src/backend/entities/candidate_entity.py index 4b0718f..91a83a1 100644 --- a/src/backend/entities/candidate_entity.py +++ b/src/backend/entities/candidate_entity.py @@ -18,7 +18,7 @@ 
from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse import defines from logger import logger import agents as agents -from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType) +from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType) from llm_manager import llm_manager from agents.base import Agent diff --git a/src/backend/generate_types.py b/src/backend/generate_types.py index e69dcb2..a23c579 100644 --- a/src/backend/generate_types.py +++ b/src/backend/generate_types.py @@ -797,7 +797,7 @@ Generated conversion functions can be used like: const jobs = convertArrayFromApi(apiResponse, 'Job'); Enum types are now properly handled: - status: ChatStatusType = ChatStatusType.DONE -> status: ChatStatusType (not locked to "done") + status: ApiStatusType = ApiStatusType.DONE -> status: ApiStatusType (not locked to "done") """ ) diff --git a/src/backend/image_generator/profile_image.py b/src/backend/image_generator/profile_image.py index 9dd0446..20423af 100644 --- a/src/backend/image_generator/profile_image.py +++ b/src/backend/image_generator/profile_image.py @@ -24,7 +24,7 @@ import uuid from .image_model_cache import ImageModelCache -from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType +from models import ApiActivityType, ApiStatusType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType from logger import logger from image_generator.image_model_cache import ImageModelCache @@ -39,6 +39,7 @@ TIME_ESTIMATES = { } class ImageRequest(BaseModel): + session_id: str filepath: str prompt: str model: str = "black-forest-labs/FLUX.1-schnell" @@ -53,18 +54,12 @@ model_cache = ImageModelCache() def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str): """Background worker for Flux image generation""" try: - # Your existing estimates calculation - estimates = {"per_step": 0.5} # Replace with your actual estimates - resolution_scale = (params.height * params.width) / (512 * 512) - # Flux: Run generation in the background and yield progress updates - estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale - status_queue.put({ - "status": "running", - "message": f"Initializing image generation...", - "estimated_time_remaining": estimated_gen_time, - "progress": 0 - }) + status_queue.put(ChatMessageStatus( + session_id=params.session_id, + content=f"Initializing image generation.", + activity=ApiActivityType.GENERATING_IMAGE, + )) # Start the generation task start_gen_time = time.time() @@ -74,11 +69,11 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task # Send progress updates progress = int((step+1) / params.iterations * 100) - status_queue.put({ - "status": "running", - "message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.", - "progress": progress - }) + status_queue.put(ChatMessageStatus( + session_id=params.session_id, + content=f"Processing step {step+1}/{params.iterations} ({progress}%)", + activity=ApiActivityType.GENERATING_IMAGE, + )) return callback_kwargs # Replace this block with your actual Flux pipe call: @@ -98,27 +93,22 @@ def flux_worker(pipe: Any, params: 
ImageRequest, status_queue: queue.Queue, task
         image.save(params.filepath)

         # Final completion status
-        status_queue.put({
-            "status": "completed",
-            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
-            "progress": 100,
-            "generation_time": gen_time,
-            "per_step_time": per_step_time,
-            "image_path": params.filepath
-        })
+        status_queue.put(ChatMessage(
+            session_id=params.session_id,
+            status=ApiStatusType.DONE,
+            content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
+        ))

     except Exception as e:
         logger.error(traceback.format_exc())
         logger.error(e)
-        status_queue.put({
-            "status": "error",
-            "message": f"Generation failed: {str(e)}",
-            "error": str(e),
-            "progress": 0
-        })
+        status_queue.put(ChatMessageError(
+            session_id=params.session_id,
+            content=f"Error during image generation: {str(e)}",
+        ))

-async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
+async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
     """
     Single async function that handles background Flux generation with status streaming
     """
@@ -136,7 +126,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
     worker_thread.start()

     # Initial status
-    yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}
+    status_message = ChatMessageStatus(
+        session_id=params.session_id,
+        content=f"Starting image generation with task ID {task_id}",
+        activity=ApiActivityType.THINKING
+    )
+    yield status_message

     # Stream status updates
     completed = False
@@ -147,14 +142,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
             # Try to get status update (non-blocking)
             status_update = status_queue.get_nowait()

-            # Add task_id to status update
-            status_update['task_id'] = task_id
-
             # Send status update
             yield status_update

             # Check if completed
-            if status_update.get('status') in ['completed', 'error']:
+            if status_update.status in (ApiStatusType.DONE, ApiStatusType.ERROR):
+                logger.info(f"Image generation finished for task {task_id}")
                 completed = True

             last_heartbeat = time.time()
@@ -163,11 +156,11 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
             # No new status, send heartbeat if needed
             current_time = time.time()
             if current_time - last_heartbeat > 2:  # Heartbeat every 2 seconds
-                heartbeat = {
-                    'status': 'heartbeat',
-                    'task_id': task_id,
-                    'timestamp': current_time
-                }
+                heartbeat = ChatMessageStatus(
+                    session_id=params.session_id,
+                    content=f"Heartbeat for task {task_id}",
+                    activity=ApiActivityType.HEARTBEAT,
+                )
                 yield heartbeat
                 last_heartbeat = current_time

@@ -178,28 +171,25 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
     if not completed:
         if worker_thread.is_alive():
             # Thread still running but we might have missed the completion signal
-            timeout_status = {
-                'status': 'timeout',
-                'task_id': task_id,
-                'message': 'Generation timed out or connection lost'
-            }
+            timeout_status = ChatMessageError(
+                session_id=params.session_id,
+                content=f"Generation timeout for task {task_id}. 
The process may still be running.", + ) yield timeout_status else: # Thread completed but we might have missed the final status - final_status = { - 'status': 'completed', - 'task_id': task_id, - 'message': 'Generation completed' - } + final_status = ChatMessage( + session_id=params.session_id, + status=ApiStatusType.DONE, + content=f"Generation completed for task {task_id}.", + ) yield final_status except Exception as e: - error_status = { - 'status': 'error', - 'task_id': task_id, - 'message': f'Server error: {str(e)}', - 'error': str(e) - } + error_status = ChatMessageError( + session_id=params.session_id, + content=f'Server error: {str(e)}' + ) logger.error(error_status) yield error_status @@ -208,41 +198,38 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive(): worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup -def status(chat_message: ChatMessage, status: str, progress: float = 0, estimated_time_remaining="...") -> ChatMessage: +def status(session_id: str, status: str) -> ChatMessageStatus: """Update chat message status and return it.""" - message = chat_message.copy(deep=True) - message.id = str(uuid.uuid4()) - message.timestamp = datetime.now(UTC) - message.type = ChatMessageType.THINKING - message.status = ChatStatusType.STREAMING - message.content = status - return message - -async def generate_image(user_message: ChatMessage, request: ImageRequest) -> AsyncGenerator[ChatMessage, None]: - """Generate an image with specified dimensions and yield status updates with time estimates.""" - chat_message = ChatMessage( - session_id=user_message.session_id, - tunables=user_message.tunables, - status=ChatStatusType.INITIALIZING, - type=ChatMessageType.PREPARING, - sender=ChatSenderType.ASSISTANT, - content="", - timestamp=datetime.now(UTC) + chat_message = ChatMessageStatus( + session_id=session_id, + activity=ApiActivityType.GENERATING_IMAGE, + content=status, ) + return chat_message + +async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]: + """Generate an image with specified dimensions and yield status updates with time estimates.""" + session_id = request.session_id + prompt = request.prompt.strip() try: # Validate prompt - prompt = user_message.content.strip() if not prompt: - chat_message.status = ChatStatusType.ERROR - chat_message.content = "Prompt cannot be empty" - yield chat_message + error_message = ChatMessageError( + session_id=session_id, + content="Prompt cannot be empty." + ) + logger.error(error_message.content) + yield error_message return # Validate dimensions if request.height <= 0 or request.width <= 0: - chat_message.status = ChatStatusType.ERROR - chat_message.content = "Height and width must be positive" - yield chat_message + error_message = ChatMessageError( + session_id=session_id, + content="Height and width must be positive integers." 
+ ) + logger.error(error_message.content) + yield error_message return filedir = os.path.dirname(request.filepath) @@ -252,43 +239,50 @@ async def generate_image(user_message: ChatMessage, request: ImageRequest) -> As model_type = "flux" device = "cpu" - yield status(chat_message, f"Starting image generation...") - # Get initial time estimate, scaled by resolution estimates = TIME_ESTIMATES[model_type][device] resolution_scale = (request.height * request.width) / (512 * 512) estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale - yield status(chat_message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}") - + # Initialize or get cached pipeline start_time = time.time() - yield status(chat_message, f"Loading generative image model...") + yield status(session_id, f"Loading generative image model...") pipe = await model_cache.get_pipeline(request.model, device) load_time = time.time() - start_time - yield status(chat_message, f"Model loaded in {load_time:.1f} seconds.", progress=10) + yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",) - async for status_message in async_generate_image(pipe, request): - chat_message.content = json.dumps(status_message) # Merge properties from async_generate_image over the message... - chat_message.type = ChatMessageType.HEARTBEAT if status_message.get("status") == "heartbeat" else ChatMessageType.THINKING - if chat_message.type != ChatMessageType.HEARTBEAT: - logger.info(chat_message.content) - yield chat_message + progress = None + async for progress in async_generate_image(pipe, request): + if progress.status == ApiStatusType.ERROR: + yield progress + return + if progress.status != ApiStatusType.DONE: + yield progress + + if not progress: + error_message = ChatMessageError( + session_id=session_id, + content="Image generation failed to produce a valid response." 
+ ) + logger.error(f"⚠️ {error_message.content}") + yield error_message + return # Final result total_time = time.time() - start_time - chat_message.status = ChatStatusType.DONE - chat_message.type = ChatMessageType.RESPONSE - chat_message.content = json.dumps({ - "status": f"Image generation complete in {total_time:.1f} seconds", - "progress": 100, - "filename": request.filepath - }) + chat_message = ChatMessage( + session_id=session_id, + status=ApiStatusType.DONE, + content=f"Image generated successfully in {total_time:.1f} seconds.", + ) yield chat_message except Exception as e: - chat_message.status = ChatStatusType.ERROR - chat_message.content = str(e) - yield chat_message + error_message = ChatMessageError( + session_id=session_id, + content=f"Error during image generation: {str(e)}" + ) logger.error(traceback.format_exc()) - logger.error(chat_message.content) + logger.error(error_message.content) + yield error_message return \ No newline at end of file diff --git a/src/backend/main.py b/src/backend/main.py index 738b495..316c538 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -64,7 +64,7 @@ import agents # ============================= from models import ( # API - Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest, + ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest, # User models Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse, CandidateAI, @@ -73,7 +73,7 @@ from models import ( JobFull, JobApplication, ApplicationStatus, # Chat models - ChatSession, ChatMessage, ChatContext, ChatQuery, ChatStatusType, ChatMessageBase, ChatMessageUser, ChatSenderType, ChatMessageType, ChatContextType, + ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType, ChatMessageRagSearch, # Document models @@ -361,7 +361,6 @@ def filter_and_paginate( async def stream_agent_response(chat_agent: agents.Agent, user_message: ChatMessageUser, - candidate: Candidate, chat_session_data: Dict[str, Any] | None = None, database: RedisDatabase | None = None) -> StreamingResponse: async def message_stream_generator(): @@ -372,30 +371,33 @@ async def stream_agent_response(chat_agent: agents.Agent, async for generated_message in chat_agent.generate( llm=llm_manager.get_llm(), model=defines.model, - user_message=user_message, - user=candidate, + session_id=user_message.session_id, + prompt=user_message.content, ): + if generated_message.status == ApiStatusType.ERROR: + logger.error(f"❌ AI generation error: {generated_message.content}") + yield f"data: {json.dumps({'status': 'error'})}\n\n" + return + # Store reference to the complete AI message - if generated_message.status == ChatStatusType.DONE: + if generated_message.status == ApiStatusType.DONE: final_message = generated_message # If the message is not done, convert it to a ChatMessageBase to remove # metadata and other unnecessary fields for streaming - if generated_message.status != ChatStatusType.DONE: - generated_message = model_cast.cast_to_model(ChatMessageBase, generated_message) + if generated_message.status != ApiStatusType.DONE: + if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus): + raise TypeError( + f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}" + ) - json_data = generated_message.model_dump(mode='json', by_alias=True, exclude_unset=True) + json_data = 
generated_message.model_dump(mode='json', by_alias=True) json_str = json.dumps(json_data) - - log = f"🔗 Message status={generated_message.status}, sender={getattr(generated_message, 'sender', 'unknown')}" - if last_log != log: - last_log = log - logger.info(log) - + yield f"data: {json_str}\n\n" # After streaming is complete, persist the final AI message to database - if final_message and final_message.status == ChatStatusType.DONE: + if final_message and final_message.status == ApiStatusType.DONE: try: if database and chat_session_data: await database.add_chat_message(final_message.session_id, final_message.model_dump()) @@ -412,9 +414,12 @@ async def stream_agent_response(chat_agent: agents.Agent, message_stream_generator(), media_type="text/event-stream", headers={ - "Cache-Control": "no-cache", + "Cache-Control": "no-cache, no-store, must-revalidate", "Connection": "keep-alive", - "X-Accel-Buffering": "no", + "X-Accel-Buffering": "no", # Nginx + "X-Content-Type-Options": "nosniff", + "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs + "Transfer-Encoding": "chunked", }, ) @@ -678,19 +683,19 @@ async def create_candidate_ai( async for generated_message in generate_agent.generate( llm=llm_manager.get_llm(), model=defines.model, - user_message=user_message, - user=None, + session_id=user_message.session_id, + prompt=user_message.content, ): - if generated_message.status == ChatStatusType.ERROR: + if generated_message.status == ApiStatusType.ERROR: logger.error(f"❌ AI generation error: {generated_message.content}") return JSONResponse( status_code=500, content=create_error_response("AI_GENERATION_ERROR", generated_message.content) ) - if generated_message.type == ChatMessageType.RESPONSE and state == 0: + if generated_message.status == ApiStatusType.DONE and state == 0: persona_message = generated_message state = 1 # Switch to resume generation - elif generated_message.type == ChatMessageType.RESPONSE and state == 1: + elif generated_message.status == ApiStatusType.DONE and state == 1: resume_message = generated_message if not persona_message: @@ -2412,7 +2417,6 @@ async def update_candidate( ) is_AI = candidate_data.get("is_AI", False) - logger.info(json.dumps(candidate_data, indent=2)) candidate = CandidateAI.model_validate(candidate_data) if is_AI else Candidate.model_validate(candidate_data) # Check authorization (user can only update their own profile) @@ -2804,8 +2808,8 @@ async def post_candidate_rag_search( async for generated_message in chat_agent.generate( llm=llm_manager.get_llm(), model=defines.model, - user_message=user_message, - user=candidate, + session_id=user_message.session_id, + prompt=user_message.prompt, ): rag_message = generated_message @@ -3087,7 +3091,6 @@ async def post_chat_session_message_stream( return await stream_agent_response( chat_agent=chat_agent, user_message=user_message, - candidate=candidate, database=database, chat_session_data=chat_session_data, ) @@ -3389,15 +3392,10 @@ async def get_candidate_skill_match( agent.generate( llm=llm_manager.get_llm(), model=defines.model, - user_message=ChatMessageUser( - sender_id=candidate.id, - session_id="", - content=requirement, - timestamp=datetime.now(UTC) + session_id="", + prompt=requirement, ), - user=candidate, ) - ) if skill_match is None: return JSONResponse( status_code=500, diff --git a/src/backend/models.py b/src/backend/models.py index bd7d3eb..2fbf27a 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -71,8 +71,11 @@ class InterviewRecommendation(str, Enum): class 
ChatSenderType(str, Enum): USER = "user" ASSISTANT = "assistant" - AGENT = "agent" + # Frontend can use this to set mock responses SYSTEM = "system" + INFORMATION = "information" + WARNING = "warning" + ERROR = "error" class Requirements(BaseModel): required: List[str] = Field(default_factory=list) @@ -107,25 +110,13 @@ class SkillMatch(BaseModel): model_config = { "populate_by_name": True # Allow both field names and aliases } - -class ChatMessageType(str, Enum): - ERROR = "error" - GENERATING = "generating" - INFO = "info" - PREPARING = "preparing" - PROCESSING = "processing" - HEARTBEAT = "heartbeat" - RESPONSE = "response" - SEARCHING = "searching" - RAG_RESULT = "rag_result" - SYSTEM = "system" - THINKING = "thinking" - TOOLING = "tooling" - USER = "user" +class ApiMessageType(str, Enum): + BINARY = "binary" + TEXT = "text" + JSON = "json" -class ChatStatusType(str, Enum): - INITIALIZING = "initializing" +class ApiStatusType(str, Enum): STREAMING = "streaming" STATUS = "status" DONE = "done" @@ -772,26 +763,55 @@ class LLMMessage(BaseModel): content: str = Field(default="") tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True) - -class ChatMessageBase(BaseModel): +class ApiMessage(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) session_id: str = Field(..., alias="sessionId") sender_id: Optional[str] = Field(None, alias="senderId") - status: ChatStatusType #= ChatStatusType.INITIALIZING - type: ChatMessageType #= ChatMessageType.PREPARING - sender: ChatSenderType #= ChatSenderType.SYSTEM + status: ApiStatusType + type: ApiMessageType timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="timestamp") - tunables: Optional[Tunables] = None - content: str = "" model_config = { "populate_by_name": True # Allow both field names and aliases } -class ChatMessageRagSearch(ChatMessageBase): - status: ChatStatusType = ChatStatusType.DONE - type: ChatMessageType = ChatMessageType.RAG_RESULT - sender: ChatSenderType = ChatSenderType.USER +class ChatMessageStreaming(ApiMessage): + status: ApiStatusType = ApiStatusType.STREAMING + type: ApiMessageType = ApiMessageType.TEXT + content: str + +class ApiActivityType(str, Enum): + SYSTEM = "system" # Used solely on frontend + INFO = "info" # Used solely on frontend + SEARCHING = "searching" # Used when generating RAG information + THINKING = "thinking" # Used when determing if AI will use tools + GENERATING = "generating" # Used when AI is generating a response + GENERATING_IMAGE = "generating_image" # Used when AI is generating an image + TOOLING = "tooling" # Used when AI is using tools + HEARTBEAT = "heartbeat" # Used for periodic updates + +class ChatMessageStatus(ApiMessage): + status: ApiStatusType = ApiStatusType.STATUS + type: ApiMessageType = ApiMessageType.TEXT + activity: ApiActivityType + content: Any + +class ChatMessageError(ApiMessage): + status: ApiStatusType = ApiStatusType.ERROR + type: ApiMessageType = ApiMessageType.TEXT + content: str + +class ChatMessageRagSearch(ApiMessage): + type: ApiMessageType = ApiMessageType.JSON dimensions: int = 2 | 3 + content: List[ChromaDBGetResponse] = [] + +class JobRequirementsMessage(ApiMessage): + type: ApiMessageType = ApiMessageType.JSON + title: Optional[str] + summary: Optional[str] + company: Optional[str] + description: str + requirements: Optional[JobRequirements] class ChatMessageMetaData(BaseModel): model: AIModelType = AIModelType.QWEN2_5 @@ -808,23 +828,26 @@ class ChatMessageMetaData(BaseModel): prompt_eval_count: int = 0 
@@ -808,23 +828,26 @@ class ChatMessageMetaData(BaseModel):
     prompt_eval_count: int = 0
     prompt_eval_duration: int = 0
     options: Optional[ChatOptions] = None
-    tools: Optional[Dict[str, Any]] = None
-    timers: Optional[Dict[str, float]] = None
+    tools: Dict[str, Any] = Field(default_factory=dict)
+    timers: Dict[str, float] = Field(default_factory=dict)

     model_config = {
         "populate_by_name": True  # Allow both field names and aliases
     }

-class ChatMessageUser(ChatMessageBase):
-    status: ChatStatusType = ChatStatusType.INITIALIZING
-    type: ChatMessageType = ChatMessageType.GENERATING
-    sender: ChatSenderType = ChatSenderType.USER
+class ChatMessageUser(ApiMessage):
+    type: ApiMessageType = ApiMessageType.TEXT
+    status: ApiStatusType = ApiStatusType.DONE
+    role: ChatSenderType = ChatSenderType.USER
+    content: str = ""
+    tunables: Optional[Tunables] = None

-class ChatMessage(ChatMessageBase):
+class ChatMessage(ChatMessageUser):
+    role: ChatSenderType = ChatSenderType.ASSISTANT
+    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
     #attachments: Optional[List[Attachment]] = None
     #reactions: Optional[List[MessageReaction]] = None
     #is_edited: bool = Field(False, alias="isEdited")
     #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
-    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)

 class ChatSession(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
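Two details in this hunk are worth spelling out: `Field(default_factory=dict)` gives each instance its own dict and spares callers the None guards the old `Optional[...] = None` defaults forced, and `ChatMessage` now inherits `ChatMessageUser`, overriding only the `role` default. A minimal sketch of both behaviors, with the models trimmed to the relevant fields (assumed to mirror the diff, not imported from the backend):

```python
# Minimal sketch of the default_factory and role-override behaviors above;
# trimmed, assumed-equivalent models, not the real backend classes.
from enum import Enum
from typing import Any, Dict

from pydantic import BaseModel, Field


class ChatSenderType(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"


class ChatMessageMetaData(BaseModel):
    # default_factory: every instance gets a fresh dict (no shared mutable
    # default), and callers never need `if metadata.timers is None` checks.
    tools: Dict[str, Any] = Field(default_factory=dict)
    timers: Dict[str, float] = Field(default_factory=dict)


class ChatMessageUser(BaseModel):
    role: ChatSenderType = ChatSenderType.USER
    content: str = ""


class ChatMessage(ChatMessageUser):
    # Same shape as a user message; only the defaults differ.
    role: ChatSenderType = ChatSenderType.ASSISTANT
    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)


a, b = ChatMessageMetaData(), ChatMessageMetaData()
a.timers["llm_ms"] = 412.0
assert b.timers == {}                     # instances do not share state

reply = ChatMessage(content="Hi!")
reply.metadata.timers["total_ms"] = 9.7   # no None guard needed
assert reply.role is ChatSenderType.ASSISTANT
```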
list of floats") - return [float(x) for x in value] class ChromaDBFileWatcher(FileSystemEventHandler): def __init__(