Chat and stuff is working again
This commit is contained in:
parent efef926e45
commit 48e6eeaa71
@@ -15,14 +15,14 @@ import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from 'utils/Global';
 import { useAuth } from "hooks/AuthContext";
 import { StreamingResponse } from 'services/api-client';
-import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { PaginatedResponse } from 'types/conversion';

 import './Conversation.css';
 import { useSelectedCandidate } from 'hooks/GlobalContext';

 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };

 const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
@@ -249,8 +249,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
     ...conversationRef.current,
     {
         ...defaultMessage,
-        type: 'user',
-        sender: 'user',
+        type: 'text',
         content: query.prompt,
     }
 ]);
@@ -260,34 +259,30 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
 );

 const chatMessage: ChatMessageUser = {
+    role: "user",
     sessionId: chatSession.id,
     content: query.prompt,
     tunables: query.tunables,
     status: "done",
-    type: "user",
-    sender: "user",
+    type: "text",
     timestamp: new Date()
 };

 controllerRef.current = apiClient.sendMessageStream(chatMessage, {
-    onMessage: (msg: ChatMessageBase) => {
+    onMessage: (msg: ChatMessage) => {
         console.log("onMessage:", msg);
-        if (msg.type === "response") {
-            setConversation([
-                ...conversationRef.current,
-                msg
-            ]);
-            setStreamingMessage(undefined);
-            setProcessingMessage(undefined);
-            setProcessing(false);
-        } else {
-            setProcessingMessage(msg);
-        }
+        setConversation([
+            ...conversationRef.current,
+            msg
+        ]);
+        setStreamingMessage(undefined);
+        setProcessingMessage(undefined);
+        setProcessing(false);
         if (onResponse) {
             onResponse(msg);
         }
     },
-    onError: (error: string | ChatMessageBase) => {
+    onError: (error: string | ChatMessageError) => {
         console.log("onError:", error);
         // Type-guard to determine if this is a ChatMessageBase or a string
         if (typeof error === "object" && error !== null && "content" in error) {
@@ -298,12 +293,12 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             setProcessingMessage({ ...defaultMessage, content: error as string });
         }
     },
-    onStreaming: (chunk: ChatMessageBase) => {
+    onStreaming: (chunk: ChatMessageStreaming) => {
         console.log("onStreaming:", chunk);
        setStreamingMessage({ ...defaultMessage, ...chunk });
     },
-    onStatusChange: (status: string) => {
-        console.log("onStatusChange:", status);
+    onStatus: (status: ChatMessageStatus) => {
+        console.log("onStatus:", status);
     },
     onComplete: () => {
         console.log("onComplete");

@@ -20,7 +20,7 @@ import CheckCircleIcon from '@mui/icons-material/CheckCircle';
 import ErrorIcon from '@mui/icons-material/Error';
 import PendingIcon from '@mui/icons-material/Pending';
 import WarningIcon from '@mui/icons-material/Warning';
-import { Candidate, ChatMessage, ChatMessageBase, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
+import { Candidate, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
 import { useAuth } from 'hooks/AuthContext';
 import { BackstoryPageProps } from './BackstoryTab';
 import { toCamelCase } from 'types/conversion';
@@ -31,8 +31,8 @@ interface JobAnalysisProps extends BackstoryPageProps {
     candidate: Candidate;
 }

-const defaultMessage: ChatMessageUser = {
-    type: "preparing", status: "done", sender: "user", sessionId: "", timestamp: new Date(), content: ""
+const defaultMessage: ChatMessage = {
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };

 const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
@@ -62,23 +62,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
     return;
 }

-const createSession = async () => {
-    try {
-        const session: ChatSession = await apiClient.createCandidateChatSession(
-            candidate.username,
-            'job_requirements',
-            `Generate requirements for ${job.title}`
-        );
-        setSnack("Job analysis session started");
+try {
+    setCreatingSession(true);
+    apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
+        .then(session => {
            setRequirementsSession(session);
-    } catch (error) {
-        console.log(error);
-        setSnack("Unable to create requirements session", "error");
-    }
+            setCreatingSession(false);
+        });
+} catch (error) {
+    setSnack('Unable to load chat session', 'error');
+} finally {
+    setCreatingSession(false);
 };
-setCreatingSession(true);
-createSession();
-}

 }, [requirementsSession, apiClient, candidate]);
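
The new session bootstrap resolves getOrCreateChatSession with .then inside a try block; a hedged async/await sketch of the same effect (illustrative, not part of the commit) that also lets the catch observe a rejected promise:

    useEffect(() => {
        if (requirementsSession) return;
        (async () => {
            setCreatingSession(true);
            try {
                const session = await apiClient.getOrCreateChatSession(
                    candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements');
                setRequirementsSession(session);
            } catch (error) {
                setSnack('Unable to load chat session', 'error');
            } finally {
                setCreatingSession(false);
            }
        })();
    }, [requirementsSession, apiClient, candidate]);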

@@ -94,50 +90,48 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
 apiClient.sendMessageStream(chatMessage, {
     onMessage: (msg: ChatMessage) => {
         console.log(`onMessage: ${msg.type}`, msg);
-        if (msg.type === "response") {
-            const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
-            const requirements: { requirement: string, domain: string }[] = [];
-            if (job.requirements?.technicalSkills) {
-                job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
-                job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
-            }
-            if (job.requirements?.experienceRequirements) {
-                job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
-                job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
-            }
-            if (job.requirements?.softSkills) {
-                job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
-            }
-            if (job.requirements?.experience) {
-                job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
-            }
-            if (job.requirements?.education) {
-                job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
-            }
-            if (job.requirements?.certifications) {
-                job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
-            }
-            if (job.requirements?.preferredAttributes) {
-                job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
-            }
-
-            const initialSkillMatches = requirements.map(req => ({
-                requirement: req.requirement,
-                domain: req.domain,
-                status: 'waiting' as const,
-                matchScore: 0,
-                assessment: '',
-                description: '',
-                citations: []
-            }));
-
-            setRequirements(requirements);
-            setSkillMatches(initialSkillMatches);
-            setStatusMessage(null);
-            setLoadingRequirements(false);
+        const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
+        const requirements: { requirement: string, domain: string }[] = [];
+        if (job.requirements?.technicalSkills) {
+            job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
+            job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
+        }
+        if (job.requirements?.experienceRequirements) {
+            job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
+            job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
+        }
+        if (job.requirements?.softSkills) {
+            job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
+        }
+        if (job.requirements?.experience) {
+            job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
+        }
+        if (job.requirements?.education) {
+            job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
+        }
+        if (job.requirements?.certifications) {
+            job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
+        }
+        if (job.requirements?.preferredAttributes) {
+            job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
+        }
+
+        const initialSkillMatches = requirements.map(req => ({
+            requirement: req.requirement,
+            domain: req.domain,
+            status: 'waiting' as const,
+            matchScore: 0,
+            assessment: '',
+            description: '',
+            citations: []
+        }));
+
+        setRequirements(requirements);
+        setSkillMatches(initialSkillMatches);
+        setStatusMessage(null);
+        setLoadingRequirements(false);
     },
-    onError: (error: string | ChatMessageBase) => {
+    onError: (error: string | ChatMessageError) => {
         console.log("onError:", error);
         // Type-guard to determine if this is a ChatMessageBase or a string
         if (typeof error === "object" && error !== null && "content" in error) {
@@ -147,11 +141,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
         }
         setLoadingRequirements(false);
     },
-    onStreaming: (chunk: ChatMessageBase) => {
+    onStreaming: (chunk: ChatMessageStreaming) => {
         // console.log("onStreaming:", chunk);
     },
-    onStatusChange: (status: string) => {
-        console.log(`onStatusChange: ${status}`);
+    onStatus: (status: ChatMessageStatus) => {
+        console.log(`onStatus: ${status}`);
     },
     onComplete: () => {
         console.log("onComplete");

@@ -32,9 +32,9 @@ import { SetSnackType } from './Snack';
 import { CopyBubble } from './CopyBubble';
 import { Scrollable } from './Scrollable';
 import { BackstoryElementProps } from './BackstoryTab';
-import { ChatMessage, ChatSession, ChatMessageType, ChatMessageMetaData, ChromaDBGetResponse } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageMetaData, ChromaDBGetResponse, ApiActivityType, ChatMessageUser, ChatMessageError, ChatMessageStatus, ChatSenderType } from 'types/types';

-const getStyle = (theme: Theme, type: ChatMessageType): any => {
+const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
     const defaultRadius = '16px';
     const defaultStyle = {
         padding: theme.spacing(1, 2),
@@ -56,7 +56,7 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
 };

 const styles: any = {
-    response: {
+    assistant: {
         ...defaultStyle,
         backgroundColor: theme.palette.primary.main,
         border: `1px solid ${theme.palette.secondary.main}`,
@@ -90,9 +90,10 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
         boxShadow: '0 1px 3px rgba(216, 58, 58, 0.15)',
     },
     'fact-check': 'qualifications',
+    generating: 'status',
     'job-description': 'content',
     'job-requirements': 'qualifications',
-    info: {
+    information: {
         ...defaultStyle,
         backgroundColor: '#BFD8D8',
         border: `1px solid ${theme.palette.secondary.main}`,
@@ -156,25 +157,29 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
         }
     }

+    if (!(type in styles)) {
+        console.log(`Style does not exist for: ${type}`);
+    }
+
     return styles[type];
 }

-const getIcon = (messageType: ChatMessageType): React.ReactNode | null => {
+const getIcon = (activityType: ApiActivityType | ChatSenderType | "error"): React.ReactNode | null => {
     const icons: any = {
         error: <ErrorOutline color="error" />,
         generating: <LocationSearchingIcon />,
-        info: <InfoOutline color="info" />,
+        information: <InfoOutline color="info" />,
         preparing: <LocationSearchingIcon />,
         processing: <LocationSearchingIcon />,
         system: <Memory />,
         thinking: <Psychology />,
         tooling: <LocationSearchingIcon />,
     };
-    return icons[messageType] || null;
+    return icons[activityType] || null;
 }

 interface MessageProps extends BackstoryElementProps {
-    message: ChatMessage,
+    message: ChatMessageUser | ChatMessage | ChatMessageError | ChatMessageStatus,
     title?: string,
     chatSession?: ChatSession,
     className?: string,
@@ -319,7 +324,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };

 interface MessageContainerProps {
-    type: ChatMessageType,
+    type: ApiActivityType | ChatSenderType | "error",
     metadataView?: React.ReactNode | null,
     messageView?: React.ReactNode | null,
     sx?: SxProps<Theme>,
@@ -360,13 +365,21 @@ const Message = (props: MessageProps) => {
     setSnack
 };
 const theme = useTheme();
-const style: any = getStyle(theme, message.type);
+const type: ApiActivityType | ChatSenderType | "error" = ('activity' in message) ? message.activity : ('error' in message) ? 'error' : (message as ChatMessage).role;
+const style: any = getStyle(theme, type);

 const handleMetaExpandClick = () => {
     setMetaExpanded(!metaExpanded);
 };

-const content = message.content?.trim();
+let content;
+if (typeof (message.content) === "string") {
+    content = message.content.trim();
+} else {
+    console.error(`message content is not a string`);
+    return (<></>)
+}

 if (!content) {
     return (<></>)
 };
@@ -376,7 +389,8 @@ const Message = (props: MessageProps) => {
 );

 let metadataView = (<></>);
-if (message.metadata) {
+const metadata: ChatMessageMetaData | null = ('metadata' in message) ? (message.metadata as ChatMessageMetaData || null) : null;
+if (metadata) {
     metadataView = (
         <Box sx={{ display: "flex", flexDirection: "column", width: "100%" }}>
             <Box sx={{ display: "flex", alignItems: "center", gap: 1, flexDirection: "row" }}>
@@ -395,18 +409,18 @@ const Message = (props: MessageProps) => {
             </Box>
             <Collapse in={metaExpanded} timeout="auto" unmountOnExit>
                 <CardContent>
-                    <MessageMeta messageProps={props} metadata={message.metadata} />
+                    <MessageMeta messageProps={props} metadata={metadata} />
                 </CardContent>
             </Collapse>
         </Box>);
 }

-const copyContent = message.sender === 'assistant' ? message.content : undefined;
+const copyContent = (type === 'assistant') ? message.content : undefined;

 if (!expandable) {
     /* When not expandable, the styles are applied directly to MessageContainer */
     return (<>
-        {messageView && <MessageContainer copyContent={copyContent} type={message.type} {...{ messageView, metadataView }} sx={{ ...style, ...sx }} />}
+        {messageView && <MessageContainer copyContent={copyContent} type={type} {...{ messageView, metadataView }} sx={{ ...style, ...sx }} />}
     </>);
 }
@@ -435,7 +449,7 @@ const Message = (props: MessageProps) => {
         {title || ''}
     </AccordionSummary>
     <AccordionDetails sx={{ mt: 0, mb: 0, p: 0, pl: 2, pr: 2 }}>
-        <MessageContainer copyContent={copyContent} type={message.type} {...{ messageView, metadataView }} />
+        <MessageContainer copyContent={copyContent} type={type} {...{ messageView, metadataView }} />
     </AccordionDetails>
 </Accordion>
 );
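
The inline `type` derivation narrows the message union by property presence; the same discrimination as a standalone helper (a sketch mirroring the component, not part of the commit):

    type DisplayType = ApiActivityType | ChatSenderType | "error";
    const displayTypeOf = (m: ChatMessageUser | ChatMessage | ChatMessageError | ChatMessageStatus): DisplayType =>
        'activity' in m ? m.activity              // ChatMessageStatus carries an activity
        : 'error' in m ? 'error'                  // error-shaped messages use the error style
        : (m as ChatMessage).role;                // otherwise style by role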

@@ -25,11 +25,6 @@ import { useAuth } from 'hooks/AuthContext';
 import * as Types from 'types/types';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
 import { useNavigate } from 'react-router-dom';
 import { Message } from './Message';
-const defaultMessage: Types.ChatMessageBase = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
-};
-

 interface VectorVisualizerProps extends BackstoryPageProps {
     inline?: boolean;

@@ -12,7 +12,7 @@ import {
     Send as SendIcon
 } from '@mui/icons-material';
 import { useAuth } from 'hooks/AuthContext';
-import { ChatMessageBase, ChatMessage, ChatSession, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { ConversationHandle } from 'components/Conversation';
 import { BackstoryPageProps } from 'components/BackstoryTab';
 import { Message } from 'components/Message';
@@ -25,7 +25,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryT
 import { BackstoryQuery } from 'components/BackstoryQuery';

 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
 };

 const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
@@ -33,7 +33,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 const navigate = useNavigate();
 const { selectedCandidate } = useSelectedCandidate()
 const theme = useTheme();
-const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
+const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
 const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
 const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);

@@ -48,32 +48,6 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 const [streaming, setStreaming] = useState<boolean>(false);
 const messagesEndRef = useRef(null);

-// Load sessions for the selectedCandidate
-const loadSessions = async () => {
-    if (!selectedCandidate) return;
-
-    try {
-        setLoading(true);
-        const result = await apiClient.getCandidateChatSessions(selectedCandidate.username);
-        let session = null;
-        if (result.sessions.data.length === 0) {
-            session = await apiClient.createCandidateChatSession(
-                selectedCandidate.username,
-                'candidate_chat',
-                `Backstory chat about ${selectedCandidate.fullName}`
-            );
-        } else {
-            session = result.sessions.data[0];
-        }
-        setChatSession(session);
-        setLoading(false);
-    } catch (error) {
-        setSnack('Unable to load chat session', 'error');
-    } finally {
-        setLoading(false);
-    }
-};
-
 // Load messages for current session
 const loadMessages = async () => {
     if (!chatSession?.id) return;
@@ -107,20 +81,22 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr

 // Send message
 const sendMessage = async (message: string) => {
-    if (!message.trim() || !chatSession?.id || streaming) return;
+    if (!message.trim() || !chatSession?.id || streaming || !selectedCandidate) return;

     const messageContent = message;
     setStreaming(true);

     const chatMessage: ChatMessageUser = {
         sessionId: chatSession.id,
+        role: "user",
         content: messageContent,
         status: "done",
-        type: "user",
-        sender: "user",
+        type: "text",
         timestamp: new Date()
     };

+    setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
+
     setMessages(prev => {
         const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
         return [...filtered, chatMessage] as any;
@@ -129,34 +105,31 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 try {
     apiClient.sendMessageStream(chatMessage, {
         onMessage: (msg: ChatMessage) => {
             console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
-            if (msg.type === "response") {
-                setMessages(prev => {
-                    const filtered = prev.filter((m: any) => m.id !== msg.id);
-                    return [...filtered, msg] as any;
-                });
-                setStreamingMessage(null);
-                setProcessingMessage(null);
-            } else {
-                setProcessingMessage(msg);
-            }
+            setMessages(prev => {
+                const filtered = prev.filter((m: any) => m.id !== msg.id);
+                return [...filtered, msg] as any;
+            });
+            setStreamingMessage(null);
+            setProcessingMessage(null);
         },
-        onError: (error: string | ChatMessageBase) => {
+        onError: (error: string | ChatMessageError) => {
             console.log("onError:", error);
             let message: string;
             // Type-guard to determine if this is a ChatMessageBase or a string
             if (typeof error === "object" && error !== null && "content" in error) {
-                setProcessingMessage(error as ChatMessage);
+                setProcessingMessage(error);
                 message = error.content as string;
             } else {
-                setProcessingMessage({ ...defaultMessage, content: error as string });
+                setProcessingMessage({ ...defaultMessage, status: "error", content: error })
             }
             setStreaming(false);
         },
-        onStreaming: (chunk: ChatMessageBase) => {
+        onStreaming: (chunk: ChatMessageStreaming) => {
             // console.log("onStreaming:", chunk);
-            setStreamingMessage({ ...defaultMessage, ...chunk });
+            setStreamingMessage({ ...chunk, role: 'assistant' });
         },
-        onStatusChange: (status: string) => {
-            console.log(`onStatusChange: ${status}`);
+        onStatus: (status: ChatMessageStatus) => {
+            setProcessingMessage(status);
         },
         onComplete: () => {
             console.log("onComplete");
@@ -178,7 +151,19 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr

 // Load sessions when username changes
 useEffect(() => {
-    loadSessions();
+    if (!selectedCandidate) return;
+    try {
+        setLoading(true);
+        apiClient.getOrCreateChatSession(selectedCandidate, `Backstory chat with ${selectedCandidate.fullName}`, 'candidate_chat')
+            .then(session => {
+                setChatSession(session);
+                setLoading(false);
+            });
+    } catch (error) {
+        setSnack('Unable to load chat session', 'error');
+    } finally {
+        setLoading(false);
+    }
 }, [selectedCandidate]);

 // Load messages when session changes
@@ -195,9 +180,9 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr

 const welcomeMessage: ChatMessage = {
     sessionId: chatSession?.id || '',
-    type: "info",
+    role: "information",
+    type: "text",
     status: "done",
-    sender: "system",
     timestamp: new Date(),
     content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`
 };
@@ -252,7 +237,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
     },
 }}>
 {messages.length === 0 && <Message {...{ chatSession, message: welcomeMessage, setSnack, submitQuery }} />}
-{messages.map((message: ChatMessageBase) => (
+{messages.map((message: ChatMessage) => (
     <Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
 ))}
 {processingMessage !== null && (
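
Since processingMessage is now ChatMessageStatus | ChatMessageError | null, a renderer can branch on shape; a sketch (ActivityIndicator and ErrorBanner are hypothetical components, not part of the commit):

    {processingMessage && ('activity' in processingMessage
        ? <ActivityIndicator activity={processingMessage.activity} />  // transient status
        : <ErrorBanner text={processingMessage.content} />)}           // terminal error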

@@ -8,7 +8,6 @@ import IconButton from '@mui/material/IconButton';
 import CancelIcon from '@mui/icons-material/Cancel';
 import SendIcon from '@mui/icons-material/Send';
 import PropagateLoader from 'react-spinners/PropagateLoader';
-import { jsonrepair } from 'jsonrepair';

 import { CandidateInfo } from '../components/CandidateInfo';
 import { Quote } from 'components/Quote';
@@ -18,57 +17,24 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
 import { Scrollable } from '../components/Scrollable';
 import { Pulse } from 'components/Pulse';
 import { StreamingResponse } from 'services/api-client';
-import { ChatContext, ChatMessage, ChatMessageUser, ChatMessageBase, ChatSession, ChatQuery, Candidate, CandidateAI } from 'types/types';
+import { ChatMessage, ChatMessageUser, ChatSession, CandidateAI, ChatMessageStatus, ChatMessageError } from 'types/types';
 import { useAuth } from 'hooks/AuthContext';
 import { Message } from 'components/Message';
-import { Types } from '@uiw/react-json-view';
-import { assert } from 'console';

-const emptyUser: CandidateAI = {
-    userType: "candidate",
-    isAI: true,
-    description: "[blank]",
-    username: "[blank]",
-    firstName: "[blank]",
-    lastName: "[blank]",
-    fullName: "[blank] [blank]",
-    questions: [],
-    location: {
-        city: '[blank]',
-        country: '[blank]'
-    },
-    email: '[blank]',
-    createdAt: new Date(),
-    updatedAt: new Date(),
-    status: "pending",
-    skills: [],
-    experience: [],
-    education: [],
-    preferredJobTypes: [],
-    languages: [],
-    certifications: [],
-    isAdmin: false,
-    profileImage: undefined,
-    ragContentSize: 0
-};
-
 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
 };

 const GenerateCandidate = (props: BackstoryElementProps) => {
     const { apiClient, user } = useAuth();
     const { setSnack, submitQuery } = props;
     const [streaming, setStreaming] = useState<boolean>(false);
     const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
     const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
     const [processing, setProcessing] = useState<boolean>(false);
     const [generatedUser, setGeneratedUser] = useState<CandidateAI | null>(null);
     const [prompt, setPrompt] = useState<string>('');
     const [resume, setResume] = useState<string | null>(null);
     const [canGenImage, setCanGenImage] = useState<boolean>(false);
-    const [timestamp, setTimestamp] = useState<number>(0);
-    const [state, setState] = useState<number>(0); // Replaced stateRef
+    const [timestamp, setTimestamp] = useState<string>('');
     const [shouldGenerateProfile, setShouldGenerateProfile] = useState<boolean>(false);
     const [chatSession, setChatSession] = useState<ChatSession | null>(null);
     const [loading, setLoading] = useState<boolean>(false);
@@ -83,61 +49,24 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     return;
 }

-const createChatSession = async () => {
-    console.log('Creating chat session');
-    try {
-        const response: ChatSession = await apiClient.createCandidateChatSession(
-            generatedUser.username,
-            "generate_image",
-            "Profile image generation"
-        );
-        setChatSession(response);
-        console.log(`Chat session created for generate_image`, response);
-        setSnack(`Chat session created for generate_image: ${response.id}`);
-    } catch (e) {
-        console.error(e);
-        setSnack("Unable to create image generation session.", "error");
-    }
-};
-
-setLoading(true);
-createChatSession().then(() => { setLoading(false) });
-}, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack]);
-
-const generatePersona = async (prompt: string) => {
-    const userMessage: ChatMessageUser = {
-        content: prompt,
-        sessionId: "",
-        sender: "user",
-        status: "done",
-        type: "user",
-        timestamp: new Date()
-    };
-    setPrompt(prompt || '');
-    setProcessing(true);
-    setProcessingMessage({ ...defaultMessage, content: "Generating persona..." });
-    try {
-        const result = await apiClient.createCandidateAI(userMessage);
-        console.log(result.message, result);
-        setGeneratedUser(result.candidate);
-        setResume(result.resume);
-        setCanGenImage(true);
-        setShouldGenerateProfile(true); // Reset the flag
+setLoading(true);
+apiClient.getOrCreateChatSession(generatedUser, `Profile image generator for ${generatedUser.fullName}`, 'generate_image')
+    .then(session => {
+        setChatSession(session);
+        setLoading(false);
+    });
 } catch (error) {
-    console.error(error);
-    setPrompt('');
-    setResume(null);
-    setProcessing(false);
-    setProcessingMessage(null);
-    setSnack("Unable to generate AI persona", "error");
+    setSnack('Unable to load chat session', 'error');
+} finally {
+    setLoading(false);
 }
-};
+}, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack, apiClient]);

 const cancelQuery = useCallback(() => {
     if (controllerRef.current) {
         controllerRef.current.cancel();
         controllerRef.current = null;
-        setState(0);
         setProcessing(false);
     }
 }, []);
@@ -146,8 +75,38 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     if (processing) {
         return;
     }

+    const generatePersona = async (prompt: string) => {
+        const userMessage: ChatMessageUser = {
+            type: "text",
+            role: "user",
+            content: prompt,
+            sessionId: "",
+            status: "done",
+            timestamp: new Date()
+        };
+        setPrompt(prompt || '');
+        setProcessing(true);
+        setProcessingMessage({ ...defaultMessage, content: "Generating persona..." });
+        try {
+            const result = await apiClient.createCandidateAI(userMessage);
+            console.log(result.message, result);
+            setGeneratedUser(result.candidate);
+            setResume(result.resume);
+            setCanGenImage(true);
+            setShouldGenerateProfile(true); // Reset the flag
+        } catch (error) {
+            console.error(error);
+            setPrompt('');
+            setResume(null);
+            setProcessing(false);
+            setProcessingMessage(null);
+            setSnack("Unable to generate AI persona", "error");
+        }
+    };
+
     generatePersona(value);
-}, [processing, generatePersona]);
+}, [processing, apiClient, setSnack]);

 const handleSendClick = useCallback(() => {
     const value = (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || "";
@@ -173,13 +132,12 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     setProcessingMessage({ ...defaultMessage, content: 'Starting image generation...' });
     setProcessing(true);
     setCanGenImage(false);
-    setState(3);

     const chatMessage: ChatMessageUser = {
         sessionId: chatSession.id || '',
+        role: "user",
         status: "done",
-        type: "user",
-        sender: "user",
+        type: "text",
         timestamp: new Date(),
         content: prompt
     };
@@ -187,34 +145,23 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     controllerRef.current = apiClient.sendMessageStream(chatMessage, {
         onMessage: async (msg: ChatMessage) => {
             console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
-            if (msg.type === "heartbeat" && msg.content) {
-                const heartbeat = JSON.parse(msg.content);
-                setTimestamp(heartbeat.timestamp);
-            }
-            if (msg.type === "thinking" && msg.content) {
-                const status = JSON.parse(msg.content);
-                setProcessingMessage({ ...defaultMessage, content: status.message });
-            }
-            if (msg.type === "response") {
-                controllerRef.current = null;
-                try {
-                    await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" });
-                    const { success, message } = await apiClient.deleteChatSession(chatSession.id || '');
-                    console.log(`Profile generated for ${username} and chat session was ${!success ? 'not ' : ''} deleted: ${message}}`);
-                    setGeneratedUser({
-                        ...generatedUser,
-                        profileImage: "profile.png"
-                    } as CandidateAI);
-                    setState(0);
-                    setCanGenImage(true);
-                    setShouldGenerateProfile(false);
-                } catch (error) {
-                    console.error(error);
-                    setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error");
-                }
+            controllerRef.current = null;
+            try {
+                await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" });
+                const { success, message } = await apiClient.deleteChatSession(chatSession.id || '');
+                console.log(`Profile generated for ${username} and chat session was ${!success ? 'not ' : ''} deleted: ${message}}`);
+                setGeneratedUser({
+                    ...generatedUser,
+                    profileImage: "profile.png"
+                } as CandidateAI);
+                setCanGenImage(true);
+                setShouldGenerateProfile(false);
+            } catch (error) {
+                console.error(error);
+                setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error");
+            }
         },
-        onError: (error) => {
+        onError: (error: string | ChatMessageError) => {
             console.log("onError:", error);
             // Type-guard to determine if this is a ChatMessageBase or a string
             if (typeof error === "object" && error !== null && "content" in error) {
@@ -223,27 +170,28 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
                 setSnack(error as string, "error");
             }
             setProcessingMessage(null);
             setStreaming(false);
             setProcessing(false);
             controllerRef.current = null;
-            setState(0);
             setCanGenImage(true);
             setShouldGenerateProfile(false);
         },
         onComplete: () => {
             setProcessingMessage(null);
             setStreaming(false);
             setProcessing(false);
             controllerRef.current = null;
-            setState(0);
             setCanGenImage(true);
             setShouldGenerateProfile(false);
         },
-        onStatusChange: (status: string) => {
+        onStatus: (status: ChatMessageStatus) => {
+            if (status.activity === "heartbeat" && status.content) {
+                setTimestamp(status.timestamp?.toISOString() || '');
+            } else if (status.content) {
+                setProcessingMessage({ ...defaultMessage, content: status.content });
+            }
             console.log(`onStatusChange: ${status}`);
         },
     });
-}, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack]);
+}, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack, apiClient]);

 if (!user?.isAdmin) {
     return (<Box>You must be logged in as an admin to generate AI candidates.</Box>);

@@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types';

 const LoadingPage = (props: BackstoryPageProps) => {
     const preamble: ChatMessage = {
-        sender: 'system',
-        type: 'preparing',
+        role: 'assistant',
+        type: 'text',
         status: 'done',
         sessionId: '',
         content: 'Please wait while connecting to Backstory...',

@@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types';

 const LoginRequired = (props: BackstoryPageProps) => {
     const preamble: ChatMessage = {
-        sender: 'system',
-        type: 'preparing',
+        role: 'assistant',
+        type: 'text',
         status: 'done',
         sessionId: '',
         content: 'You must be logged to view this feature.',

@@ -39,11 +39,11 @@ import {
 // ============================

 interface StreamingOptions {
-    onStatusChange?: (status: Types.ChatStatusType) => void;
+    onStatus?: (status: Types.ChatMessageStatus) => void;
     onMessage?: (message: Types.ChatMessage) => void;
-    onStreaming?: (chunk: Types.ChatMessageBase) => void;
+    onStreaming?: (chunk: Types.ChatMessageStreaming) => void;
     onComplete?: () => void;
-    onError?: (error: string | Types.ChatMessageBase) => void;
+    onError?: (error: string | Types.ChatMessageError) => void;
     onWarn?: (warning: string) => void;
     signal?: AbortSignal;
 }
@@ -761,6 +761,20 @@ class ApiClient {
     return result;
 }

+async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
+    const result = await this.getCandidateChatSessions(candidate.username);
+    /* Find the 'candidate_chat' session if it exists, otherwise create it */
+    let session = result.sessions.data.find(session => session.title === 'candidate_chat');
+    if (!session) {
+        session = await this.createCandidateChatSession(
+            candidate.username,
+            context_type,
+            title
+        );
+    }
+    return session;
+}
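
Usage sketch for the new helper (illustrative, not part of the commit; note that the lookup compares session.title against the literal 'candidate_chat' regardless of the context_type argument):

    const session = await apiClient.getOrCreateChatSession(
        selectedCandidate,                                   // Types.Candidate
        `Backstory chat with ${selectedCandidate.fullName}`, // title
        'candidate_chat'                                     // Types.ChatContextType
    );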

@@ -965,7 +979,7 @@ class ApiClient {
 * Send message with streaming response support and date conversion
 */
 sendMessageStream(
-    chatMessage: Types.ChatMessageBase,
+    chatMessage: Types.ChatMessageUser,
     options: StreamingOptions = {}
 ): StreamingResponse {
     const abortController = new AbortController();
@@ -998,9 +1012,8 @@ class ApiClient {

     const decoder = new TextDecoder();
     let buffer = '';
-    let incomingMessage: Types.ChatMessage | null = null;
+    let streamingMessage: Types.ChatMessageStreaming | null = null;
     const incomingMessageList: Types.ChatMessage[] = [];
-    let incomingStatus : Types.ChatStatusType | null = null;
     try {
         while (true) {
             const { done, value } = await reader.read();
@@ -1022,39 +1035,41 @@ class ApiClient {
             try {
                 if (line.startsWith('data: ')) {
                     const data = line.slice(5).trim();
-                    const incoming: Types.ChatMessageBase = JSON.parse(data);
+                    const incoming: any = JSON.parse(data);

-                    // Convert date fields for incoming messages
-                    const convertedIncoming = convertChatMessageFromApi(incoming);
-
-                    // Trigger callbacks based on status
-                    if (convertedIncoming.status !== incomingStatus) {
-                        options.onStatusChange?.(convertedIncoming.status);
-                        incomingStatus = convertedIncoming.status;
-                    }
+                    console.log(incoming.status, incoming);

                     // Handle different status types
-                    switch (convertedIncoming.status) {
+                    switch (incoming.status) {
                         case 'streaming':
-                            if (incomingMessage === null) {
-                                incomingMessage = {...convertedIncoming};
+                            console.log(incoming.status, incoming);
+                            const streaming = Types.convertChatMessageStreamingFromApi(incoming);
+                            if (streamingMessage === null) {
+                                streamingMessage = {...streaming};
                             } else {
                                 // Can't do a simple += as typescript thinks .content might not be there
-                                incomingMessage.content = (incomingMessage?.content || '') + convertedIncoming.content;
+                                streamingMessage.content = (streamingMessage?.content || '') + streaming.content;
                                 // Update timestamp to latest
-                                incomingMessage.timestamp = convertedIncoming.timestamp;
+                                streamingMessage.timestamp = streaming.timestamp;
                             }
-                            options.onStreaming?.(convertedIncoming);
+                            options.onStreaming?.(streamingMessage);
                             break;

+                        case 'status':
+                            const status = Types.convertChatMessageStatusFromApi(incoming);
+                            options.onStatus?.(status);
+                            break;

                         case 'error':
-                            options.onError?.(convertedIncoming);
+                            const error = Types.convertChatMessageErrorFromApi(incoming);
+                            options.onError?.(error);
                             break;

-                        default:
-                            incomingMessageList.push(convertedIncoming);
+                        case 'done':
+                            const message = Types.convertChatMessageFromApi(incoming);
+                            incomingMessageList.push(message);
                             try {
-                                options.onMessage?.(convertedIncoming);
+                                options.onMessage?.(message);
                             } catch (error) {
                                 console.error('onMessage handler failed: ', error);
                             }
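
For orientation, the SSE frames this loop parses look roughly like the following (illustrative payloads; field sets taken from the message interfaces in types/types):

    // data: {"status":"streaming","type":"text","sessionId":"abc","content":"Hel"}
    // data: {"status":"status","type":"json","sessionId":"abc","activity":"thinking","content":"..."}
    // data: {"status":"error","type":"text","sessionId":"abc","content":"something failed"}
    // data: {"status":"done","type":"text","sessionId":"abc","role":"assistant","content":"Hello there."}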

@@ -1,6 +1,6 @@
 // Generated TypeScript types from Pydantic models
 // Source: src/backend/models.py
-// Generated on: 2025-06-04T17:02:08.242818
+// Generated on: 2025-06-05T00:24:02.132276
 // DO NOT EDIT MANUALLY - This file is auto-generated

 // ============================
@@ -11,15 +11,17 @@ export type AIModelType = "qwen2.5" | "flux-schnell";

 export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";

+export type ApiActivityType = "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+
+export type ApiMessageType = "binary" | "text" | "json";
+
+export type ApiStatusType = "streaming" | "status" | "done" | "error";
+
 export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";

 export type ChatContextType = "job_search" | "job_requirements" | "candidate_chat" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile" | "generate_image" | "rag_search" | "skill_match";

-export type ChatMessageType = "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-
-export type ChatSenderType = "user" | "assistant" | "agent" | "system";
-
-export type ChatStatusType = "initializing" | "streaming" | "status" | "done" | "error";
+export type ChatSenderType = "user" | "assistant" | "system" | "information" | "warning" | "error";

 export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";

@@ -88,6 +90,15 @@ export interface Analytics {
     segment?: string;
 }

+export interface ApiMessage {
+    id?: string;
+    sessionId: string;
+    senderId?: string;
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
+    timestamp?: Date;
+}
+
 export interface ApiResponse {
     success: boolean;
     data?: any;
@@ -284,24 +295,22 @@ export interface ChatMessage {
     id?: string;
     sessionId: string;
     senderId?: string;
-    status: "initializing" | "streaming" | "status" | "done" | "error";
-    type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-    sender: "user" | "assistant" | "agent" | "system";
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
     timestamp?: Date;
-    tunables?: Tunables;
+    role: "user" | "assistant" | "system" | "information" | "warning" | "error";
     content: string;
+    tunables?: Tunables;
     metadata?: ChatMessageMetaData;
 }

-export interface ChatMessageBase {
+export interface ChatMessageError {
     id?: string;
     sessionId: string;
     senderId?: string;
-    status: "initializing" | "streaming" | "status" | "done" | "error";
-    type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-    sender: "user" | "assistant" | "agent" | "system";
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
     timestamp?: Date;
-    tunables?: Tunables;
     content: string;
 }

@@ -328,25 +337,44 @@ export interface ChatMessageRagSearch {
     id?: string;
     sessionId: string;
     senderId?: string;
-    status: "initializing" | "streaming" | "status" | "done" | "error";
-    type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-    sender: "user" | "assistant" | "agent" | "system";
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
     timestamp?: Date;
-    tunables?: Tunables;
-    content: string;
-    dimensions: number;
+    content: Array<ChromaDBGetResponse>;
 }

+export interface ChatMessageStatus {
+    id?: string;
+    sessionId: string;
+    senderId?: string;
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
+    timestamp?: Date;
+    activity: "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+    content: any;
+}
+
+export interface ChatMessageStreaming {
+    id?: string;
+    sessionId: string;
+    senderId?: string;
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
+    timestamp?: Date;
+    content: string;
+}
+
 export interface ChatMessageUser {
     id?: string;
     sessionId: string;
     senderId?: string;
-    status: "initializing" | "streaming" | "status" | "done" | "error";
-    type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-    sender: "user" | "assistant" | "agent" | "system";
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
     timestamp?: Date;
-    tunables?: Tunables;
+    role: "user" | "assistant" | "system" | "information" | "warning" | "error";
     content: string;
+    tunables?: Tunables;
 }
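
The message shapes are discriminated by status plus shape-specific fields; hedged runtime guards (a sketch, not part of the generated file):

    const isStreaming = (m: any): m is ChatMessageStreaming => m?.status === 'streaming';
    const isStatus = (m: any): m is ChatMessageStatus => m?.status === 'status' && 'activity' in m;
    const isError = (m: any): m is ChatMessageError => m?.status === 'error';
    const isDone = (m: any): m is ChatMessage => m?.status === 'done' && 'role' in m;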

 export interface ChatOptions {
@@ -644,6 +672,20 @@ export interface JobRequirements {
     preferredAttributes?: Array<string>;
 }

+export interface JobRequirementsMessage {
+    id?: string;
+    sessionId: string;
+    senderId?: string;
+    status: "streaming" | "status" | "done" | "error";
+    type: "binary" | "text" | "json";
+    timestamp?: Date;
+    title?: string;
+    summary?: string;
+    company?: string;
+    description: string;
+    requirements?: JobRequirements;
+}
+
 export interface JobResponse {
     success: boolean;
     data?: Job;
@@ -923,6 +965,19 @@ export function convertAnalyticsFromApi(data: any): Analytics {
         timestamp: new Date(data.timestamp),
     };
 }
+/**
+ * Convert ApiMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertApiMessageFromApi(data: any): ApiMessage {
+    if (!data) return data;
+
+    return {
+        ...data,
+        // Convert timestamp from ISO string to Date
+        timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+    };
+}
 /**
 * Convert ApplicationDecision from API response, parsing date fields
 * Date fields: date
@@ -1067,10 +1122,10 @@ export function convertChatMessageFromApi(data: any): ChatMessage {
     };
 }
 /**
- * Convert ChatMessageBase from API response, parsing date fields
+ * Convert ChatMessageError from API response, parsing date fields
 * Date fields: timestamp
 */
-export function convertChatMessageBaseFromApi(data: any): ChatMessageBase {
+export function convertChatMessageErrorFromApi(data: any): ChatMessageError {
     if (!data) return data;

     return {
@@ -1092,6 +1147,32 @@ export function convertChatMessageRagSearchFromApi(data: any): ChatMessageRagSea
         timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
     };
 }
+/**
+ * Convert ChatMessageStatus from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStatusFromApi(data: any): ChatMessageStatus {
+    if (!data) return data;
+
+    return {
+        ...data,
+        // Convert timestamp from ISO string to Date
+        timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+    };
+}
+/**
+ * Convert ChatMessageStreaming from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStreamingFromApi(data: any): ChatMessageStreaming {
+    if (!data) return data;
+
+    return {
+        ...data,
+        // Convert timestamp from ISO string to Date
+        timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+    };
+}
 /**
 * Convert ChatMessageUser from API response, parsing date fields
 * Date fields: timestamp
@@ -1268,6 +1349,19 @@ export function convertJobFullFromApi(data: any): JobFull {
         featuredUntil: data.featuredUntil ? new Date(data.featuredUntil) : undefined,
     };
 }
+/**
+ * Convert JobRequirementsMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertJobRequirementsMessageFromApi(data: any): JobRequirementsMessage {
+    if (!data) return data;
+
+    return {
+        ...data,
+        // Convert timestamp from ISO string to Date
+        timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+    };
+}
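
Usage sketch for the generated converters (illustrative):

    const status = convertFromApi<ChatMessageStatus>(raw, 'ChatMessageStatus');
    // status.timestamp is now a Date (or undefined) rather than an ISO string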

 /**
 * Convert MessageReaction from API response, parsing date fields
 * Date fields: timestamp
@@ -1348,6 +1442,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
     switch (modelType) {
         case 'Analytics':
             return convertAnalyticsFromApi(data) as T;
+        case 'ApiMessage':
+            return convertApiMessageFromApi(data) as T;
         case 'ApplicationDecision':
             return convertApplicationDecisionFromApi(data) as T;
         case 'Attachment':
@@ -1366,10 +1462,14 @@ export function convertFromApi<T>(data: any, modelType: string): T {
             return convertCertificationFromApi(data) as T;
         case 'ChatMessage':
             return convertChatMessageFromApi(data) as T;
-        case 'ChatMessageBase':
-            return convertChatMessageBaseFromApi(data) as T;
+        case 'ChatMessageError':
+            return convertChatMessageErrorFromApi(data) as T;
         case 'ChatMessageRagSearch':
             return convertChatMessageRagSearchFromApi(data) as T;
+        case 'ChatMessageStatus':
+            return convertChatMessageStatusFromApi(data) as T;
+        case 'ChatMessageStreaming':
+            return convertChatMessageStreamingFromApi(data) as T;
         case 'ChatMessageUser':
             return convertChatMessageUserFromApi(data) as T;
         case 'ChatSession':
@@ -1394,6 +1494,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
             return convertJobApplicationFromApi(data) as T;
         case 'JobFull':
             return convertJobFullFromApi(data) as T;
+        case 'JobRequirementsMessage':
+            return convertJobRequirementsMessageFromApi(data) as T;
         case 'MessageReaction':
             return convertMessageReactionFromApi(data) as T;
         case 'RAGConfiguration':

@@ -24,7 +24,7 @@ from datetime import datetime, UTC
 from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
 import numpy as np # type: ignore

-from models import ( LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate)
+from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
 from logger import logger
 import defines
 from .registry import agent_registry
@@ -300,12 +300,35 @@ class Agent(BaseModel, ABC):
         )
         self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

+    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
+        """
+        Extracts the RAG context from the rag_message.
+        """
+        if not rag_message.content:
+            return ""
+
+        context = []
+        for chroma_results in rag_message.content:
+            for index, metadata in enumerate(chroma_results.metadatas):
+                content = "\n".join([
+                    line.strip()
+                    for line in chroma_results.documents[index].split("\n")
+                    if line
+                ]).strip()
+                context.append(f"""
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results.ids[index]}
+Content: {content}
+""")
+        return "\n".join(context)
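
A TypeScript mirror of the structure get_rag_context walks, for the frontend side (illustrative; assumes ChromaDBGetResponse carries parallel ids/documents/metadatas arrays, as used above):

    const ragContextOf = (results: ChromaDBGetResponse[]): string =>
        results.flatMap(r => r.documents.map((doc, i) =>
            `Source: ${r.metadatas[i]?.doc_type ?? 'unknown'}: ${r.metadatas[i]?.path ?? ''}\n` +
            `Document reference: ${r.ids[i]}\nContent: ${doc.trim()}`
        )).join('\n');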
|
||||
|
||||
async def generate_rag_results(
|
||||
self,
|
||||
chat_message: ChatMessage,
|
||||
session_id: str,
|
||||
prompt: str,
|
||||
top_k: int=defines.default_rag_top_k,
|
||||
threshold: float=defines.default_rag_threshold,
|
||||
) -> AsyncGenerator[ChatMessage, None]:
|
||||
) -> AsyncGenerator[ChatMessageRagSearch | ChatMessageError | ChatMessageStatus, None]:
|
||||
"""
|
||||
Generate RAG results for the given query.
|
||||
|
||||
@@ -315,223 +338,193 @@ class Agent(BaseModel, ABC):
        Returns:
            A list of dictionaries containing the RAG results.
        """
        rag_message = ChatMessage(
            session_id=chat_message.session_id,
            tunables=chat_message.tunables,
            status=ChatStatusType.INITIALIZING,
            type=ChatMessageType.PREPARING,
            sender=ChatSenderType.ASSISTANT,
            content="",
            timestamp=datetime.now(UTC),
            metadata=ChatMessageMetaData()
        )

        if not self.user:
            logger.error("No user set for RAG generation")
            rag_message.status = ChatStatusType.DONE
            rag_message.content = ""
            yield rag_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="No user set for RAG generation."
            )
            yield error_message
            return

        try:
            entries: int = 0
            user: Candidate = self.user
            rag_content: str = ""
            for rag in user.rags:
                if not rag.enabled:
                    continue
                status_message = ChatMessage(
                    session_id=chat_message.session_id,
                    sender=ChatSenderType.AGENT,
                    status = ChatStatusType.INITIALIZING,
                    type = ChatMessageType.SEARCHING,
                    content = f"Checking RAG context {rag.name}...")
                yield status_message
        results : List[ChromaDBGetResponse] = []
        entries: int = 0
        user: Candidate = self.user
        for rag in user.rags:
            if not rag.enabled:
                continue

            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.SEARCHING,
                content = f"Searching RAG context {rag.name}..."
            )
            yield status_message

            try:
                chroma_results = user.file_watcher.find_similar(
                    query=chat_message.content, top_k=top_k, threshold=threshold
                    query=prompt, top_k=top_k, threshold=threshold
                )
                if chroma_results:
                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
                if not chroma_results:
                    continue
                query_embedding = np.array(chroma_results["query_embedding"]).flatten()

                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]

                rag_metadata = ChromaDBGetResponse(
                    query=chat_message.content,
                    query_embedding=query_embedding.tolist(),
                    name=rag.name,
                    ids=chroma_results.get("ids", []),
                    embeddings=chroma_results.get("embeddings", []),
                    documents=chroma_results.get("documents", []),
                    metadatas=chroma_results.get("metadatas", []),
                    umap_embedding_2d=umap_2d.tolist(),
                    umap_embedding_3d=umap_3d.tolist(),
                    size=user.file_watcher.collection.count()
                )
                rag_metadata = ChromaDBGetResponse(
                    name=rag.name,
                    query_embedding=query_embedding.tolist(),
                    ids=chroma_results.get("ids", []),
                    embeddings=chroma_results.get("embeddings", []),
                    documents=chroma_results.get("documents", []),
                    metadatas=chroma_results.get("metadatas", []),
                    umap_embedding_2d=umap_2d.tolist(),
                    umap_embedding_3d=umap_3d.tolist(),
                )
                results.append(rag_metadata)
            except Exception as e:
                continue_message = ChatMessageStatus(
                    session_id=session_id,
                    activity=ApiActivityType.SEARCHING,
                    content=f"Error searching RAG context {rag.name}: {str(e)}"
                )
                yield continue_message

                entries += len(rag_metadata.documents)
                rag_message.metadata.rag_results.append(rag_metadata)

                for index, metadata in enumerate(chroma_results["metadatas"]):
                    content = "\n".join(
                        [
                            line.strip()
                            for line in chroma_results["documents"][index].split("\n")
                            if line
                        ]
                    ).strip()
                    rag_content += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results["ids"][index]}
Content: { content }
"""
            rag_message.content = rag_content.strip()
            rag_message.type = ChatMessageType.RAG_RESULT
            rag_message.status = ChatStatusType.DONE
            yield rag_message
            return

        except Exception as e:
            rag_message.status = ChatStatusType.ERROR
            rag_message.content = f"Error generating RAG results: {str(e)}"
            logger.error(traceback.format_exc())
            logger.error(rag_message.content)
            yield rag_message
            return

    async def llm_one_shot(self, llm: Any, model: str, user_message: ChatMessageUser, system_prompt: str, temperature=0.7):
        chat_message = ChatMessage(
            session_id=user_message.session_id,
            tunables=user_message.tunables,
            status=ChatStatusType.INITIALIZING,
            type=ChatMessageType.PREPARING,
            sender=ChatSenderType.AGENT,
            content="",
            timestamp=datetime.now(UTC)
        final_message = ChatMessageRagSearch(
            session_id=session_id,
            content=results,
            status=ApiStatusType.DONE,
        )
        yield final_message
        return

    async def llm_one_shot(
        self,
        llm: Any, model: str,
        session_id: str, prompt: str, system_prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:

        self.set_optimal_context_size(
            llm, model, prompt=chat_message.content
            llm=llm, model=model, prompt=prompt
        )

        chat_message.metadata = ChatMessageMetaData()
        chat_message.metadata.options = ChatOptions(
        options = ChatOptions(
            seed=8911,
            num_ctx=self.context_size,
            temperature=temperature, # Higher temperature to encourage tool usage
            temperature=temperature,
        )

        messages: List[LLMMessage] = [
            LLMMessage(role="system", content=system_prompt),
            LLMMessage(role="user", content=user_message.content),
            LLMMessage(role="user", content=prompt),
        ]

        # Reset the response for streaming
        chat_message.content = ""
        chat_message.type = ChatMessageType.GENERATING
        chat_message.status = ChatStatusType.STREAMING
        status_message = ChatMessageStatus(
            session_id=session_id,
            activity=ApiActivityType.GENERATING,
            content=f"Generating response..."
        )
        yield status_message

        logger.info(f"Message options: {chat_message.metadata.options.model_dump(exclude_unset=True)}")
        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
        response = None
        content = ""
        for response in llm.chat(
            model=model,
            messages=messages,
            options={
                **chat_message.metadata.options.model_dump(exclude_unset=True),
                **options.model_dump(exclude_unset=True),
            },
            stream=True,
        ):
            if not response:
                chat_message.status = ChatStatusType.ERROR
                chat_message.content = "No response from LLM."
                yield chat_message
                error_message = ChatMessageError(
                    session_id=session_id,
                    content="No response from LLM."
                )
                yield error_message
                return

            chat_message.content += response.message.content
            content += response.message.content

            if not response.done:
                chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
                chat_chunk.content = response.message.content
                yield chat_message
                continue
                streaming_message = ChatMessageStreaming(
                    session_id=session_id,
                    content=response.message.content,
                    status=ApiStatusType.STREAMING,
                )
                yield streaming_message

        if not response:
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = "No response from LLM."
            yield chat_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="No response from LLM."
            )
            yield error_message
            return

        self.collect_metrics(response)
        chat_message.metadata.eval_count += response.eval_count
        chat_message.metadata.eval_duration += response.eval_duration
        chat_message.metadata.prompt_eval_count += response.prompt_eval_count
        chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
        self.context_tokens = (
            response.prompt_eval_count + response.eval_count
        )
        chat_message.type = ChatMessageType.RESPONSE
        chat_message.status = ChatStatusType.DONE
        yield chat_message

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        chat_message = ChatMessage(
            session_id=user_message.session_id,
            tunables=user_message.tunables,
            status=ChatStatusType.INITIALIZING,
            type=ChatMessageType.PREPARING,
            sender=ChatSenderType.ASSISTANT,
            content="",
            timestamp=datetime.now(UTC)
        )
        chat_message = ChatMessage(
            session_id=session_id,
            tunables=tunables,
            status=ApiStatusType.DONE,
            content=content,
            metadata = ChatMessageMetaData(
                options=options,
                eval_count=response.eval_count,
                eval_duration=response.eval_duration,
                prompt_eval_count=response.prompt_eval_count,
                prompt_eval_duration=response.prompt_eval_duration,

        self.set_optimal_context_size(
            llm, model, prompt=chat_message.content
        )
            )
        yield chat_message
        return

        chat_message.metadata = ChatMessageMetaData()
        chat_message.metadata.options = ChatOptions(
            seed=8911,
            num_ctx=self.context_size,
            temperature=temperature, # Higher temperature to encourage tool usage
    async def generate(
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        if not self.user:
            error_message = ChatMessageError(
                session_id=session_id,
                content="No user set for chat generation."
            )
            yield error_message
            return

        user_message = ChatMessageUser(
            session_id=session_id,
            content=prompt,
        )

        # Create a dict for storing various timing stats
        chat_message.metadata.timers = {}
        user = self.user

        self.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.metrics.generate_duration.labels(agent=self.agent_type).time():
            context = None
            if self.user:
                rag_message = None
                async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
                    if rag_message.status == ApiStatusType.ERROR:
                        yield rag_message
                        return
                    # Only yield messages that are in a streaming state
                    if rag_message.status == ApiStatusType.STATUS:
                        yield rag_message

            rag_message : Optional[ChatMessage] = None
            async for rag_message in self.generate_rag_results(chat_message=user_message):
                if rag_message.status == ChatStatusType.ERROR:
                    chat_message.status = rag_message.status
                    chat_message.content = rag_message.content
                    yield chat_message
                    return
                yield rag_message
                if not isinstance(rag_message, ChatMessageRagSearch):
                    raise ValueError(
                        f"Expected ChatMessageRagSearch, got {type(rag_message)}"
                    )

            rag_context = ""
            if rag_message:
                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
                chat_message.metadata.rag_results = rag_results
                for chroma_results in rag_results:
                    for index, metadata in enumerate(chroma_results.metadatas):
                        content = "\n".join([
                            line.strip()
                            for line in chroma_results.documents[index].split("\n")
                            if line
                        ]).strip()
                        rag_context += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: { content }

"""
                context = self.get_rag_context(rag_message)

            # Create a pruned down message list based purely on the prompt and responses,
            # discarding the full preamble generated by prepare_message
@@ -540,24 +533,24 @@ Content: { content }
            ]
            # Add the conversation history to the messages
            messages.extend([
                LLMMessage(role=m.sender, content=m.content.strip())
                LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
                for m in self.conversation
            ])
            # Add the RAG context to the messages if available
            if rag_context and user:
            if context:
                messages.append(
                    LLMMessage(
                        role="user",
                        content=f"<|context|>\nThe following is context information about {user.full_name}:\n{rag_context.strip()}\n</|context|>\n\nPrompt to respond to:\n{user_message.content.strip()}\n"
                        content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
                    )
                )
            else:
                # Only the actual user query is provided with the full context message
                messages.append(
                    LLMMessage(role=user_message.sender, content=user_message.content.strip())
                    LLMMessage(role="user", content=prompt)
                )

            chat_message.metadata.llm_history = messages
            llm_history = messages

            # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
            # message.metadata.tools = {
@@ -660,57 +653,89 @@ Content: { content }
            #     return

            # not use_tools
            chat_message.type = ChatMessageType.THINKING
            chat_message.content = f"Generating response..."
            yield chat_message
            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.GENERATING,
                content=f"Generating response..."
            )
            yield status_message

            # Reset the response for streaming
            chat_message.content = ""
            # Set the response for streaming
            self.set_optimal_context_size(
                llm, model, prompt=prompt
            )

            options = ChatOptions(
                seed=8911,
                num_ctx=self.context_size,
                temperature=temperature,
            )
            logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
            content = ""
            start_time = time.perf_counter()
            chat_message.type = ChatMessageType.GENERATING
            chat_message.status = ChatStatusType.STREAMING

            response = None
            for response in llm.chat(
                model=model,
                messages=messages,
                options={
                    **chat_message.metadata.options.model_dump(exclude_unset=True),
                    **options.model_dump(exclude_unset=True),
                },
                stream=True,
            ):
                if not response:
                    chat_message.status = ChatStatusType.ERROR
                    chat_message.content = "No response from LLM."
                    yield chat_message
                    error_message = ChatMessageError(
                        session_id=session_id,
                        content="No response from LLM."
                    )
                    yield error_message
                    return

                chat_message.content += response.message.content
                content += response.message.content

                if not response.done:
                    chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
                    chat_chunk.content = response.message.content
                    yield chat_message
                    continue

                if response.done:
                    self.collect_metrics(response)
                    chat_message.metadata.eval_count += response.eval_count
                    chat_message.metadata.eval_duration += response.eval_duration
                    chat_message.metadata.prompt_eval_count += response.prompt_eval_count
                    chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
                    self.context_tokens = (
                        response.prompt_eval_count + response.eval_count
                    streaming_message = ChatMessageStreaming(
                        session_id=session_id,
                        content=response.message.content,
                    )
                    chat_message.type = ChatMessageType.RESPONSE
                    chat_message.status = ChatStatusType.DONE
                    yield chat_message
                    yield streaming_message

            if not response:
                error_message = ChatMessageError(
                    session_id=session_id,
                    content="No response from LLM."
                )
                yield error_message
                return

            self.collect_metrics(response)
            self.context_tokens = (
                response.prompt_eval_count + response.eval_count
            )
            end_time = time.perf_counter()
            chat_message.metadata.timers["streamed"] = end_time - start_time

            chat_message = ChatMessage(
                session_id=session_id,
                tunables=tunables,
                status=ApiStatusType.DONE,
                content=content,
                metadata = ChatMessageMetaData(
                    options=options,
                    eval_count=response.eval_count,
                    eval_duration=response.eval_duration,
                    prompt_eval_count=response.prompt_eval_count,
                    prompt_eval_duration=response.prompt_eval_duration,
                    timers={
                        "llm_streamed": end_time - start_time,
                        "llm_with_tools": 0, # Placeholder for tool processing time
                    },
                )
            )


            # Add the user and chat messages to the conversation
            self.conversation.append(user_message)
            self.conversation.append(chat_message)
            yield chat_message
            return

    # async def process_message(
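For reference, a minimal sketch of how a caller might drive the reworked generator API; the agent, llm, model name, and session values below are invented, and only the ApiStatusType dispatch mirrors the code above:

async def run_once(agent, llm) -> ChatMessage:
    final = None
    async for msg in agent.generate(
        llm=llm, model="example-model",          # hypothetical model name
        session_id="session-123",                # hypothetical session
        prompt="Tell me about the candidate.",
    ):
        if msg.status == ApiStatusType.ERROR:
            raise RuntimeError(msg.content)      # errors are yielded, not raised
        if msg.status != ApiStatusType.DONE:
            continue                             # status / streaming chunks
        final = msg                              # the single DONE ChatMessage
    return final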
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
from logger import logger

from .registry import agent_registry
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType, ChatMessageUser, Candidate)
from models import ( ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatQuery, ChatMessage, Tunables, ApiStatusType, ChatMessageUser, Candidate)


system_message = f"""
@@ -34,8 +34,15 @@ class CandidateChat(Agent):
    system_prompt: str = system_message

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
    ):
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        user = self.user
        if not user:
            logger.error("User is not set for CandidateChat agent.")
            raise ValueError("User must be set before generating candidate chat responses.")
        self.system_prompt = f"""
You are a helpful expert system representing {user.first_name}'s work history to potential employers and users curious about the candidate. You want to incorporate as many facts and details about {user.first_name} as possible.

@@ -49,7 +56,10 @@ Use that spelling instead of any spelling you may find in the <|context|>.
{system_message}
"""

        async for message in super().generate(llm, model, user_message, user, temperature):
        async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
            if message.status == ApiStatusType.ERROR:
                yield message
                return
            yield message

# Register the base agent
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
from logger import logger

from .registry import agent_registry
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType)
from models import ( ChatQuery, ChatMessage, Tunables, ApiStatusType)

system_message = f"""
Launched on {datetime.now().isoformat()}.
@@ -25,7 +25,7 @@ import os
import hashlib

from .base import Agent, agent_registry, LLMMessage
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
from models import ActivityType, ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
import model_cast
from logger import logger
import defines
@@ -44,43 +44,48 @@ class ImageGenerator(Agent):
    system_prompt: str = "" # No system prompt is used

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
    ) -> AsyncGenerator[ChatMessage, None]:
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        if not self.user:
            logger.error("User is not set for ImageGenerator agent.")
            raise ValueError("User must be set before generating images.")
        user = self.user

        file_path = os.path.join(defines.user_dir, user.username, "profile.png")
        chat_message = ChatMessage(
            session_id=user_message.session_id,
            tunables=user_message.tunables,
            status=ChatStatusType.INITIALIZING,
            type=ChatMessageType.PREPARING,
            sender=ChatSenderType.ASSISTANT,
            content="",
            timestamp=datetime.now(UTC)
        )

        chat_message.metadata = ChatMessageMetaData()
        try:
            #
            # Generate the profile picture
            #
            chat_message.content = f"Generating: {user_message.content}"
            yield chat_message
            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.GENERATING_IMAGE,
                content=f"Generating {prompt}...",
            )
            yield status_message

            logger.info(f"Image generation: {file_path} <- {user_message.content}")
            request = ImageRequest(filepath=file_path, prompt=user_message.content, iterations=4, height=256, width=256, guidance_scale=7.5)
            logger.info(f"Image generation: {file_path} <- {prompt}")
            request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
            generated_message = None
            async for generated_message in generate_image(
                user_message=user_message,
                request=request
            ):
                if generated_message.status != "done":
                if generated_message.status == ApiStatusType.ERROR:
                    yield generated_message
                    return
                if generated_message.status != ApiStatusType.DONE:
                    yield generated_message
                    continue

            if generated_message is None:
                chat_message.status = ChatStatusType.ERROR
                chat_message.content = "Image generation failed."
                yield chat_message
                error_message = ChatMessageError(
                    session_id=session_id,
                    content="Image generation failed to produce a valid response."
                )
                logger.error(f"⚠️ {error_message.content}")
                yield error_message
                return

            logger.info("Image generation done...")
@@ -88,18 +93,23 @@ class ImageGenerator(Agent):
            user.profile_image = "profile.png"

            # Image generated
            generated_message.status = ChatStatusType.DONE
            generated_message.content = f"{defines.api_prefix}/profile/{user.username}"
            yield generated_message
            generated_image = ChatMessage(
                session_id=session_id,
                status=ApiStatusType.DONE,
                content = f"{defines.api_prefix}/profile/{user.username}",
                metadata=generated_message.metadata
            )
            yield generated_image
            return

        except Exception as e:
            chat_message.status = ChatStatusType.ERROR
            error_message = ChatMessageError(
                session_id=session_id,
                content=f"Error generating image: {str(e)}"
            )
            logger.error(traceback.format_exc())
            logger.error(chat_message.content)
            chat_message.content = f"Error in image generation: {str(e)}"
            logger.error(chat_message.content)
            yield chat_message
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

# Register the base agent
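A hedged construction sketch for the reworked request flow; the path, session id, and prompt below are invented, while the keyword arguments mirror the call in the diff:

request = ImageRequest(
    session_id="session-123",                         # hypothetical
    filepath="/tmp/profile.png",                      # hypothetical
    prompt="professional headshot, studio lighting",  # hypothetical
    iterations=4, height=256, width=256, guidance_scale=7.5,
)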
@@ -27,7 +27,7 @@ import random
from names_dataset import NameDataset, NameWrapper # type: ignore

from .base import Agent, agent_registry, LLMMessage
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
import model_cast
from logger import logger
import defines
@@ -316,14 +316,16 @@ class GeneratePersona(Agent):
        self.full_name = f"{self.first_name} {self.last_name}"

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
    ):
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        self.randomize()

        status_message = ChatMessage(session_id=user_message.session_id)
        original_prompt = user_message.content
        original_prompt = prompt.strip()

        user_message.content = f"""\
        prompt = f"""\
```json
{json.dumps({
    "age": self.age,
@@ -337,7 +339,7 @@ class GeneratePersona(Agent):
"""

        if original_prompt:
            user_message.content += f"""
            prompt += f"""
Incorporate the following into the job description: {original_prompt}
"""

@@ -348,16 +350,24 @@ Incorporate the following into the job description: {original_prompt}
        generating_message = None
        async for generating_message in self.llm_one_shot(
            llm=llm, model=model,
            user_message=user_message,
            session_id=session_id,
            prompt=prompt,
            system_prompt=generate_persona_system_prompt,
            temperature=temperature,
        ):
            if generating_message.status == ChatStatusType.ERROR:
            if generating_message.status == ApiStatusType.ERROR:
                logger.error(f"Error generating persona: {generating_message.content}")
                raise Exception(generating_message.content)
            yield generating_message
            if generating_message.status != ApiStatusType.DONE:
                yield generating_message

        if not generating_message:
            raise Exception("No response from LLM during persona generation")
            error_message = ChatMessageError(
                session_id=session_id,
                content="Persona generation failed to generate a response."
            )
            yield error_message
            return

        json_str = self.extract_json_from_text(generating_message.content)
        try:
@@ -405,37 +415,39 @@ Incorporate the following into the job description: {original_prompt}
            persona["location"] = None
            persona["is_ai"] = True
        except Exception as e:
            generating_message.content = f"Unable to parse LLM returned content: {json_str} {str(e)}"
            generating_message.status = ChatStatusType.ERROR
            error_message = ChatMessageError(
                session_id=session_id,
                content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
            )
            logger.error(f"❌ Error parsing LLM response: {error_message.content}")
            logger.error(traceback.format_exc())
            logger.error(generating_message.content)
            yield generating_message
            yield error_message
            return

        logger.info(f"✅ Persona for {persona['username']} generated successfully")

        # Persona generated
        status_message = ChatMessage(
            session_id=user_message.session_id,
            status = ChatStatusType.STATUS,
            type = ChatMessageType.RESPONSE,
        persona_message = ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            type=ApiMessageType.JSON,
            content = json.dumps(persona)
        )
        yield status_message
        yield persona_message

        #
        # Generate the resume
        #
        status_message = ChatMessage(
            session_id=user_message.session_id,
            status = ChatStatusType.STATUS,
            type = ChatMessageType.THINKING,
        status_message = ChatMessageStatus(
            session_id=session_id,
            activity = ApiActivityType.THINKING,
            content = f"Generating resume for {persona['full_name']}..."
        )
        logger.info(f"🤖 {status_message.content}")
        yield status_message

        user_message.content = f"""
        content = f"""
```json
{{
    "full_name": "{persona["full_name"]}",
@@ -449,30 +461,39 @@ Incorporate the following into the job description: {original_prompt}
```
"""
        if original_prompt:
            user_message.content += f"""
            content += f"""
Make sure at least one of the candidate's job descriptions takes into account the following: {original_prompt}."""

        async for generating_message in self.llm_one_shot(
            llm=llm, model=model,
            user_message=user_message,
            session_id=session_id,
            prompt=content,
            system_prompt=generate_resume_system_prompt,
            temperature=temperature,
        ):
            if generating_message.status == ChatStatusType.ERROR:
            if generating_message.status == ApiStatusType.ERROR:
                logger.error(f"❌ Error generating resume: {generating_message.content}")
                raise Exception(generating_message.content)
            if generating_message.status != ApiStatusType.DONE:
                yield generating_message

        if not generating_message:
            raise Exception("No response from LLM during persona generation")
            error_message = ChatMessageError(
                session_id=session_id,
                content="Resume generation failed to generate a response."
            )
            logger.error(f"❌ {error_message.content}")
            yield error_message
            return

        resume = self.extract_markdown_from_text(generating_message.content)
        status_message = ChatMessage(
            session_id=user_message.session_id,
            status=ChatStatusType.DONE,
            type=ChatMessageType.RESPONSE,
        resume_message = ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            type=ApiMessageType.TEXT,
            content=resume
        )
        yield status_message
        yield resume_message
        return

    def extract_json_from_text(self, text: str) -> str:
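Sketch of consuming the two-stage persona pipeline above, which now yields a JSON persona message followed by a markdown resume message; persona_agent, the model name, and the session id are hypothetical:

import json

async def build_persona(persona_agent, llm):
    persona, resume = None, None
    async for msg in persona_agent.generate(
        llm=llm, model="example-model",
        session_id="session-123", prompt="",
    ):
        if msg.status != ApiStatusType.DONE:
            continue                          # skip status/streaming updates
        if msg.type == ApiMessageType.JSON:
            persona = json.loads(msg.content) # first DONE message: persona dict
        elif msg.type == ApiMessageType.TEXT:
            resume = msg.content              # second DONE message: markdown resume
    return persona, resume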
@@ -20,7 +20,7 @@ import asyncio
import numpy as np # type: ignore

from .base import Agent, agent_registry, LLMMessage
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, JobRequirements
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
import model_cast
from logger import logger
import defines
@@ -81,56 +81,51 @@ class JobRequirementsAgent(Agent):
        return system_prompt, prompt

    async def analyze_job_requirements(
        self, llm: Any, model: str, user_message: ChatMessage
    ) -> AsyncGenerator[ChatMessage, None]:
        self, llm: Any, model: str, session_id: str, prompt: str
    ) -> AsyncGenerator[ChatMessage | ChatMessageError, None]:
        """Analyze job requirements from job description."""
        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
        analyze_message = user_message.model_copy()
        analyze_message.content = prompt
        system_prompt, prompt = self.create_job_analysis_prompt(prompt)
        generated_message = None
        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
            if generated_message.status == ChatStatusType.ERROR:
                generated_message.content = "Error analyzing job requirements."
        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
            if generated_message.status == ApiStatusType.ERROR:
                yield generated_message
                return
            if generated_message.status != ApiStatusType.DONE:
                yield generated_message

        if not generated_message:
            status_message = ChatMessage(
                session_id=user_message.session_id,
                sender=ChatSenderType.AGENT,
                status = ChatStatusType.ERROR,
                type = ChatMessageType.ERROR,
                content = "Job requirements analysis failed to generate a response.")
            yield status_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="Job requirements analysis failed to generate a response.")
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        generated_message.status = ChatStatusType.DONE
        generated_message.type = ChatMessageType.RESPONSE
        yield generated_message
        return

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
    ) -> AsyncGenerator[ChatMessage, None]:
        # Stage 1A: Analyze job requirements
        status_message = ChatMessage(
            session_id=user_message.session_id,
            sender=ChatSenderType.AGENT,
            status=ChatStatusType.STATUS,
            type=ChatMessageType.THINKING,
        status_message = ChatMessageStatus(
            session_id=session_id,
            content = f"Analyzing job requirements")
        yield status_message

        generated_message = None
        async for generated_message in self.analyze_job_requirements(llm, model, user_message):
            if generated_message.status == ChatStatusType.ERROR:
                status_message.status = ChatStatusType.ERROR
                status_message.content = generated_message.content
                yield status_message
        async for generated_message in self.analyze_job_requirements(llm, model, session_id, prompt):
            if generated_message.status == ApiStatusType.ERROR:
                yield generated_message
                return
            if generated_message.status != ApiStatusType.DONE:
                yield generated_message

        if not generated_message:
            status_message.status = ChatStatusType.ERROR
            status_message.content = "Job requirements analysis failed."
            status_message = ChatMessageStatus(
                session_id=session_id,
                content="Job requirements analysis failed to generate a response.")
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return

@@ -150,34 +145,33 @@ class JobRequirementsAgent(Agent):
            if not job_requirements:
                raise ValueError("Job requirements data is empty or invalid.")
        except json.JSONDecodeError as e:
            status_message.status = ChatStatusType.ERROR
            status_message.status = ApiStatusType.ERROR
            status_message.content = f"Failed to parse job requirements JSON: {str(e)}\n\n{job_requirements_data}"
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return
        except ValueError as e:
            status_message.status = ChatStatusType.ERROR
            status_message.status = ApiStatusType.ERROR
            status_message.content = f"Job requirements validation error: {str(e)}\n\n{job_requirements_data}"
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return
        except Exception as e:
            status_message.status = ChatStatusType.ERROR
            status_message.status = ApiStatusType.ERROR
            status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
            logger.error(traceback.format_exc())
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return
        status_message.status = ChatStatusType.DONE
        status_message.type = ChatMessageType.RESPONSE
        job_data = {
            "company": company_name,
            "title": job_title,
            "summary": job_summary,
            "requirements": job_requirements.model_dump(mode="json", exclude_unset=True)
        }
        status_message.content = json.dumps(job_data)
        yield status_message
        job_requirements_message = JobRequirementsMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            requirements=job_requirements,
            company=company_name,
            title=job_title,
            summary=job_summary
        )
        yield job_requirements_message

        logger.info(f"✅ Job requirements analysis completed successfully.")
        return
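Sketch of reading the typed result that replaces the old JSON-in-content payload; job_agent, the model name, the session id, and job_description are hypothetical:

async def read_requirements(job_agent, llm, job_description: str):
    final = None
    async for msg in job_agent.generate(
        llm=llm, model="example-model",
        session_id="session-123", prompt=job_description,
    ):
        if msg.status == ApiStatusType.DONE:
            final = msg
    if isinstance(final, JobRequirementsMessage):
        # Structured fields instead of json.loads(msg.content)
        return final.company, final.title, final.requirements
    return None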
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
from logger import logger

from .registry import agent_registry
from models import ( ChatMessage, ChatStatusType, ChatMessage, ChatOptions, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate )
from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
from rag import ( ChromaDBGetResponse )

class Chat(Agent):
@@ -19,8 +19,11 @@ class Chat(Agent):
    _agent_type: ClassVar[str] = agent_type # Add this for registration

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessage, user: Candidate, temperature=0.7
    ) -> AsyncGenerator[ChatMessage, None]:
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        """
        Generate a response based on the user message and the provided LLM.

@@ -34,65 +37,25 @@ class Chat(Agent):
        Yields:
            ChatMessage: The generated response.
        """
        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

        if user.id != user_message.sender_id:
            logger.error(f"User {user.username} id does not match message {user_message.sender_id}")
            raise ValueError("User does not match message sender")

        chat_message = ChatMessage(
            session_id=user_message.session_id,
            tunables=user_message.tunables,
            status=ChatStatusType.INITIALIZING,
            type=ChatMessageType.PREPARING,
            sender=ChatSenderType.ASSISTANT,
            content="",
            timestamp=datetime.now(UTC)
        )

        chat_message.metadata = ChatMessageMetaData()
        chat_message.metadata.options = ChatOptions(
            seed=8911,
            num_ctx=self.context_size,
            temperature=temperature, # Higher temperature to encourage tool usage
        )

        # Create a dict for storing various timing stats
        chat_message.metadata.timers = {}

        self.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.metrics.generate_duration.labels(agent=self.agent_type).time():

            rag_message : Optional[ChatMessage] = None
            async for rag_message in self.generate_rag_results(chat_message=user_message):
                if rag_message.status == ChatStatusType.ERROR:
                    chat_message.status = rag_message.status
                    chat_message.content = rag_message.content
                    yield chat_message
                    return
        rag_message = None
        async for rag_message in self.generate_rag_results(session_id, prompt):
            if rag_message.status == ApiStatusType.ERROR:
                yield rag_message
                return
            if rag_message.status != ApiStatusType.DONE:
                yield rag_message

            if rag_message:
                chat_message.content = ""
                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
                chat_message.metadata.rag_results = rag_results
                for chroma_results in rag_results:
                    for index, metadata in enumerate(chroma_results.metadatas):
                        content = "\n".join([
                            line.strip()
                            for line in chroma_results.documents[index].split("\n")
                            if line
                        ]).strip()
                        chat_message.content += f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: { content }
        if not isinstance(rag_message, ChatMessageRagSearch):
            logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
            error_message = ChatMessageError(
                session_id=session_id,
                content="RAG search did not return a valid response."
            )
            yield error_message
            return

"""

        chat_message.status = ChatStatusType.DONE
        chat_message.type = ChatMessageType.RAG_RESULT
        yield chat_message
        rag_message.status = ApiStatusType.DONE
        yield rag_message

# Register the base agent
agent_registry.register(Chat._agent_type, Chat)
@@ -20,7 +20,7 @@ import asyncio
import numpy as np # type: ignore

from .base import Agent, agent_registry, LLMMessage
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, SkillMatch
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
import model_cast
from logger import logger
import defines
@@ -29,13 +29,13 @@ class SkillMatchAgent(Agent):
    agent_type: Literal["skill_match"] = "skill_match" # type: ignore
    _agent_type: ClassVar[str] = agent_type # Add this for registration

    def generate_skill_assessment_prompt(self, skill, rag_content):
    def generate_skill_assessment_prompt(self, skill, rag_context):
        """
        Generate a system prompt to query the LLM for evidence of a specific skill

        Parameters:
        - skill (str): The specific skill to assess from job requirements
        - rag_content (str): Additional RAG content queried from candidate documents
        - rag_context (str): Additional RAG content queried from candidate documents

        Returns:
        - str: A system prompt tailored to assess the specific skill
@@ -98,7 +98,7 @@ Adhere strictly to the JSON output format requested. Do not include any addition
RESPOND WITH ONLY VALID JSON USING THE EXACT FORMAT SPECIFIED.

<candidate_info>
{rag_content}
{rag_context}
</candidate_info>

JSON RESPONSE:"""
@@ -106,110 +106,102 @@ JSON RESPONSE:"""
        return system_prompt, prompt

    async def analyze_job_requirements(
        self, llm: Any, model: str, user_message: ChatMessage
        self, llm: Any, model: str, session_id: str, requirement: str
    ) -> AsyncGenerator[ChatMessage, None]:
        """Analyze job requirements from job description."""
        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
        analyze_message = user_message.model_copy()
        analyze_message.content = prompt
        system_prompt, prompt = self.create_job_analysis_prompt(requirement)

        generated_message = None
        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
            if generated_message.status == ChatStatusType.ERROR:
        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
            if generated_message.status == ApiStatusType.ERROR:
                generated_message.content = "Error analyzing job requirements."
                yield generated_message
                return
            if generated_message.status != ApiStatusType.DONE:
                yield generated_message

        if not generated_message:
            status_message = ChatMessage(
                session_id=user_message.session_id,
                sender=ChatSenderType.AGENT,
                status = ChatStatusType.ERROR,
                type = ChatMessageType.ERROR,
            error_message = ChatMessageError(
                session_id=session_id,
                content = "Job requirements analysis failed to generate a response.")
            yield status_message
            yield error_message
            return

        generated_message.status = ChatStatusType.DONE
        generated_message.type = ChatMessageType.RESPONSE
        yield generated_message
        chat_message = ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            content=generated_message.content,
            metadata=generated_message.metadata,
        )
        yield chat_message
        return

    async def generate(
        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
    ) -> AsyncGenerator[ChatMessage, None]:
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
        # Stage 1A: Analyze job requirements
        status_message = ChatMessage(
            session_id=user_message.session_id,
            sender=ChatSenderType.AGENT,
            status=ChatStatusType.STATUS,
            type=ChatMessageType.THINKING,
            content = f"Analyzing job requirements")
        yield status_message

        rag_message = None
        async for rag_message in self.generate_rag_results(chat_message=user_message):
            if rag_message.status == ChatStatusType.ERROR:
                status_message.status = ChatStatusType.ERROR
                status_message.content = rag_message.content
                logger.error(f"⚠️ {status_message.content}")
                yield status_message
        async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
            if rag_message.status == ApiStatusType.ERROR:
                yield rag_message
                return
            if rag_message.status != ApiStatusType.DONE:
                yield rag_message

        if rag_message is None:
            status_message.status = ChatStatusType.ERROR
            status_message.content = "Failed to retrieve RAG context."
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="RAG search did not return a valid response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        logger.info(f"🔍 RAG content retrieved: {len(rag_message.content)} bytes")
        rag_context = self.get_rag_context(rag_message)
        system_prompt, prompt = self.generate_skill_assessment_prompt(skill=prompt, rag_context=rag_context)

        system_prompt, prompt = self.generate_skill_assessment_prompt(skill=user_message.content, rag_content=rag_message.content)

        user_message.content = prompt
        skill_assessment = None
        async for skill_assessment in self.llm_one_shot(llm=llm, model=model, user_message=user_message, system_prompt=system_prompt, temperature=0.7):
            if skill_assessment.status == ChatStatusType.ERROR:
                status_message.status = ChatStatusType.ERROR
                status_message.content = skill_assessment.content
                logger.error(f"⚠️ {status_message.content}")
                yield status_message
        async for skill_assessment in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
            if skill_assessment.status == ApiStatusType.ERROR:
                logger.error(f"⚠️ {skill_assessment.content}")
                yield skill_assessment
                return
            if skill_assessment.status != ApiStatusType.DONE:
                yield skill_assessment

        if skill_assessment is None:
            status_message.status = ChatStatusType.ERROR
            status_message.content = "Failed to generate skill assessment."
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="Skill assessment failed to generate a response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        json_str = self.extract_json_from_text(skill_assessment.content)
        skill_assessment_data = ""
        try:
            skill_assessment_data = json.loads(json_str).get("skill_assessment", {})
        except json.JSONDecodeError as e:
            status_message.status = ChatStatusType.ERROR
            status_message.content = f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment_data}"
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return
        except ValueError as e:
            status_message.status = ChatStatusType.ERROR
            status_message.content = f"Skill assessment validation error: {str(e)}\n\n{skill_assessment_data}"
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            return
        except Exception as e:
            status_message.status = ChatStatusType.ERROR
            status_message.content = f"Unexpected error processing Skill assessment: {str(e)}\n\n{skill_assessment_data}"
            logger.error(traceback.format_exc())
            logger.error(f"⚠️ {status_message.content}")
            yield status_message
            error_message = ChatMessageError(
                session_id=session_id,
                content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment.content}"
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return
        status_message.status = ChatStatusType.DONE
        status_message.type = ChatMessageType.RESPONSE
        status_message.content = json.dumps(skill_assessment_data)
        yield status_message

        skill_assessment_message = ChatMessage(
            session_id=session_id,
            status=ApiStatusType.DONE,
            content=json.dumps(skill_assessment_data),
            metadata=skill_assessment.metadata
        )
        yield skill_assessment_message
        logger.info(f"✅ Skill assessment completed successfully.")
        return

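The parse step above collapses three except branches into a single error message; a minimal helper capturing that defensive pattern (the helper name is invented, json_str is whatever extract_json_from_text recovered):

import json

def parse_skill_assessment(json_str: str):
    """Return the skill_assessment dict, or None when the LLM JSON is bad."""
    try:
        return json.loads(json_str).get("skill_assessment", {})
    except (json.JSONDecodeError, ValueError):
        return None  # caller yields a ChatMessageError instead of raising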
@@ -18,7 +18,7 @@ from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType)
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from llm_manager import llm_manager
from agents.base import Agent

@@ -797,7 +797,7 @@ Generated conversion functions can be used like:
const jobs = convertArrayFromApi<Job>(apiResponse, 'Job');

Enum types are now properly handled:
status: ChatStatusType = ChatStatusType.DONE -> status: ChatStatusType (not locked to "done")
status: ApiStatusType = ApiStatusType.DONE -> status: ApiStatusType (not locked to "done")
"""
)

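A minimal sketch of the enum-default pattern the note above refers to, assuming Pydantic models like those in the diff (the Example model and the two-member enum are illustrative only):

from enum import Enum
from pydantic import BaseModel

class ApiStatusType(str, Enum):  # abbreviated two-member stand-in
    DONE = "done"
    STREAMING = "streaming"

class Example(BaseModel):
    # Typed as the enum with a DONE default; the generated TypeScript should
    # expose ApiStatusType, not the string literal "done".
    status: ApiStatusType = ApiStatusType.DONE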
@ -24,7 +24,7 @@ import uuid
|
||||
|
||||
from .image_model_cache import ImageModelCache
|
||||
|
||||
from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
|
||||
from models import ApiActivityType, ApiStatusType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType
|
||||
from logger import logger
|
||||
|
||||
from image_generator.image_model_cache import ImageModelCache
|
||||
@ -39,6 +39,7 @@ TIME_ESTIMATES = {
|
||||
}
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str
|
||||
filepath: str
|
||||
prompt: str
|
||||
model: str = "black-forest-labs/FLUX.1-schnell"
|
||||
@ -53,18 +54,12 @@ model_cache = ImageModelCache()
|
||||
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
|
||||
"""Background worker for Flux image generation"""
|
||||
try:
|
||||
# Your existing estimates calculation
|
||||
estimates = {"per_step": 0.5} # Replace with your actual estimates
|
||||
resolution_scale = (params.height * params.width) / (512 * 512)
|
||||
|
||||
# Flux: Run generation in the background and yield progress updates
|
||||
estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
|
||||
status_queue.put({
|
||||
"status": "running",
|
||||
"message": f"Initializing image generation...",
|
||||
"estimated_time_remaining": estimated_gen_time,
|
||||
"progress": 0
|
||||
})
|
||||
status_queue.put(ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Initializing image generation.",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
|
||||
# Start the generation task
|
||||
start_gen_time = time.time()
|
||||
@ -74,11 +69,11 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
# Send progress updates
|
||||
progress = int((step+1) / params.iterations * 100)
|
||||
|
||||
status_queue.put({
|
||||
"status": "running",
|
||||
"message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.",
|
||||
"progress": progress
|
||||
})
|
||||
status_queue.put(ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
return callback_kwargs
|
||||
|
||||
# Replace this block with your actual Flux pipe call:
|
||||
@ -98,27 +93,22 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
image.save(params.filepath)
|
||||
|
||||
# Final completion status
|
||||
status_queue.put({
|
||||
"status": "completed",
|
||||
"message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
|
||||
"progress": 100,
|
||||
"generation_time": gen_time,
|
||||
"per_step_time": per_step_time,
|
||||
"image_path": params.filepath
|
||||
})
|
||||
status_queue.put(ChatMessage(
|
||||
session_id=params.session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
|
||||
))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
status_queue.put({
|
||||
"status": "error",
|
||||
"message": f"Generation failed: {str(e)}",
|
||||
"error": str(e),
|
||||
"progress": 0
|
||||
})
|
||||
status_queue.put(ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f"Error during image generation: {str(e)}",
|
||||
))
|
||||
|
||||
|
||||
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
|
||||
"""
|
||||
Single async function that handles background Flux generation with status streaming
|
||||
"""
|
||||
@ -136,7 +126,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
|
||||
worker_thread.start()
|
||||
|
||||
# Initial status
|
||||
yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Starting image generation with task ID {task_id}",
|
||||
activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
# Stream status updates
|
||||
completed = False
|
||||
@ -147,14 +142,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
# Try to get status update (non-blocking)
status_update = status_queue.get_nowait()

# Add task_id to status update
status_update['task_id'] = task_id

# Send status update
yield status_update

# Check if completed
if status_update.get('status') in ['completed', 'error']:
if status_update.status == ApiStatusType.DONE:
    logger.info(f"Image generation completed for task {task_id}")
    completed = True

last_heartbeat = time.time()
@ -163,11 +156,11 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
# No new status, send heartbeat if needed
current_time = time.time()
if current_time - last_heartbeat > 2:  # Heartbeat every 2 seconds
    heartbeat = {
        'status': 'heartbeat',
        'task_id': task_id,
        'timestamp': current_time
    }
    heartbeat = ChatMessageStatus(
        session_id=params.session_id,
        content=f"Heartbeat for task {task_id}",
        activity=ApiActivityType.HEARTBEAT,
    )
    yield heartbeat
    last_heartbeat = current_time
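The two hunks above are the consumer side of the worker queue: a non-blocking get_nowait() poll that yields typed updates, plus a heartbeat when the queue stays quiet. A condensed sketch of that loop; the `is_done` predicate and `make_heartbeat` helper are hypothetical stand-ins for the checks shown in the diff:

```python
import asyncio
import queue
import time

async def stream_updates(status_queue: queue.Queue, is_done, make_heartbeat,
                         heartbeat_interval: float = 2.0):
    """Yield queued updates; emit a heartbeat when no update arrives in time."""
    last_heartbeat = time.time()
    while True:
        try:
            update = status_queue.get_nowait()
        except queue.Empty:
            now = time.time()
            if now - last_heartbeat > heartbeat_interval:
                yield make_heartbeat()
                last_heartbeat = now
            await asyncio.sleep(0.1)  # yield control instead of busy-waiting
            continue
        yield update
        if is_done(update):           # e.g. update.status == ApiStatusType.DONE
            return
        last_heartbeat = time.time()
```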
@ -178,28 +171,25 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
if not completed:
    if worker_thread.is_alive():
        # Thread still running but we might have missed the completion signal
        timeout_status = {
            'status': 'timeout',
            'task_id': task_id,
            'message': 'Generation timed out or connection lost'
        }
        timeout_status = ChatMessageError(
            session_id=params.session_id,
            content=f"Generation timeout for task {task_id}. The process may still be running.",
        )
        yield timeout_status
    else:
        # Thread completed but we might have missed the final status
        final_status = {
            'status': 'completed',
            'task_id': task_id,
            'message': 'Generation completed'
        }
        final_status = ChatMessage(
            session_id=params.session_id,
            status=ApiStatusType.DONE,
            content=f"Generation completed for task {task_id}.",
        )
        yield final_status

except Exception as e:
    error_status = {
        'status': 'error',
        'task_id': task_id,
        'message': f'Server error: {str(e)}',
        'error': str(e)
    }
    error_status = ChatMessageError(
        session_id=params.session_id,
        content=f'Server error: {str(e)}'
    )
    logger.error(error_status)
    yield error_status
@ -208,41 +198,38 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
    worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup

def status(chat_message: ChatMessage, status: str, progress: float = 0, estimated_time_remaining="...") -> ChatMessage:
def status(session_id: str, status: str) -> ChatMessageStatus:
    """Update chat message status and return it."""
    message = chat_message.copy(deep=True)
    message.id = str(uuid.uuid4())
    message.timestamp = datetime.now(UTC)
    message.type = ChatMessageType.THINKING
    message.status = ChatStatusType.STREAMING
    message.content = status
    return message

async def generate_image(user_message: ChatMessage, request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates."""
    chat_message = ChatMessage(
        session_id=user_message.session_id,
        tunables=user_message.tunables,
        status=ChatStatusType.INITIALIZING,
        type=ChatMessageType.PREPARING,
        sender=ChatSenderType.ASSISTANT,
        content="",
        timestamp=datetime.now(UTC)
    chat_message = ChatMessageStatus(
        session_id=session_id,
        activity=ApiActivityType.GENERATING_IMAGE,
        content=status,
    )
    return chat_message

async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates."""
    session_id = request.session_id
    prompt = request.prompt.strip()
    try:
        # Validate prompt
        prompt = user_message.content.strip()
        if not prompt:
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = "Prompt cannot be empty"
            yield chat_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="Prompt cannot be empty."
            )
            logger.error(error_message.content)
            yield error_message
            return

        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            chat_message.status = ChatStatusType.ERROR
            chat_message.content = "Height and width must be positive"
            yield chat_message
            error_message = ChatMessageError(
                session_id=session_id,
                content="Height and width must be positive integers."
            )
            logger.error(error_message.content)
            yield error_message
            return

        filedir = os.path.dirname(request.filepath)
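The join at the top of this hunk is the other half of the worker-thread lifecycle: the generator always gives the background thread a bounded join on the way out, so shutdown never hangs. A self-contained sketch of that shape; the sentinel and timeout values are illustrative, not the project's:

```python
import queue
import threading

def worker(status_q: queue.Queue) -> None:
    for i in range(3):
        status_q.put(f"step {i + 1}/3")
    status_q.put(None)  # sentinel: generation finished

def stream() -> None:
    status_q: queue.Queue = queue.Queue()
    t = threading.Thread(target=worker, args=(status_q,), daemon=True)
    t.start()
    try:
        while (update := status_q.get(timeout=5.0)) is not None:
            print(update)
    finally:
        if t.is_alive():
            t.join(timeout=1.0)  # bounded wait so shutdown never hangs

stream()
```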
@ -252,43 +239,50 @@ async def generate_image(user_message: ChatMessage, request: ImageRequest) -> As
model_type = "flux"
device = "cpu"

yield status(chat_message, f"Starting image generation...")

# Get initial time estimate, scaled by resolution
estimates = TIME_ESTIMATES[model_type][device]
resolution_scale = (request.height * request.width) / (512 * 512)
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
yield status(chat_message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

# Initialize or get cached pipeline
start_time = time.time()
yield status(chat_message, f"Loading generative image model...")
yield status(session_id, f"Loading generative image model...")
pipe = await model_cache.get_pipeline(request.model, device)
load_time = time.time() - start_time
yield status(chat_message, f"Model loaded in {load_time:.1f} seconds.", progress=10)
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.")

async for status_message in async_generate_image(pipe, request):
    chat_message.content = json.dumps(status_message)  # Merge properties from async_generate_image over the message...
    chat_message.type = ChatMessageType.HEARTBEAT if status_message.get("status") == "heartbeat" else ChatMessageType.THINKING
    if chat_message.type != ChatMessageType.HEARTBEAT:
        logger.info(chat_message.content)
    yield chat_message
progress = None
async for progress in async_generate_image(pipe, request):
    if progress.status == ApiStatusType.ERROR:
        yield progress
        return
    if progress.status != ApiStatusType.DONE:
        yield progress

if not progress:
    error_message = ChatMessageError(
        session_id=session_id,
        content="Image generation failed to produce a valid response."
    )
    logger.error(f"⚠️ {error_message.content}")
    yield error_message
    return

# Final result
total_time = time.time() - start_time
chat_message.status = ChatStatusType.DONE
chat_message.type = ChatMessageType.RESPONSE
chat_message.content = json.dumps({
    "status": f"Image generation complete in {total_time:.1f} seconds",
    "progress": 100,
    "filename": request.filepath
})
chat_message = ChatMessage(
    session_id=session_id,
    status=ApiStatusType.DONE,
    content=f"Image generated successfully in {total_time:.1f} seconds.",
)
yield chat_message

except Exception as e:
    chat_message.status = ChatStatusType.ERROR
    chat_message.content = str(e)
    yield chat_message
    error_message = ChatMessageError(
        session_id=session_id,
        content=f"Error during image generation: {str(e)}"
    )
    logger.error(traceback.format_exc())
    logger.error(chat_message.content)
    logger.error(error_message.content)
    yield error_message
    return
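After this refactor, generate_image() yields only typed messages, so a caller can dispatch purely on the status field. A self-contained sketch of the consumer side; the models and generator here are trimmed stand-ins mirroring this commit, not the project's actual API:

```python
import asyncio
from enum import Enum
from pydantic import BaseModel

class ApiStatusType(str, Enum):   # stand-in mirroring the enum in this commit
    STATUS = "status"
    DONE = "done"
    ERROR = "error"

class Msg(BaseModel):
    status: ApiStatusType
    content: str

async def fake_generate_image():
    """Stand-in for generate_image(): progress messages, then a terminal DONE."""
    yield Msg(status=ApiStatusType.STATUS, content="Loading model...")
    yield Msg(status=ApiStatusType.STATUS, content="Step 1/2 (50%)")
    yield Msg(status=ApiStatusType.DONE, content="Image generated in 1.0 seconds.")

async def main():
    final = None
    async for message in fake_generate_image():
        if message.status == ApiStatusType.ERROR:
            raise RuntimeError(message.content)
        if message.status == ApiStatusType.DONE:
            final = message           # only the terminal message is kept
        else:
            print(message.content)    # progress goes straight to the client
    print("result:", final.content)

asyncio.run(main())
```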
@ -64,7 +64,7 @@ import agents
# =============================
from models import (
    # API
    Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
    ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,

    # User models
    Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse, CandidateAI,
@ -73,7 +73,7 @@ from models import (
    JobFull, JobApplication, ApplicationStatus,

    # Chat models
    ChatSession, ChatMessage, ChatContext, ChatQuery, ChatStatusType, ChatMessageBase, ChatMessageUser, ChatSenderType, ChatMessageType, ChatContextType,
    ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
    ChatMessageRagSearch,

    # Document models
@ -361,7 +361,6 @@ def filter_and_paginate(

async def stream_agent_response(chat_agent: agents.Agent,
                                user_message: ChatMessageUser,
                                candidate: Candidate,
                                chat_session_data: Dict[str, Any] | None = None,
                                database: RedisDatabase | None = None) -> StreamingResponse:
    async def message_stream_generator():
@ -372,30 +371,33 @@ async def stream_agent_response(chat_agent: agents.Agent,
        async for generated_message in chat_agent.generate(
            llm=llm_manager.get_llm(),
            model=defines.model,
            user_message=user_message,
            user=candidate,
            session_id=user_message.session_id,
            prompt=user_message.content,
        ):
            if generated_message.status == ApiStatusType.ERROR:
                logger.error(f"❌ AI generation error: {generated_message.content}")
                yield f"data: {json.dumps({'status': 'error'})}\n\n"
                return

            # Store reference to the complete AI message
            if generated_message.status == ChatStatusType.DONE:
            if generated_message.status == ApiStatusType.DONE:
                final_message = generated_message

            # If the message is not done, convert it to a ChatMessageBase to remove
            # metadata and other unnecessary fields for streaming
            if generated_message.status != ChatStatusType.DONE:
                generated_message = model_cast.cast_to_model(ChatMessageBase, generated_message)
            if generated_message.status != ApiStatusType.DONE:
                if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
                    raise TypeError(
                        f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
                    )

            json_data = generated_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
            json_data = generated_message.model_dump(mode='json', by_alias=True)
            json_str = json.dumps(json_data)

            log = f"🔗 Message status={generated_message.status}, sender={getattr(generated_message, 'sender', 'unknown')}"
            if last_log != log:
                last_log = log
                logger.info(log)

            yield f"data: {json_str}\n\n"

        # After streaming is complete, persist the final AI message to database
        if final_message and final_message.status == ChatStatusType.DONE:
        if final_message and final_message.status == ApiStatusType.DONE:
            try:
                if database and chat_session_data:
                    await database.add_chat_message(final_message.session_id, final_message.model_dump())
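The generator above frames every message as a Server-Sent Events `data:` line. A minimal FastAPI sketch of the same framing; the route and payloads are invented for illustration:

```python
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

async def event_source():
    # Each SSE frame is "data: <json>\n\n"; clients split frames on the blank line.
    for payload in ({"status": "streaming", "content": "Hel"},
                    {"status": "streaming", "content": "lo"},
                    {"status": "done", "content": "Hello"}):
        yield f"data: {json.dumps(payload)}\n\n"

@app.get("/demo/stream")
async def demo_stream():
    return StreamingResponse(event_source(), media_type="text/event-stream")
```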
@ -412,9 +414,12 @@ async def stream_agent_response(chat_agent: agents.Agent,
    message_stream_generator(),
    media_type="text/event-stream",
    headers={
        "Cache-Control": "no-cache",
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "X-Accel-Buffering": "no",  # Nginx
        "X-Content-Type-Options": "nosniff",
        "Access-Control-Allow-Origin": "*",  # Adjust for your CORS needs
        "Transfer-Encoding": "chunked",
    },
)
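On the consuming side, any streaming HTTP client can read these frames line by line. A sketch with httpx; the endpoint path and payload shape are hypothetical, not this project's actual route:

```python
import json

import httpx

async def stream_chat(base_url: str, session_id: str, prompt: str):
    """Yield parsed SSE payloads from a hypothetical /chat/stream endpoint."""
    payload = {"sessionId": session_id, "content": prompt}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/chat/stream", json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])
```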
@ -678,19 +683,19 @@ async def create_candidate_ai(
async for generated_message in generate_agent.generate(
    llm=llm_manager.get_llm(),
    model=defines.model,
    user_message=user_message,
    user=None,
    session_id=user_message.session_id,
    prompt=user_message.content,
):
    if generated_message.status == ChatStatusType.ERROR:
    if generated_message.status == ApiStatusType.ERROR:
        logger.error(f"❌ AI generation error: {generated_message.content}")
        return JSONResponse(
            status_code=500,
            content=create_error_response("AI_GENERATION_ERROR", generated_message.content)
        )
    if generated_message.type == ChatMessageType.RESPONSE and state == 0:
    if generated_message.status == ApiStatusType.DONE and state == 0:
        persona_message = generated_message
        state = 1  # Switch to resume generation
    elif generated_message.type == ChatMessageType.RESPONSE and state == 1:
    elif generated_message.status == ApiStatusType.DONE and state == 1:
        resume_message = generated_message

if not persona_message:
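The loop above is a small two-phase state machine: the first DONE message is treated as the persona, the second as the resume. A reduced sketch of that pattern; the generator and message shape are stand-ins:

```python
from typing import AsyncIterator, Optional, Tuple

async def collect_two_phase(messages: AsyncIterator) -> Tuple[Optional[object], Optional[object]]:
    """Return (persona, resume): the first and second 'done' messages seen."""
    persona = resume = None
    state = 0
    async for message in messages:
        if getattr(message, "status", None) != "done":
            continue  # skip streaming/status traffic
        if state == 0:
            persona, state = message, 1   # first completion: persona
        elif state == 1:
            resume = message              # second completion: resume
    return persona, resume
```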
@ -2412,7 +2417,6 @@ async def update_candidate(
)

is_AI = candidate_data.get("is_AI", False)
logger.info(json.dumps(candidate_data, indent=2))
candidate = CandidateAI.model_validate(candidate_data) if is_AI else Candidate.model_validate(candidate_data)

# Check authorization (user can only update their own profile)
@ -2804,8 +2808,8 @@ async def post_candidate_rag_search(
async for generated_message in chat_agent.generate(
    llm=llm_manager.get_llm(),
    model=defines.model,
    user_message=user_message,
    user=candidate,
    session_id=user_message.session_id,
    prompt=user_message.prompt,
):
    rag_message = generated_message
@ -3087,7 +3091,6 @@ async def post_chat_session_message_stream(
return await stream_agent_response(
    chat_agent=chat_agent,
    user_message=user_message,
    candidate=candidate,
    database=database,
    chat_session_data=chat_session_data,
)
@ -3389,15 +3392,10 @@ async def get_candidate_skill_match(
agent.generate(
    llm=llm_manager.get_llm(),
    model=defines.model,
    user_message=ChatMessageUser(
        sender_id=candidate.id,
        session_id="",
        content=requirement,
        timestamp=datetime.now(UTC)
    session_id="",
    prompt=requirement,
    ),
    user=candidate,
)
)
if skill_match is None:
    return JSONResponse(
        status_code=500,
@ -71,8 +71,11 @@ class InterviewRecommendation(str, Enum):
class ChatSenderType(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    AGENT = "agent"
    # Frontend can use this to set mock responses
    SYSTEM = "system"
    INFORMATION = "information"
    WARNING = "warning"
    ERROR = "error"

class Requirements(BaseModel):
    required: List[str] = Field(default_factory=list)
@ -108,24 +111,12 @@ class SkillMatch(BaseModel):
    "populate_by_name": True  # Allow both field names and aliases
}

class ApiMessageType(str, Enum):
    BINARY = "binary"
    TEXT = "text"
    JSON = "json"

class ChatMessageType(str, Enum):
    ERROR = "error"
    GENERATING = "generating"
    INFO = "info"
    PREPARING = "preparing"
    PROCESSING = "processing"
    HEARTBEAT = "heartbeat"
    RESPONSE = "response"
    SEARCHING = "searching"
    RAG_RESULT = "rag_result"
    SYSTEM = "system"
    THINKING = "thinking"
    TOOLING = "tooling"
    USER = "user"

class ChatStatusType(str, Enum):
    INITIALIZING = "initializing"
class ApiStatusType(str, Enum):
    STREAMING = "streaming"
    STATUS = "status"
    DONE = "done"
@ -772,26 +763,55 @@ class LLMMessage(BaseModel):
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)


class ChatMessageBase(BaseModel):
class ApiMessage(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = Field(..., alias="sessionId")
    sender_id: Optional[str] = Field(None, alias="senderId")
    status: ChatStatusType  #= ChatStatusType.INITIALIZING
    type: ChatMessageType  #= ChatMessageType.PREPARING
    sender: ChatSenderType  #= ChatSenderType.SYSTEM
    status: ApiStatusType
    type: ApiMessageType
    timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="timestamp")
    tunables: Optional[Tunables] = None
    content: str = ""
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatMessageRagSearch(ChatMessageBase):
    status: ChatStatusType = ChatStatusType.DONE
    type: ChatMessageType = ChatMessageType.RAG_RESULT
    sender: ChatSenderType = ChatSenderType.USER
class ChatMessageStreaming(ApiMessage):
    status: ApiStatusType = ApiStatusType.STREAMING
    type: ApiMessageType = ApiMessageType.TEXT
    content: str

class ApiActivityType(str, Enum):
    SYSTEM = "system"                      # Used solely on frontend
    INFO = "info"                          # Used solely on frontend
    SEARCHING = "searching"                # Used when generating RAG information
    THINKING = "thinking"                  # Used when determining if AI will use tools
    GENERATING = "generating"              # Used when AI is generating a response
    GENERATING_IMAGE = "generating_image"  # Used when AI is generating an image
    TOOLING = "tooling"                    # Used when AI is using tools
    HEARTBEAT = "heartbeat"                # Used for periodic updates

class ChatMessageStatus(ApiMessage):
    status: ApiStatusType = ApiStatusType.STATUS
    type: ApiMessageType = ApiMessageType.TEXT
    activity: ApiActivityType
    content: Any

class ChatMessageError(ApiMessage):
    status: ApiStatusType = ApiStatusType.ERROR
    type: ApiMessageType = ApiMessageType.TEXT
    content: str

class ChatMessageRagSearch(ApiMessage):
    type: ApiMessageType = ApiMessageType.JSON
    dimensions: int = 2 | 3
    content: List[ChromaDBGetResponse] = []

class JobRequirementsMessage(ApiMessage):
    type: ApiMessageType = ApiMessageType.JSON
    title: Optional[str]
    summary: Optional[str]
    company: Optional[str]
    description: str
    requirements: Optional[JobRequirements]

class ChatMessageMetaData(BaseModel):
    model: AIModelType = AIModelType.QWEN2_5
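With the hierarchy above, each subclass pins its own status default, so receivers can route on that one field instead of isinstance checks. A trimmed, self-contained sketch mirroring these models (not the project's actual definitions):

```python
import uuid
from enum import Enum

from pydantic import BaseModel, Field

class ApiStatusType(str, Enum):     # trimmed stand-in for the enum above
    STREAMING = "streaming"
    STATUS = "status"
    DONE = "done"
    ERROR = "error"

class ApiMessage(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = Field(..., alias="sessionId")
    status: ApiStatusType
    model_config = {"populate_by_name": True}

class ChatMessageStreaming(ApiMessage):
    status: ApiStatusType = ApiStatusType.STREAMING
    content: str

class ChatMessageError(ApiMessage):
    status: ApiStatusType = ApiStatusType.ERROR
    content: str

# Each subclass pins its own status, so receivers can route without isinstance:
msg = ChatMessageStreaming(session_id="abc", content="Hel")
assert msg.status is ApiStatusType.STREAMING
```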
@ -808,23 +828,26 @@ class ChatMessageMetaData(BaseModel):
    prompt_eval_count: int = 0
    prompt_eval_duration: int = 0
    options: Optional[ChatOptions] = None
    tools: Optional[Dict[str, Any]] = None
    timers: Optional[Dict[str, float]] = None
    tools: Dict[str, Any] = Field(default_factory=dict)
    timers: Dict[str, float] = Field(default_factory=dict)
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatMessageUser(ChatMessageBase):
    status: ChatStatusType = ChatStatusType.INITIALIZING
    type: ChatMessageType = ChatMessageType.GENERATING
    sender: ChatSenderType = ChatSenderType.USER
class ChatMessageUser(ApiMessage):
    type: ApiMessageType = ApiMessageType.TEXT
    status: ApiStatusType = ApiStatusType.DONE
    role: ChatSenderType = ChatSenderType.USER
    content: str = ""
    tunables: Optional[Tunables] = None

class ChatMessage(ChatMessageBase):
class ChatMessage(ChatMessageUser):
    role: ChatSenderType = ChatSenderType.ASSISTANT
    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
    #attachments: Optional[List[Attachment]] = None
    #reactions: Optional[List[MessageReaction]] = None
    #is_edited: bool = Field(False, alias="isEdited")
    #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)

class ChatSession(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
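Because ApiMessage sets populate_by_name, the backend accepts camelCase payloads from the frontend and emits them back with by_alias=True, which is what stream_agent_response's model_dump(mode='json', by_alias=True) relies on. A small round-trip sketch with a stand-in model:

```python
from pydantic import BaseModel, Field

class Wire(BaseModel):
    session_id: str = Field(..., alias="sessionId")
    content: str = ""
    model_config = {"populate_by_name": True}

inbound = Wire.model_validate({"sessionId": "abc", "content": "hi"})  # camelCase in
print(inbound.session_id)                              # snake_case in Python: 'abc'
print(inbound.model_dump(mode="json", by_alias=True))  # {'sessionId': 'abc', 'content': 'hi'}
```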
@ -27,8 +27,9 @@ from .markdown_chunker import (

# When imported as a module, use relative imports
import defines
from models import ChromaDBGetResponse

__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
@ -38,46 +39,6 @@ class RagEntry(BaseModel):
    description: str = ""
    enabled: bool = True

class ChromaDBGetResponse(BaseModel):
    name: str = ""
    size: int = 0
    ids: List[str] = []
    embeddings: List[List[float]] = Field(default=[])
    documents: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    query: str = ""
    query_embedding: Optional[List[float]] = Field(default=None)
    umap_embedding_2d: Optional[List[float]] = Field(default=None)
    umap_embedding_3d: Optional[List[float]] = Field(default=None)
    enabled: bool = True

    class Config:
        validate_assignment = True

    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
    @classmethod
    def validate_embeddings(cls, value, field):
        # logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
        if value is None:
            return value
        if isinstance(value, np.ndarray):
            if field.field_name == "embeddings":
                if value.ndim != 2:
                    raise ValueError(f"{field.field_name} must be a 2-dimensional NumPy array")
                return [[float(x) for x in row] for row in value.tolist()]
            else:
                if value.ndim != 1:
                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
                return [float(x) for x in value.tolist()]
        if field.field_name == "embeddings":
            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
                raise ValueError(f"{field.field_name} must be a list of lists of floats")
            return [[float(x) for x in sublist] for sublist in value]
        else:
            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
                raise ValueError(f"{field.field_name} must be a list of floats")
            return [float(x) for x in value]

class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
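The validator removed above coerced NumPy arrays into plain float lists before serialization; the class now lives in models. A compact sketch of that coercion pattern, assuming mode='before' so the array reaches the validator ahead of pydantic's type checking:

```python
from typing import List

import numpy as np
from pydantic import BaseModel, field_validator

class Embeddings(BaseModel):
    embeddings: List[List[float]] = []

    @field_validator("embeddings", mode="before")
    @classmethod
    def coerce(cls, value):
        # NumPy arrays serialize badly to JSON; convert to plain float lists first.
        if isinstance(value, np.ndarray):
            if value.ndim != 2:
                raise ValueError("embeddings must be a 2-dimensional array")
            return [[float(x) for x in row] for row in value.tolist()]
        return value

print(Embeddings(embeddings=np.ones((2, 3))).embeddings)
# [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```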