Chat and stuff is working again

commit 48e6eeaa71
parent efef926e45

@@ -15,14 +15,14 @@ import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from 'utils/Global';
 import { useAuth } from "hooks/AuthContext";
 import { StreamingResponse } from 'services/api-client';
-import { ChatMessage, ChatMessageBase, ChatContext, ChatSession, ChatQuery, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { PaginatedResponse } from 'types/conversion';
 
 import './Conversation.css';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
 
 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };
 
 const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
@@ -249,8 +249,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
             ...conversationRef.current,
             {
                 ...defaultMessage,
-                type: 'user',
-                sender: 'user',
+                type: 'text',
                 content: query.prompt,
             }
         ]);
@@ -260,34 +259,30 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
         );
 
         const chatMessage: ChatMessageUser = {
+            role: "user",
             sessionId: chatSession.id,
             content: query.prompt,
             tunables: query.tunables,
             status: "done",
-            type: "user",
-            sender: "user",
+            type: "text",
             timestamp: new Date()
         };
 
         controllerRef.current = apiClient.sendMessageStream(chatMessage, {
-            onMessage: (msg: ChatMessageBase) => {
+            onMessage: (msg: ChatMessage) => {
                 console.log("onMessage:", msg);
-                if (msg.type === "response") {
-                    setConversation([
-                        ...conversationRef.current,
-                        msg
-                    ]);
-                    setStreamingMessage(undefined);
-                    setProcessingMessage(undefined);
-                    setProcessing(false);
-                } else {
-                    setProcessingMessage(msg);
-                }
+                setConversation([
+                    ...conversationRef.current,
+                    msg
+                ]);
+                setStreamingMessage(undefined);
+                setProcessingMessage(undefined);
+                setProcessing(false);
                 if (onResponse) {
                     onResponse(msg);
                 }
             },
-            onError: (error: string | ChatMessageBase) => {
+            onError: (error: string | ChatMessageError) => {
                 console.log("onError:", error);
                 // Type-guard to determine if this is a ChatMessageBase or a string
                 if (typeof error === "object" && error !== null && "content" in error) {
@@ -298,12 +293,12 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
                     setProcessingMessage({ ...defaultMessage, content: error as string });
                 }
             },
-            onStreaming: (chunk: ChatMessageBase) => {
+            onStreaming: (chunk: ChatMessageStreaming) => {
                 console.log("onStreaming:", chunk);
                 setStreamingMessage({ ...defaultMessage, ...chunk });
             },
-            onStatusChange: (status: string) => {
-                console.log("onStatusChange:", status);
+            onStatus: (status: ChatMessageStatus) => {
+                console.log("onStatus:", status);
             },
             onComplete: () => {
                 console.log("onComplete");

@@ -20,7 +20,7 @@ import CheckCircleIcon from '@mui/icons-material/CheckCircle';
 import ErrorIcon from '@mui/icons-material/Error';
 import PendingIcon from '@mui/icons-material/Pending';
 import WarningIcon from '@mui/icons-material/Warning';
-import { Candidate, ChatMessage, ChatMessageBase, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
+import { Candidate, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, JobRequirements, SkillMatch } from 'types/types';
 import { useAuth } from 'hooks/AuthContext';
 import { BackstoryPageProps } from './BackstoryTab';
 import { toCamelCase } from 'types/conversion';
@@ -31,8 +31,8 @@ interface JobAnalysisProps extends BackstoryPageProps {
     candidate: Candidate;
 }
 
-const defaultMessage: ChatMessageUser = {
-    type: "preparing", status: "done", sender: "user", sessionId: "", timestamp: new Date(), content: ""
+const defaultMessage: ChatMessage = {
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
 };
 
 const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
@@ -62,23 +62,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
             return;
         }
 
-        const createSession = async () => {
-            try {
-                const session: ChatSession = await apiClient.createCandidateChatSession(
-                    candidate.username,
-                    'job_requirements',
-                    `Generate requirements for ${job.title}`
-                );
-                setSnack("Job analysis session started");
+        try {
+            setCreatingSession(true);
+            apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
+                .then(session => {
                     setRequirementsSession(session);
-            } catch (error) {
-                console.log(error);
-                setSnack("Unable to create requirements session", "error");
-            }
+                    setCreatingSession(false);
+                });
+        } catch (error) {
+            setSnack('Unable to load chat session', 'error');
+        } finally {
             setCreatingSession(false);
-        };
-        setCreatingSession(true);
-        createSession();
+        }
+
     }, [requirementsSession, apiClient, candidate]);
 
     // Fetch initial requirements
@@ -94,50 +90,48 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
         apiClient.sendMessageStream(chatMessage, {
             onMessage: (msg: ChatMessage) => {
                 console.log(`onMessage: ${msg.type}`, msg);
-                if (msg.type === "response") {
-                    const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
-                    const requirements: { requirement: string, domain: string }[] = [];
-                    if (job.requirements?.technicalSkills) {
-                        job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
-                        job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
-                    }
-                    if (job.requirements?.experienceRequirements) {
-                        job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
-                        job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
-                    }
-                    if (job.requirements?.softSkills) {
-                        job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
-                    }
-                    if (job.requirements?.experience) {
-                        job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
-                    }
-                    if (job.requirements?.education) {
-                        job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
-                    }
-                    if (job.requirements?.certifications) {
-                        job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
-                    }
-                    if (job.requirements?.preferredAttributes) {
-                        job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
-                    }
-
-                    const initialSkillMatches = requirements.map(req => ({
-                        requirement: req.requirement,
-                        domain: req.domain,
-                        status: 'waiting' as const,
-                        matchScore: 0,
-                        assessment: '',
-                        description: '',
-                        citations: []
-                    }));
-
-                    setRequirements(requirements);
-                    setSkillMatches(initialSkillMatches);
-                    setStatusMessage(null);
-                    setLoadingRequirements(false);
-                }
+                const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
+                const requirements: { requirement: string, domain: string }[] = [];
+                if (job.requirements?.technicalSkills) {
+                    job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
+                    job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
+                }
+                if (job.requirements?.experienceRequirements) {
+                    job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
+                    job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
+                }
+                if (job.requirements?.softSkills) {
+                    job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
+                }
+                if (job.requirements?.experience) {
+                    job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
+                }
+                if (job.requirements?.education) {
+                    job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
+                }
+                if (job.requirements?.certifications) {
+                    job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
+                }
+                if (job.requirements?.preferredAttributes) {
+                    job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
+                }
+
+                const initialSkillMatches = requirements.map(req => ({
+                    requirement: req.requirement,
+                    domain: req.domain,
+                    status: 'waiting' as const,
+                    matchScore: 0,
+                    assessment: '',
+                    description: '',
+                    citations: []
+                }));
+
+                setRequirements(requirements);
+                setSkillMatches(initialSkillMatches);
+                setStatusMessage(null);
+                setLoadingRequirements(false);
             },
-            onError: (error: string | ChatMessageBase) => {
+            onError: (error: string | ChatMessageError) => {
                 console.log("onError:", error);
                 // Type-guard to determine if this is a ChatMessageBase or a string
                 if (typeof error === "object" && error !== null && "content" in error) {
@@ -147,11 +141,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
                 }
                 setLoadingRequirements(false);
             },
-            onStreaming: (chunk: ChatMessageBase) => {
+            onStreaming: (chunk: ChatMessageStreaming) => {
                 // console.log("onStreaming:", chunk);
             },
-            onStatusChange: (status: string) => {
-                console.log(`onStatusChange: ${status}`);
+            onStatus: (status: ChatMessageStatus) => {
+                console.log(`onStatus: ${status}`);
             },
             onComplete: () => {
                 console.log("onComplete");

@@ -32,9 +32,9 @@ import { SetSnackType } from './Snack';
 import { CopyBubble } from './CopyBubble';
 import { Scrollable } from './Scrollable';
 import { BackstoryElementProps } from './BackstoryTab';
-import { ChatMessage, ChatSession, ChatMessageType, ChatMessageMetaData, ChromaDBGetResponse } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageMetaData, ChromaDBGetResponse, ApiActivityType, ChatMessageUser, ChatMessageError, ChatMessageStatus, ChatSenderType } from 'types/types';
 
-const getStyle = (theme: Theme, type: ChatMessageType): any => {
+const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"): any => {
     const defaultRadius = '16px';
     const defaultStyle = {
         padding: theme.spacing(1, 2),
@@ -56,7 +56,7 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
     };
 
     const styles: any = {
-        response: {
+        assistant: {
             ...defaultStyle,
             backgroundColor: theme.palette.primary.main,
             border: `1px solid ${theme.palette.secondary.main}`,
@@ -90,9 +90,10 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
             boxShadow: '0 1px 3px rgba(216, 58, 58, 0.15)',
         },
         'fact-check': 'qualifications',
+        generating: 'status',
         'job-description': 'content',
         'job-requirements': 'qualifications',
-        info: {
+        information: {
             ...defaultStyle,
             backgroundColor: '#BFD8D8',
             border: `1px solid ${theme.palette.secondary.main}`,
@@ -156,25 +157,29 @@ const getStyle = (theme: Theme, type: ChatMessageType): any => {
         }
     }
 
+    if (!(type in styles)) {
+        console.log(`Style does not exist for: ${type}`);
+    }
+
     return styles[type];
 }
 
-const getIcon = (messageType: ChatMessageType): React.ReactNode | null => {
+const getIcon = (activityType: ApiActivityType | ChatSenderType | "error"): React.ReactNode | null => {
     const icons: any = {
         error: <ErrorOutline color="error" />,
         generating: <LocationSearchingIcon />,
-        info: <InfoOutline color="info" />,
+        information: <InfoOutline color="info" />,
         preparing: <LocationSearchingIcon />,
         processing: <LocationSearchingIcon />,
         system: <Memory />,
         thinking: <Psychology />,
         tooling: <LocationSearchingIcon />,
     };
-    return icons[messageType] || null;
+    return icons[activityType] || null;
 }
 
 interface MessageProps extends BackstoryElementProps {
-    message: ChatMessage,
+    message: ChatMessageUser | ChatMessage | ChatMessageError | ChatMessageStatus,
     title?: string,
     chatSession?: ChatSession,
     className?: string,
@@ -319,7 +324,7 @@ const MessageMeta = (props: MessageMetaProps) => {
 };
 
 interface MessageContainerProps {
-    type: ChatMessageType,
+    type: ApiActivityType | ChatSenderType | "error",
     metadataView?: React.ReactNode | null,
     messageView?: React.ReactNode | null,
     sx?: SxProps<Theme>,
@@ -360,13 +365,21 @@ const Message = (props: MessageProps) => {
         setSnack
     };
     const theme = useTheme();
-    const style: any = getStyle(theme, message.type);
+    const type: ApiActivityType | ChatSenderType | "error" = ('activity' in message) ? message.activity : ('error' in message) ? 'error' : (message as ChatMessage).role;
+    const style: any = getStyle(theme, type);
 
     const handleMetaExpandClick = () => {
         setMetaExpanded(!metaExpanded);
     };
 
-    const content = message.content?.trim();
+    let content;
+    if (typeof (message.content) === "string") {
+        content = message.content.trim();
+    } else {
+        console.error(`message content is not a string`);
+        return (<></>)
+    }
+
     if (!content) {
         return (<></>)
     };
@@ -376,7 +389,8 @@ const Message = (props: MessageProps) => {
     );
 
     let metadataView = (<></>);
-    if (message.metadata) {
+    const metadata: ChatMessageMetaData | null = ('metadata' in message) ? (message.metadata as ChatMessageMetaData || null) : null;
+    if (metadata) {
         metadataView = (
             <Box sx={{ display: "flex", flexDirection: "column", width: "100%" }}>
                 <Box sx={{ display: "flex", alignItems: "center", gap: 1, flexDirection: "row" }}>
@@ -395,18 +409,18 @@ const Message = (props: MessageProps) => {
                 </Box>
                 <Collapse in={metaExpanded} timeout="auto" unmountOnExit>
                     <CardContent>
-                        <MessageMeta messageProps={props} metadata={message.metadata} />
+                        <MessageMeta messageProps={props} metadata={metadata} />
                     </CardContent>
                 </Collapse>
             </Box>);
     }
 
-    const copyContent = message.sender === 'assistant' ? message.content : undefined;
+    const copyContent = (type === 'assistant') ? message.content : undefined;
 
     if (!expandable) {
         /* When not expandable, the styles are applied directly to MessageContainer */
         return (<>
-            {messageView && <MessageContainer copyContent={copyContent} type={message.type} {...{ messageView, metadataView }} sx={{ ...style, ...sx }} />}
+            {messageView && <MessageContainer copyContent={copyContent} type={type} {...{ messageView, metadataView }} sx={{ ...style, ...sx }} />}
         </>);
     }
 
@@ -435,7 +449,7 @@ const Message = (props: MessageProps) => {
                     {title || ''}
                 </AccordionSummary>
                 <AccordionDetails sx={{ mt: 0, mb: 0, p: 0, pl: 2, pr: 2 }}>
-                    <MessageContainer copyContent={copyContent} type={message.type} {...{ messageView, metadataView }} />
+                    <MessageContainer copyContent={copyContent} type={type} {...{ messageView, metadataView }} />
                 </AccordionDetails>
             </Accordion>
     );

@@ -25,11 +25,6 @@ import { useAuth } from 'hooks/AuthContext';
 import * as Types from 'types/types';
 import { useSelectedCandidate } from 'hooks/GlobalContext';
 import { useNavigate } from 'react-router-dom';
-import { Message } from './Message';
-
-const defaultMessage: Types.ChatMessageBase = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
-};
 
 interface VectorVisualizerProps extends BackstoryPageProps {
     inline?: boolean;

@@ -12,7 +12,7 @@ import {
     Send as SendIcon
 } from '@mui/icons-material';
 import { useAuth } from 'hooks/AuthContext';
-import { ChatMessageBase, ChatMessage, ChatSession, ChatMessageUser } from 'types/types';
+import { ChatMessage, ChatSession, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
 import { ConversationHandle } from 'components/Conversation';
 import { BackstoryPageProps } from 'components/BackstoryTab';
 import { Message } from 'components/Message';
@@ -25,7 +25,7 @@ import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryT
 import { BackstoryQuery } from 'components/BackstoryQuery';
 
 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
 };
 
 const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
@@ -33,7 +33,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
     const navigate = useNavigate();
     const { selectedCandidate } = useSelectedCandidate()
     const theme = useTheme();
-    const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
+    const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
     const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
     const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
 
@@ -48,32 +48,6 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
     const [streaming, setStreaming] = useState<boolean>(false);
     const messagesEndRef = useRef(null);
 
-    // Load sessions for the selectedCandidate
-    const loadSessions = async () => {
-        if (!selectedCandidate) return;
-
-        try {
-            setLoading(true);
-            const result = await apiClient.getCandidateChatSessions(selectedCandidate.username);
-            let session = null;
-            if (result.sessions.data.length === 0) {
-                session = await apiClient.createCandidateChatSession(
-                    selectedCandidate.username,
-                    'candidate_chat',
-                    `Backstory chat about ${selectedCandidate.fullName}`
-                );
-            } else {
-                session = result.sessions.data[0];
-            }
-            setChatSession(session);
-            setLoading(false);
-        } catch (error) {
-            setSnack('Unable to load chat session', 'error');
-        } finally {
-            setLoading(false);
-        }
-    };
-
     // Load messages for current session
     const loadMessages = async () => {
         if (!chatSession?.id) return;
@@ -107,20 +81,22 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 
     // Send message
     const sendMessage = async (message: string) => {
-        if (!message.trim() || !chatSession?.id || streaming) return;
+        if (!message.trim() || !chatSession?.id || streaming || !selectedCandidate) return;
 
         const messageContent = message;
         setStreaming(true);
 
         const chatMessage: ChatMessageUser = {
             sessionId: chatSession.id,
+            role: "user",
             content: messageContent,
             status: "done",
-            type: "user",
-            sender: "user",
+            type: "text",
             timestamp: new Date()
         };
 
+        setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
+
         setMessages(prev => {
             const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
             return [...filtered, chatMessage] as any;
@@ -129,34 +105,31 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
         try {
             apiClient.sendMessageStream(chatMessage, {
                 onMessage: (msg: ChatMessage) => {
-                    console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
-                    if (msg.type === "response") {
-                        setMessages(prev => {
-                            const filtered = prev.filter((m: any) => m.id !== msg.id);
-                            return [...filtered, msg] as any;
-                        });
-                        setStreamingMessage(null);
-                        setProcessingMessage(null);
-                    } else {
-                        setProcessingMessage(msg);
-                    }
+                    setMessages(prev => {
+                        const filtered = prev.filter((m: any) => m.id !== msg.id);
+                        return [...filtered, msg] as any;
+                    });
+                    setStreamingMessage(null);
+                    setProcessingMessage(null);
                 },
-                onError: (error: string | ChatMessageBase) => {
+                onError: (error: string | ChatMessageError) => {
                     console.log("onError:", error);
+                    let message: string;
                     // Type-guard to determine if this is a ChatMessageBase or a string
                     if (typeof error === "object" && error !== null && "content" in error) {
-                        setProcessingMessage(error as ChatMessage);
+                        setProcessingMessage(error);
+                        message = error.content as string;
                     } else {
-                        setProcessingMessage({ ...defaultMessage, content: error as string });
+                        setProcessingMessage({ ...defaultMessage, status: "error", content: error })
                     }
                     setStreaming(false);
                 },
-                onStreaming: (chunk: ChatMessageBase) => {
+                onStreaming: (chunk: ChatMessageStreaming) => {
                     // console.log("onStreaming:", chunk);
-                    setStreamingMessage({ ...defaultMessage, ...chunk });
+                    setStreamingMessage({ ...chunk, role: 'assistant' });
                 },
-                onStatusChange: (status: string) => {
-                    console.log(`onStatusChange: ${status}`);
+                onStatus: (status: ChatMessageStatus) => {
+                    setProcessingMessage(status);
                 },
                 onComplete: () => {
                     console.log("onComplete");
@@ -178,7 +151,19 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 
     // Load sessions when username changes
     useEffect(() => {
-        loadSessions();
+        if (!selectedCandidate) return;
+        try {
+            setLoading(true);
+            apiClient.getOrCreateChatSession(selectedCandidate, `Backstory chat with ${selectedCandidate.fullName}`, 'candidate_chat')
+                .then(session => {
+                    setChatSession(session);
+                    setLoading(false);
+                });
+        } catch (error) {
+            setSnack('Unable to load chat session', 'error');
+        } finally {
+            setLoading(false);
+        }
     }, [selectedCandidate]);
 
     // Load messages when session changes
@@ -195,9 +180,9 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
 
     const welcomeMessage: ChatMessage = {
         sessionId: chatSession?.id || '',
-        type: "info",
+        role: "information",
+        type: "text",
         status: "done",
-        sender: "system",
         timestamp: new Date(),
         content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`
     };
@@ -252,7 +237,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
         },
     }}>
         {messages.length === 0 && <Message {...{ chatSession, message: welcomeMessage, setSnack, submitQuery }} />}
-        {messages.map((message: ChatMessageBase) => (
+        {messages.map((message: ChatMessage) => (
             <Message key={message.id} {...{ chatSession, message, setSnack, submitQuery }} />
         ))}
         {processingMessage !== null && (

@@ -8,7 +8,6 @@ import IconButton from '@mui/material/IconButton';
 import CancelIcon from '@mui/icons-material/Cancel';
 import SendIcon from '@mui/icons-material/Send';
 import PropagateLoader from 'react-spinners/PropagateLoader';
-import { jsonrepair } from 'jsonrepair';
 
 import { CandidateInfo } from '../components/CandidateInfo';
 import { Quote } from 'components/Quote';
@@ -18,57 +17,24 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
 import { Scrollable } from '../components/Scrollable';
 import { Pulse } from 'components/Pulse';
 import { StreamingResponse } from 'services/api-client';
-import { ChatContext, ChatMessage, ChatMessageUser, ChatMessageBase, ChatSession, ChatQuery, Candidate, CandidateAI } from 'types/types';
+import { ChatMessage, ChatMessageUser, ChatSession, CandidateAI, ChatMessageStatus, ChatMessageError } from 'types/types';
 import { useAuth } from 'hooks/AuthContext';
 import { Message } from 'components/Message';
-import { Types } from '@uiw/react-json-view';
-import { assert } from 'console';
 
-const emptyUser: CandidateAI = {
-    userType: "candidate",
-    isAI: true,
-    description: "[blank]",
-    username: "[blank]",
-    firstName: "[blank]",
-    lastName: "[blank]",
-    fullName: "[blank] [blank]",
-    questions: [],
-    location: {
-        city: '[blank]',
-        country: '[blank]'
-    },
-    email: '[blank]',
-    createdAt: new Date(),
-    updatedAt: new Date(),
-    status: "pending",
-    skills: [],
-    experience: [],
-    education: [],
-    preferredJobTypes: [],
-    languages: [],
-    certifications: [],
-    isAdmin: false,
-    profileImage: undefined,
-    ragContentSize: 0
-};
-
 const defaultMessage: ChatMessage = {
-    type: "preparing", status: "done", sender: "system", sessionId: "", timestamp: new Date(), content: ""
+    status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
 };
 
 const GenerateCandidate = (props: BackstoryElementProps) => {
     const { apiClient, user } = useAuth();
     const { setSnack, submitQuery } = props;
-    const [streaming, setStreaming] = useState<boolean>(false);
-    const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
     const [processingMessage, setProcessingMessage] = useState<ChatMessage | null>(null);
     const [processing, setProcessing] = useState<boolean>(false);
     const [generatedUser, setGeneratedUser] = useState<CandidateAI | null>(null);
     const [prompt, setPrompt] = useState<string>('');
     const [resume, setResume] = useState<string | null>(null);
     const [canGenImage, setCanGenImage] = useState<boolean>(false);
-    const [timestamp, setTimestamp] = useState<number>(0);
-    const [state, setState] = useState<number>(0); // Replaced stateRef
+    const [timestamp, setTimestamp] = useState<string>('');
     const [shouldGenerateProfile, setShouldGenerateProfile] = useState<boolean>(false);
     const [chatSession, setChatSession] = useState<ChatSession | null>(null);
     const [loading, setLoading] = useState<boolean>(false);
@@ -83,61 +49,24 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
             return;
         }
 
-        const createChatSession = async () => {
-            console.log('Creating chat session');
-            try {
-                const response: ChatSession = await apiClient.createCandidateChatSession(
-                    generatedUser.username,
-                    "generate_image",
-                    "Profile image generation"
-                );
-                setChatSession(response);
-                console.log(`Chat session created for generate_image`, response);
-                setSnack(`Chat session created for generate_image: ${response.id}`);
-            } catch (e) {
-                console.error(e);
-                setSnack("Unable to create image generation session.", "error");
-            }
-        };
-
-        setLoading(true);
-        createChatSession().then(() => { setLoading(false) });
-    }, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack]);
-
-    const generatePersona = async (prompt: string) => {
-        const userMessage: ChatMessageUser = {
-            content: prompt,
-            sessionId: "",
-            sender: "user",
-            status: "done",
-            type: "user",
-            timestamp: new Date()
-        };
-        setPrompt(prompt || '');
-        setProcessing(true);
-        setProcessingMessage({ ...defaultMessage, content: "Generating persona..." });
         try {
-            const result = await apiClient.createCandidateAI(userMessage);
-            console.log(result.message, result);
-            setGeneratedUser(result.candidate);
-            setResume(result.resume);
-            setCanGenImage(true);
-            setShouldGenerateProfile(true); // Reset the flag
+            setLoading(true);
+            apiClient.getOrCreateChatSession(generatedUser, `Profile image generator for ${generatedUser.fullName}`, 'generate_image')
+                .then(session => {
+                    setChatSession(session);
+                    setLoading(false);
+                });
         } catch (error) {
-            console.error(error);
-            setPrompt('');
-            setResume(null);
-            setProcessing(false);
-            setProcessingMessage(null);
-            setSnack("Unable to generate AI persona", "error");
+            setSnack('Unable to load chat session', 'error');
+        } finally {
+            setLoading(false);
         }
-    };
+    }, [generatedUser, chatSession, loading, setChatSession, setLoading, setSnack, apiClient]);
 
     const cancelQuery = useCallback(() => {
         if (controllerRef.current) {
             controllerRef.current.cancel();
             controllerRef.current = null;
-            setState(0);
             setProcessing(false);
         }
     }, []);
@@ -146,8 +75,38 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
         if (processing) {
             return;
         }
+
+        const generatePersona = async (prompt: string) => {
+            const userMessage: ChatMessageUser = {
+                type: "text",
+                role: "user",
+                content: prompt,
+                sessionId: "",
+                status: "done",
+                timestamp: new Date()
+            };
+            setPrompt(prompt || '');
+            setProcessing(true);
+            setProcessingMessage({ ...defaultMessage, content: "Generating persona..." });
+            try {
+                const result = await apiClient.createCandidateAI(userMessage);
+                console.log(result.message, result);
+                setGeneratedUser(result.candidate);
+                setResume(result.resume);
+                setCanGenImage(true);
+                setShouldGenerateProfile(true); // Reset the flag
+            } catch (error) {
+                console.error(error);
+                setPrompt('');
+                setResume(null);
+                setProcessing(false);
+                setProcessingMessage(null);
+                setSnack("Unable to generate AI persona", "error");
+            }
+        };
+
         generatePersona(value);
-    }, [processing, generatePersona]);
+    }, [processing, apiClient, setSnack]);
 
     const handleSendClick = useCallback(() => {
         const value = (backstoryTextRef.current && backstoryTextRef.current.getAndResetValue()) || "";
@@ -173,13 +132,12 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
         setProcessingMessage({ ...defaultMessage, content: 'Starting image generation...' });
         setProcessing(true);
         setCanGenImage(false);
-        setState(3);
 
         const chatMessage: ChatMessageUser = {
             sessionId: chatSession.id || '',
+            role: "user",
             status: "done",
-            type: "user",
-            sender: "user",
+            type: "text",
             timestamp: new Date(),
             content: prompt
         };
@@ -187,34 +145,23 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
         controllerRef.current = apiClient.sendMessageStream(chatMessage, {
             onMessage: async (msg: ChatMessage) => {
                 console.log(`onMessage: ${msg.type} ${msg.content}`, msg);
-                if (msg.type === "heartbeat" && msg.content) {
-                    const heartbeat = JSON.parse(msg.content);
-                    setTimestamp(heartbeat.timestamp);
-                }
-                if (msg.type === "thinking" && msg.content) {
-                    const status = JSON.parse(msg.content);
-                    setProcessingMessage({ ...defaultMessage, content: status.message });
-                }
-                if (msg.type === "response") {
-                    controllerRef.current = null;
-                    try {
-                        await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" });
-                        const { success, message } = await apiClient.deleteChatSession(chatSession.id || '');
-                        console.log(`Profile generated for ${username} and chat session was ${!success ? 'not ' : ''} deleted: ${message}}`);
-                        setGeneratedUser({
-                            ...generatedUser,
-                            profileImage: "profile.png"
-                        } as CandidateAI);
-                        setState(0);
-                        setCanGenImage(true);
-                        setShouldGenerateProfile(false);
-                    } catch (error) {
-                        console.error(error);
-                        setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error");
-                    }
+                controllerRef.current = null;
+                try {
+                    await apiClient.updateCandidate(generatedUser.id || '', { profileImage: "profile.png" });
+                    const { success, message } = await apiClient.deleteChatSession(chatSession.id || '');
+                    console.log(`Profile generated for ${username} and chat session was ${!success ? 'not ' : ''} deleted: ${message}}`);
+                    setGeneratedUser({
+                        ...generatedUser,
+                        profileImage: "profile.png"
+                    } as CandidateAI);
+                    setCanGenImage(true);
+                    setShouldGenerateProfile(false);
+                } catch (error) {
+                    console.error(error);
+                    setSnack(`Unable to update ${username} to indicate they have a profile picture.`, "error");
                 }
             },
-            onError: (error) => {
+            onError: (error: string | ChatMessageError) => {
                 console.log("onError:", error);
                 // Type-guard to determine if this is a ChatMessageBase or a string
                 if (typeof error === "object" && error !== null && "content" in error) {
@@ -223,27 +170,28 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
                     setSnack(error as string, "error");
                 }
                 setProcessingMessage(null);
-                setStreaming(false);
                 setProcessing(false);
                 controllerRef.current = null;
-                setState(0);
                 setCanGenImage(true);
                 setShouldGenerateProfile(false);
             },
             onComplete: () => {
                 setProcessingMessage(null);
-                setStreaming(false);
                 setProcessing(false);
                 controllerRef.current = null;
-                setState(0);
                 setCanGenImage(true);
                 setShouldGenerateProfile(false);
             },
-            onStatusChange: (status: string) => {
+            onStatus: (status: ChatMessageStatus) => {
+                if (status.activity === "heartbeat" && status.content) {
+                    setTimestamp(status.timestamp?.toISOString() || '');
+                } else if (status.content) {
+                    setProcessingMessage({ ...defaultMessage, content: status.content });
+                }
                 console.log(`onStatusChange: ${status}`);
             },
         });
-    }, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack]);
+    }, [chatSession, shouldGenerateProfile, generatedUser, prompt, setSnack, apiClient]);
 
     if (!user?.isAdmin) {
         return (<Box>You must be logged in as an admin to generate AI candidates.</Box>);

@@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types';
 
 const LoadingPage = (props: BackstoryPageProps) => {
     const preamble: ChatMessage = {
-        sender: 'system',
-        type: 'preparing',
+        role: 'assistant',
+        type: 'text',
         status: 'done',
         sessionId: '',
         content: 'Please wait while connecting to Backstory...',

@@ -5,8 +5,8 @@ import { ChatMessage } from 'types/types';
 
 const LoginRequired = (props: BackstoryPageProps) => {
     const preamble: ChatMessage = {
-        sender: 'system',
-        type: 'preparing',
+        role: 'assistant',
+        type: 'text',
         status: 'done',
         sessionId: '',
         content: 'You must be logged to view this feature.',

@@ -39,11 +39,11 @@ import {
 // ============================
 
 interface StreamingOptions {
-    onStatusChange?: (status: Types.ChatStatusType) => void;
+    onStatus?: (status: Types.ChatMessageStatus) => void;
     onMessage?: (message: Types.ChatMessage) => void;
-    onStreaming?: (chunk: Types.ChatMessageBase) => void;
+    onStreaming?: (chunk: Types.ChatMessageStreaming) => void;
     onComplete?: () => void;
-    onError?: (error: string | Types.ChatMessageBase) => void;
+    onError?: (error: string | Types.ChatMessageError) => void;
     onWarn?: (warning: string) => void;
     signal?: AbortSignal;
 }
@@ -761,6 +761,20 @@ class ApiClient {
         return result;
     }
 
+    async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
+        const result = await this.getCandidateChatSessions(candidate.username);
+        /* Find the 'candidate_chat' session if it exists, otherwise create it */
+        let session = result.sessions.data.find(session => session.title === 'candidate_chat');
+        if (!session) {
+            session = await this.createCandidateChatSession(
+                candidate.username,
+                context_type,
+                title
+            );
+        }
+        return session;
+    }
+
     async getCandidateSimilarContent(query: string
     ): Promise<Types.ChromaDBGetResponse> {
         const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
@@ -965,7 +979,7 @@ class ApiClient {
      * Send message with streaming response support and date conversion
      */
     sendMessageStream(
-        chatMessage: Types.ChatMessageBase,
+        chatMessage: Types.ChatMessageUser,
         options: StreamingOptions = {}
     ): StreamingResponse {
         const abortController = new AbortController();
@@ -998,9 +1012,8 @@ class ApiClient {
 
         const decoder = new TextDecoder();
         let buffer = '';
-        let incomingMessage: Types.ChatMessage | null = null;
+        let streamingMessage: Types.ChatMessageStreaming | null = null;
         const incomingMessageList: Types.ChatMessage[] = [];
-        let incomingStatus : Types.ChatStatusType | null = null;
         try {
             while (true) {
                 const { done, value } = await reader.read();
@@ -1022,39 +1035,41 @@ class ApiClient {
                     try {
                         if (line.startsWith('data: ')) {
                             const data = line.slice(5).trim();
-                            const incoming: Types.ChatMessageBase = JSON.parse(data);
+                            const incoming: any = JSON.parse(data);
 
-                            // Convert date fields for incoming messages
-                            const convertedIncoming = convertChatMessageFromApi(incoming);
-
-                            // Trigger callbacks based on status
-                            if (convertedIncoming.status !== incomingStatus) {
-                                options.onStatusChange?.(convertedIncoming.status);
-                                incomingStatus = convertedIncoming.status;
-                            }
+                            console.log(incoming.status, incoming);
 
                             // Handle different status types
-                            switch (convertedIncoming.status) {
+                            switch (incoming.status) {
                                 case 'streaming':
-                                    if (incomingMessage === null) {
-                                        incomingMessage = {...convertedIncoming};
+                                    console.log(incoming.status, incoming);
+                                    const streaming = Types.convertChatMessageStreamingFromApi(incoming);
+                                    if (streamingMessage === null) {
+                                        streamingMessage = {...streaming};
                                     } else {
                                         // Can't do a simple += as typescript thinks .content might not be there
-                                        incomingMessage.content = (incomingMessage?.content || '') + convertedIncoming.content;
+                                        streamingMessage.content = (streamingMessage?.content || '') + streaming.content;
                                         // Update timestamp to latest
-                                        incomingMessage.timestamp = convertedIncoming.timestamp;
+                                        streamingMessage.timestamp = streamingMessage.timestamp;
                                     }
-                                    options.onStreaming?.(convertedIncoming);
+                                    options.onStreaming?.(streamingMessage);
+                                    break;
+
+                                case 'status':
+                                    const status = Types.convertChatMessageStatusFromApi(incoming);
+                                    options.onStatus?.(status);
                                     break;
 
                                 case 'error':
-                                    options.onError?.(convertedIncoming);
+                                    const error = Types.convertChatMessageErrorFromApi(incoming);
+                                    options.onError?.(error);
                                     break;
 
-                                default:
-                                    incomingMessageList.push(convertedIncoming);
+                                case 'done':
+                                    const message = Types.convertChatMessageFromApi(incoming);
+                                    incomingMessageList.push(message);
                                     try {
-                                        options.onMessage?.(convertedIncoming);
+                                        options.onMessage?.(message);
                                     } catch (error) {
                                         console.error('onMessage handler failed: ', error);
                                     }

@@ -1,6 +1,6 @@
 // Generated TypeScript types from Pydantic models
 // Source: src/backend/models.py
-// Generated on: 2025-06-04T17:02:08.242818
+// Generated on: 2025-06-05T00:24:02.132276
 // DO NOT EDIT MANUALLY - This file is auto-generated
 
 // ============================
@@ -11,15 +11,17 @@ export type AIModelType = "qwen2.5" | "flux-schnell";
 
 export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";
 
+export type ApiActivityType = "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+
+export type ApiMessageType = "binary" | "text" | "json";
+
+export type ApiStatusType = "streaming" | "status" | "done" | "error";
+
 export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";
 
 export type ChatContextType = "job_search" | "job_requirements" | "candidate_chat" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile" | "generate_image" | "rag_search" | "skill_match";
 
-export type ChatMessageType = "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-
-export type ChatSenderType = "user" | "assistant" | "agent" | "system";
-
-export type ChatStatusType = "initializing" | "streaming" | "status" | "done" | "error";
+export type ChatSenderType = "user" | "assistant" | "system" | "information" | "warning" | "error";
 
 export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";
 
@@ -88,6 +90,15 @@ export interface Analytics {
   segment?: string;
 }

+export interface ApiMessage {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+}
+
 export interface ApiResponse {
   success: boolean;
   data?: any;
@@ -284,24 +295,22 @@ export interface ChatMessage {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
+  role: "user" | "assistant" | "system" | "information" | "warning" | "error";
   content: string;
+  tunables?: Tunables;
   metadata?: ChatMessageMetaData;
 }

-export interface ChatMessageBase {
+export interface ChatMessageError {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
   content: string;
 }
@@ -328,25 +337,44 @@ export interface ChatMessageRagSearch {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
-  content: string;
   dimensions: number;
+  content: Array<ChromaDBGetResponse>;
+}
+
+export interface ChatMessageStatus {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  activity: "system" | "info" | "searching" | "thinking" | "generating" | "generating_image" | "tooling" | "heartbeat";
+  content: any;
+}
+
+export interface ChatMessageStreaming {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  content: string;
 }

 export interface ChatMessageUser {
   id?: string;
   sessionId: string;
   senderId?: string;
-  status: "initializing" | "streaming" | "status" | "done" | "error";
-  type: "error" | "generating" | "info" | "preparing" | "processing" | "heartbeat" | "response" | "searching" | "rag_result" | "system" | "thinking" | "tooling" | "user";
-  sender: "user" | "assistant" | "agent" | "system";
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
   timestamp?: Date;
-  tunables?: Tunables;
+  role: "user" | "assistant" | "system" | "information" | "warning" | "error";
   content: string;
+  tunables?: Tunables;
 }

 export interface ChatOptions {
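Since this file is generated, the interface changes above trace back to the Pydantic models in src/backend/models.py. A rough sketch of what the corresponding backend models plausibly look like — field names mirror the generated TypeScript, but the shared base class, camelCase aliasing, and defaults are assumptions, not code from this commit:

    # Hypothetical sketch of the Pydantic side that would generate the
    # ApiMessage/ChatMessageStreaming interfaces above (assumed layout).
    from datetime import datetime, UTC
    from enum import Enum
    from typing import Optional
    from pydantic import BaseModel, ConfigDict, Field
    from pydantic.alias_generators import to_camel

    class ApiStatusType(str, Enum):
        STREAMING = "streaming"
        STATUS = "status"
        DONE = "done"
        ERROR = "error"

    class ApiMessage(BaseModel):
        # Serializes session_id/sender_id as sessionId/senderId.
        model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
        id: Optional[str] = None
        session_id: str
        sender_id: Optional[str] = None
        status: ApiStatusType
        type: str = "text"
        timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC))

    class ChatMessageStreaming(ApiMessage):
        status: ApiStatusType = ApiStatusType.STREAMING
        content: str = ""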
@@ -644,6 +672,20 @@ export interface JobRequirements {
   preferredAttributes?: Array<string>;
 }

+export interface JobRequirementsMessage {
+  id?: string;
+  sessionId: string;
+  senderId?: string;
+  status: "streaming" | "status" | "done" | "error";
+  type: "binary" | "text" | "json";
+  timestamp?: Date;
+  title?: string;
+  summary?: string;
+  company?: string;
+  description: string;
+  requirements?: JobRequirements;
+}
+
 export interface JobResponse {
   success: boolean;
   data?: Job;
@@ -923,6 +965,19 @@ export function convertAnalyticsFromApi(data: any): Analytics {
     timestamp: new Date(data.timestamp),
   };
 }
+/**
+ * Convert ApiMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertApiMessageFromApi(data: any): ApiMessage {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert ApplicationDecision from API response, parsing date fields
  * Date fields: date
@@ -1067,10 +1122,10 @@ export function convertChatMessageFromApi(data: any): ChatMessage {
   };
 }
 /**
- * Convert ChatMessageBase from API response, parsing date fields
+ * Convert ChatMessageError from API response, parsing date fields
  * Date fields: timestamp
  */
-export function convertChatMessageBaseFromApi(data: any): ChatMessageBase {
+export function convertChatMessageErrorFromApi(data: any): ChatMessageError {
   if (!data) return data;

   return {
@@ -1092,6 +1147,32 @@ export function convertChatMessageRagSearchFromApi(data: any): ChatMessageRagSearch {
     timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
   };
 }
+/**
+ * Convert ChatMessageStatus from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStatusFromApi(data: any): ChatMessageStatus {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
+/**
+ * Convert ChatMessageStreaming from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertChatMessageStreamingFromApi(data: any): ChatMessageStreaming {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert ChatMessageUser from API response, parsing date fields
  * Date fields: timestamp
@@ -1268,6 +1349,19 @@ export function convertJobFullFromApi(data: any): JobFull {
     featuredUntil: data.featuredUntil ? new Date(data.featuredUntil) : undefined,
   };
 }
+/**
+ * Convert JobRequirementsMessage from API response, parsing date fields
+ * Date fields: timestamp
+ */
+export function convertJobRequirementsMessageFromApi(data: any): JobRequirementsMessage {
+  if (!data) return data;
+
+  return {
+    ...data,
+    // Convert timestamp from ISO string to Date
+    timestamp: data.timestamp ? new Date(data.timestamp) : undefined,
+  };
+}
 /**
  * Convert MessageReaction from API response, parsing date fields
  * Date fields: timestamp
@@ -1348,6 +1442,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
   switch (modelType) {
     case 'Analytics':
       return convertAnalyticsFromApi(data) as T;
+    case 'ApiMessage':
+      return convertApiMessageFromApi(data) as T;
     case 'ApplicationDecision':
       return convertApplicationDecisionFromApi(data) as T;
     case 'Attachment':
@@ -1366,10 +1462,14 @@ export function convertFromApi<T>(data: any, modelType: string): T {
       return convertCertificationFromApi(data) as T;
     case 'ChatMessage':
       return convertChatMessageFromApi(data) as T;
-    case 'ChatMessageBase':
-      return convertChatMessageBaseFromApi(data) as T;
+    case 'ChatMessageError':
+      return convertChatMessageErrorFromApi(data) as T;
     case 'ChatMessageRagSearch':
       return convertChatMessageRagSearchFromApi(data) as T;
+    case 'ChatMessageStatus':
+      return convertChatMessageStatusFromApi(data) as T;
+    case 'ChatMessageStreaming':
+      return convertChatMessageStreamingFromApi(data) as T;
     case 'ChatMessageUser':
       return convertChatMessageUserFromApi(data) as T;
     case 'ChatSession':
@@ -1394,6 +1494,8 @@ export function convertFromApi<T>(data: any, modelType: string): T {
       return convertJobApplicationFromApi(data) as T;
     case 'JobFull':
       return convertJobFullFromApi(data) as T;
+    case 'JobRequirementsMessage':
+      return convertJobRequirementsMessageFromApi(data) as T;
     case 'MessageReaction':
       return convertMessageReactionFromApi(data) as T;
     case 'RAGConfiguration':
@@ -24,7 +24,7 @@ from datetime import datetime, UTC
 from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
 import numpy as np # type: ignore

-from models import ( LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageBase, ChatMessageUser, Tunables, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate)
+from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
 from logger import logger
 import defines
 from .registry import agent_registry
@@ -300,12 +300,35 @@ class Agent(BaseModel, ABC):
         )
         self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)

+    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
+        """
+        Extracts the RAG context from the rag_message.
+        """
+        if not rag_message.content:
+            return ""
+
+        context = []
+        for chroma_results in rag_message.content:
+            for index, metadata in enumerate(chroma_results.metadatas):
+                content = "\n".join([
+                    line.strip()
+                    for line in chroma_results.documents[index].split("\n")
+                    if line
+                ]).strip()
+                context.append(f"""
+Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
+Document reference: {chroma_results.ids[index]}
+Content: {content}
+""")
+        return "\n".join(context)
+
     async def generate_rag_results(
         self,
-        chat_message: ChatMessage,
+        session_id: str,
+        prompt: str,
         top_k: int=defines.default_rag_top_k,
         threshold: float=defines.default_rag_threshold,
-    ) -> AsyncGenerator[ChatMessage, None]:
+    ) -> AsyncGenerator[ChatMessageRagSearch | ChatMessageError | ChatMessageStatus, None]:
         """
         Generate RAG results for the given query.
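For orientation, a minimal, hypothetical consumer of the reworked generate_rag_results()/get_rag_context() pair — the `agent` fixture and the error handling are assumptions, not code from this commit:

    # Drives the generator to completion: ChatMessageStatus updates are
    # transient, and the stream ends with a ChatMessageRagSearch whose
    # content get_rag_context() flattens into prompt-ready text.
    async def build_rag_context(agent, session_id: str, prompt: str) -> str:
        rag_message = None
        async for rag_message in agent.generate_rag_results(
            session_id=session_id, prompt=prompt
        ):
            if rag_message.status == ApiStatusType.ERROR:
                raise RuntimeError(rag_message.content)
        if not isinstance(rag_message, ChatMessageRagSearch):
            raise RuntimeError("RAG search did not produce a final result")
        return agent.get_rag_context(rag_message)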
@@ -315,223 +338,193 @@ class Agent(BaseModel, ABC):
         Returns:
             A list of dictionaries containing the RAG results.
         """
-        rag_message = ChatMessage(
-            session_id=chat_message.session_id,
-            tunables=chat_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC),
-            metadata=ChatMessageMetaData()
-        )
-
         if not self.user:
-            logger.error("No user set for RAG generation")
-            rag_message.status = ChatStatusType.DONE
-            rag_message.content = ""
-            yield rag_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No user set for RAG generation."
+            )
+            yield error_message
             return

-        try:
-            entries: int = 0
-            user: Candidate = self.user
-            rag_content: str = ""
-            for rag in user.rags:
-                if not rag.enabled:
-                    continue
-                status_message = ChatMessage(
-                    session_id=chat_message.session_id,
-                    sender=ChatSenderType.AGENT,
-                    status = ChatStatusType.INITIALIZING,
-                    type = ChatMessageType.SEARCHING,
-                    content = f"Checking RAG context {rag.name}...")
-                yield status_message
-
+        results : List[ChromaDBGetResponse] = []
+        entries: int = 0
+        user: Candidate = self.user
+        for rag in user.rags:
+            if not rag.enabled:
+                continue
+
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.SEARCHING,
+                content = f"Searching RAG context {rag.name}..."
+            )
+            yield status_message
+
+            try:
                 chroma_results = user.file_watcher.find_similar(
-                    query=chat_message.content, top_k=top_k, threshold=threshold
+                    query=prompt, top_k=top_k, threshold=threshold
                 )
-                if chroma_results:
-                    query_embedding = np.array(chroma_results["query_embedding"]).flatten()
+                if not chroma_results:
+                    continue
+                query_embedding = np.array(chroma_results["query_embedding"]).flatten()

-                    umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
-                    umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
+                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
+                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]

-                    rag_metadata = ChromaDBGetResponse(
-                        query=chat_message.content,
-                        query_embedding=query_embedding.tolist(),
-                        name=rag.name,
-                        ids=chroma_results.get("ids", []),
-                        embeddings=chroma_results.get("embeddings", []),
-                        documents=chroma_results.get("documents", []),
-                        metadatas=chroma_results.get("metadatas", []),
-                        umap_embedding_2d=umap_2d.tolist(),
-                        umap_embedding_3d=umap_3d.tolist(),
-                        size=user.file_watcher.collection.count()
-                    )
+                rag_metadata = ChromaDBGetResponse(
+                    name=rag.name,
+                    query_embedding=query_embedding.tolist(),
+                    ids=chroma_results.get("ids", []),
+                    embeddings=chroma_results.get("embeddings", []),
+                    documents=chroma_results.get("documents", []),
+                    metadatas=chroma_results.get("metadatas", []),
+                    umap_embedding_2d=umap_2d.tolist(),
+                    umap_embedding_3d=umap_3d.tolist(),
+                )
+                results.append(rag_metadata)
+            except Exception as e:
+                continue_message = ChatMessageStatus(
+                    session_id=session_id,
+                    activity=ApiActivityType.SEARCHING,
+                    content=f"Error searching RAG context {rag.name}: {str(e)}"
+                )
+                yield continue_message

-                    entries += len(rag_metadata.documents)
-                    rag_message.metadata.rag_results.append(rag_metadata)
-
-                    for index, metadata in enumerate(chroma_results["metadatas"]):
-                        content = "\n".join(
-                            [
-                                line.strip()
-                                for line in chroma_results["documents"][index].split("\n")
-                                if line
-                            ]
-                        ).strip()
-                        rag_content += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
-Document reference: {chroma_results["ids"][index]}
-Content: { content }
-"""
-            rag_message.content = rag_content.strip()
-            rag_message.type = ChatMessageType.RAG_RESULT
-            rag_message.status = ChatStatusType.DONE
-            yield rag_message
-            return
-
-        except Exception as e:
-            rag_message.status = ChatStatusType.ERROR
-            rag_message.content = f"Error generating RAG results: {str(e)}"
-            logger.error(traceback.format_exc())
-            logger.error(rag_message.content)
-            yield rag_message
-            return
-
-    async def llm_one_shot(self, llm: Any, model: str, user_message: ChatMessageUser, system_prompt: str, temperature=0.7):
-        chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.AGENT,
-            content="",
-            timestamp=datetime.now(UTC)
+        final_message = ChatMessageRagSearch(
+            session_id=session_id,
+            content=results,
+            status=ApiStatusType.DONE,
         )
+        yield final_message
+        return
+
+    async def llm_one_shot(
+        self,
+        llm: Any, model: str,
+        session_id: str, prompt: str, system_prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:

         self.set_optimal_context_size(
-            llm, model, prompt=chat_message.content
+            llm=llm, model=model, prompt=prompt
         )

-        chat_message.metadata = ChatMessageMetaData()
-        chat_message.metadata.options = ChatOptions(
+        options = ChatOptions(
             seed=8911,
             num_ctx=self.context_size,
-            temperature=temperature, # Higher temperature to encourage tool usage
+            temperature=temperature,
         )

         messages: List[LLMMessage] = [
             LLMMessage(role="system", content=system_prompt),
-            LLMMessage(role="user", content=user_message.content),
+            LLMMessage(role="user", content=prompt),
         ]

-        # Reset the response for streaming
-        chat_message.content = ""
-        chat_message.type = ChatMessageType.GENERATING
-        chat_message.status = ChatStatusType.STREAMING
+        status_message = ChatMessageStatus(
+            session_id=session_id,
+            activity=ApiActivityType.GENERATING,
+            content=f"Generating response..."
+        )
+        yield status_message

-        logger.info(f"Message options: {chat_message.metadata.options.model_dump(exclude_unset=True)}")
+        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
         response = None
+        content = ""
         for response in llm.chat(
             model=model,
             messages=messages,
             options={
-                **chat_message.metadata.options.model_dump(exclude_unset=True),
+                **options.model_dump(exclude_unset=True),
             },
             stream=True,
         ):
             if not response:
-                chat_message.status = ChatStatusType.ERROR
-                chat_message.content = "No response from LLM."
-                yield chat_message
+                error_message = ChatMessageError(
+                    session_id=session_id,
+                    content="No response from LLM."
+                )
+                yield error_message
                 return

-            chat_message.content += response.message.content
+            content += response.message.content

             if not response.done:
-                chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
-                chat_chunk.content = response.message.content
-                yield chat_message
-                continue
+                streaming_message = ChatMessageStreaming(
+                    session_id=session_id,
+                    content=response.message.content,
+                    status=ApiStatusType.STREAMING,
+                )
+                yield streaming_message

         if not response:
-            chat_message.status = ChatStatusType.ERROR
-            chat_message.content = "No response from LLM."
-            yield chat_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No response from LLM."
+            )
+            yield error_message
             return

         self.collect_metrics(response)
-        chat_message.metadata.eval_count += response.eval_count
-        chat_message.metadata.eval_duration += response.eval_duration
-        chat_message.metadata.prompt_eval_count += response.prompt_eval_count
-        chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
         self.context_tokens = (
             response.prompt_eval_count + response.eval_count
         )
-        chat_message.type = ChatMessageType.RESPONSE
-        chat_message.status = ChatStatusType.DONE
-        yield chat_message
-
-    async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage | ChatMessageBase, None]:
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")

         chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC)
-        )
-
-        self.set_optimal_context_size(
-            llm, model, prompt=chat_message.content
-        )
-
-        chat_message.metadata = ChatMessageMetaData()
-        chat_message.metadata.options = ChatOptions(
-            seed=8911,
-            num_ctx=self.context_size,
-            temperature=temperature, # Higher temperature to encourage tool usage
+            session_id=session_id,
+            tunables=tunables,
+            status=ApiStatusType.DONE,
+            content=content,
+            metadata = ChatMessageMetaData(
+                options=options,
+                eval_count=response.eval_count,
+                eval_duration=response.eval_duration,
+                prompt_eval_count=response.prompt_eval_count,
+                prompt_eval_duration=response.prompt_eval_duration,
+            )
         )
+        yield chat_message
+        return

-        # Create a dict for storing various timing stats
-        chat_message.metadata.timers = {}
-
+    async def generate(
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
+        if not self.user:
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="No user set for chat generation."
+            )
+            yield error_message
+            return
+
+        user_message = ChatMessageUser(
+            session_id=session_id,
+            content=prompt,
+        )
+        user = self.user
+
         self.metrics.generate_count.labels(agent=self.agent_type).inc()
         with self.metrics.generate_duration.labels(agent=self.agent_type).time():
-            rag_message : Optional[ChatMessage] = None
-            async for rag_message in self.generate_rag_results(chat_message=user_message):
-                if rag_message.status == ChatStatusType.ERROR:
-                    chat_message.status = rag_message.status
-                    chat_message.content = rag_message.content
-                    yield chat_message
-                    return
-                yield rag_message
+            context = None
+            if self.user:
+                rag_message = None
+                async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
+                    if rag_message.status == ApiStatusType.ERROR:
+                        yield rag_message
+                        return
+                    # Only yield messages that are in a streaming state
+                    if rag_message.status == ApiStatusType.STATUS:
+                        yield rag_message

-            rag_context = ""
-            if rag_message:
-                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
-                chat_message.metadata.rag_results = rag_results
-                for chroma_results in rag_results:
-                    for index, metadata in enumerate(chroma_results.metadatas):
-                        content = "\n".join([
-                            line.strip()
-                            for line in chroma_results.documents[index].split("\n")
-                            if line
-                        ]).strip()
-                        rag_context += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
-Document reference: {chroma_results.ids[index]}
-Content: { content }
-"""
+                if not isinstance(rag_message, ChatMessageRagSearch):
+                    raise ValueError(
+                        f"Expected ChatMessageRagSearch, got {type(rag_message)}"
+                    )
+
+                context = self.get_rag_context(rag_message)

             # Create a pruned down message list based purely on the prompt and responses,
             # discarding the full preamble generated by prepare_message
@@ -540,24 +533,24 @@ Content: { content }
             ]
             # Add the conversation history to the messages
             messages.extend([
-                LLMMessage(role=m.sender, content=m.content.strip())
+                LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
                 for m in self.conversation
             ])
             # Add the RAG context to the messages if available
-            if rag_context and user:
+            if context:
                 messages.append(
                     LLMMessage(
                         role="user",
-                        content=f"<|context|>\nThe following is context information about {user.full_name}:\n{rag_context.strip()}\n</|context|>\n\nPrompt to respond to:\n{user_message.content.strip()}\n"
+                        content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
                     )
                 )
             else:
                 # Only the actual user query is provided with the full context message
                 messages.append(
-                    LLMMessage(role=user_message.sender, content=user_message.content.strip())
+                    LLMMessage(role="user", content=prompt)
                 )

-            chat_message.metadata.llm_history = messages
+            llm_history = messages

             # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
             # message.metadata.tools = {
@@ -660,57 +653,89 @@ Content: { content }
             # return

             # not use_tools
-            chat_message.type = ChatMessageType.THINKING
-            chat_message.content = f"Generating response..."
-            yield chat_message
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.GENERATING,
+                content=f"Generating response..."
+            )
+            yield status_message

-            # Reset the response for streaming
-            chat_message.content = ""
+            # Set the response for streaming
+            self.set_optimal_context_size(
+                llm, model, prompt=prompt
+            )
+
+            options = ChatOptions(
+                seed=8911,
+                num_ctx=self.context_size,
+                temperature=temperature,
+            )
+            logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
+            content = ""
             start_time = time.perf_counter()
-            chat_message.type = ChatMessageType.GENERATING
-            chat_message.status = ChatStatusType.STREAMING
+            response = None

             for response in llm.chat(
                 model=model,
                 messages=messages,
                 options={
-                    **chat_message.metadata.options.model_dump(exclude_unset=True),
+                    **options.model_dump(exclude_unset=True),
                 },
                 stream=True,
             ):
                 if not response:
-                    chat_message.status = ChatStatusType.ERROR
-                    chat_message.content = "No response from LLM."
-                    yield chat_message
+                    error_message = ChatMessageError(
+                        session_id=session_id,
+                        content="No response from LLM."
+                    )
+                    yield error_message
                     return

-                chat_message.content += response.message.content
+                content += response.message.content

                 if not response.done:
-                    chat_chunk = model_cast.cast_to_model(ChatMessageBase, chat_message)
-                    chat_chunk.content = response.message.content
-                    yield chat_message
-                    continue
-
-                if response.done:
-                    self.collect_metrics(response)
-                    chat_message.metadata.eval_count += response.eval_count
-                    chat_message.metadata.eval_duration += response.eval_duration
-                    chat_message.metadata.prompt_eval_count += response.prompt_eval_count
-                    chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
-                    self.context_tokens = (
-                        response.prompt_eval_count + response.eval_count
+                    streaming_message = ChatMessageStreaming(
+                        session_id=session_id,
+                        content=response.message.content,
                     )
-                    chat_message.type = ChatMessageType.RESPONSE
-                    chat_message.status = ChatStatusType.DONE
-                    yield chat_message
+                    yield streaming_message

+            if not response:
+                error_message = ChatMessageError(
+                    session_id=session_id,
+                    content="No response from LLM."
+                )
+                yield error_message
+                return
+
+            self.collect_metrics(response)
+            self.context_tokens = (
+                response.prompt_eval_count + response.eval_count
+            )
             end_time = time.perf_counter()
-            chat_message.metadata.timers["streamed"] = end_time - start_time
+            chat_message = ChatMessage(
+                session_id=session_id,
+                tunables=tunables,
+                status=ApiStatusType.DONE,
+                content=content,
+                metadata = ChatMessageMetaData(
+                    options=options,
+                    eval_count=response.eval_count,
+                    eval_duration=response.eval_duration,
+                    prompt_eval_count=response.prompt_eval_count,
+                    prompt_eval_duration=response.prompt_eval_duration,
+                    timers={
+                        "llm_streamed": end_time - start_time,
+                        "llm_with_tools": 0, # Placeholder for tool processing time
+                    },
+                )
+            )

             # Add the user and chat messages to the conversation
             self.conversation.append(user_message)
             self.conversation.append(chat_message)
+            yield chat_message
             return

     # async def process_message(
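Taken together, generate() now emits a small protocol: zero or more ChatMessageStatus/ChatMessageStreaming updates, then exactly one ChatMessage (or a ChatMessageError). A hypothetical driver to make that contract concrete — not part of the commit:

    # Drains Agent.generate(): transient messages are skipped, errors
    # raise, and the single DONE message is returned for persistence.
    async def run_agent(agent, llm, model: str, session_id: str, prompt: str):
        final = None
        async for message in agent.generate(
            llm=llm, model=model, session_id=session_id, prompt=prompt
        ):
            if message.status == ApiStatusType.ERROR:
                raise RuntimeError(message.content)
            if message.status in (ApiStatusType.STATUS, ApiStatusType.STREAMING):
                continue  # transient progress updates and token deltas
            final = message  # the single ApiStatusType.DONE message
        if final is None:
            raise RuntimeError("agent produced no final message")
        return final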
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
 from logger import logger

 from .registry import agent_registry
-from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType, ChatMessageUser, Candidate)
+from models import ( ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatQuery, ChatMessage, Tunables, ApiStatusType, ChatMessageUser, Candidate)


 system_message = f"""
@@ -34,8 +34,15 @@ class CandidateChat(Agent):
     system_prompt: str = system_message

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
-    ):
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
+        user = self.user
+        if not user:
+            logger.error("User is not set for CandidateChat agent.")
+            raise ValueError("User must be set before generating candidate chat responses.")
         self.system_prompt = f"""
 You are a helpful expert system representing a {user.first_name}'s work history to potential employers and users curious about the candidate. You want to incorporate as many facts and details about {user.first_name} as possible.
@@ -49,7 +56,10 @@ Use that spelling instead of any spelling you may find in the <|context|>.
 {system_message}
 """

-        async for message in super().generate(llm, model, user_message, user, temperature):
+        async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
+            if message.status == ApiStatusType.ERROR:
+                yield message
+                return
             yield message

 # Register the base agent
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
 from logger import logger

 from .registry import agent_registry
-from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType)
+from models import ( ChatQuery, ChatMessage, Tunables, ApiStatusType)

 system_message = f"""
 Launched on {datetime.now().isoformat()}.
@@ -25,7 +25,7 @@ import os
 import hashlib

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
+from models import ActivityType, ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -44,43 +44,48 @@ class ImageGenerator(Agent):
     system_prompt: str = "" # No system prompt is used

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage, None]:
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
+        if not self.user:
+            logger.error("User is not set for ImageGenerator agent.")
+            raise ValueError("User must be set before generating images.")
+        user = self.user
+
         file_path = os.path.join(defines.user_dir, user.username, "profile.png")
-        chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC)
-        )
-
-        chat_message.metadata = ChatMessageMetaData()
         try:
             #
             # Generate the profile picture
             #
-            chat_message.content = f"Generating: {user_message.content}"
-            yield chat_message
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                activity=ApiActivityType.GENERATING_IMAGE,
+                content=f"Generating {prompt}...",
+            )
+            yield status_message

-            logger.info(f"Image generation: {file_path} <- {user_message.content}")
-            request = ImageRequest(filepath=file_path, prompt=user_message.content, iterations=4, height=256, width=256, guidance_scale=7.5)
+            logger.info(f"Image generation: {file_path} <- {prompt}")
+            request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
             generated_message = None
             async for generated_message in generate_image(
-                user_message=user_message,
                 request=request
             ):
-                if generated_message.status != "done":
+                if generated_message.status == ApiStatusType.ERROR:
                     yield generated_message
+                    return
+                if generated_message.status != ApiStatusType.DONE:
+                    yield generated_message
+                    continue

             if generated_message is None:
-                chat_message.status = ChatStatusType.ERROR
-                chat_message.content = "Image generation failed."
-                yield chat_message
+                error_message = ChatMessageError(
+                    session_id=session_id,
+                    content="Image generation failed to produce a valid response."
+                )
+                logger.error(f"⚠️ {error_message.content}")
+                yield error_message
                 return

             logger.info("Image generation done...")
@@ -88,18 +93,23 @@ class ImageGenerator(Agent):
             user.profile_image = "profile.png"

             # Image generated
-            generated_message.status = ChatStatusType.DONE
-            generated_message.content = f"{defines.api_prefix}/profile/{user.username}"
-            yield generated_message
+            generated_image = ChatMessage(
+                session_id=session_id,
+                status=ApiStatusType.DONE,
+                content = f"{defines.api_prefix}/profile/{user.username}",
+                metadata=generated_message.metadata
+            )
+            yield generated_image
             return

         except Exception as e:
-            chat_message.status = ChatStatusType.ERROR
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content=f"Error generating image: {str(e)}"
+            )
             logger.error(traceback.format_exc())
-            logger.error(chat_message.content)
-            chat_message.content = f"Error in image generation: {str(e)}"
-            logger.error(chat_message.content)
-            yield chat_message
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return

 # Register the base agent
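A sketch of a caller for the reworked image agent, assuming a websocket-style `send` coroutine (hypothetical; only the DONE message carries the profile URL):

    # Relays progress to the client and returns the final image URL.
    async def make_profile_image(agent, llm, model, session_id, prompt, send):
        async for message in agent.generate(
            llm=llm, model=model, session_id=session_id, prompt=prompt
        ):
            if message.status == ApiStatusType.ERROR:
                await send(message)
                return None
            if message.status != ApiStatusType.DONE:
                await send(message)  # ChatMessageStatus progress updates
                continue
            return message.content   # e.g. "<api_prefix>/profile/<username>"
        return None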
@@ -27,7 +27,7 @@ import random
 from names_dataset import NameDataset, NameWrapper # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
+from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -316,14 +316,16 @@ class GeneratePersona(Agent):
         self.full_name = f"{self.first_name} {self.last_name}"

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate, temperature=0.7
-    ):
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         self.randomize()

-        status_message = ChatMessage(session_id=user_message.session_id)
-        original_prompt = user_message.content
+        original_prompt = prompt.strip()

-        user_message.content = f"""\
+        prompt = f"""\
 ```json
 {json.dumps({
     "age": self.age,
@@ -337,7 +339,7 @@ class GeneratePersona(Agent):
 """

         if original_prompt:
-            user_message.content += f"""
+            prompt += f"""
 Incorporate the following into the job description: {original_prompt}
 """

@@ -348,16 +350,24 @@ Incorporate the following into the job description: {original_prompt}
         generating_message = None
         async for generating_message in self.llm_one_shot(
             llm=llm, model=model,
-            user_message=user_message,
+            session_id=session_id,
+            prompt=prompt,
             system_prompt=generate_persona_system_prompt,
             temperature=temperature,
         ):
-            if generating_message.status == ChatStatusType.ERROR:
+            if generating_message.status == ApiStatusType.ERROR:
                 logger.error(f"Error generating persona: {generating_message.content}")
-                raise Exception(generating_message.content)
+                yield generating_message
+            if generating_message.status != ApiStatusType.DONE:
+                yield generating_message

         if not generating_message:
-            raise Exception("No response from LLM during persona generation")
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Persona generation failed to generate a response."
+            )
+            yield error_message
+            return

         json_str = self.extract_json_from_text(generating_message.content)
         try:
@@ -405,37 +415,39 @@ Incorporate the following into the job description: {original_prompt}
                 persona["location"] = None
             persona["is_ai"] = True
         except Exception as e:
-            generating_message.content = f"Unable to parse LLM returned content: {json_str} {str(e)}"
-            generating_message.status = ChatStatusType.ERROR
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
+            )
+            logger.error(f"❌ Error parsing LLM response: {error_message.content}")
             logger.error(traceback.format_exc())
             logger.error(generating_message.content)
-            yield generating_message
+            yield error_message
             return

         logger.info(f"✅ Persona for {persona['username']} generated successfully")

         # Persona generated
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            status = ChatStatusType.STATUS,
-            type = ChatMessageType.RESPONSE,
+        persona_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            type=ApiMessageType.JSON,
             content = json.dumps(persona)
         )
-        yield status_message
+        yield persona_message

         #
         # Generate the resume
         #
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            status = ChatStatusType.STATUS,
-            type = ChatMessageType.THINKING,
+        status_message = ChatMessageStatus(
+            session_id=session_id,
+            activity = ApiActivityType.THINKING,
             content = f"Generating resume for {persona['full_name']}..."
         )
         logger.info(f"🤖 {status_message.content}")
         yield status_message

-        user_message.content = f"""
+        content = f"""
 ```json
 {{
     "full_name": "{persona["full_name"]}",
@@ -449,30 +461,39 @@ Incorporate the following into the job description: {original_prompt}
 ```
 """
         if original_prompt:
-            user_message.content += f"""
+            content += f"""
 Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""

         async for generating_message in self.llm_one_shot(
             llm=llm, model=model,
-            user_message=user_message,
+            session_id=session_id,
+            prompt=content,
             system_prompt=generate_resume_system_prompt,
             temperature=temperature,
         ):
-            if generating_message.status == ChatStatusType.ERROR:
+            if generating_message.status == ApiStatusType.ERROR:
                 logger.error(f"❌ Error generating resume: {generating_message.content}")
                 raise Exception(generating_message.content)
+            if generating_message.status != ApiStatusType.DONE:
+                yield generating_message

         if not generating_message:
-            raise Exception("No response from LLM during persona generation")
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Resume generation failed to generate a response."
+            )
+            logger.error(f"❌ {error_message.content}")
+            yield error_message
+            return

         resume = self.extract_markdown_from_text(generating_message.content)
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            status=ChatStatusType.DONE,
-            type=ChatMessageType.RESPONSE,
+        resume_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            type=ApiMessageType.TEXT,
             content=resume
         )
-        yield status_message
+        yield resume_message
         return

     def extract_json_from_text(self, text: str) -> str:
@@ -20,7 +20,7 @@ import asyncio
 import numpy as np # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, JobRequirements
+from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -81,56 +81,51 @@ class JobRequirementsAgent(Agent):
         return system_prompt, prompt

     async def analyze_job_requirements(
-        self, llm: Any, model: str, user_message: ChatMessage
-    ) -> AsyncGenerator[ChatMessage, None]:
+        self, llm: Any, model: str, session_id: str, prompt: str
+    ) -> AsyncGenerator[ChatMessage | ChatMessageError, None]:
         """Analyze job requirements from job description."""
-        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
-        analyze_message = user_message.model_copy()
-        analyze_message.content = prompt
+        system_prompt, prompt = self.create_job_analysis_prompt(prompt)
         generated_message = None
-        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
-            if generated_message.status == ChatStatusType.ERROR:
-                generated_message.content = "Error analyzing job requirements."
+        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+            if generated_message.status == ApiStatusType.ERROR:
                 yield generated_message
                 return
+            if generated_message.status != ApiStatusType.DONE:
+                yield generated_message

         if not generated_message:
-            status_message = ChatMessage(
-                session_id=user_message.session_id,
-                sender=ChatSenderType.AGENT,
-                status = ChatStatusType.ERROR,
-                type = ChatMessageType.ERROR,
-                content = "Job requirements analysis failed to generate a response.")
-            yield status_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Job requirements analysis failed to generate a response.")
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return

-        generated_message.status = ChatStatusType.DONE
-        generated_message.type = ChatMessageType.RESPONSE
         yield generated_message
         return

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ChatMessage, None]:
         # Stage 1A: Analyze job requirements
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            sender=ChatSenderType.AGENT,
-            status=ChatStatusType.STATUS,
-            type=ChatMessageType.THINKING,
+        status_message = ChatMessageStatus(
+            session_id=session_id,
             content = f"Analyzing job requirements")
         yield status_message

         generated_message = None
-        async for generated_message in self.analyze_job_requirements(llm, model, user_message):
-            if generated_message.status == ChatStatusType.ERROR:
-                status_message.status = ChatStatusType.ERROR
-                status_message.content = generated_message.content
-                yield status_message
+        async for generated_message in self.analyze_job_requirements(llm, model, session_id, prompt):
+            if generated_message.status == ApiStatusType.ERROR:
+                yield generated_message
                 return
+            if generated_message.status != ApiStatusType.DONE:
+                yield generated_message

         if not generated_message:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = "Job requirements analysis failed."
+            status_message = ChatMessageStatus(
+                session_id=session_id,
+                content="Job requirements analysis failed to generate a response.")
+            logger.error(f"⚠️ {status_message.content}")
             yield status_message
             return

@ -150,34 +145,33 @@ class JobRequirementsAgent(Agent):
|
|||||||
if not job_requirements:
|
if not job_requirements:
|
||||||
raise ValueError("Job requirements data is empty or invalid.")
|
raise ValueError("Job requirements data is empty or invalid.")
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
status_message.status = ChatStatusType.ERROR
|
status_message.status = ApiStatusType.ERROR
|
||||||
status_message.content = f"Failed to parse job requirements JSON: {str(e)}\n\n{job_requirements_data}"
|
status_message.content = f"Failed to parse job requirements JSON: {str(e)}\n\n{job_requirements_data}"
|
||||||
logger.error(f"⚠️ {status_message.content}")
|
logger.error(f"⚠️ {status_message.content}")
|
||||||
yield status_message
|
yield status_message
|
||||||
return
|
return
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
status_message.status = ChatStatusType.ERROR
|
status_message.status = ApiStatusType.ERROR
|
||||||
status_message.content = f"Job requirements validation error: {str(e)}\n\n{job_requirements_data}"
|
status_message.content = f"Job requirements validation error: {str(e)}\n\n{job_requirements_data}"
|
||||||
logger.error(f"⚠️ {status_message.content}")
|
logger.error(f"⚠️ {status_message.content}")
|
||||||
yield status_message
|
yield status_message
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
status_message.status = ChatStatusType.ERROR
|
status_message.status = ApiStatusType.ERROR
|
||||||
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"⚠️ {status_message.content}")
|
logger.error(f"⚠️ {status_message.content}")
|
||||||
yield status_message
|
yield status_message
|
||||||
return
|
return
|
||||||
status_message.status = ChatStatusType.DONE
|
job_requirements_message = JobRequirementsMessage(
|
||||||
status_message.type = ChatMessageType.RESPONSE
|
session_id=session_id,
|
||||||
job_data = {
|
status=ApiStatusType.DONE,
|
||||||
"company": company_name,
|
requirements=job_requirements,
|
||||||
"title": job_title,
|
company=company_name,
|
||||||
"summary": job_summary,
|
title=job_title,
|
||||||
"requirements": job_requirements.model_dump(mode="json", exclude_unset=True)
|
summary=job_summary
|
||||||
}
|
)
|
||||||
status_message.content = json.dumps(job_data)
|
yield job_requirements_message
|
||||||
yield status_message
|
|
||||||
|
|
||||||
logger.info(f"✅ Job requirements analysis completed successfully.")
|
logger.info(f"✅ Job requirements analysis completed successfully.")
|
||||||
return
|
return
|
||||||
|
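The JobRequirementsMessage change above replaces json.dumps-into-content with a typed terminal message. A hedged sketch with simplified stand-ins for JobRequirements and JobRequirementsMessage (the real models carry more fields, for example description):

```python
# Simplified stand-ins; field names follow the diff, defaults are assumed.
from typing import List, Optional
from pydantic import BaseModel

class JobRequirements(BaseModel):
    required: List[str] = []
    preferred: List[str] = []

class JobRequirementsMessage(BaseModel):
    session_id: str
    title: Optional[str] = None
    summary: Optional[str] = None
    company: Optional[str] = None
    requirements: Optional[JobRequirements] = None

msg = JobRequirementsMessage(
    session_id="abc123",
    company="Acme",
    title="Backend Engineer",
    summary="Build streaming APIs.",
    requirements=JobRequirements(required=["Python"], preferred=["FastAPI"]),
)
print(msg.model_dump(mode="json"))  # structured payload, no manual json.dumps
```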
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
 from logger import logger

 from .registry import agent_registry
-from models import ( ChatMessage, ChatStatusType, ChatMessage, ChatOptions, ChatMessageType, ChatSenderType, ChatStatusType, ChatMessageMetaData, Candidate )
+from models import ( ChatMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
 from rag import ( ChromaDBGetResponse )

 class Chat(Agent):
@@ -19,8 +19,11 @@ class Chat(Agent):
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessage, user: Candidate, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage, None]:
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         """
         Generate a response based on the user message and the provided LLM.

@@ -34,65 +37,25 @@ class Chat(Agent):
         Yields:
             ChatMessage: The generated response.
         """
-        logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
-
-        if user.id != user_message.sender_id:
-            logger.error(f"User {user.username} id does not match message {user_message.sender_id}")
-            raise ValueError("User does not match message sender")
-
-        chat_message = ChatMessage(
-            session_id=user_message.session_id,
-            tunables=user_message.tunables,
-            status=ChatStatusType.INITIALIZING,
-            type=ChatMessageType.PREPARING,
-            sender=ChatSenderType.ASSISTANT,
-            content="",
-            timestamp=datetime.now(UTC)
-        )
-
-        chat_message.metadata = ChatMessageMetaData()
-        chat_message.metadata.options = ChatOptions(
-            seed=8911,
-            num_ctx=self.context_size,
-            temperature=temperature, # Higher temperature to encourage tool usage
-        )
-
-        # Create a dict for storing various timing stats
-        chat_message.metadata.timers = {}
-
-        self.metrics.generate_count.labels(agent=self.agent_type).inc()
-        with self.metrics.generate_duration.labels(agent=self.agent_type).time():
-
-            rag_message : Optional[ChatMessage] = None
-            async for rag_message in self.generate_rag_results(chat_message=user_message):
-                if rag_message.status == ChatStatusType.ERROR:
-                    chat_message.status = rag_message.status
-                    chat_message.content = rag_message.content
-                    yield chat_message
-                    return
-                yield rag_message
-
-            if rag_message:
-                chat_message.content = ""
-                rag_results: List[ChromaDBGetResponse] = rag_message.metadata.rag_results
-                chat_message.metadata.rag_results = rag_results
-                for chroma_results in rag_results:
-                    for index, metadata in enumerate(chroma_results.metadatas):
-                        content = "\n".join([
-                            line.strip()
-                            for line in chroma_results.documents[index].split("\n")
-                            if line
-                        ]).strip()
-                        chat_message.content += f"""
-Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
-Document reference: {chroma_results.ids[index]}
-Content: { content }
-
-"""
-            chat_message.status = ChatStatusType.DONE
-            chat_message.type = ChatMessageType.RAG_RESULT
-            yield chat_message
+        rag_message = None
+        async for rag_message in self.generate_rag_results(session_id, prompt):
+            if rag_message.status == ApiStatusType.ERROR:
+                yield rag_message
+                return
+            if rag_message.status != ApiStatusType.DONE:
+                yield rag_message
+
+        if not isinstance(rag_message, ChatMessageRagSearch):
+            logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="RAG search did not return a valid response."
+            )
+            yield error_message
+            return
+
+        rag_message.status = ApiStatusType.DONE
+        yield rag_message

 # Register the base agent
 agent_registry.register(Chat._agent_type, Chat)
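The isinstance guard above is the crux of the new Chat.generate: the value left in rag_message after the loop must be the typed ChatMessageRagSearch, not one of the interleaved status updates. An illustrative reduction with stand-in models:

```python
# Stand-in models; the real ApiMessage hierarchy lives in models.
from pydantic import BaseModel

class ApiMessage(BaseModel):
    session_id: str
    content: str = ""

class ChatMessageStatus(ApiMessage):
    activity: str = "searching"

class ChatMessageRagSearch(ApiMessage):
    dimensions: int = 2

def finish(last: ApiMessage) -> ChatMessageRagSearch:
    if not isinstance(last, ChatMessageRagSearch):
        raise TypeError(f"Expected ChatMessageRagSearch, got {type(last).__name__}")
    return last

print(finish(ChatMessageRagSearch(session_id="s1")).dimensions)
```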
@@ -20,7 +20,7 @@ import asyncio
 import numpy as np # type: ignore

 from .base import Agent, agent_registry, LLMMessage
-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType, SkillMatch
+from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
 import model_cast
 from logger import logger
 import defines
@@ -29,13 +29,13 @@ class SkillMatchAgent(Agent):
     agent_type: Literal["skill_match"] = "skill_match" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

-    def generate_skill_assessment_prompt(self, skill, rag_content):
+    def generate_skill_assessment_prompt(self, skill, rag_context):
         """
         Generate a system prompt to query the LLM for evidence of a specific skill

         Parameters:
         - skill (str): The specific skill to assess from job requirements
-        - rag_content (str): Additional RAG content queried from candidate documents
+        - rag_context (str): Additional RAG content queried from candidate documents

         Returns:
         - str: A system prompt tailored to assess the specific skill
@@ -98,7 +98,7 @@ Adhere strictly to the JSON output format requested. Do not include any addition
 RESPOND WITH ONLY VALID JSON USING THE EXACT FORMAT SPECIFIED.

 <candidate_info>
-{rag_content}
+{rag_context}
 </candidate_info>

 JSON RESPONSE:"""
@@ -106,110 +106,102 @@ JSON RESPONSE:"""
         return system_prompt, prompt

     async def analyze_job_requirements(
-        self, llm: Any, model: str, user_message: ChatMessage
+        self, llm: Any, model: str, session_id: str, requirement: str
     ) -> AsyncGenerator[ChatMessage, None]:
         """Analyze job requirements from job description."""
-        system_prompt, prompt = self.create_job_analysis_prompt(user_message.content)
-        analyze_message = user_message.model_copy()
-        analyze_message.content = prompt
+        system_prompt, prompt = self.create_job_analysis_prompt(requirement)
         generated_message = None
-        async for generated_message in self.llm_one_shot(llm, model, system_prompt=system_prompt, user_message=analyze_message):
-            if generated_message.status == ChatStatusType.ERROR:
+        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+            if generated_message.status == ApiStatusType.ERROR:
                 generated_message.content = "Error analyzing job requirements."
                 yield generated_message
                 return
+            if generated_message.status != ApiStatusType.DONE:
+                yield generated_message

         if not generated_message:
-            status_message = ChatMessage(
-                session_id=user_message.session_id,
-                sender=ChatSenderType.AGENT,
-                status = ChatStatusType.ERROR,
-                type = ChatMessageType.ERROR,
+            error_message = ChatMessageError(
+                session_id=session_id,
                 content = "Job requirements analysis failed to generate a response.")
-            yield status_message
+            yield error_message
             return

-        generated_message.status = ChatStatusType.DONE
-        generated_message.type = ChatMessageType.RESPONSE
-        yield generated_message
+        chat_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            content=generated_message.content,
+            metadata=generated_message.metadata,
+        )
+        yield chat_message
         return

     async def generate(
-        self, llm: Any, model: str, user_message: ChatMessageUser, user: Candidate | None, temperature=0.7
-    ) -> AsyncGenerator[ChatMessage, None]:
+        self, llm: Any, model: str,
+        session_id: str, prompt: str,
+        tunables: Optional[Tunables] = None,
+        temperature=0.7
+    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         # Stage 1A: Analyze job requirements
-        status_message = ChatMessage(
-            session_id=user_message.session_id,
-            sender=ChatSenderType.AGENT,
-            status=ChatStatusType.STATUS,
-            type=ChatMessageType.THINKING,
-            content = f"Analyzing job requirements")
-        yield status_message
-
         rag_message = None
-        async for rag_message in self.generate_rag_results(chat_message=user_message):
-            if rag_message.status == ChatStatusType.ERROR:
-                status_message.status = ChatStatusType.ERROR
-                status_message.content = rag_message.content
-                logger.error(f"⚠️ {status_message.content}")
-                yield status_message
+        async for rag_message in self.generate_rag_results(session_id=session_id, prompt=prompt):
+            if rag_message.status == ApiStatusType.ERROR:
+                yield rag_message
                 return
+            if rag_message.status != ApiStatusType.DONE:
+                yield rag_message

         if rag_message is None:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = "Failed to retrieve RAG context."
-            logger.error(f"⚠️ {status_message.content}")
-            yield status_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="RAG search did not return a valid response."
+            )
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return

         logger.info(f"🔍 RAG content retrieved: {len(rag_message.content)} bytes")
-
-        system_prompt, prompt = self.generate_skill_assessment_prompt(skill=user_message.content, rag_content=rag_message.content)
-
-        user_message.content = prompt
+        rag_context = self.get_rag_context(rag_message)
+        system_prompt, prompt = self.generate_skill_assessment_prompt(skill=prompt, rag_context=rag_context)

         skill_assessment = None
-        async for skill_assessment in self.llm_one_shot(llm=llm, model=model, user_message=user_message, system_prompt=system_prompt, temperature=0.7):
-            if skill_assessment.status == ChatStatusType.ERROR:
-                status_message.status = ChatStatusType.ERROR
-                status_message.content = skill_assessment.content
-                logger.error(f"⚠️ {status_message.content}")
-                yield status_message
+        async for skill_assessment in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
+            if skill_assessment.status == ApiStatusType.ERROR:
+                logger.error(f"⚠️ {skill_assessment.content}")
+                yield skill_assessment
                 return
+            if skill_assessment.status != ApiStatusType.DONE:
+                yield skill_assessment

         if skill_assessment is None:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = "Failed to generate skill assessment."
-            logger.error(f"⚠️ {status_message.content}")
-            yield status_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Skill assessment failed to generate a response."
+            )
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return

         json_str = self.extract_json_from_text(skill_assessment.content)
         skill_assessment_data = ""
         try:
             skill_assessment_data = json.loads(json_str).get("skill_assessment", {})
-        except json.JSONDecodeError as e:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment_data}"
-            logger.error(f"⚠️ {status_message.content}")
-            yield status_message
-            return
-        except ValueError as e:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = f"Skill assessment validation error: {str(e)}\n\n{skill_assessment_data}"
-            logger.error(f"⚠️ {status_message.content}")
-            yield status_message
-            return
         except Exception as e:
-            status_message.status = ChatStatusType.ERROR
-            status_message.content = f"Unexpected error processing Skill assessment: {str(e)}\n\n{skill_assessment_data}"
-            logger.error(traceback.format_exc())
-            logger.error(f"⚠️ {status_message.content}")
-            yield status_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_assessment.content}"
+            )
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
             return
-        status_message.status = ChatStatusType.DONE
-        status_message.type = ChatMessageType.RESPONSE
-        status_message.content = json.dumps(skill_assessment_data)
-        yield status_message

+        skill_assessment_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            content=json.dumps(skill_assessment_data),
+            metadata=skill_assessment.metadata
+        )
+        yield skill_assessment_message
         logger.info(f"✅ Skill assessment completed successfully.")
         return
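The skill-assessment path above funnels every parse failure into a single ChatMessageError. A small self-contained sketch of the extract-then-parse step; extract_json_from_text here is a simplified stand-in for the agent's helper:

```python
# extract_json_from_text below is a simplified stand-in, not the agent's code.
import json

def extract_json_from_text(text: str) -> str:
    # Pull the first JSON object out of free-form LLM output.
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("no JSON object found")
    return text[start:end + 1]

raw = 'Sure! Here is the result: {"skill_assessment": {"skill": "Python", "evidence_strength": "strong"}}'
try:
    data = json.loads(extract_json_from_text(raw)).get("skill_assessment", {})
    print(data["skill"], data["evidence_strength"])
except Exception as e:  # collapsed error handling, as in the new diff
    print(f"Failed to parse Skill assessment JSON: {e}")
```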
@@ -18,7 +18,7 @@ from rag import start_file_watcher, ChromaDBFileWatcher, ChromaDBGetResponse
 import defines
 from logger import logger
 import agents as agents
-from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageType, ChatMessageMetaData, ChatStatusType, Candidate, ChatContextType)
+from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
 from llm_manager import llm_manager
 from agents.base import Agent

@@ -797,7 +797,7 @@ Generated conversion functions can be used like:
 const jobs = convertArrayFromApi<Job>(apiResponse, 'Job');

 Enum types are now properly handled:
-status: ChatStatusType = ChatStatusType.DONE -> status: ChatStatusType (not locked to "done")
+status: ApiStatusType = ApiStatusType.DONE -> status: ApiStatusType (not locked to "done")
 """
 )
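A small sketch of the enum behavior the generated docstring above describes: a str-valued enum field serializes to its literal and validates back from it, so the field is typed without being locked to one value. Stand-in model:

```python
# Stand-in model; the commit's ApiStatusType has more members.
from enum import Enum
from pydantic import BaseModel

class ApiStatusType(str, Enum):
    DONE = "done"
    ERROR = "error"

class Msg(BaseModel):
    status: ApiStatusType = ApiStatusType.DONE

print(Msg().model_dump(mode="json"))            # {'status': 'done'}
print(Msg.model_validate({"status": "error"}))  # parses back to the enum
```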
@@ -24,7 +24,7 @@ import uuid

 from .image_model_cache import ImageModelCache

-from models import Candidate, ChatMessage, ChatMessageBase, ChatMessageMetaData, ChatMessageType, ChatMessageUser, ChatOptions, ChatSenderType, ChatStatusType
+from models import ApiActivityType, ApiStatusType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType
 from logger import logger

 from image_generator.image_model_cache import ImageModelCache
@@ -39,6 +39,7 @@ TIME_ESTIMATES = {
 }

 class ImageRequest(BaseModel):
+    session_id: str
     filepath: str
     prompt: str
     model: str = "black-forest-labs/FLUX.1-schnell"
@@ -53,18 +54,12 @@ model_cache = ImageModelCache()
 def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
     """Background worker for Flux image generation"""
     try:
-        # Your existing estimates calculation
-        estimates = {"per_step": 0.5} # Replace with your actual estimates
-        resolution_scale = (params.height * params.width) / (512 * 512)
-
         # Flux: Run generation in the background and yield progress updates
-        estimated_gen_time = estimates["per_step"] * params.iterations * resolution_scale
-        status_queue.put({
-            "status": "running",
-            "message": f"Initializing image generation...",
-            "estimated_time_remaining": estimated_gen_time,
-            "progress": 0
-        })
+        status_queue.put(ChatMessageStatus(
+            session_id=params.session_id,
+            content=f"Initializing image generation.",
+            activity=ApiActivityType.GENERATING_IMAGE,
+        ))

         # Start the generation task
         start_gen_time = time.time()
@@ -74,11 +69,11 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
             # Send progress updates
             progress = int((step+1) / params.iterations * 100)

-            status_queue.put({
-                "status": "running",
-                "message": f"Processing step {step+1}/{params.iterations} ({progress}%) complete.",
-                "progress": progress
-            })
+            status_queue.put(ChatMessageStatus(
+                session_id=params.session_id,
+                content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
+                activity=ApiActivityType.GENERATING_IMAGE,
+            ))
             return callback_kwargs

         # Replace this block with your actual Flux pipe call:
@@ -98,27 +93,22 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
         image.save(params.filepath)

         # Final completion status
-        status_queue.put({
-            "status": "completed",
-            "message": f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
-            "progress": 100,
-            "generation_time": gen_time,
-            "per_step_time": per_step_time,
-            "image_path": params.filepath
-        })
+        status_queue.put(ChatMessage(
+            session_id=params.session_id,
+            status=ApiStatusType.DONE,
+            content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
+        ))

     except Exception as e:
         logger.error(traceback.format_exc())
         logger.error(e)
-        status_queue.put({
-            "status": "error",
-            "message": f"Generation failed: {str(e)}",
-            "error": str(e),
-            "progress": 0
-        })
+        status_queue.put(ChatMessageError(
+            session_id=params.session_id,
+            content=f"Error during image generation: {str(e)}",
+        ))

-async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[Dict[str, Any], None]:
+async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
     """
     Single async function that handles background Flux generation with status streaming
     """
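async_generate_image above bridges a worker thread to an async generator through a queue.Queue, draining it with get_nowait() between awaits. A minimal runnable sketch of that bridge, with plain dicts in place of the commit's message models:

```python
# Plain dicts stand in for ChatMessageStatus/ChatMessage here.
import asyncio
import queue
import threading

def worker(q: queue.Queue):
    # The worker thread pushes typed status updates onto the queue.
    for step in range(3):
        q.put({"status": "status", "content": f"step {step + 1}/3"})
    q.put({"status": "done", "content": "finished"})

async def stream():
    q: queue.Queue = queue.Queue()
    t = threading.Thread(target=worker, args=(q,), daemon=True)
    t.start()
    done = False
    while not done:
        try:
            msg = q.get_nowait()  # non-blocking, as in the diff
            yield msg
            done = msg["status"] == "done"
        except queue.Empty:
            await asyncio.sleep(0.1)  # let the event loop breathe
    t.join(timeout=1.0)

async def main():
    async for msg in stream():
        print(msg)

asyncio.run(main())
```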
@@ -136,7 +126,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
     worker_thread.start()

     # Initial status
-    yield {'status': 'starting', 'task_id': task_id, 'message': 'Initializing image generation'}
+    status_message = ChatMessageStatus(
+        session_id=params.session_id,
+        content=f"Starting image generation with task ID {task_id}",
+        activity=ApiActivityType.THINKING
+    )
+    yield status_message

     # Stream status updates
     completed = False
@@ -147,14 +142,12 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
             # Try to get status update (non-blocking)
             status_update = status_queue.get_nowait()

-            # Add task_id to status update
-            status_update['task_id'] = task_id
-
             # Send status update
             yield status_update

             # Check if completed
-            if status_update.get('status') in ['completed', 'error']:
+            if status_update.status == ApiStatusType.DONE:
+                logger.info(f"Image generation completed for task {task_id}")
                 completed = True

             last_heartbeat = time.time()
@@ -163,11 +156,11 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
             # No new status, send heartbeat if needed
             current_time = time.time()
             if current_time - last_heartbeat > 2: # Heartbeat every 2 seconds
-                heartbeat = {
-                    'status': 'heartbeat',
-                    'task_id': task_id,
-                    'timestamp': current_time
-                }
+                heartbeat = ChatMessageStatus(
+                    session_id=params.session_id,
+                    content=f"Heartbeat for task {task_id}",
+                    activity=ApiActivityType.HEARTBEAT,
+                )
                 yield heartbeat
                 last_heartbeat = current_time

@@ -178,28 +171,25 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
     if not completed:
         if worker_thread.is_alive():
             # Thread still running but we might have missed the completion signal
-            timeout_status = {
-                'status': 'timeout',
-                'task_id': task_id,
-                'message': 'Generation timed out or connection lost'
-            }
+            timeout_status = ChatMessageError(
+                session_id=params.session_id,
+                content=f"Generation timeout for task {task_id}. The process may still be running.",
+            )
             yield timeout_status
         else:
             # Thread completed but we might have missed the final status
-            final_status = {
-                'status': 'completed',
-                'task_id': task_id,
-                'message': 'Generation completed'
-            }
+            final_status = ChatMessage(
+                session_id=params.session_id,
+                status=ApiStatusType.DONE,
+                content=f"Generation completed for task {task_id}.",
+            )
             yield final_status

     except Exception as e:
-        error_status = {
-            'status': 'error',
-            'task_id': task_id,
-            'message': f'Server error: {str(e)}',
-            'error': str(e)
-        }
+        error_status = ChatMessageError(
+            session_id=params.session_id,
+            content=f'Server error: {str(e)}'
+        )
         logger.error(error_status)
         yield error_status

@@ -208,41 +198,38 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
     if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
         worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup

-def status(chat_message: ChatMessage, status: str, progress: float = 0, estimated_time_remaining="...") -> ChatMessage:
+def status(session_id: str, status: str) -> ChatMessageStatus:
     """Update chat message status and return it."""
-    message = chat_message.copy(deep=True)
-    message.id = str(uuid.uuid4())
-    message.timestamp = datetime.now(UTC)
-    message.type = ChatMessageType.THINKING
-    message.status = ChatStatusType.STREAMING
-    message.content = status
-    return message
-
-async def generate_image(user_message: ChatMessage, request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
-    """Generate an image with specified dimensions and yield status updates with time estimates."""
-    chat_message = ChatMessage(
-        session_id=user_message.session_id,
-        tunables=user_message.tunables,
-        status=ChatStatusType.INITIALIZING,
-        type=ChatMessageType.PREPARING,
-        sender=ChatSenderType.ASSISTANT,
-        content="",
-        timestamp=datetime.now(UTC)
+    chat_message = ChatMessageStatus(
+        session_id=session_id,
+        activity=ApiActivityType.GENERATING_IMAGE,
+        content=status,
     )
+    return chat_message

+async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
+    """Generate an image with specified dimensions and yield status updates with time estimates."""
+    session_id = request.session_id
+    prompt = request.prompt.strip()
     try:
         # Validate prompt
-        prompt = user_message.content.strip()
         if not prompt:
-            chat_message.status = ChatStatusType.ERROR
-            chat_message.content = "Prompt cannot be empty"
-            yield chat_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Prompt cannot be empty."
+            )
+            logger.error(error_message.content)
+            yield error_message
             return

         # Validate dimensions
         if request.height <= 0 or request.width <= 0:
-            chat_message.status = ChatStatusType.ERROR
-            chat_message.content = "Height and width must be positive"
-            yield chat_message
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Height and width must be positive integers."
+            )
+            logger.error(error_message.content)
+            yield error_message
             return

         filedir = os.path.dirname(request.filepath)
@@ -252,43 +239,50 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
         model_type = "flux"
         device = "cpu"

-        yield status(chat_message, f"Starting image generation...")
-
         # Get initial time estimate, scaled by resolution
         estimates = TIME_ESTIMATES[model_type][device]
         resolution_scale = (request.height * request.width) / (512 * 512)
         estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
-        yield status(chat_message, f"Estimated generation time: ~{estimated_total:.1f} seconds for {request.width}x{request.height}")

         # Initialize or get cached pipeline
         start_time = time.time()
-        yield status(chat_message, f"Loading generative image model...")
+        yield status(session_id, f"Loading generative image model...")
         pipe = await model_cache.get_pipeline(request.model, device)
         load_time = time.time() - start_time
-        yield status(chat_message, f"Model loaded in {load_time:.1f} seconds.", progress=10)
+        yield status(session_id, f"Model loaded in {load_time:.1f} seconds.")

-        async for status_message in async_generate_image(pipe, request):
-            chat_message.content = json.dumps(status_message) # Merge properties from async_generate_image over the message...
-            chat_message.type = ChatMessageType.HEARTBEAT if status_message.get("status") == "heartbeat" else ChatMessageType.THINKING
-            if chat_message.type != ChatMessageType.HEARTBEAT:
-                logger.info(chat_message.content)
-            yield chat_message
+        progress = None
+        async for progress in async_generate_image(pipe, request):
+            if progress.status == ApiStatusType.ERROR:
+                yield progress
+                return
+            if progress.status != ApiStatusType.DONE:
+                yield progress

+        if not progress:
+            error_message = ChatMessageError(
+                session_id=session_id,
+                content="Image generation failed to produce a valid response."
+            )
+            logger.error(f"⚠️ {error_message.content}")
+            yield error_message
+            return

         # Final result
         total_time = time.time() - start_time
-        chat_message.status = ChatStatusType.DONE
-        chat_message.type = ChatMessageType.RESPONSE
-        chat_message.content = json.dumps({
-            "status": f"Image generation complete in {total_time:.1f} seconds",
-            "progress": 100,
-            "filename": request.filepath
-        })
+        chat_message = ChatMessage(
+            session_id=session_id,
+            status=ApiStatusType.DONE,
+            content=f"Image generated successfully in {total_time:.1f} seconds.",
+        )
         yield chat_message

     except Exception as e:
-        chat_message.status = ChatStatusType.ERROR
-        chat_message.content = str(e)
-        yield chat_message
+        error_message = ChatMessageError(
+            session_id=session_id,
+            content=f"Error during image generation: {str(e)}"
+        )
         logger.error(traceback.format_exc())
-        logger.error(chat_message.content)
+        logger.error(error_message.content)
+        yield error_message
         return
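With session_id carried on ImageRequest, the reworked status() helper above can mint progress messages anywhere without threading a ChatMessage through the call chain. A hedged sketch using simplified stand-ins for the commit's models:

```python
# Simplified stand-ins for ApiActivityType/ChatMessageStatus.
from enum import Enum
from pydantic import BaseModel

class ApiActivityType(str, Enum):
    GENERATING_IMAGE = "generating_image"

class ChatMessageStatus(BaseModel):
    session_id: str
    activity: ApiActivityType
    content: str

def status(session_id: str, text: str) -> ChatMessageStatus:
    # Every progress update is a fresh, self-describing message.
    return ChatMessageStatus(
        session_id=session_id,
        activity=ApiActivityType.GENERATING_IMAGE,
        content=text,
    )

print(status("sess-1", "Loading generative image model...").model_dump())
```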
@@ -64,7 +64,7 @@ import agents
 # =============================
 from models import (
     # API
-    Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
+    ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, Job, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,

     # User models
     Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse, CandidateAI,
@@ -73,7 +73,7 @@ from models import (
     JobFull, JobApplication, ApplicationStatus,

     # Chat models
-    ChatSession, ChatMessage, ChatContext, ChatQuery, ChatStatusType, ChatMessageBase, ChatMessageUser, ChatSenderType, ChatMessageType, ChatContextType,
+    ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
     ChatMessageRagSearch,

     # Document models
@@ -361,7 +361,6 @@ def filter_and_paginate(

 async def stream_agent_response(chat_agent: agents.Agent,
                                 user_message: ChatMessageUser,
-                                candidate: Candidate,
                                 chat_session_data: Dict[str, Any] | None = None,
                                 database: RedisDatabase | None = None) -> StreamingResponse:
     async def message_stream_generator():
@@ -372,30 +371,33 @@ async def stream_agent_response(chat_agent: agents.Agent,
         async for generated_message in chat_agent.generate(
             llm=llm_manager.get_llm(),
             model=defines.model,
-            user_message=user_message,
-            user=candidate,
+            session_id=user_message.session_id,
+            prompt=user_message.content,
         ):
+            if generated_message.status == ApiStatusType.ERROR:
+                logger.error(f"❌ AI generation error: {generated_message.content}")
+                yield f"data: {json.dumps({'status': 'error'})}\n\n"
+                return
+
             # Store reference to the complete AI message
-            if generated_message.status == ChatStatusType.DONE:
+            if generated_message.status == ApiStatusType.DONE:
                 final_message = generated_message

             # If the message is not done, convert it to a ChatMessageBase to remove
             # metadata and other unnecessary fields for streaming
-            if generated_message.status != ChatStatusType.DONE:
-                generated_message = model_cast.cast_to_model(ChatMessageBase, generated_message)
+            if generated_message.status != ApiStatusType.DONE:
+                if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
+                    raise TypeError(
+                        f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
+                    )

-            json_data = generated_message.model_dump(mode='json', by_alias=True, exclude_unset=True)
+            json_data = generated_message.model_dump(mode='json', by_alias=True)
             json_str = json.dumps(json_data)

-            log = f"🔗 Message status={generated_message.status}, sender={getattr(generated_message, 'sender', 'unknown')}"
-            if last_log != log:
-                last_log = log
-                logger.info(log)
-
             yield f"data: {json_str}\n\n"

         # After streaming is complete, persist the final AI message to database
-        if final_message and final_message.status == ChatStatusType.DONE:
+        if final_message and final_message.status == ApiStatusType.DONE:
             try:
                 if database and chat_session_data:
                     await database.add_chat_message(final_message.session_id, final_message.model_dump())
@@ -412,9 +414,12 @@ async def stream_agent_response(chat_agent: agents.Agent,
         message_stream_generator(),
         media_type="text/event-stream",
         headers={
-            "Cache-Control": "no-cache",
+            "Cache-Control": "no-cache, no-store, must-revalidate",
             "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",
+            "X-Accel-Buffering": "no", # Nginx
+            "X-Content-Type-Options": "nosniff",
+            "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
+            "Transfer-Encoding": "chunked",
         },
     )
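A hedged client-side sketch for consuming the SSE stream configured above. The endpoint URL and payload shape are assumptions, and httpx is not part of this commit; each event arrives as a `data: <json>` line, matching the generator's output:

```python
# Assumed endpoint URL and payload; httpx is an external dependency.
import asyncio
import json

import httpx

async def consume(url: str, payload: dict):
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=payload) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines between events
                msg = json.loads(line[len("data: "):])
                if msg.get("status") == "done":
                    return msg  # the final message carries the full content
                print("update:", msg.get("status"))

# asyncio.run(consume("http://localhost:8000/chat/sessions/ID/messages/stream",
#                     {"content": "hello"}))
```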
@@ -678,19 +683,19 @@ async def create_candidate_ai(
     async for generated_message in generate_agent.generate(
         llm=llm_manager.get_llm(),
         model=defines.model,
-        user_message=user_message,
-        user=None,
+        session_id=user_message.session_id,
+        prompt=user_message.content,
     ):
-        if generated_message.status == ChatStatusType.ERROR:
+        if generated_message.status == ApiStatusType.ERROR:
             logger.error(f"❌ AI generation error: {generated_message.content}")
             return JSONResponse(
                 status_code=500,
                 content=create_error_response("AI_GENERATION_ERROR", generated_message.content)
             )
-        if generated_message.type == ChatMessageType.RESPONSE and state == 0:
+        if generated_message.status == ApiStatusType.DONE and state == 0:
             persona_message = generated_message
             state = 1 # Switch to resume generation
-        elif generated_message.type == ChatMessageType.RESPONSE and state == 1:
+        elif generated_message.status == ApiStatusType.DONE and state == 1:
             resume_message = generated_message

     if not persona_message:
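The two-stage loop in create_candidate_ai above is a small state machine: the first DONE message is the persona, the second the resume. An illustrative reduction with stubs in place of the real generator:

```python
# Stubs replace the real agent stream; statuses are plain strings here.
from dataclasses import dataclass

@dataclass
class Msg:
    status: str
    content: str

stream = [Msg("streaming", "..."), Msg("done", "persona"),
          Msg("streaming", "..."), Msg("done", "resume")]

persona_message = resume_message = None
state = 0
for m in stream:
    if m.status == "done" and state == 0:
        persona_message = m
        state = 1  # switch to resume generation
    elif m.status == "done" and state == 1:
        resume_message = m

assert persona_message and resume_message
print(persona_message.content, "/", resume_message.content)
```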
@@ -2412,7 +2417,6 @@ async def update_candidate(
         )

     is_AI = candidate_data.get("is_AI", False)
-    logger.info(json.dumps(candidate_data, indent=2))
     candidate = CandidateAI.model_validate(candidate_data) if is_AI else Candidate.model_validate(candidate_data)

     # Check authorization (user can only update their own profile)
@@ -2804,8 +2808,8 @@ async def post_candidate_rag_search(
     async for generated_message in chat_agent.generate(
         llm=llm_manager.get_llm(),
         model=defines.model,
-        user_message=user_message,
-        user=candidate,
+        session_id=user_message.session_id,
+        prompt=user_message.prompt,
     ):
         rag_message = generated_message

@@ -3087,7 +3091,6 @@ async def post_chat_session_message_stream(
     return await stream_agent_response(
         chat_agent=chat_agent,
         user_message=user_message,
-        candidate=candidate,
         database=database,
         chat_session_data=chat_session_data,
     )
@@ -3389,15 +3392,10 @@ async def get_candidate_skill_match(
         agent.generate(
             llm=llm_manager.get_llm(),
             model=defines.model,
-            user_message=ChatMessageUser(
-                sender_id=candidate.id,
-                session_id="",
-                content=requirement,
-                timestamp=datetime.now(UTC)
+            session_id="",
+            prompt=requirement,
             ),
-            user=candidate,
         )
-    )
     if skill_match is None:
         return JSONResponse(
             status_code=500,
@@ -71,8 +71,11 @@ class InterviewRecommendation(str, Enum):
 class ChatSenderType(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
-    AGENT = "agent"
+    # Frontend can use this to set mock responses
     SYSTEM = "system"
+    INFORMATION = "information"
+    WARNING = "warning"
+    ERROR = "error"

 class Requirements(BaseModel):
     required: List[str] = Field(default_factory=list)
@@ -108,24 +111,12 @@ class SkillMatch(BaseModel):
         "populate_by_name": True # Allow both field names and aliases
     }

+class ApiMessageType(str, Enum):
+    BINARY = "binary"
+    TEXT = "text"
+    JSON = "json"
+
-class ChatMessageType(str, Enum):
-    ERROR = "error"
-    GENERATING = "generating"
-    INFO = "info"
-    PREPARING = "preparing"
-    PROCESSING = "processing"
-    HEARTBEAT = "heartbeat"
-    RESPONSE = "response"
-    SEARCHING = "searching"
-    RAG_RESULT = "rag_result"
-    SYSTEM = "system"
-    THINKING = "thinking"
-    TOOLING = "tooling"
-    USER = "user"
-
-class ChatStatusType(str, Enum):
-    INITIALIZING = "initializing"
+class ApiStatusType(str, Enum):
     STREAMING = "streaming"
     STATUS = "status"
     DONE = "done"
@@ -772,26 +763,55 @@ class LLMMessage(BaseModel):
     content: str = Field(default="")
     tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)

-class ChatMessageBase(BaseModel):
+class ApiMessage(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     session_id: str = Field(..., alias="sessionId")
     sender_id: Optional[str] = Field(None, alias="senderId")
-    status: ChatStatusType #= ChatStatusType.INITIALIZING
-    type: ChatMessageType #= ChatMessageType.PREPARING
-    sender: ChatSenderType #= ChatSenderType.SYSTEM
+    status: ApiStatusType
+    type: ApiMessageType
     timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="timestamp")
-    tunables: Optional[Tunables] = None
-    content: str = ""
     model_config = {
         "populate_by_name": True # Allow both field names and aliases
     }

-class ChatMessageRagSearch(ChatMessageBase):
-    status: ChatStatusType = ChatStatusType.DONE
-    type: ChatMessageType = ChatMessageType.RAG_RESULT
-    sender: ChatSenderType = ChatSenderType.USER
+class ChatMessageStreaming(ApiMessage):
+    status: ApiStatusType = ApiStatusType.STREAMING
+    type: ApiMessageType = ApiMessageType.TEXT
+    content: str
+
+class ApiActivityType(str, Enum):
+    SYSTEM = "system" # Used solely on frontend
+    INFO = "info" # Used solely on frontend
+    SEARCHING = "searching" # Used when generating RAG information
+    THINKING = "thinking" # Used when determining if AI will use tools
+    GENERATING = "generating" # Used when AI is generating a response
+    GENERATING_IMAGE = "generating_image" # Used when AI is generating an image
+    TOOLING = "tooling" # Used when AI is using tools
+    HEARTBEAT = "heartbeat" # Used for periodic updates
+
+class ChatMessageStatus(ApiMessage):
+    status: ApiStatusType = ApiStatusType.STATUS
+    type: ApiMessageType = ApiMessageType.TEXT
+    activity: ApiActivityType
+    content: Any
+
+class ChatMessageError(ApiMessage):
+    status: ApiStatusType = ApiStatusType.ERROR
+    type: ApiMessageType = ApiMessageType.TEXT
+    content: str
+
+class ChatMessageRagSearch(ApiMessage):
+    type: ApiMessageType = ApiMessageType.JSON
     dimensions: int = 2 | 3
+    content: List[ChromaDBGetResponse] = []
+
+class JobRequirementsMessage(ApiMessage):
+    type: ApiMessageType = ApiMessageType.JSON
+    title: Optional[str]
+    summary: Optional[str]
+    company: Optional[str]
+    description: str
+    requirements: Optional[JobRequirements]

 class ChatMessageMetaData(BaseModel):
     model: AIModelType = AIModelType.QWEN2_5
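The new hierarchy above pins status and type defaults per subclass, so consumers can dispatch on isinstance instead of comparing string fields. A compact sketch with simplified stand-ins, not the full models from this commit:

```python
# Simplified stand-ins for the ApiMessage hierarchy.
from enum import Enum
from typing import Any
from pydantic import BaseModel

class ApiStatusType(str, Enum):
    STREAMING = "streaming"
    STATUS = "status"
    DONE = "done"
    ERROR = "error"

class ApiMessage(BaseModel):
    session_id: str
    status: ApiStatusType

class ChatMessageStatus(ApiMessage):
    status: ApiStatusType = ApiStatusType.STATUS
    activity: str
    content: Any

class ChatMessageError(ApiMessage):
    status: ApiStatusType = ApiStatusType.ERROR
    content: str

def handle(msg: ApiMessage) -> str:
    if isinstance(msg, ChatMessageError):
        return f"error: {msg.content}"
    if isinstance(msg, ChatMessageStatus):
        return f"[{msg.activity}] {msg.content}"
    return "unknown"

print(handle(ChatMessageStatus(session_id="s", activity="thinking", content="...")))
print(handle(ChatMessageError(session_id="s", content="boom")))
```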
@@ -808,23 +828,26 @@ class ChatMessageMetaData(BaseModel):
     prompt_eval_count: int = 0
     prompt_eval_duration: int = 0
     options: Optional[ChatOptions] = None
-    tools: Optional[Dict[str, Any]] = None
-    timers: Optional[Dict[str, float]] = None
+    tools: Dict[str, Any] = Field(default_factory=dict)
+    timers: Dict[str, float] = Field(default_factory=dict)
     model_config = {
         "populate_by_name": True # Allow both field names and aliases
     }

-class ChatMessageUser(ChatMessageBase):
-    status: ChatStatusType = ChatStatusType.INITIALIZING
-    type: ChatMessageType = ChatMessageType.GENERATING
-    sender: ChatSenderType = ChatSenderType.USER
+class ChatMessageUser(ApiMessage):
+    type: ApiMessageType = ApiMessageType.TEXT
+    status: ApiStatusType = ApiStatusType.DONE
+    role: ChatSenderType = ChatSenderType.USER
+    content: str = ""
+    tunables: Optional[Tunables] = None

-class ChatMessage(ChatMessageBase):
+class ChatMessage(ChatMessageUser):
+    role: ChatSenderType = ChatSenderType.ASSISTANT
+    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
     #attachments: Optional[List[Attachment]] = None
     #reactions: Optional[List[MessageReaction]] = None
     #is_edited: bool = Field(False, alias="isEdited")
     #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
-    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)

 class ChatSession(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
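ApiMessage keeps the sessionId alias plus populate_by_name, so the backend can use snake_case while the frontend sees camelCase. A minimal sketch of that round trip:

```python
# Minimal model; the real ApiMessage carries more fields.
from pydantic import BaseModel, Field

class ApiMessage(BaseModel):
    session_id: str = Field(..., alias="sessionId")
    model_config = {"populate_by_name": True}

m = ApiMessage(session_id="abc")                        # field name works
print(m.model_dump(mode="json", by_alias=True))         # {'sessionId': 'abc'}
print(ApiMessage.model_validate({"sessionId": "abc"}))  # alias works too
```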
@@ -27,8 +27,9 @@ from .markdown_chunker import (

     # When imported as a module, use relative imports
     import defines
+    from models import ChromaDBGetResponse

-__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "ChromaDBGetResponse"]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

 DEFAULT_CHUNK_SIZE = 750
 DEFAULT_CHUNK_OVERLAP = 100
@@ -38,46 +39,6 @@ class RagEntry(BaseModel):
     description: str = ""
     enabled: bool = True

-class ChromaDBGetResponse(BaseModel):
-    name: str = ""
-    size: int = 0
-    ids: List[str] = []
-    embeddings: List[List[float]] = Field(default=[])
-    documents: List[str] = []
-    metadatas: List[Dict[str, Any]] = []
-    query: str = ""
-    query_embedding: Optional[List[float]] = Field(default=None)
-    umap_embedding_2d: Optional[List[float]] = Field(default=None)
-    umap_embedding_3d: Optional[List[float]] = Field(default=None)
-    enabled: bool = True
-
-    class Config:
-        validate_assignment = True
-
-    @field_validator("embeddings", "query_embedding", "umap_embedding_2d", "umap_embedding_3d")
-    @classmethod
-    def validate_embeddings(cls, value, field):
-        # logging.info(f"Validating {field.field_name} with value: {type(value)} - {value}")
-        if value is None:
-            return value
-        if isinstance(value, np.ndarray):
-            if field.field_name == "embeddings":
-                if value.ndim != 2:
-                    raise ValueError(f"{field.name} must be a 2-dimensional NumPy array")
-                return [[float(x) for x in row] for row in value.tolist()]
-            else:
-                if value.ndim != 1:
-                    raise ValueError(f"{field.field_name} must be a 1-dimensional NumPy array")
-                return [float(x) for x in value.tolist()]
-        if field.field_name == "embeddings":
-            if not all(isinstance(sublist, list) and all(isinstance(x, (int, float)) for x in sublist) for sublist in value):
-                raise ValueError(f"{field.field_name} must be a list of lists of floats")
-            return [[float(x) for x in sublist] for sublist in value]
-        else:
-            if not isinstance(value, list) or not all(isinstance(x, (int, float)) for x in value):
-                raise ValueError(f"{field.field_name} must be a list of floats")
-            return [float(x) for x in value]
-
 class ChromaDBFileWatcher(FileSystemEventHandler):
     def __init__(
         self,
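The validator deleted here (ChromaDBGetResponse moved into models) coerced NumPy arrays into plain float lists so pydantic could serialize them. A condensed sketch of that pattern, assuming numpy and pydantic v2:

```python
# Condensed illustration of the removed validator's core idea.
from typing import List, Optional

import numpy as np
from pydantic import BaseModel, field_validator

class Embeddings(BaseModel):
    query_embedding: Optional[List[float]] = None

    @field_validator("query_embedding", mode="before")
    @classmethod
    def coerce(cls, value):
        # Accept a 1-D NumPy array and convert it to a plain float list.
        if isinstance(value, np.ndarray):
            if value.ndim != 1:
                raise ValueError("query_embedding must be 1-dimensional")
            return [float(x) for x in value.tolist()]
        return value

print(Embeddings(query_embedding=np.array([0.1, 0.2])).query_embedding)
```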