Multi-LLM backend working with Ollama

James Ketr 2025-06-07 15:03:00 -07:00
parent 5fed56ba76
commit 8aa8577874
25 changed files with 3802 additions and 357 deletions

View File

@ -0,0 +1,541 @@
import React, { useState, useEffect, useRef, JSX } from 'react';
import {
Box,
Button,
Typography,
Paper,
TextField,
Grid,
Dialog,
DialogTitle,
DialogContent,
DialogContentText,
DialogActions,
IconButton,
useTheme,
useMediaQuery,
Chip,
Divider,
Card,
CardContent,
CardHeader,
LinearProgress,
Stack,
Alert
} from '@mui/material';
import {
SyncAlt,
Favorite,
Settings,
Info,
Search,
AutoFixHigh,
Image,
Psychology,
Build,
CloudUpload,
Description,
Business,
LocationOn,
Work,
CheckCircle,
Star
} from '@mui/icons-material';
import { styled } from '@mui/material/styles';
import DescriptionIcon from '@mui/icons-material/Description';
import FileUploadIcon from '@mui/icons-material/FileUpload';
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { BackstoryElementProps } from './BackstoryTab';
import { LoginRequired } from 'components/ui/LoginRequired';
import * as Types from 'types/types';
const VisuallyHiddenInput = styled('input')({
clip: 'rect(0 0 0 0)',
clipPath: 'inset(50%)',
height: 1,
overflow: 'hidden',
position: 'absolute',
bottom: 0,
left: 0,
whiteSpace: 'nowrap',
width: 1,
});
const UploadBox = styled(Box)(({ theme }) => ({
border: `2px dashed ${theme.palette.primary.main}`,
borderRadius: theme.shape.borderRadius * 2,
padding: theme.spacing(4),
textAlign: 'center',
backgroundColor: theme.palette.action.hover,
transition: 'all 0.3s ease',
cursor: 'pointer',
'&:hover': {
backgroundColor: theme.palette.action.selected,
borderColor: theme.palette.primary.dark,
},
}));
const StatusBox = styled(Box)(({ theme }) => ({
display: 'flex',
alignItems: 'center',
gap: theme.spacing(1),
padding: theme.spacing(1, 2),
backgroundColor: theme.palette.background.paper,
borderRadius: theme.shape.borderRadius,
border: `1px solid ${theme.palette.divider}`,
minHeight: 48,
}));
const getIcon = (type: Types.ApiActivityType) => {
switch (type) {
case 'converting':
return <SyncAlt color="primary" />;
case 'heartbeat':
return <Favorite color="error" />;
case 'system':
return <Settings color="action" />;
case 'info':
return <Info color="info" />;
case 'searching':
return <Search color="primary" />;
case 'generating':
return <AutoFixHigh color="secondary" />;
case 'generating_image':
return <Image color="primary" />;
case 'thinking':
return <Psychology color="secondary" />;
case 'tooling':
return <Build color="action" />;
default:
return <Info color="action" />;
}
};
interface JobCreator extends BackstoryElementProps {
onSave?: (job: Types.Job) => void;
}
const JobCreator = (props: JobCreator) => {
const { user, apiClient } = useAuth();
const { onSave } = props;
const { selectedCandidate } = useSelectedCandidate();
const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack, submitQuery } = props;
const backstoryProps = { setSnack, submitQuery };
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
const [jobDescription, setJobDescription] = useState<string>('');
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
const [jobTitle, setJobTitle] = useState<string>('');
const [company, setCompany] = useState<string>('');
const [summary, setSummary] = useState<string>('');
const [jobLocation, setJobLocation] = useState<string>('');
const [jobId, setJobId] = useState<string>('');
const [jobStatus, setJobStatus] = useState<string>('');
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
const [isProcessing, setIsProcessing] = useState<boolean>(false);
useEffect(() => {
}, [jobTitle, jobDescription, company]);
const fileInputRef = useRef<HTMLInputElement>(null);
if (!user?.id) {
return (
<LoginRequired asset="job creation" />
);
}
const jobStatusHandlers = {
onStatus: (status: Types.ChatMessageStatus) => {
console.log('status:', status.content);
setJobStatusIcon(getIcon(status.activity));
setJobStatus(status.content);
},
onMessage: (job: Types.Job) => {
console.log('onMessage - job', job);
setCompany(job.company || '');
setJobDescription(job.description);
setSummary(job.summary || '');
setJobTitle(job.title || '');
setJobRequirements(job.requirements || null);
setJobStatusIcon(<></>);
setJobStatus('');
},
onError: (error: Types.ChatMessageError) => {
console.log('onError', error);
setSnack(error.content, "error");
setIsProcessing(false);
},
onComplete: () => {
setJobStatusIcon(<></>);
setJobStatus('');
setIsProcessing(false);
}
};
const handleJobUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
if (e.target.files && e.target.files[0]) {
const file = e.target.files[0];
const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
let docType: Types.DocumentType | null = null;
switch (fileExtension.substring(1)) {
case "pdf":
docType = "pdf";
break;
case "docx":
docType = "docx";
break;
case "md":
docType = "markdown";
break;
case "txt":
docType = "txt";
break;
}
if (!docType) {
setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error');
return;
}
try {
setIsProcessing(true);
setJobDescription('');
setJobTitle('');
setJobRequirements(null);
setSummary('');
const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
const job = await controller.promise;
if (!job) {
return;
}
console.log(`Job id: ${job.id}`);
e.target.value = '';
} catch (error) {
console.error(error);
setSnack('Failed to upload document', 'error');
setIsProcessing(false);
}
}
};
const handleUploadClick = () => {
fileInputRef.current?.click();
};
const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => {
if (!items || items.length === 0) return null;
return (
<Box sx={{ mb: 3 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 1.5 }}>
{icon}
<Typography variant="subtitle1" sx={{ ml: 1, fontWeight: 600 }}>
{title}
</Typography>
{required && <Chip label="Required" size="small" color="error" sx={{ ml: 1 }} />}
</Box>
<Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
{items.map((item, index) => (
<Chip
key={index}
label={item}
variant="outlined"
size="small"
sx={{ mb: 1 }}
/>
))}
</Stack>
</Box>
);
};
const renderJobRequirements = () => {
if (!jobRequirements) return null;
return (
<Card elevation={2} sx={{ mt: 3 }}>
<CardHeader
title="Job Requirements Analysis"
avatar={<CheckCircle color="success" />}
sx={{ pb: 1 }}
/>
<CardContent sx={{ pt: 0 }}>
{renderRequirementSection(
"Technical Skills (Required)",
jobRequirements.technicalSkills.required,
<Build color="primary" />,
true
)}
{renderRequirementSection(
"Technical Skills (Preferred)",
jobRequirements.technicalSkills.preferred,
<Build color="action" />
)}
{renderRequirementSection(
"Experience Requirements (Required)",
jobRequirements.experienceRequirements.required,
<Work color="primary" />,
true
)}
{renderRequirementSection(
"Experience Requirements (Preferred)",
jobRequirements.experienceRequirements.preferred,
<Work color="action" />
)}
{renderRequirementSection(
"Soft Skills",
jobRequirements.softSkills,
<Psychology color="secondary" />
)}
{renderRequirementSection(
"Experience",
jobRequirements.experience,
<Star color="warning" />
)}
{renderRequirementSection(
"Education",
jobRequirements.education,
<Description color="info" />
)}
{renderRequirementSection(
"Certifications",
jobRequirements.certifications,
<CheckCircle color="success" />
)}
{renderRequirementSection(
"Preferred Attributes",
jobRequirements.preferredAttributes,
<Star color="secondary" />
)}
</CardContent>
</Card>
);
};
const handleSave = async () => {
const newJob: Types.Job = {
ownerId: user?.id || '',
ownerType: 'candidate',
description: jobDescription,
company: company,
summary: summary,
title: jobTitle,
requirements: jobRequirements || undefined
};
setIsProcessing(true);
const job = await apiClient.createJob(newJob);
setIsProcessing(false);
onSave ? onSave(job) : setSelectedJob(job);
};
const handleExtractRequirements = () => {
// Implement requirements extraction logic here
setIsProcessing(true);
// This would call your API to extract requirements from the job description
};
const renderJobCreation = () => {
if (!user) {
return <Box>You must be logged in</Box>;
}
return (
<Box sx={{
mx: 'auto', p: { xs: 2, sm: 3 },
}}>
{/* Upload Section */}
<Card elevation={3} sx={{ mb: 4 }}>
<CardHeader
title="Job Information"
subheader="Upload a job description or enter details manually"
avatar={<Work color="primary" />}
/>
<CardContent>
<Grid container spacing={3}>
<Grid size={{ xs: 12, md: 6 }}>
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
<CloudUpload sx={{ mr: 1 }} />
Upload Job Description
</Typography>
<UploadBox onClick={handleUploadClick}>
<CloudUpload sx={{ fontSize: 48, color: 'primary.main', mb: 2 }} />
<Typography variant="h6" gutterBottom>
Drop your job description here
</Typography>
<Typography variant="body2" color="text.secondary" sx={{ mb: 2 }}>
Supported formats: PDF, DOCX, TXT, MD
</Typography>
<Button
variant="contained"
startIcon={<FileUploadIcon />}
disabled={isProcessing}
// onClick={handleUploadClick}
>
Choose File
</Button>
</UploadBox>
<VisuallyHiddenInput
ref={fileInputRef}
type="file"
accept=".txt,.md,.docx,.pdf"
onChange={handleJobUpload}
/>
</Grid>
<Grid size={{ xs: 12, md: 6 }}>
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
<Description sx={{ mr: 1 }} />
Or Enter Manually
</Typography>
<TextField
fullWidth
multiline
rows={isMobile ? 8 : 12}
placeholder="Paste or type the job description here..."
variant="outlined"
value={jobDescription}
onChange={(e) => setJobDescription(e.target.value)}
disabled={isProcessing}
sx={{ mb: 2 }}
/>
{jobRequirements === null && jobDescription && (
<Button
variant="outlined"
onClick={handleExtractRequirements}
startIcon={<AutoFixHigh />}
disabled={isProcessing}
fullWidth={isMobile}
>
Extract Requirements
</Button>
)}
</Grid>
</Grid>
{(jobStatus || isProcessing) && (
<Box sx={{ mt: 3 }}>
<StatusBox>
{jobStatusIcon}
<Typography variant="body2" sx={{ ml: 1 }}>
{jobStatus || 'Processing...'}
</Typography>
</StatusBox>
{isProcessing && <LinearProgress sx={{ mt: 1 }} />}
</Box>
)}
</CardContent>
</Card>
{/* Job Details Section */}
<Card elevation={3} sx={{ mb: 4 }}>
<CardHeader
title="Job Details"
subheader="Enter specific information about the position"
avatar={<Business color="primary" />}
/>
<CardContent>
<Grid container spacing={3}>
<Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Job Title"
variant="outlined"
value={jobTitle}
onChange={(e) => setJobTitle(e.target.value)}
required
disabled={isProcessing}
InputProps={{
startAdornment: <Work sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid>
<Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Company"
variant="outlined"
value={company}
onChange={(e) => setCompany(e.target.value)}
required
disabled={isProcessing}
InputProps={{
startAdornment: <Business sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid>
{/* <Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Job Location"
variant="outlined"
value={jobLocation}
onChange={(e) => setJobLocation(e.target.value)}
disabled={isProcessing}
InputProps={{
startAdornment: <LocationOn sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid> */}
<Grid size={{ xs: 12, md: 6 }}>
<Box sx={{ display: 'flex', gap: 2, alignItems: 'flex-end', height: '100%' }}>
<Button
variant="contained"
onClick={handleSave}
disabled={!jobTitle || !company || !jobDescription || isProcessing}
fullWidth={isMobile}
size="large"
startIcon={<CheckCircle />}
>
Save Job
</Button>
</Box>
</Grid>
</Grid>
</CardContent>
</Card>
{/* Job Summary */}
{summary !== '' &&
<Card elevation={2} sx={{ mt: 3 }}>
<CardHeader
title="Job Summary"
avatar={<CheckCircle color="success" />}
sx={{ pb: 1 }}
/>
<CardContent sx={{ pt: 0 }}>
{summary}
</CardContent>
</Card>
}
{/* Requirements Display */}
{renderJobRequirements()}
</Box>
);
};
return (
<Box className="JobManagement"
sx={{
background: "white",
p: 0,
}}>
{selectedJob === null && renderJobCreation()}
</Box>
);
};
export { JobCreator };

View File

@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => {
}; };
const JobManagement = (props: BackstoryElementProps) => { const JobManagement = (props: BackstoryElementProps) => {
const { user, apiClient } = useAuth(); const { user, apiClient } = useAuth();
const { selectedCandidate } = useSelectedCandidate();
const { selectedJob, setSelectedJob } = useSelectedJob(); const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack, submitQuery } = props; const { setSnack, submitQuery } = props;
const backstoryProps = { setSnack, submitQuery };
const theme = useTheme(); const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('sm')); const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
const [jobDescription, setJobDescription] = useState<string>(''); const [jobDescription, setJobDescription] = useState<string>('');
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null); const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
const [jobTitle, setJobTitle] = useState<string>(''); const [jobTitle, setJobTitle] = useState<string>('');
const [company, setCompany] = useState<string>(''); const [company, setCompany] = useState<string>('');
const [summary, setSummary] = useState<string>(''); const [summary, setSummary] = useState<string>('');
const [jobLocation, setJobLocation] = useState<string>('');
const [jobId, setJobId] = useState<string>('');
const [jobStatus, setJobStatus] = useState<string>(''); const [jobStatus, setJobStatus] = useState<string>('');
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>); const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
const [isProcessing, setIsProcessing] = useState<boolean>(false); const [isProcessing, setIsProcessing] = useState<boolean>(false);
useEffect(() => {
}, [jobTitle, jobDescription, company]);
const fileInputRef = useRef<HTMLInputElement>(null); const fileInputRef = useRef<HTMLInputElement>(null);
if (!user?.id) { if (!user?.id) {
return ( return (
<LoginRequired asset="candidate analysis" /> <LoginRequired asset="candidate analysis" />
@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => {
setJobTitle(''); setJobTitle('');
setJobRequirements(null); setJobRequirements(null);
setSummary(''); setSummary('');
const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers); const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
const document = await controller.promise; const job = await controller.promise;
if (!document) { if (!job) {
return; return;
} }
console.log(`Document id: ${document.id}`); console.log(`Job id: ${job.id}`);
e.target.value = ''; e.target.value = '';
} catch (error) { } catch (error) {
console.error(error); console.error(error);
@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => {
// This would call your API to extract requirements from the job description // This would call your API to extract requirements from the job description
}; };
const loadJob = async () => {
const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d");
setSelectedJob(job);
}
const renderJobCreation = () => { const renderJobCreation = () => {
if (!user) { if (!user) {
return <Box>You must be logged in</Box>; return <Box>You must be logged in</Box>;
@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => {
<Box sx={{ <Box sx={{
mx: 'auto', p: { xs: 2, sm: 3 }, mx: 'auto', p: { xs: 2, sm: 3 },
}}> }}>
<Button onClick={loadJob} variant="contained">Load Job</Button>
{/* Upload Section */} {/* Upload Section */}
<Card elevation={3} sx={{ mb: 4 }}> <Card elevation={3} sx={{ mb: 4 }}>
<CardHeader <CardHeader

View File

@ -14,7 +14,8 @@ import {
CardContent, CardContent,
useTheme, useTheme,
LinearProgress, LinearProgress,
useMediaQuery useMediaQuery,
Button
} from '@mui/material'; } from '@mui/material';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import CheckCircleIcon from '@mui/icons-material/CheckCircle';
@ -28,6 +29,8 @@ import { toCamelCase } from 'types/conversion';
import { Job } from 'types/types'; import { Job } from 'types/types';
import { StyledMarkdown } from './StyledMarkdown'; import { StyledMarkdown } from './StyledMarkdown';
import { Scrollable } from './Scrollable'; import { Scrollable } from './Scrollable';
import { start } from 'repl';
import { TypesElement } from '@uiw/react-json-view';
interface JobAnalysisProps extends BackstoryPageProps { interface JobAnalysisProps extends BackstoryPageProps {
job: Job; job: Job;
@ -56,6 +59,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
const [overallScore, setOverallScore] = useState<number>(0); const [overallScore, setOverallScore] = useState<number>(0);
const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null); const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null);
const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null); const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null);
const [startAnalysis, setStartAnalysis] = useState<boolean>(false);
const [analyzing, setAnalyzing] = useState<boolean>(false);
const isMobile = useMediaQuery(theme.breakpoints.down('sm')); const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
// Handle accordion expansion // Handle accordion expansion
@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
setExpanded(isExpanded ? panel : false); setExpanded(isExpanded ? panel : false);
}; };
useEffect(() => { const initializeRequirements = (job: Job) => {
if (!job || !job.requirements) { if (!job || !job.requirements) {
return; return;
} }
@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
setSkillMatches(initialSkillMatches); setSkillMatches(initialSkillMatches);
setStatusMessage(null); setStatusMessage(null);
setLoadingRequirements(false); setLoadingRequirements(false);
setOverallScore(0);
}, [job, setRequirements]); }
useEffect(() => { useEffect(() => {
if (requirementsSession || creatingSession) { initializeRequirements(job);
return; }, [job]);
}
try {
setCreatingSession(true);
apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
.then(session => {
setRequirementsSession(session);
setCreatingSession(false);
});
} catch (error) {
setSnack('Unable to load chat session', 'error');
} finally {
setCreatingSession(false);
}
}, [requirementsSession, apiClient, candidate]);
// Fetch initial requirements
useEffect(() => {
if (!job.description || !requirementsSession || loadingRequirements) {
return;
}
const getRequirements = async () => {
setLoadingRequirements(true);
try {
const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description };
apiClient.sendMessageStream(chatMessage, {
onMessage: (msg: ChatMessage) => {
console.log(`onMessage: ${msg.type}`, msg);
const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
const requirements: { requirement: string, domain: string }[] = [];
if (job.requirements?.technicalSkills) {
job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
}
if (job.requirements?.experienceRequirements) {
job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
}
if (job.requirements?.softSkills) {
job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
}
if (job.requirements?.experience) {
job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
}
if (job.requirements?.education) {
job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
}
if (job.requirements?.certifications) {
job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
}
if (job.requirements?.preferredAttributes) {
job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
}
const initialSkillMatches = requirements.map(req => ({
requirement: req.requirement,
domain: req.domain,
status: 'waiting' as const,
matchScore: 0,
assessment: '',
description: '',
citations: []
}));
setRequirements(requirements);
setSkillMatches(initialSkillMatches);
setStatusMessage(null);
setLoadingRequirements(false);
},
onError: (error: string | ChatMessageError) => {
console.log("onError:", error);
// Type-guard to determine if this is a ChatMessageBase or a string
if (typeof error === "object" && error !== null && "content" in error) {
setSnack(error.content || 'Error obtaining requirements from job description.', "error");
} else {
setSnack(error as string, "error");
}
setLoadingRequirements(false);
},
onStreaming: (chunk: ChatMessageStreaming) => {
// console.log("onStreaming:", chunk);
},
onStatus: (status: ChatMessageStatus) => {
console.log(`onStatus: ${status}`);
},
onComplete: () => {
console.log("onComplete");
setStatusMessage(null);
setLoadingRequirements(false);
}
});
} catch (error) {
console.error('Failed to send message:', error);
setLoadingRequirements(false);
}
};
getRequirements();
}, [job, requirementsSession]);
// Fetch match data for each requirement // Fetch match data for each requirement
useEffect(() => { useEffect(() => {
if (!startAnalysis || analyzing || !job.requirements) {
return;
}
const fetchMatchData = async () => { const fetchMatchData = async () => {
if (requirements.length === 0) return; if (requirements.length === 0) return;
@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
} }
}; };
if (!loadingRequirements) { setAnalyzing(true);
fetchMatchData(); fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) });
} }, [job, startAnalysis, analyzing, requirements, loadingRequirements]);
}, [requirements, loadingRequirements]);
// Get color based on match score // Get color based on match score
const getMatchColor = (score: number): string => { const getMatchColor = (score: number): string => {
@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
return <ErrorIcon color="error" />; return <ErrorIcon color="error" />;
}; };
const beginAnalysis = () => {
initializeRequirements(job);
setStartAnalysis(true);
};
return ( return (
<Box> <Box>
<Paper elevation={3} sx={{ p: 3, mb: 4 }}> <Paper elevation={3} sx={{ p: 3, mb: 4 }}>
@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
</Box> </Box>
<Grid size={{ xs: 12 }} sx={{ mt: 2 }}> <Grid size={{ xs: 12 }} sx={{ mt: 2 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}> <Box sx={{ display: 'flex', alignItems: 'center', mb: 2, gap: 1 }}>
{<Button disabled={analyzing || startAnalysis} onClick={beginAnalysis} variant="contained">Start Analysis</Button>}
{overallScore !== 0 && <>
<Typography variant="h5" component="h2" sx={{ mr: 2 }}> <Typography variant="h5" component="h2" sx={{ mr: 2 }}>
Overall Match: Overall Match:
</Typography> </Typography>
@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
fontWeight: 'bold' fontWeight: 'bold'
}} }}
/> />
</>}
</Box> </Box>
</Grid> </Grid>
</Grid> </Grid>

View File

@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash'; import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner'; import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from './DeleteConfirmation'; import { DeleteConfirmation } from '../DeleteConfirmation';
interface CandidateInfoProps { interface CandidateInfoProps {
candidate: Candidate; candidate: Candidate;

View File

@ -4,7 +4,7 @@ import Button from '@mui/material/Button';
import Box from '@mui/material/Box'; import Box from '@mui/material/Box';
import { BackstoryElementProps } from 'components/BackstoryTab'; import { BackstoryElementProps } from 'components/BackstoryTab';
import { CandidateInfo } from 'components/CandidateInfo'; import { CandidateInfo } from 'components/ui/CandidateInfo';
import { Candidate } from "types/types"; import { Candidate } from "types/types";
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate } from 'hooks/GlobalContext'; import { useSelectedCandidate } from 'hooks/GlobalContext';

View File

@ -0,0 +1,97 @@
import React from 'react';
import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material';
import {
Card,
CardContent,
Divider,
useTheme,
} from '@mui/material';
import DeleteIcon from '@mui/icons-material/Delete';
import { useMediaQuery } from '@mui/material';
import { JobFull } from 'types/types';
import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from '../DeleteConfirmation';
interface JobInfoProps {
job: JobFull;
sx?: SxProps;
action?: string;
elevation?: number;
variant?: "small" | "normal" | null
};
const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
const { job } = props;
const { user, apiClient } = useAuth();
const {
sx,
action = '',
elevation = 1,
variant = "normal"
} = props;
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
const isAdmin = user?.isAdmin;
const deleteJob = async (jobId: string | undefined) => {
if (jobId) {
await apiClient.deleteJob(jobId);
}
}
if (!job) {
return <Box>No user loaded.</Box>;
}
return (
<Card
elevation={elevation}
sx={{
display: "flex",
borderColor: 'transparent',
borderWidth: 2,
borderStyle: 'solid',
transition: 'all 0.3s ease',
flexDirection: "column",
...sx
}}
{...rest}
>
<CardContent sx={{ display: "flex", flexGrow: 1, p: 3, height: '100%', flexDirection: 'column', alignItems: 'stretch', position: "relative" }}>
{variant !== "small" && <>
{job.location &&
<Typography variant="body2" sx={{ mb: 1 }}>
<strong>Location:</strong> {job.location.city}, {job.location.state || job.location.country}
</Typography>
}
{job.company &&
<Typography variant="body2" sx={{ mb: 1 }}>
<strong>Company:</strong> {job.company}
</Typography>
}
{job.summary && <Typography variant="body2">
<strong>Summary:</strong> {job.summary}
</Typography>
}
</>}
</CardContent>
<CardActions>
{isAdmin &&
<DeleteConfirmation
onDelete={() => { deleteJob(job.id); }}
sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }}
action="delete"
label="job"
title="Delete job"
icon=<DeleteIcon />
message={`Are you sure you want to delete ${job.id}? This action cannot be undone.`}
/>}
</CardActions>
</Card>
);
};
export { JobInfo };

View File

@ -0,0 +1,69 @@
import React, { useEffect, useState } from 'react';
import { useNavigate } from "react-router-dom";
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { JobInfo } from 'components/ui/JobInfo';
import { Job, JobFull } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedJob } from 'hooks/GlobalContext';
interface JobPickerProps extends BackstoryElementProps {
onSelect?: (job: JobFull) => void
};
const JobPicker = (props: JobPickerProps) => {
const { onSelect } = props;
const { apiClient } = useAuth();
const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack } = props;
const [jobs, setJobs] = useState<JobFull[] | null>(null);
useEffect(() => {
if (jobs !== null) {
return;
}
const getJobs = async () => {
try {
const results = await apiClient.getJobs();
const jobs: JobFull[] = results.data;
jobs.sort((a, b) => {
let result = a.company?.localeCompare(b.company || '');
if (result === 0) {
result = a.title?.localeCompare(b.title || '');
}
return result || 0;
});
setJobs(jobs);
} catch (err) {
setSnack("" + err);
}
};
getJobs();
}, [jobs, setSnack]);
return (
<Box sx={{display: "flex", flexDirection: "column", mb: 1}}>
<Box sx={{ display: "flex", gap: 1, flexWrap: "wrap", justifyContent: "center" }}>
{jobs?.map((j, i) =>
<Box key={`${j.id}`}
onClick={() => { onSelect ? onSelect(j) : setSelectedJob(j); }}
sx={{ cursor: "pointer" }}>
{selectedJob?.id === j.id &&
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", backgroundColor: "#f0f0f0", "&:hover": { border: "2px solid orange" } }} job={j} />
}
{selectedJob?.id !== j.id &&
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", border: "2px solid transparent", "&:hover": { border: "2px solid orange" } }} job={j} />
}
</Box>
)}
</Box>
</Box>
);
};
export {
JobPicker
};

View File

@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation';
import { BackstoryPageProps } from 'components/BackstoryTab'; import { BackstoryPageProps } from 'components/BackstoryTab';
import { Message } from 'components/Message'; import { Message } from 'components/Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation'; import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { CandidateInfo } from 'components/CandidateInfo'; import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { useSelectedCandidate } from 'hooks/GlobalContext'; import { useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader'; import PropagateLoader from 'react-spinners/PropagateLoader';

View File

@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel';
import SendIcon from '@mui/icons-material/Send'; import SendIcon from '@mui/icons-material/Send';
import PropagateLoader from 'react-spinners/PropagateLoader'; import PropagateLoader from 'react-spinners/PropagateLoader';
import { CandidateInfo } from '../components/CandidateInfo'; import { CandidateInfo } from '../components/ui/CandidateInfo';
import { Quote } from 'components/Quote'; import { Quote } from 'components/Quote';
import { BackstoryElementProps } from 'components/BackstoryTab'; import { BackstoryElementProps } from 'components/BackstoryTab';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';

View File

@ -7,31 +7,74 @@ import {
Button, Button,
Typography, Typography,
Paper, Paper,
Avatar,
useTheme, useTheme,
Snackbar, Snackbar,
Container,
Grid,
Alert, Alert,
Tabs,
Tab,
Card,
CardContent,
Divider,
Avatar,
Badge,
} from '@mui/material'; } from '@mui/material';
import {
Person,
PersonAdd,
AccountCircle,
Add,
WorkOutline,
AddCircle,
} from '@mui/icons-material';
import PersonIcon from '@mui/icons-material/Person'; import PersonIcon from '@mui/icons-material/Person';
import WorkIcon from '@mui/icons-material/Work'; import WorkIcon from '@mui/icons-material/Work';
import AssessmentIcon from '@mui/icons-material/Assessment'; import AssessmentIcon from '@mui/icons-material/Assessment';
import { JobMatchAnalysis } from 'components/JobMatchAnalysis'; import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
import { Candidate } from "types/types"; import { Candidate, Job, JobFull } from "types/types";
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { BackstoryPageProps } from 'components/BackstoryTab'; import { BackstoryPageProps } from 'components/BackstoryTab';
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext'; import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { CandidateInfo } from 'components/CandidateInfo'; import { CandidateInfo } from 'components/ui/CandidateInfo';
import { ComingSoon } from 'components/ui/ComingSoon'; import { ComingSoon } from 'components/ui/ComingSoon';
import { JobManagement } from 'components/JobManagement'; import { JobManagement } from 'components/JobManagement';
import { LoginRequired } from 'components/ui/LoginRequired'; import { LoginRequired } from 'components/ui/LoginRequired';
import { Scrollable } from 'components/Scrollable'; import { Scrollable } from 'components/Scrollable';
import { CandidatePicker } from 'components/ui/CandidatePicker'; import { CandidatePicker } from 'components/ui/CandidatePicker';
import { JobPicker } from 'components/ui/JobPicker';
import { JobCreator } from 'components/JobCreator';
function WorkAddIcon() {
return (
<Box position="relative" display="inline-flex"
sx={{
lineHeight: "30px",
mb: "6px",
}}
>
<WorkOutline sx={{ fontSize: 24 }} />
<Add
sx={{
position: 'absolute',
bottom: -2,
right: -2,
fontSize: 14,
bgcolor: 'background.paper',
borderRadius: '50%',
boxShadow: 1,
}}
color="primary"
/>
</Box>
);
}
// Main component // Main component
const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => { const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const theme = useTheme(); const theme = useTheme();
const { user } = useAuth(); const { user, apiClient } = useAuth();
const navigate = useNavigate(); const navigate = useNavigate();
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate() const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
const { selectedJob, setSelectedJob } = useSelectedJob() const { selectedJob, setSelectedJob } = useSelectedJob()
@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
const [activeStep, setActiveStep] = useState(0); const [activeStep, setActiveStep] = useState(0);
const [analysisStarted, setAnalysisStarted] = useState(false); const [analysisStarted, setAnalysisStarted] = useState(false);
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const [jobTab, setJobTab] = useState<string>('load');
useEffect(() => { useEffect(() => {
if (selectedJob && activeStep === 1) { console.log({ activeStep, selectedCandidate, selectedJob });
setActiveStep(2);
if (!selectedCandidate) {
if (activeStep !== 0) {
setActiveStep(0);
}
} else if (!selectedJob) {
if (activeStep !== 1) {
setActiveStep(1);
}
} }
}, [selectedJob, activeStep]); }, [selectedCandidate, selectedJob, activeStep])
// Steps in our process // Steps in our process
const steps = [ const steps = [
@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
setSelectedJob(null); setSelectedJob(null);
break; break;
case 1: /* Select Job */ case 1: /* Select Job */
setSelectedCandidate(null);
setSelectedJob(null); setSelectedJob(null);
break; break;
case 2: /* Job Analysis */ case 2: /* Job Analysis */
@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
setActiveStep(1); setActiveStep(1);
} }
const onJobSelect = (job: Job) => {
setSelectedJob(job)
setActiveStep(2);
}
// Render function for the candidate selection step // Render function for the candidate selection step
const renderCandidateSelection = () => ( const renderCandidateSelection = () => (
<Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}> <Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}>
@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
</Paper> </Paper>
); );
const handleTabChange = (event: React.SyntheticEvent, value: string) => {
setJobTab(value);
};
// Render function for the job description step // Render function for the job description step
const renderJobDescription = () => ( const renderJobDescription = () => {
<Box sx={{ mt: 3 }}> if (!selectedCandidate) {
{selectedCandidate && ( return;
<JobManagement }
return (<Box sx={{ mt: 3 }}>
<Box sx={{ borderBottom: 1, borderColor: 'divider', mb: 3 }}>
<Tabs value={jobTab} onChange={handleTabChange} centered>
<Tab value='load' icon={<WorkOutline />} label="Load" />
<Tab value='create' icon={<WorkAddIcon />} label="Create" />
</Tabs>
</Box>
{jobTab === 'load' &&
<JobPicker {...backstoryProps} onSelect={onJobSelect} />
}
{jobTab === 'create' &&
<JobCreator
{...backstoryProps} {...backstoryProps}
/> onSave={onJobSelect}
)} />}
</Box> </Box>
); );
}
// Render function for the analysis step // Render function for the analysis step
const renderAnalysis = () => ( const renderAnalysis = () => (
@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
</Button> </Button>
) : ( ) : (
<Button onClick={handleNext} variant="contained"> <Button onClick={handleNext} variant="contained">
{activeStep === steps[steps.length - 1].index - 1 ? 'Start Analysis' : 'Next'} {activeStep === steps[steps.length - 1].index - 1 ? 'Done' : 'Next'}
</Button> </Button>
)} )}
</Box> </Box>

View File

@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown';
import { BackstoryPageProps } from '../components/BackstoryTab'; import { BackstoryPageProps } from '../components/BackstoryTab';
import { Conversation, ConversationHandle } from '../components/Conversation'; import { Conversation, ConversationHandle } from '../components/Conversation';
import { BackstoryQuery } from '../components/BackstoryQuery'; import { BackstoryQuery } from '../components/BackstoryQuery';
import { CandidateInfo } from 'components/CandidateInfo'; import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types'; import { Candidate } from 'types/types';

View File

@ -52,7 +52,7 @@ interface StreamingOptions<T = Types.ChatMessage> {
signal?: AbortSignal; signal?: AbortSignal;
} }
interface DeleteCandidateResponse { interface DeleteResponse {
success: boolean; success: boolean;
message: string; message: string;
} }
@ -530,14 +530,24 @@ class ApiClient {
return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate'); return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate');
} }
async deleteCandidate(id: string): Promise<DeleteCandidateResponse> { async deleteCandidate(id: string): Promise<DeleteResponse> {
const response = await fetch(`${this.baseUrl}/candidates/${id}`, { const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
method: 'DELETE', method: 'DELETE',
headers: this.defaultHeaders, headers: this.defaultHeaders,
body: JSON.stringify({ id }) body: JSON.stringify({ id })
}); });
return handleApiResponse<DeleteCandidateResponse>(response); return handleApiResponse<DeleteResponse>(response);
}
async deleteJob(id: string): Promise<DeleteResponse> {
const response = await fetch(`${this.baseUrl}/jobs/${id}`, {
method: 'DELETE',
headers: this.defaultHeaders,
body: JSON.stringify({ id })
});
return handleApiResponse<DeleteResponse>(response);
} }
async uploadCandidateProfile(file: File): Promise<boolean> { async uploadCandidateProfile(file: File): Promise<boolean> {
@ -641,7 +651,7 @@ class ApiClient {
return this.handleApiResponseWithConversion<Types.Job>(response, 'Job'); return this.handleApiResponseWithConversion<Types.Job>(response, 'Job');
} }
async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> { async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.JobFull>> {
const paginatedRequest = createPaginatedRequest(request); const paginatedRequest = createPaginatedRequest(request);
const params = toUrlParams(formatApiRequest(paginatedRequest)); const params = toUrlParams(formatApiRequest(paginatedRequest));
@ -649,7 +659,7 @@ class ApiClient {
headers: this.defaultHeaders headers: this.defaultHeaders
}); });
return this.handlePaginatedApiResponseWithConversion<Types.Job>(response, 'Job'); return this.handlePaginatedApiResponseWithConversion<Types.JobFull>(response, 'JobFull');
} }
async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> { async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
@ -844,18 +854,28 @@ class ApiClient {
} }
}; };
return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions); return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions);
// { }
// method: 'POST',
// headers: {
// // Don't set Content-Type - browser will set it automatically with boundary
// 'Authorization': this.defaultHeaders['Authorization']
// },
// body: formData
// });
// const result = await handleApiResponse<Types.Document>(response); createJobFromFile(file: File, streamingOptions?: StreamingOptions<Types.Job>): StreamingResponse<Types.Job> {
const formData = new FormData()
// return result; formData.append('file', file);
formData.append('filename', file.name);
streamingOptions = {
...streamingOptions,
headers: {
// Don't set Content-Type - browser will set it automatically with boundary
'Authorization': this.defaultHeaders['Authorization']
}
};
return this.streamify<Types.Job>('/jobs/upload', formData, streamingOptions);
}
getJobRequirements(jobId: string, streamingOptions?: StreamingOptions<Types.DocumentMessage>): StreamingResponse<Types.DocumentMessage> {
streamingOptions = {
...streamingOptions,
headers: this.defaultHeaders,
};
return this.streamify<Types.DocumentMessage>(`/jobs/requirements/${jobId}`, null, streamingOptions);
} }
async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> { async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> {
@ -1001,7 +1021,7 @@ class ApiClient {
* @param options callbacks, headers, and method * @param options callbacks, headers, and method
* @returns * @returns
*/ */
streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit, options: StreamingOptions<T> = {}) : StreamingResponse<T> { streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit | null, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
const abortController = new AbortController(); const abortController = new AbortController();
const signal = options.signal || abortController.signal; const signal = options.signal || abortController.signal;
const headers = options.headers || null; const headers = options.headers || null;

View File

@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models // Generated TypeScript types from Pydantic models
// Source: src/backend/models.py // Source: src/backend/models.py
// Generated on: 2025-06-05T22:02:22.004513 // Generated on: 2025-06-07T20:43:58.855207
// DO NOT EDIT MANUALLY - This file is auto-generated // DO NOT EDIT MANUALLY - This file is auto-generated
// ============================ // ============================
@ -323,7 +323,7 @@ export interface ChatMessageMetaData {
presencePenalty?: number; presencePenalty?: number;
stopSequences?: Array<string>; stopSequences?: Array<string>;
ragResults?: Array<ChromaDBGetResponse>; ragResults?: Array<ChromaDBGetResponse>;
llmHistory?: Array<LLMMessage>; llmHistory?: Array<any>;
evalCount: number; evalCount: number;
evalDuration: number; evalDuration: number;
promptEvalCount: number; promptEvalCount: number;
@ -711,12 +711,6 @@ export interface JobResponse {
meta?: Record<string, any>; meta?: Record<string, any>;
} }
export interface LLMMessage {
role: string;
content: string;
toolCalls?: Array<Record<string, any>>;
}
export interface Language { export interface Language {
language: string; language: string;
proficiency: "basic" | "conversational" | "fluent" | "native"; proficiency: "basic" | "conversational" | "fluent" | "native";

View File

@ -137,7 +137,7 @@ class Agent(BaseModel, ABC):
# llm: Any, # llm: Any,
# model: str, # model: str,
# message: ChatMessage, # message: ChatMessage,
# tool_message: Any, # llama response message # tool_message: Any, # llama response
# messages: List[LLMMessage], # messages: List[LLMMessage],
# ) -> AsyncGenerator[ChatMessage, None]: # ) -> AsyncGenerator[ChatMessage, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@ -270,15 +270,15 @@ class Agent(BaseModel, ABC):
# }, # },
# stream=True, # stream=True,
# ): # ):
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}") # # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
# message.status = "streaming" # message.status = "streaming"
# message.chunk = response.message.content # message.chunk = response.content
# message.content += message.chunk # message.content += message.chunk
# if not response.done: # if not response.finish_reason:
# yield message # yield message
# if response.done: # if response.finish_reason:
# self.collect_metrics(response) # self.collect_metrics(response)
# message.metadata.eval_count += response.eval_count # message.metadata.eval_count += response.eval_count
# message.metadata.eval_duration += response.eval_duration # message.metadata.eval_duration += response.eval_duration
@ -296,9 +296,9 @@ class Agent(BaseModel, ABC):
def collect_metrics(self, response): def collect_metrics(self, response):
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc( self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
response.prompt_eval_count response.usage.prompt_eval_count
) )
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count) self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str: def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
""" """
@ -361,7 +361,7 @@ Content: {content}
yield status_message yield status_message
try: try:
chroma_results = user.file_watcher.find_similar( chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold query=prompt, top_k=top_k, threshold=threshold
) )
if not chroma_results: if not chroma_results:
@ -430,7 +430,7 @@ Content: {content}
logger.info(f"Message options: {options.model_dump(exclude_unset=True)}") logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
response = None response = None
content = "" content = ""
for response in llm.chat( async for response in llm.chat_stream(
model=model, model=model,
messages=messages, messages=messages,
options={ options={
@ -446,12 +446,12 @@ Content: {content}
yield error_message yield error_message
return return
content += response.message.content content += response.content
if not response.done: if not response.finish_reason:
streaming_message = ChatMessageStreaming( streaming_message = ChatMessageStreaming(
session_id=session_id, session_id=session_id,
content=response.message.content, content=response.content,
status=ApiStatusType.STREAMING, status=ApiStatusType.STREAMING,
) )
yield streaming_message yield streaming_message
@ -466,7 +466,7 @@ Content: {content}
self.collect_metrics(response) self.collect_metrics(response)
self.context_tokens = ( self.context_tokens = (
response.prompt_eval_count + response.eval_count response.usage.prompt_eval_count + response.usage.eval_count
) )
chat_message = ChatMessage( chat_message = ChatMessage(
@ -476,10 +476,10 @@ Content: {content}
content=content, content=content,
metadata = ChatMessageMetaData( metadata = ChatMessageMetaData(
options=options, options=options,
eval_count=response.eval_count, eval_count=response.usage.eval_count,
eval_duration=response.eval_duration, eval_duration=response.usage.eval_duration,
prompt_eval_count=response.prompt_eval_count, prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.prompt_eval_duration, prompt_eval_duration=response.usage.prompt_eval_duration,
) )
) )
@ -588,12 +588,12 @@ Content: {content}
# end_time = time.perf_counter() # end_time = time.perf_counter()
# message.metadata.timers["tool_check"] = end_time - start_time # message.metadata.timers["tool_check"] = end_time - start_time
# if not response.message.tool_calls: # if not response.tool_calls:
# logger.info("LLM indicates tools will not be used") # logger.info("LLM indicates tools will not be used")
# # The LLM will not use tools, so disable use_tools so we can stream the full response # # The LLM will not use tools, so disable use_tools so we can stream the full response
# use_tools = False # use_tools = False
# else: # else:
# tool_metadata["attempted"] = response.message.tool_calls # tool_metadata["attempted"] = response.tool_calls
# if use_tools: # if use_tools:
# logger.info("LLM indicates tools will be used") # logger.info("LLM indicates tools will be used")
@ -626,15 +626,15 @@ Content: {content}
# yield message # yield message
# return # return
# if response.message.tool_calls: # if response.tool_calls:
# tool_metadata["used"] = response.message.tool_calls # tool_metadata["used"] = response.tool_calls
# # Process all yielded items from the handler # # Process all yielded items from the handler
# start_time = time.perf_counter() # start_time = time.perf_counter()
# async for message in self.process_tool_calls( # async for message in self.process_tool_calls(
# llm=llm, # llm=llm,
# model=model, # model=model,
# message=message, # message=message,
# tool_message=response.message, # tool_message=response,
# messages=messages, # messages=messages,
# ): # ):
# if message.status == "error": # if message.status == "error":
@ -647,7 +647,7 @@ Content: {content}
# return # return
# logger.info("LLM indicated tools will be used, and then they weren't") # logger.info("LLM indicated tools will be used, and then they weren't")
# message.content = response.message.content # message.content = response.content
# message.status = "done" # message.status = "done"
# yield message # yield message
# return # return
@ -674,7 +674,7 @@ Content: {content}
content = "" content = ""
start_time = time.perf_counter() start_time = time.perf_counter()
response = None response = None
for response in llm.chat( async for response in llm.chat_stream(
model=model, model=model,
messages=messages, messages=messages,
options={ options={
@ -690,12 +690,12 @@ Content: {content}
yield error_message yield error_message
return return
content += response.message.content content += response.content
if not response.done: if not response.finish_reason:
streaming_message = ChatMessageStreaming( streaming_message = ChatMessageStreaming(
session_id=session_id, session_id=session_id,
content=response.message.content, content=response.content,
) )
yield streaming_message yield streaming_message
@ -709,7 +709,7 @@ Content: {content}
self.collect_metrics(response) self.collect_metrics(response)
self.context_tokens = ( self.context_tokens = (
response.prompt_eval_count + response.eval_count response.usage.prompt_eval_count + response.usage.eval_count
) )
end_time = time.perf_counter() end_time = time.perf_counter()
@ -720,10 +720,10 @@ Content: {content}
content=content, content=content,
metadata = ChatMessageMetaData( metadata = ChatMessageMetaData(
options=options, options=options,
eval_count=response.eval_count, eval_count=response.usage.eval_count,
eval_duration=response.eval_duration, eval_duration=response.usage.eval_duration,
prompt_eval_count=response.prompt_eval_count, prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.prompt_eval_duration, prompt_eval_duration=response.usage.prompt_eval_duration,
timers={ timers={
"llm_streamed": end_time - start_time, "llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time "llm_with_tools": 0, # Placeholder for tool processing time

View File

@ -178,12 +178,13 @@ class RedisDatabase:
'jobs': 'job:', 'jobs': 'job:',
'job_applications': 'job_application:', 'job_applications': 'job_application:',
'chat_sessions': 'chat_session:', 'chat_sessions': 'chat_session:',
'chat_messages': 'chat_messages:', # This will store lists 'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:', 'ai_parameters': 'ai_parameters:',
'users': 'user:', 'users': 'user:',
'candidate_documents': 'candidate_documents:', 'candidate_documents': 'candidate_documents:',
'job_requirements': 'job_requirements:', # Add this line
} }
def _serialize(self, data: Any) -> str: def _serialize(self, data: Any) -> str:
"""Serialize data to JSON string for Redis storage""" """Serialize data to JSON string for Redis storage"""
if data is None: if data is None:
@ -236,8 +237,9 @@ class RedisDatabase:
# Delete each document's metadata # Delete each document's metadata
for doc_id in document_ids: for doc_id in document_ids:
pipe.delete(f"document:{doc_id}") pipe.delete(f"document:{doc_id}")
pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}")
deleted_count += 1 deleted_count += 1
# Delete the candidate's document list # Delete the candidate's document list
pipe.delete(key) pipe.delete(key)
@ -250,7 +252,110 @@ class RedisDatabase:
except Exception as e: except Exception as e:
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}") logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
raise raise
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Retrieve cached skill match assessment"""
try:
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error retrieving cached skill match: {e}")
return None
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
"""Cache skill match assessment with TTL (default 30 days)"""
try:
await self.redis.setex(
cache_key,
ttl,
json.dumps(assessment_data, default=str)
)
return True
except Exception as e:
logger.error(f"Error caching skill match: {e}")
return False
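get_cached_skill_match() and cache_skill_match() are a plain get/set pair; callers are expected to build the key themselves and recompute on a miss. Below is a minimal read-through sketch under those assumptions: the key layout follows the skill_match:{candidate_id}:* pattern the invalidation helpers further down scan for, and the requirement hashing plus the caller-supplied scorer are illustrative, not part of this commit.

# Hedged read-through sketch over the two methods above. The key prefix matches
# the "skill_match:{candidate_id}:*" pattern the invalidation helpers scan for;
# hashing the requirement text and the caller-supplied scorer are assumptions.
import hashlib
from typing import Any, Awaitable, Callable, Dict


async def get_or_compute_skill_match(
    db,
    candidate_id: str,
    requirement: str,
    compute: Callable[[str, str], Awaitable[Dict[str, Any]]],
) -> Dict[str, Any]:
    digest = hashlib.sha256(requirement.encode("utf-8")).hexdigest()[:16]
    cache_key = f"skill_match:{candidate_id}:{digest}"

    cached = await db.get_cached_skill_match(cache_key)
    if cached is not None:
        return cached

    assessment = await compute(candidate_id, requirement)  # e.g. an agent-backed scorer
    await db.cache_skill_match(cache_key, assessment)      # default TTL is 30 days
    return assessment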
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
"""Get the last time candidate's skill information was updated"""
try:
# This assumes you track skill update timestamps in your candidate data
candidate_data = await self.get_candidate(candidate_id)
if candidate_data and 'skills_updated_at' in candidate_data:
return datetime.fromisoformat(candidate_data['skills_updated_at'])
return None
except Exception as e:
logger.error(f"Error getting candidate skill update time: {e}")
return None
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
"""Get the timestamp of the latest RAG data update for a specific user"""
try:
rag_update_key = f"user:{user_id}:rag_last_update"
timestamp_str = await self.redis.get(rag_update_key)
if timestamp_str:
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
return None
except Exception as e:
logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
return None
async def update_user_rag_timestamp(self, user_id: str) -> bool:
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
try:
rag_update_key = f"user:{user_id}:rag_last_update"
current_time = datetime.utcnow().isoformat()
await self.redis.set(rag_update_key, current_time)
return True
except Exception as e:
logger.error(f"Error updating RAG timestamp for user {user_id}: {e}")
return False
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
"""Invalidate all cached skill matches for a specific candidate"""
try:
pattern = f"skill_match:{candidate_id}:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating candidate skill cache: {e}")
return 0
async def clear_all_skill_match_cache(self) -> int:
"""Clear all skill match cache (useful after major system updates)"""
try:
pattern = "skill_match:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error clearing skill match cache: {e}")
return 0
async def invalidate_user_skill_cache(self, user_id: str) -> int:
"""Invalidate all cached skill matches when a user's RAG data is updated"""
try:
# This assumes all candidates belonging to this user need cache invalidation
# You might need to adjust the pattern based on how you associate candidates with users
pattern = f"skill_match:*"
keys = await self.redis.keys(pattern)
# Filter keys that belong to candidates owned by this user
# This would require additional logic to determine candidate ownership
# For now, you might want to clear all cache when any user's RAG data updates
# or implement a more sophisticated mapping
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
return 0
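# Illustrative usage sketch (editor's note, not part of this commit): how the
# skill-match cache helpers above are intended to compose. `run_skill_match_agent`
# is a hypothetical stand-in for the agent call that produces the assessment dict.
# Note: the builtin hash() is salted per process; a stable digest (e.g. hashlib)
# may be preferable for cache keys that must survive restarts.
async def _example_skill_match_flow(db: "RedisDatabase", candidate_id: str, requirement: str) -> Dict[str, Any]:
    cache_key = f"skill_match:{candidate_id}:{hash(requirement)}"
    cached = await db.get_cached_skill_match(cache_key)
    if cached is not None:
        return cached
    assessment = await run_skill_match_agent(candidate_id, requirement)  # hypothetical helper
    await db.cache_skill_match(cache_key, assessment)  # default TTL: 30 days
    return assessment

async def _example_on_skills_changed(db: "RedisDatabase", candidate_id: str) -> None:
    # Drop stale assessments once a candidate's skill data changes
    await db.invalidate_candidate_skill_cache(candidate_id)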
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
@ -330,7 +435,134 @@ class RedisDatabase:
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
]
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
"""Get cached job requirements analysis for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
data = await self.redis.get(key)
if data:
requirements_data = self._deserialize(data)
logger.debug(f"📋 Retrieved cached job requirements for document {document_id}")
return requirements_data
logger.debug(f"📋 No cached job requirements found for document {document_id}")
return None
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
return None
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
"""Save job requirements analysis results for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
# Add metadata to the requirements
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
}
await self.redis.set(key, self._serialize(requirements_with_meta))
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
logger.debug(f"📋 Saved job requirements for document {document_id}")
return True
except Exception as e:
logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
return False
async def delete_job_requirements(self, document_id: str) -> bool:
"""Delete cached job requirements for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
result = await self.redis.delete(key)
if result > 0:
logger.debug(f"📋 Deleted job requirements for document {document_id}")
return True
return False
except Exception as e:
logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
return False
async def get_all_job_requirements(self) -> Dict[str, Any]:
"""Get all cached job requirements"""
try:
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
if value:
result[document_id] = self._deserialize(value)
return result
except Exception as e:
logger.error(f"❌ Error retrieving all job requirements: {e}")
return {}
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
"""Get all job requirements analysis for documents belonging to a candidate"""
try:
# Get all documents for the candidate
candidate_documents = await self.get_candidate_documents(candidate_id)
if not candidate_documents:
return []
# Get job requirements for each document
job_requirements = []
for doc in candidate_documents:
doc_id = doc.get("id")
if doc_id:
requirements = await self.get_job_requirements(doc_id)
if requirements:
# Add document metadata to requirements
requirements["document_filename"] = doc.get("filename")
requirements["document_original_name"] = doc.get("originalName")
job_requirements.append(requirements)
return job_requirements
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
return []
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
"""Invalidate (delete) cached job requirements for a document"""
# This is an alias for delete_job_requirements for semantic clarity
return await self.delete_job_requirements(document_id)
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
"""Delete job requirements for multiple documents and return count of deleted items"""
try:
deleted_count = 0
pipe = self.redis.pipeline()
for doc_id in document_ids:
key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}"
pipe.delete(key)
deleted_count += 1
results = await pipe.execute()
actual_deleted = sum(1 for result in results if result > 0)
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
return actual_deleted
except Exception as e:
logger.error(f"❌ Error bulk deleting job requirements: {e}")
return 0
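# Illustrative sketch (editor's note, not part of this commit): the typical
# lifecycle of the job-requirements cache around a document upload.
# `analyze_job_document` is a hypothetical stand-in for the requirements agent.
async def _example_job_requirements_flow(db: "RedisDatabase", document_id: str, content: str) -> Dict:
    cached = await db.get_job_requirements(document_id)
    if cached:
        return cached
    requirements = await analyze_job_document(content)  # hypothetical agent call
    await db.save_job_requirements(document_id, requirements)
    return requirements
# When a document is removed, its cached analysis should be dropped as well,
# e.g. via delete_job_requirements() or bulk_delete_job_requirements().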
# Viewer operations
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
"""Get viewer by ID"""
@ -1484,6 +1716,74 @@ class RedisDatabase:
key = f"{self.KEY_PREFIXES['users']}{email}" key = f"{self.KEY_PREFIXES['users']}{email}"
await self.redis.delete(key) await self.redis.delete(key)
async def get_job_requirements_stats(self) -> Dict[str, Any]:
"""Get statistics about cached job requirements"""
try:
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
if keys:
# Get cache dates for analysis
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
stats["documents_with_requirements"].append(document_id)
# Track cache dates
cached_at = requirements_data.get("cached_at")
if cached_at:
cache_date = cached_at[:10] # Extract date part
stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
return stats
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
async def cleanup_orphaned_job_requirements(self) -> int:
"""Clean up job requirements for documents that no longer exist"""
try:
# Get all job requirements
all_requirements = await self.get_all_job_requirements()
if not all_requirements:
return 0
orphaned_count = 0
pipe = self.redis.pipeline()
for document_id in all_requirements.keys():
# Check if the document still exists
document_exists = await self.get_document(document_id)
if not document_exists:
# Document no longer exists, delete its job requirements
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
pipe.delete(key)
orphaned_count += 1
logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}")
if orphaned_count > 0:
await pipe.execute()
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
return orphaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
return 0
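# Illustrative sketch (editor's note, not part of this commit): the stats and
# cleanup helpers above lend themselves to a periodic maintenance task, e.g.
# scheduled from the application lifespan. The interval is an assumption.
import asyncio  # assumed available in this module for the sketch

async def _example_requirements_maintenance(db: "RedisDatabase", interval_seconds: int = 86400) -> None:
    while True:
        stats = await db.get_job_requirements_stats()
        logger.info(f"Job requirements cache holds {stats['total_cached_requirements']} entries")
        removed = await db.cleanup_orphaned_job_requirements()
        if removed:
            logger.info(f"Removed {removed} orphaned job requirements entries")
        await asyncio.sleep(interval_seconds)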
# Utility methods
async def clear_all_data(self):
"""Clear all data from Redis (use with caution!)"""

View File

@ -19,8 +19,9 @@ import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
import llm_proxy as llm_manager
from agents.base import Agent
from database import RedisDatabase
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@ -115,7 +116,7 @@ class CandidateEntity(Candidate):
raise ValueError("initialize() has not been called.") raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer return self.CandidateEntity__observer
async def initialize(self, prometheus_collector: CollectorRegistry): async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized: if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means # Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code. # a subsystem is failing or there is a logic bug in the code.
@ -140,9 +141,11 @@ class CandidateEntity(Candidate):
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
llm=llm_manager.get_llm(),
user_id=self.id,
collection_name=self.username,
persist_directory=vector_db_dir,
watch_directory=rag_content_dir,
database=database,
recreate=False, # Don't recreate if exists
)
has_username_rag = any(item["name"] == self.username for item in self.rags)

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field # type: ignore
from models import ( Candidate )
from .candidate_entity import CandidateEntity
from database import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager:
@ -34,9 +35,10 @@ class EntityManager:
pass
self._cleanup_task = None
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
self._database = database
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
"""Get or create CandidateEntity with proper reference tracking"""
@ -49,7 +51,7 @@ class EntityManager:
return entity
entity = CandidateEntity(candidate=candidate)
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity

View File

@ -1,41 +0,0 @@
import ollama
import defines
_llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
class llm_manager:
"""
A class to manage LLM operations using the Ollama client.
"""
@staticmethod
def get_llm() -> ollama.Client: # type: ignore
"""
Get the Ollama client instance.
Returns:
An instance of the Ollama client.
"""
return _llm
@staticmethod
def get_models() -> list[str]:
"""
Get a list of available models from the Ollama client.
Returns:
List of model names.
"""
return _llm.models()
@staticmethod
def get_model_info(model_name: str) -> dict:
"""
Get information about a specific model.
Args:
model_name: The name of the model to retrieve information for.
Returns:
A dictionary containing model information.
"""
return _llm.model(model_name)

src/backend/llm_proxy.py Normal file (1203 lines)

File diff suppressed because it is too large

View File

@ -13,6 +13,9 @@ import uuid
import defines
import pathlib
from markitdown import MarkItDown, StreamInfo # type: ignore
import io
import uvicorn # type: ignore
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta, UTC
@ -53,7 +56,7 @@ import defines
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
import llm_proxy as llm_manager
import entities
from email_service import VerificationEmailRateLimiter, email_service
from device_manager import DeviceManager
@ -116,7 +119,8 @@ async def lifespan(app: FastAPI):
try:
# Initialize database
await db_manager.initialize()
entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
@ -1827,7 +1831,7 @@ async def upload_candidate_document(
yield error_message
return
converted = False
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
p = pathlib.Path(file_path)
p_as_md = p.with_suffix(".md")
@ -1873,53 +1877,6 @@ async def upload_candidate_document(
content=file_content,
)
yield chat_message
# If this is a job description, process it with the job requirements agent
if not options.is_job_document:
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {candidate.first_name}'s AI agent...",
activity=ApiActivityType.INFO
)
yield status_message
await asyncio.sleep(0)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
)
yield error_message
return
message = None
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
)
yield status_message
await asyncio.sleep(0)
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=file_content
):
pass
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to process job description file"
)
yield error_message
return
yield message
try:
async def to_json(method):
try:
@ -1932,7 +1889,6 @@ async def upload_candidate_document(
logger.error(f"Error in to_json conversion: {e}") logger.error(f"Error in to_json conversion: {e}")
return return
# return DebugStreamingResponse(
return StreamingResponse(
to_json(upload_stream_generator(file_content)),
media_type="text/event-stream",
@ -1944,15 +1900,64 @@ async def upload_candidate_document(
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
"Transfer-Encoding": "chunked", "Transfer-Encoding": "chunked",
}, },
) )
except Exception as e: except Exception as e:
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}") logger.error(f"❌ Document upload error: {e}")
return JSONResponse( return StreamingResponse(
status_code=500, iter([ChatMessageError(
content=create_error_response("UPLOAD_ERROR", "Failed to upload document") session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
)]),
media_type="text/event-stream"
) )
async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
)
yield error_message
return
message = None
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
)
yield status_message
await asyncio.sleep(0)
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=content
):
pass
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to process job description file"
)
yield error_message
return
logger.info(f"✅ Successfully saved job requirements job {message.id}")
yield message
return
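# Illustrative sketch (editor's note, not part of this commit): consuming the
# generator above and keeping only the final JobRequirementsMessage, which is
# effectively what the /jobs/upload endpoint below streams back to the client.
async def _example_collect_job(database: RedisDatabase, user: Candidate, content: str):
    job_message = None
    async for msg in create_job_from_content(database=database, current_user=user, content=content):
        if isinstance(msg, ChatMessageError):
            raise RuntimeError(msg.content)
        if isinstance(msg, JobRequirementsMessage):
            job_message = msg
    return job_message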
@api_router.post("/candidates/profile/upload") @api_router.post("/candidates/profile/upload")
async def upload_candidate_profile( async def upload_candidate_profile(
file: UploadFile = File(...), file: UploadFile = File(...),
@ -2573,6 +2578,7 @@ async def delete_candidate(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
)
@api_router.patch("/candidates/{candidate_id}")
async def update_candidate(
candidate_id: str = Path(...),
@ -2816,6 +2822,139 @@ async def create_candidate_job(
content=create_error_response("CREATION_FAILED", str(e)) content=create_error_response("CREATION_FAILED", str(e))
) )
@api_router.post("/jobs/upload")
async def create_job_from_file(
file: UploadFile = File(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB)
max_size = 10 * 1024 * 1024 # 10MB
file_content = await file.read()
if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit"
)]),
media_type="text/event-stream"
)
if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty"
)]),
media_type="text/event-stream"
)
"""Upload a document for the current candidate"""
async def upload_stream_generator(file_content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
)
yield error_message
return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename"
)
yield error_message
return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
# Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
)
yield error_message
return
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING
)
yield status_message
try:
md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
stream = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing
)
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content
logger.info(f"✅ Converted {file.filename} to Markdown format")
except Exception as e:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Failed to convert {file.filename} to Markdown.",
)
yield error_message
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
yield message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"Error in to_json conversion: {e}")
return
return StreamingResponse(
to_json(upload_stream_generator(file_content)),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-store, must-revalidate",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Nginx
"X-Content-Type-Options": "nosniff",
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
"Transfer-Encoding": "chunked",
},
)
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
)]),
media_type="text/event-stream"
)
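# Illustrative client sketch (editor's note, not part of this commit): reading
# the SSE stream returned by POST /jobs/upload. Assumes httpx and a bearer
# token; the "/api" prefix depends on how api_router is mounted in this app.
import httpx

async def _example_upload_job_file(path: str, token: str) -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        with open(path, "rb") as f:
            async with client.stream(
                "POST",
                "http://localhost:8000/api/jobs/upload",  # assumed base URL and prefix
                headers={"Authorization": f"Bearer {token}"},
                files={"file": f},
            ) as response:
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        print(line[len("data: "):])  # one JSON-encoded message per event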
@api_router.get("/jobs/{job_id}") @api_router.get("/jobs/{job_id}")
async def get_job( async def get_job(
job_id: str = Path(...), job_id: str = Path(...),
@ -2931,6 +3070,49 @@ async def search_jobs(
content=create_error_response("SEARCH_FAILED", str(e)) content=create_error_response("SEARCH_FAILED", str(e))
) )
@api_router.delete("/jobs/{job_id}")
async def delete_job(
job_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
):
"""Delete a Job"""
try:
# Check if admin user
if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
# Get job data
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
# Delete job from database
await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({
"message": "Job deleted successfully",
"jobId": job_id
})
except Exception as e:
logger.error(f"❌ Delete job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)
# ============================
# Chat Endpoints
# ============================
@ -3541,7 +3723,7 @@ async def get_candidate_skill_match(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get skill match for a candidate against a requirement with caching"""
try:
# Find candidate by ID
candidate_data = await database.get_candidate(candidate_id)
@ -3553,34 +3735,89 @@ async def get_candidate_skill_match(
candidate = Candidate.model_validate(candidate_data)
# Create cache key for this specific candidate + requirement combination
cache_key = f"skill_match:{candidate_id}:{hash(requirement)}"
# Get cached assessment if it exists
cached_assessment = await database.get_cached_skill_match(cache_key)
# Get the last update time for the candidate's skill information
candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id)
# Get the latest RAG data update time for the current user
user_rag_update_time = await database.get_user_rag_update_time(current_user.id)
# Determine if we need to regenerate the assessment
should_regenerate = True
cached_date = None
if cached_assessment:
cached_date = cached_assessment.get('cached_at')
if cached_date:
# Parse the cached ISO timestamp so it can be compared against the datetime values above
cached_datetime = datetime.fromisoformat(cached_date)
# Check if cached result is still valid
# Regenerate if:
# 1. Candidate skills were updated after cache date
# 2. User's RAG data was updated after cache date
if (not candidate_skill_update_time or cached_datetime >= candidate_skill_update_time) and \
(not user_rag_update_time or cached_datetime >= user_rag_update_time):
should_regenerate = False
logger.info(f"🔄 Using cached skill match for candidate {candidate.id}")
if should_regenerate:
logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}")
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
if not agent:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
)
# Generate new skill match
skill_match = await get_last_item(
agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=requirement,
),
)
if skill_match is None:
return JSONResponse(
status_code=500,
content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
)
skill_match_data = json.loads(skill_match.content)
# Cache the new assessment with current timestamp
cached_assessment = {
"skill_match": skill_match_data,
"cached_at": datetime.utcnow().isoformat(),
"candidate_id": candidate_id,
"requirement": requirement
}
await database.cache_skill_match(cache_key, cached_assessment)
logger.info(f"💾 Cached new skill match for candidate {candidate.id}")
logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
else:
# Use cached result - we know cached_assessment is not None here
if cached_assessment is None:
return JSONResponse(
status_code=500,
content=create_error_response("CACHE_ERROR", "Unexpected cache state")
)
skill_match_data = cached_assessment["skill_match"]
logger.info(f"Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
return create_success_response({
"candidateId": candidate.id,
"skillMatch": skill_match_data,
"cached": not should_regenerate,
"cacheTimestamp": cached_date
})
except Exception as e:
@ -3589,8 +3826,8 @@ async def get_candidate_skill_match(
return JSONResponse(
status_code=500,
content=create_error_response("SKILL_MATCH_ERROR", str(e))
)
@api_router.get("/candidates/{username}/chat-sessions") @api_router.get("/candidates/{username}/chat-sessions")
async def get_candidate_chat_sessions( async def get_candidate_chat_sessions(
username: str = Path(...), username: str = Path(...),
@ -3911,7 +4148,6 @@ async def track_requests(request, call_next):
# FastAPI Metrics
# ============================
prometheus_collector = CollectorRegistry()
entities.entity_manager.initialize(prometheus_collector)
# Keep the Instrumentator instance alive
instrumentator = Instrumentator(

View File

@ -768,11 +768,7 @@ class ChatOptions(BaseModel):
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True # Allow both field names and aliases
} }
from llm_proxy import (LLMMessage)
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)
class ApiMessage(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))

View File

@ -13,7 +13,6 @@ import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
@ -27,6 +26,7 @@ from .markdown_chunker import (
# When imported as a module, use relative imports
import defines
from database import RedisDatabase
from models import ChromaDBGetResponse
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
loop,
persist_directory,
collection_name,
database: RedisDatabase,
user_id: str,
chunk_size=DEFAULT_CHUNK_SIZE,
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
recreate=False,
):
self.llm = llm
self.database = database
self.user_id = user_id
self.watch_directory = watch_directory
self.persist_directory = persist_directory or defines.persist_directory
self.collection_name = collection_name
@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]: if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"]) self.collection.delete(ids=results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
logging.info( logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}" f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
) )
@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
embedding = np.array(response.get_single_embedding())
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f"Failed to get embedding: {e}")
raise
@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return embedding
async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
embedding = await self.get_embedding(text)
try:
self.collection.add(
ids=[chunk_id],
@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
"""Find similar documents to the query."""
# collection is configured with hnsw:space cosine
query_embedding = await self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
and existing_results["ids"] and existing_results["ids"]
): ):
self.collection.delete(ids=existing_results["ids"]) self.collection.delete(ids=existing_results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
extensions = (".docx", ".xlsx", ".xls", ".pdf") extensions = (".docx", ".xlsx", ".xls", ".pdf")
if file_path.endswith(extensions): if file_path.endswith(extensions):
@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# f.write(json.dumps(chunk, indent=2))
# Add chunks to collection
await self._add_embeddings_to_collection(chunks)
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Function to start the file watcher
def start_file_watcher(
llm,
user_id,
watch_directory,
persist_directory,
collection_name,
database: RedisDatabase,
initialize=False,
recreate=False,
):
@ -663,9 +674,11 @@ def start_file_watcher(
llm,
watch_directory=watch_directory,
loop=loop,
user_id=user_id,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
database=database
)
# Process all files if:

src/multi-llm/config.md Normal file (193 lines)
View File

@ -0,0 +1,193 @@
# Environment Configuration Examples
# ==== Development (Ollama only) ====
export OLLAMA_HOST="http://localhost:11434"
export DEFAULT_LLM_PROVIDER="ollama"
# ==== Production with OpenAI ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export DEFAULT_LLM_PROVIDER="openai"
# ==== Production with Anthropic ====
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export DEFAULT_LLM_PROVIDER="anthropic"
# ==== Production with multiple providers ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export GEMINI_API_KEY="your-gemini-key-here"
export OLLAMA_HOST="http://ollama-server:11434"
export DEFAULT_LLM_PROVIDER="openai"
# ==== Docker Compose Example ====
# docker-compose.yml
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- DEFAULT_LLM_PROVIDER=openai
depends_on:
- ollama
ollama:
image: ollama/ollama
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
environment:
- OLLAMA_HOST=0.0.0.0
volumes:
ollama_data:
# ==== Kubernetes ConfigMap ====
apiVersion: v1
kind: ConfigMap
metadata:
name: llm-config
data:
DEFAULT_LLM_PROVIDER: "openai"
OLLAMA_HOST: "http://ollama-service:11434"
---
apiVersion: v1
kind: Secret
metadata:
name: llm-secrets
type: Opaque
stringData:
OPENAI_API_KEY: "sk-your-key-here"
ANTHROPIC_API_KEY: "sk-ant-your-key-here"
# ==== Python configuration example ====
# config.py
import os
from llm_proxy import get_llm, LLMProvider
def configure_llm_for_environment():
"""Configure LLM based on deployment environment"""
llm = get_llm()
# Development: Use Ollama
if os.getenv('ENVIRONMENT') == 'development':
llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434')
llm.set_default_provider(LLMProvider.OLLAMA)
# Staging: Use OpenAI with rate limiting
elif os.getenv('ENVIRONMENT') == 'staging':
llm.configure_provider(LLMProvider.OPENAI,
api_key=os.getenv('OPENAI_API_KEY'),
max_retries=3,
timeout=30)
llm.set_default_provider(LLMProvider.OPENAI)
# Production: Use multiple providers with fallback
elif os.getenv('ENVIRONMENT') == 'production':
# Primary: Anthropic
llm.configure_provider(LLMProvider.ANTHROPIC,
api_key=os.getenv('ANTHROPIC_API_KEY'))
# Fallback: OpenAI
llm.configure_provider(LLMProvider.OPENAI,
api_key=os.getenv('OPENAI_API_KEY'))
# Set primary provider
llm.set_default_provider(LLMProvider.ANTHROPIC)
# ==== Usage Examples ====
# Example 1: Basic usage with default provider
async def basic_example():
llm = get_llm()
response = await llm.chat(
model="gpt-4",
messages=[{"role": "user", "content": "Hello!"}]
)
print(response.content)
# Example 2: Specify provider explicitly
async def provider_specific_example():
llm = get_llm()
# Use OpenAI specifically
response = await llm.chat(
model="gpt-4",
messages=[{"role": "user", "content": "Hello!"}],
provider=LLMProvider.OPENAI
)
# Use Anthropic specifically
response2 = await llm.chat(
model="claude-3-sonnet-20240229",
messages=[{"role": "user", "content": "Hello!"}],
provider=LLMProvider.ANTHROPIC
)
# Example 3: Streaming with provider fallback
async def streaming_with_fallback():
llm = get_llm()
try:
async for chunk in llm.chat(
model="claude-3-sonnet-20240229",
messages=[{"role": "user", "content": "Write a story"}],
provider=LLMProvider.ANTHROPIC,
stream=True
):
print(chunk.content, end='', flush=True)
except Exception as e:
print(f"Primary provider failed: {e}")
# Fallback to OpenAI
async for chunk in llm.chat(
model="gpt-4",
messages=[{"role": "user", "content": "Write a story"}],
provider=LLMProvider.OPENAI,
stream=True
):
print(chunk.content, end='', flush=True)
# Example 4: Load balancing between providers
import random
async def load_balanced_request():
llm = get_llm()
available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]
# Simple random load balancing
provider = random.choice(available_providers)
# Adjust model based on provider
model_mapping = {
LLMProvider.OPENAI: "gpt-4",
LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229"
}
response = await llm.chat(
model=model_mapping[provider],
messages=[{"role": "user", "content": "Hello!"}],
provider=provider
)
print(f"Response from {provider.value}: {response.content}")
# ==== FastAPI Startup Configuration ====
from fastapi import FastAPI
app = FastAPI()
@app.on_event("startup")
async def startup_event():
"""Configure LLM providers on application startup"""
configure_llm_for_environment()
# Verify configuration
llm = get_llm()
print(f"Configured providers: {[p.value for p in llm._initialized_providers]}")
print(f"Default provider: {llm.default_provider.value}")

src/multi-llm/example.py Normal file (238 lines)
View File

@ -0,0 +1,238 @@
from fastapi import FastAPI, HTTPException # type: ignore
from fastapi.responses import StreamingResponse # type: ignore
from pydantic import BaseModel # type: ignore
from typing import List, Optional, Dict, Any
import json
import asyncio
from llm_proxy import get_llm, LLMProvider, ChatMessage
app = FastAPI(title="Unified LLM API")
# Pydantic models for request/response
class ChatRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
provider: Optional[str] = None
stream: bool = False
max_tokens: Optional[int] = None
temperature: Optional[float] = None
class GenerateRequest(BaseModel):
model: str
prompt: str
provider: Optional[str] = None
stream: bool = False
max_tokens: Optional[int] = None
temperature: Optional[float] = None
class ChatResponseModel(BaseModel):
content: str
model: str
provider: str
finish_reason: Optional[str] = None
usage: Optional[Dict[str, int]] = None
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
"""Chat endpoint that works with any configured LLM provider"""
try:
llm = get_llm()
# Convert provider string to enum if provided
provider = None
if request.provider:
try:
provider = LLMProvider(request.provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {request.provider}"
)
# Prepare kwargs
kwargs = {}
if request.max_tokens:
kwargs['max_tokens'] = request.max_tokens
if request.temperature:
kwargs['temperature'] = request.temperature
if request.stream:
async def generate_response():
try:
# Use the type-safe streaming method
async for chunk in llm.chat_stream(
model=request.model,
messages=request.messages,
provider=provider,
**kwargs
):
# Format as Server-Sent Events
data = {
"content": chunk.content,
"model": chunk.model,
"provider": provider.value if provider else llm.default_provider.value,
"finish_reason": chunk.finish_reason
}
yield f"data: {json.dumps(data)}\n\n"
except Exception as e:
error_data = {"error": str(e)}
yield f"data: {json.dumps(error_data)}\n\n"
finally:
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
)
else:
# Use the type-safe single response method
response = await llm.chat_single(
model=request.model,
messages=request.messages,
provider=provider,
**kwargs
)
return ChatResponseModel(
content=response.content,
model=response.model,
provider=provider.value if provider else llm.default_provider.value,
finish_reason=response.finish_reason,
usage=response.usage
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate")
async def generate_endpoint(request: GenerateRequest):
"""Generate endpoint for simple text generation"""
try:
llm = get_llm()
provider = None
if request.provider:
try:
provider = LLMProvider(request.provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {request.provider}"
)
kwargs = {}
if request.max_tokens:
kwargs['max_tokens'] = request.max_tokens
if request.temperature:
kwargs['temperature'] = request.temperature
if request.stream:
async def generate_response():
try:
# Use the type-safe streaming method
async for chunk in llm.generate_stream(
model=request.model,
prompt=request.prompt,
provider=provider,
**kwargs
):
data = {
"content": chunk.content,
"model": chunk.model,
"provider": provider.value if provider else llm.default_provider.value,
"finish_reason": chunk.finish_reason
}
yield f"data: {json.dumps(data)}\n\n"
except Exception as e:
error_data = {"error": str(e)}
yield f"data: {json.dumps(error_data)}\n\n"
finally:
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream"
)
else:
# Use the type-safe single response method
response = await llm.generate_single(
model=request.model,
prompt=request.prompt,
provider=provider,
**kwargs
)
return ChatResponseModel(
content=response.content,
model=response.model,
provider=provider.value if provider else llm.default_provider.value,
finish_reason=response.finish_reason,
usage=response.usage
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models(provider: Optional[str] = None):
"""List available models for a provider"""
try:
llm = get_llm()
provider_enum = None
if provider:
try:
provider_enum = LLMProvider(provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {provider}"
)
models = await llm.list_models(provider_enum)
return {
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/providers")
async def list_providers():
"""List all configured providers"""
llm = get_llm()
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value
}
@app.post("/providers/{provider}/set-default")
async def set_default_provider(provider: str):
"""Set the default provider"""
try:
llm = get_llm()
provider_enum = LLMProvider(provider.lower())
llm.set_default_provider(provider_enum)
return {"message": f"Default provider set to {provider}", "default": provider}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
llm = get_llm()
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value
}
if __name__ == "__main__":
import uvicorn # type: ignore
uvicorn.run(app, host="0.0.0.0", port=8000)

src/multi-llm/llm_proxy.py Normal file (606 lines)
View File

@ -0,0 +1,606 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
from enum import Enum
import asyncio
import json
from dataclasses import dataclass
import os
# Standard message format for all providers
@dataclass
class ChatMessage:
role: str # "user", "assistant", "system"
content: str
def __getitem__(self, key: str):
"""Allow dictionary-style access for backward compatibility"""
if key == 'role':
return self.role
elif key == 'content':
return self.content
else:
raise KeyError(f"'{key}' not found in ChatMessage")
def __setitem__(self, key: str, value: str):
"""Allow dictionary-style assignment"""
if key == 'role':
self.role = value
elif key == 'content':
self.content = value
else:
raise KeyError(f"'{key}' not found in ChatMessage")
@classmethod
def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage':
"""Create ChatMessage from dictionary"""
return cls(role=data['role'], content=data['content'])
def to_dict(self) -> Dict[str, str]:
"""Convert ChatMessage to dictionary"""
return {'role': self.role, 'content': self.content}
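# Brief usage sketch (editor's note, not part of this commit): ChatMessage
# supports both attribute and dict-style access, so call sites that pass
# plain dicts keep working unchanged.
def _example_chat_message_usage() -> None:
    msg = ChatMessage(role="user", content="Hello")
    assert msg["content"] == msg.content == "Hello"
    msg["role"] = "system"
    round_tripped = ChatMessage.from_dict(msg.to_dict())
    assert round_tripped == msg  # dataclass equality compares fields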
@dataclass
class ChatResponse:
content: str
model: str
finish_reason: Optional[str] = None
usage: Optional[Dict[str, int]] = None
class LLMProvider(Enum):
OLLAMA = "ollama"
OPENAI = "openai"
ANTHROPIC = "anthropic"
GEMINI = "gemini"
GROK = "grok"
class BaseLLMAdapter(ABC):
"""Abstract base class for all LLM adapters"""
def __init__(self, **config):
self.config = config
@abstractmethod
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages and get response"""
pass
@abstractmethod
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text from prompt"""
pass
@abstractmethod
async def list_models(self) -> List[str]:
"""List available models"""
pass
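# Illustrative sketch (editor's note, not part of this commit): the minimal
# surface a new provider adapter has to implement. This echo adapter is a
# hypothetical example, useful in tests or as a template for real providers;
# streaming support is omitted for brevity.
class EchoAdapter(BaseLLMAdapter):
    """Toy adapter that echoes the prompt or last message back."""

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        content = messages[-1].content if messages else ""
        return ChatResponse(content=content, model=model, finish_reason="stop")

    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        return ChatResponse(content=prompt, model=model, finish_reason="stop")

    async def list_models(self) -> List[str]:
        return ["echo-1"]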
class OllamaAdapter(BaseLLMAdapter):
"""Adapter for Ollama"""
def __init__(self, **config):
super().__init__(**config)
import ollama
self.client = ollama.AsyncClient( # type: ignore
host=config.get('host', 'http://localhost:11434')
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert ChatMessage objects to Ollama format
ollama_messages = []
for msg in messages:
ollama_messages.append({
"role": msg.role,
"content": msg.content
})
if stream:
return self._stream_chat(model, ollama_messages, **kwargs)
else:
response = await self.client.chat(
model=model,
messages=ollama_messages,
stream=False,
**kwargs
)
return ChatResponse(
content=response['message']['content'],
model=model,
finish_reason=response.get('done_reason')
)
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
async for chunk in self.client.chat(
model=model,
messages=messages,
stream=True,
**kwargs
):
if chunk.get('message', {}).get('content'):
yield ChatResponse(
content=chunk['message']['content'],
model=model,
finish_reason=chunk.get('done_reason')
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
if stream:
return self._stream_generate(model, prompt, **kwargs)
else:
response = await self.client.generate(
model=model,
prompt=prompt,
stream=False,
**kwargs
)
return ChatResponse(
content=response['response'],
model=model,
finish_reason=response.get('done_reason')
)
async def _stream_generate(self, model: str, prompt: str, **kwargs):
async for chunk in self.client.generate(
model=model,
prompt=prompt,
stream=True,
**kwargs
):
if chunk.get('response'):
yield ChatResponse(
content=chunk['response'],
model=model,
finish_reason=chunk.get('done_reason')
)
async def list_models(self) -> List[str]:
models = await self.client.list()
return [model['name'] for model in models['models']]
class OpenAIAdapter(BaseLLMAdapter):
"""Adapter for OpenAI"""
def __init__(self, **config):
super().__init__(**config)
import openai # type: ignore
self.client = openai.AsyncOpenAI(
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert ChatMessage objects to OpenAI format
openai_messages = []
for msg in messages:
openai_messages.append({
"role": msg.role,
"content": msg.content
})
if stream:
return self._stream_chat(model, openai_messages, **kwargs)
else:
response = await self.client.chat.completions.create(
model=model,
messages=openai_messages,
stream=False,
**kwargs
)
return ChatResponse(
content=response.choices[0].message.content,
model=model,
finish_reason=response.choices[0].finish_reason,
usage=response.usage.model_dump() if response.usage else None
)
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
async for chunk in await self.client.chat.completions.create(
model=model,
messages=messages,
stream=True,
**kwargs
):
if chunk.choices[0].delta.content:
yield ChatResponse(
content=chunk.choices[0].delta.content,
model=model,
finish_reason=chunk.choices[0].finish_reason
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert to chat format for OpenAI
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
models = await self.client.models.list()
return [model.id for model in models.data]
class AnthropicAdapter(BaseLLMAdapter):
"""Adapter for Anthropic Claude"""
def __init__(self, **config):
super().__init__(**config)
import anthropic # type: ignore
self.client = anthropic.AsyncAnthropic(
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Anthropic requires system message to be separate
system_message = None
anthropic_messages = []
for msg in messages:
if msg.role == "system":
system_message = msg.content
else:
anthropic_messages.append({
"role": msg.role,
"content": msg.content
})
request_kwargs = {
"model": model,
"messages": anthropic_messages,
"max_tokens": kwargs.pop('max_tokens', 1000),
**kwargs
}
if system_message:
request_kwargs["system"] = system_message
if stream:
return self._stream_chat(**request_kwargs)
else:
response = await self.client.messages.create(
stream=False,
**request_kwargs
)
return ChatResponse(
content=response.content[0].text,
model=model,
finish_reason=response.stop_reason,
usage={
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens
}
)
async def _stream_chat(self, **kwargs):
async with self.client.messages.stream(**kwargs) as stream:
async for text in stream.text_stream:
yield ChatResponse(
content=text,
model=kwargs['model']
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
# Anthropic doesn't have a list models endpoint, return known models
return [
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229"
]
class GeminiAdapter(BaseLLMAdapter):
"""Adapter for Google Gemini"""
def __init__(self, **config):
super().__init__(**config)
import google.generativeai as genai # type: ignore
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
self.genai = genai
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
model_instance = self.genai.GenerativeModel(model)
# Convert messages to Gemini format
        chat_history = []
        current_message = None
        for msg in messages[:-1]:  # All but the last message go to history
            if msg.role == "user":
                chat_history.append({"role": "user", "parts": [msg.content]})
            elif msg.role == "assistant":
                chat_history.append({"role": "model", "parts": [msg.content]})
            # Note: system messages are not mapped here; Gemini chat history has no "system" role
# Last message is the current prompt
if messages:
current_message = messages[-1].content
chat = model_instance.start_chat(history=chat_history)
if not current_message:
raise ValueError("No current message provided for chat")
if stream:
return self._stream_chat(chat, current_message, **kwargs)
else:
response = await chat.send_message_async(current_message, **kwargs)
return ChatResponse(
content=response.text,
model=model,
finish_reason=response.candidates[0].finish_reason.name if response.candidates else None
)
    async def _stream_chat(self, chat, message: str, **kwargs) -> AsyncGenerator[ChatResponse, None]:
        # send_message_async is a coroutine; with stream=True it resolves to an async-iterable response
        async for chunk in await chat.send_message_async(message, stream=True, **kwargs):
if chunk.text:
yield ChatResponse(
content=chunk.text,
model=chat.model.model_name
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
models = self.genai.list_models()
return [model.name for model in models if 'generateContent' in model.supported_generation_methods]
class UnifiedLLMProxy:
"""Main proxy class that provides unified interface to all LLM providers"""
def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
self.default_provider = default_provider
self._initialized_providers = set()
def configure_provider(self, provider: LLMProvider, **config):
"""Configure a specific provider with its settings"""
adapter_classes = {
LLMProvider.OLLAMA: OllamaAdapter,
LLMProvider.OPENAI: OpenAIAdapter,
LLMProvider.ANTHROPIC: AnthropicAdapter,
LLMProvider.GEMINI: GeminiAdapter,
# Add other providers as needed
}
if provider in adapter_classes:
self.adapters[provider] = adapter_classes[provider](**config)
self._initialized_providers.add(provider)
else:
raise ValueError(f"Unsupported provider: {provider}")
def set_default_provider(self, provider: LLMProvider):
"""Set the default provider for requests"""
if provider not in self._initialized_providers:
raise ValueError(f"Provider {provider} not configured")
self.default_provider = provider
async def chat(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
# Normalize messages to ChatMessage objects
normalized_messages = []
for msg in messages:
if isinstance(msg, ChatMessage):
normalized_messages.append(msg)
elif isinstance(msg, dict):
normalized_messages.append(ChatMessage.from_dict(msg))
else:
raise ValueError(f"Invalid message type: {type(msg)}")
return await adapter.chat(model, normalized_messages, stream, **kwargs)
async def chat_stream(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
**kwargs
) -> AsyncGenerator[ChatResponse, None]:
"""Stream chat messages using specified or default provider"""
result = await self.chat(model, messages, provider, stream=True, **kwargs)
        # With stream=True the adapters return an AsyncGenerator at runtime,
        # but the Union return type cannot express that, hence the ignore
        async for chunk in result:  # type: ignore
yield chunk
async def chat_single(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
**kwargs
) -> ChatResponse:
"""Get single chat response using specified or default provider"""
result = await self.chat(model, messages, provider, stream=False, **kwargs)
        # With stream=False the adapters return a ChatResponse at runtime; narrow the Union manually
        return result  # type: ignore
async def generate(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
return await adapter.generate(model, prompt, stream, **kwargs)
async def generate_stream(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
**kwargs
) -> AsyncGenerator[ChatResponse, None]:
"""Stream text generation using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
async for chunk in result: # type: ignore
yield chunk
async def generate_single(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
**kwargs
) -> ChatResponse:
"""Get single generation response using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
return result # type: ignore
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available models for specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
return await adapter.list_models()
def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
"""Get adapter for specified provider"""
if provider not in self.adapters:
raise ValueError(f"Provider {provider} not configured")
return self.adapters[provider]
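
# A minimal sketch of driving UnifiedLLMProxy directly, without the LLMManager below.
# The host and the "llama3.2" model name are illustrative assumptions (any locally
# pulled Ollama model would do); this helper is not called anywhere in this module.
async def _example_direct_proxy_use() -> None:
    proxy = UnifiedLLMProxy()
    proxy.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
    proxy.set_default_provider(LLMProvider.OLLAMA)
    response = await proxy.chat_single(
        "llama3.2",
        [{"role": "user", "content": "Say hello in one short sentence."}],
    )
    print(response.content)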
# Environment-driven configuration and convenience access
class LLMManager:
"""Singleton manager for the unified LLM proxy"""
_instance = None
_proxy = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if LLMManager._proxy is None:
LLMManager._proxy = UnifiedLLMProxy()
self._configure_from_environment()
def _configure_from_environment(self):
"""Configure providers based on environment variables"""
if not self._proxy:
raise RuntimeError("LLM proxy not initialized")
# Configure Ollama if available
ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
# Configure OpenAI if API key is available
if os.getenv('OPENAI_API_KEY'):
self._proxy.configure_provider(LLMProvider.OPENAI)
# Configure Anthropic if API key is available
if os.getenv('ANTHROPIC_API_KEY'):
self._proxy.configure_provider(LLMProvider.ANTHROPIC)
# Configure Gemini if API key is available
if os.getenv('GEMINI_API_KEY'):
self._proxy.configure_provider(LLMProvider.GEMINI)
# Set default provider from environment or use Ollama
default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
try:
self._proxy.set_default_provider(LLMProvider(default_provider))
except ValueError:
# Fallback to Ollama if specified provider not available
self._proxy.set_default_provider(LLMProvider.OLLAMA)
def get_proxy(self) -> UnifiedLLMProxy:
"""Get the unified LLM proxy instance"""
if not self._proxy:
raise RuntimeError("LLM proxy not initialized")
return self._proxy
# Convenience function for easy access
def get_llm() -> UnifiedLLMProxy:
"""Get the configured LLM proxy"""
return LLMManager.get_instance().get_proxy()
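
# A hedged end-to-end sketch, assuming the relevant environment variables are set
# (e.g. OLLAMA_HOST, or OPENAI_API_KEY / ANTHROPIC_API_KEY / GEMINI_API_KEY) and that
# the requested model exists for the default provider; the model name is illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        llm = get_llm()
        # Single-shot completion
        reply = await llm.chat_single(
            "llama3.2",
            [{"role": "user", "content": "Give me one fun fact about llamas."}],
        )
        print(reply.content)
        # Streaming completion
        async for chunk in llm.chat_stream(
            "llama3.2",
            [{"role": "user", "content": "Now stream a haiku about llamas."}],
        ):
            print(chunk.content, end="", flush=True)
        print()

    asyncio.run(_demo())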