Multi-LLM backend working with Ollama

James Ketr 2025-06-07 15:03:00 -07:00
parent 5fed56ba76
commit 8aa8577874
25 changed files with 3802 additions and 357 deletions

View File

@ -0,0 +1,541 @@
import React, { useState, useEffect, useRef, JSX } from 'react';
import {
Box,
Button,
Typography,
Paper,
TextField,
Grid,
Dialog,
DialogTitle,
DialogContent,
DialogContentText,
DialogActions,
IconButton,
useTheme,
useMediaQuery,
Chip,
Divider,
Card,
CardContent,
CardHeader,
LinearProgress,
Stack,
Alert
} from '@mui/material';
import {
SyncAlt,
Favorite,
Settings,
Info,
Search,
AutoFixHigh,
Image,
Psychology,
Build,
CloudUpload,
Description,
Business,
LocationOn,
Work,
CheckCircle,
Star
} from '@mui/icons-material';
import { styled } from '@mui/material/styles';
import DescriptionIcon from '@mui/icons-material/Description';
import FileUploadIcon from '@mui/icons-material/FileUpload';
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { BackstoryElementProps } from './BackstoryTab';
import { LoginRequired } from 'components/ui/LoginRequired';
import * as Types from 'types/types';
const VisuallyHiddenInput = styled('input')({
clip: 'rect(0 0 0 0)',
clipPath: 'inset(50%)',
height: 1,
overflow: 'hidden',
position: 'absolute',
bottom: 0,
left: 0,
whiteSpace: 'nowrap',
width: 1,
});
const UploadBox = styled(Box)(({ theme }) => ({
border: `2px dashed ${theme.palette.primary.main}`,
borderRadius: theme.shape.borderRadius * 2,
padding: theme.spacing(4),
textAlign: 'center',
backgroundColor: theme.palette.action.hover,
transition: 'all 0.3s ease',
cursor: 'pointer',
'&:hover': {
backgroundColor: theme.palette.action.selected,
borderColor: theme.palette.primary.dark,
},
}));
const StatusBox = styled(Box)(({ theme }) => ({
display: 'flex',
alignItems: 'center',
gap: theme.spacing(1),
padding: theme.spacing(1, 2),
backgroundColor: theme.palette.background.paper,
borderRadius: theme.shape.borderRadius,
border: `1px solid ${theme.palette.divider}`,
minHeight: 48,
}));
const getIcon = (type: Types.ApiActivityType) => {
switch (type) {
case 'converting':
return <SyncAlt color="primary" />;
case 'heartbeat':
return <Favorite color="error" />;
case 'system':
return <Settings color="action" />;
case 'info':
return <Info color="info" />;
case 'searching':
return <Search color="primary" />;
case 'generating':
return <AutoFixHigh color="secondary" />;
case 'generating_image':
return <Image color="primary" />;
case 'thinking':
return <Psychology color="secondary" />;
case 'tooling':
return <Build color="action" />;
default:
return <Info color="action" />;
}
};
interface JobCreatorProps extends BackstoryElementProps {
onSave?: (job: Types.Job) => void;
}
const JobCreator = (props: JobCreatorProps) => {
const { user, apiClient } = useAuth();
const { onSave } = props;
const { selectedCandidate } = useSelectedCandidate();
const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack, submitQuery } = props;
const backstoryProps = { setSnack, submitQuery };
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
const [jobDescription, setJobDescription] = useState<string>('');
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
const [jobTitle, setJobTitle] = useState<string>('');
const [company, setCompany] = useState<string>('');
const [summary, setSummary] = useState<string>('');
const [jobLocation, setJobLocation] = useState<string>('');
const [jobId, setJobId] = useState<string>('');
const [jobStatus, setJobStatus] = useState<string>('');
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
const [isProcessing, setIsProcessing] = useState<boolean>(false);
useEffect(() => {
}, [jobTitle, jobDescription, company]);
const fileInputRef = useRef<HTMLInputElement>(null);
if (!user?.id) {
return (
<LoginRequired asset="job creation" />
);
}
const jobStatusHandlers = {
onStatus: (status: Types.ChatMessageStatus) => {
console.log('status:', status.content);
setJobStatusIcon(getIcon(status.activity));
setJobStatus(status.content);
},
onMessage: (job: Types.Job) => {
console.log('onMessage - job', job);
setCompany(job.company || '');
setJobDescription(job.description);
setSummary(job.summary || '');
setJobTitle(job.title || '');
setJobRequirements(job.requirements || null);
setJobStatusIcon(<></>);
setJobStatus('');
},
onError: (error: Types.ChatMessageError) => {
console.log('onError', error);
setSnack(error.content, "error");
setIsProcessing(false);
},
onComplete: () => {
setJobStatusIcon(<></>);
setJobStatus('');
setIsProcessing(false);
}
};
const handleJobUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
if (e.target.files && e.target.files[0]) {
const file = e.target.files[0];
const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
let docType: Types.DocumentType | null = null;
switch (fileExtension.substring(1)) {
case "pdf":
docType = "pdf";
break;
case "docx":
docType = "docx";
break;
case "md":
docType = "markdown";
break;
case "txt":
docType = "txt";
break;
}
if (!docType) {
setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error');
return;
}
try {
setIsProcessing(true);
setJobDescription('');
setJobTitle('');
setJobRequirements(null);
setSummary('');
const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
const job = await controller.promise;
if (!job) {
return;
}
console.log(`Job id: ${job.id}`);
e.target.value = '';
} catch (error) {
console.error(error);
setSnack('Failed to upload document', 'error');
setIsProcessing(false);
}
}
};
const handleUploadClick = () => {
fileInputRef.current?.click();
};
const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => {
if (!items || items.length === 0) return null;
return (
<Box sx={{ mb: 3 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 1.5 }}>
{icon}
<Typography variant="subtitle1" sx={{ ml: 1, fontWeight: 600 }}>
{title}
</Typography>
{required && <Chip label="Required" size="small" color="error" sx={{ ml: 1 }} />}
</Box>
<Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
{items.map((item, index) => (
<Chip
key={index}
label={item}
variant="outlined"
size="small"
sx={{ mb: 1 }}
/>
))}
</Stack>
</Box>
);
};
const renderJobRequirements = () => {
if (!jobRequirements) return null;
return (
<Card elevation={2} sx={{ mt: 3 }}>
<CardHeader
title="Job Requirements Analysis"
avatar={<CheckCircle color="success" />}
sx={{ pb: 1 }}
/>
<CardContent sx={{ pt: 0 }}>
{renderRequirementSection(
"Technical Skills (Required)",
jobRequirements.technicalSkills.required,
<Build color="primary" />,
true
)}
{renderRequirementSection(
"Technical Skills (Preferred)",
jobRequirements.technicalSkills.preferred,
<Build color="action" />
)}
{renderRequirementSection(
"Experience Requirements (Required)",
jobRequirements.experienceRequirements.required,
<Work color="primary" />,
true
)}
{renderRequirementSection(
"Experience Requirements (Preferred)",
jobRequirements.experienceRequirements.preferred,
<Work color="action" />
)}
{renderRequirementSection(
"Soft Skills",
jobRequirements.softSkills,
<Psychology color="secondary" />
)}
{renderRequirementSection(
"Experience",
jobRequirements.experience,
<Star color="warning" />
)}
{renderRequirementSection(
"Education",
jobRequirements.education,
<Description color="info" />
)}
{renderRequirementSection(
"Certifications",
jobRequirements.certifications,
<CheckCircle color="success" />
)}
{renderRequirementSection(
"Preferred Attributes",
jobRequirements.preferredAttributes,
<Star color="secondary" />
)}
</CardContent>
</Card>
);
};
const handleSave = async () => {
const newJob: Types.Job = {
ownerId: user?.id || '',
ownerType: 'candidate',
description: jobDescription,
company: company,
summary: summary,
title: jobTitle,
requirements: jobRequirements || undefined
};
setIsProcessing(true);
const job = await apiClient.createJob(newJob);
setIsProcessing(false);
onSave ? onSave(job) : setSelectedJob(job);
};
const handleExtractRequirements = () => {
// Implement requirements extraction logic here
setIsProcessing(true);
// This would call your API to extract requirements from the job description
};
const renderJobCreation = () => {
if (!user) {
return <Box>You must be logged in</Box>;
}
return (
<Box sx={{
mx: 'auto', p: { xs: 2, sm: 3 },
}}>
{/* Upload Section */}
<Card elevation={3} sx={{ mb: 4 }}>
<CardHeader
title="Job Information"
subheader="Upload a job description or enter details manually"
avatar={<Work color="primary" />}
/>
<CardContent>
<Grid container spacing={3}>
<Grid size={{ xs: 12, md: 6 }}>
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
<CloudUpload sx={{ mr: 1 }} />
Upload Job Description
</Typography>
<UploadBox onClick={handleUploadClick}>
<CloudUpload sx={{ fontSize: 48, color: 'primary.main', mb: 2 }} />
<Typography variant="h6" gutterBottom>
Drop your job description here
</Typography>
<Typography variant="body2" color="text.secondary" sx={{ mb: 2 }}>
Supported formats: PDF, DOCX, TXT, MD
</Typography>
<Button
variant="contained"
startIcon={<FileUploadIcon />}
disabled={isProcessing}
// onClick={handleUploadClick}
>
Choose File
</Button>
</UploadBox>
<VisuallyHiddenInput
ref={fileInputRef}
type="file"
accept=".txt,.md,.docx,.pdf"
onChange={handleJobUpload}
/>
</Grid>
<Grid size={{ xs: 12, md: 6 }}>
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
<Description sx={{ mr: 1 }} />
Or Enter Manually
</Typography>
<TextField
fullWidth
multiline
rows={isMobile ? 8 : 12}
placeholder="Paste or type the job description here..."
variant="outlined"
value={jobDescription}
onChange={(e) => setJobDescription(e.target.value)}
disabled={isProcessing}
sx={{ mb: 2 }}
/>
{jobRequirements === null && jobDescription && (
<Button
variant="outlined"
onClick={handleExtractRequirements}
startIcon={<AutoFixHigh />}
disabled={isProcessing}
fullWidth={isMobile}
>
Extract Requirements
</Button>
)}
</Grid>
</Grid>
{(jobStatus || isProcessing) && (
<Box sx={{ mt: 3 }}>
<StatusBox>
{jobStatusIcon}
<Typography variant="body2" sx={{ ml: 1 }}>
{jobStatus || 'Processing...'}
</Typography>
</StatusBox>
{isProcessing && <LinearProgress sx={{ mt: 1 }} />}
</Box>
)}
</CardContent>
</Card>
{/* Job Details Section */}
<Card elevation={3} sx={{ mb: 4 }}>
<CardHeader
title="Job Details"
subheader="Enter specific information about the position"
avatar={<Business color="primary" />}
/>
<CardContent>
<Grid container spacing={3}>
<Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Job Title"
variant="outlined"
value={jobTitle}
onChange={(e) => setJobTitle(e.target.value)}
required
disabled={isProcessing}
InputProps={{
startAdornment: <Work sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid>
<Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Company"
variant="outlined"
value={company}
onChange={(e) => setCompany(e.target.value)}
required
disabled={isProcessing}
InputProps={{
startAdornment: <Business sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid>
{/* <Grid size={{ xs: 12, md: 6 }}>
<TextField
fullWidth
label="Job Location"
variant="outlined"
value={jobLocation}
onChange={(e) => setJobLocation(e.target.value)}
disabled={isProcessing}
InputProps={{
startAdornment: <LocationOn sx={{ mr: 1, color: 'text.secondary' }} />
}}
/>
</Grid> */}
<Grid size={{ xs: 12, md: 6 }}>
<Box sx={{ display: 'flex', gap: 2, alignItems: 'flex-end', height: '100%' }}>
<Button
variant="contained"
onClick={handleSave}
disabled={!jobTitle || !company || !jobDescription || isProcessing}
fullWidth={isMobile}
size="large"
startIcon={<CheckCircle />}
>
Save Job
</Button>
</Box>
</Grid>
</Grid>
</CardContent>
</Card>
{/* Job Summary */}
{summary !== '' &&
<Card elevation={2} sx={{ mt: 3 }}>
<CardHeader
title="Job Summary"
avatar={<CheckCircle color="success" />}
sx={{ pb: 1 }}
/>
<CardContent sx={{ pt: 0 }}>
{summary}
</CardContent>
</Card>
}
{/* Requirements Display */}
{renderJobRequirements()}
</Box>
);
};
return (
<Box className="JobManagement"
sx={{
background: "white",
p: 0,
}}>
{selectedJob === null && renderJobCreation()}
</Box>
);
};
export { JobCreator };

View File

@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => {
};
const JobManagement = (props: BackstoryElementProps) => {
const { user, apiClient } = useAuth();
const { selectedCandidate } = useSelectedCandidate();
const { user, apiClient } = useAuth();
const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack, submitQuery } = props;
const backstoryProps = { setSnack, submitQuery };
const { setSnack, submitQuery } = props;
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
const [jobDescription, setJobDescription] = useState<string>('');
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
const [jobTitle, setJobTitle] = useState<string>('');
const [company, setCompany] = useState<string>('');
const [summary, setSummary] = useState<string>('');
const [jobLocation, setJobLocation] = useState<string>('');
const [jobId, setJobId] = useState<string>('');
const [jobStatus, setJobStatus] = useState<string>('');
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
const [isProcessing, setIsProcessing] = useState<boolean>(false);
useEffect(() => {
}, [jobTitle, jobDescription, company]);
const fileInputRef = useRef<HTMLInputElement>(null);
if (!user?.id) {
return (
<LoginRequired asset="candidate analysis" />
@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => {
setJobTitle('');
setJobRequirements(null);
setSummary('');
const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers);
const document = await controller.promise;
if (!document) {
const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
const job = await controller.promise;
if (!job) {
return;
}
console.log(`Document id: ${document.id}`);
console.log(`Job id: ${job.id}`);
e.target.value = '';
} catch (error) {
console.error(error);
@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => {
// This would call your API to extract requirements from the job description
};
const loadJob = async () => {
const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d");
setSelectedJob(job);
}
const renderJobCreation = () => {
if (!user) {
return <Box>You must be logged in</Box>;
@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => {
<Box sx={{
mx: 'auto', p: { xs: 2, sm: 3 },
}}>
<Button onClick={loadJob} variant="contained">Load Job</Button>
{/* Upload Section */}
<Card elevation={3} sx={{ mb: 4 }}>
<CardHeader

View File

@ -14,7 +14,8 @@ import {
CardContent,
useTheme,
LinearProgress,
useMediaQuery
useMediaQuery,
Button
} from '@mui/material';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
@ -28,6 +29,8 @@ import { toCamelCase } from 'types/conversion';
import { Job } from 'types/types';
import { StyledMarkdown } from './StyledMarkdown';
import { Scrollable } from './Scrollable';
interface JobAnalysisProps extends BackstoryPageProps {
job: Job;
@ -56,6 +59,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
const [overallScore, setOverallScore] = useState<number>(0);
const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null);
const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null);
const [startAnalysis, setStartAnalysis] = useState<boolean>(false);
const [analyzing, setAnalyzing] = useState<boolean>(false);
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
// Handle accordion expansion
@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
setExpanded(isExpanded ? panel : false);
};
useEffect(() => {
const initializeRequirements = (job: Job) => {
if (!job || !job.requirements) {
return;
}
@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
setSkillMatches(initialSkillMatches);
setStatusMessage(null);
setLoadingRequirements(false);
}, [job, setRequirements]);
setOverallScore(0);
}
useEffect(() => {
if (requirementsSession || creatingSession) {
return;
}
try {
setCreatingSession(true);
apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
.then(session => {
setRequirementsSession(session);
setCreatingSession(false);
});
} catch (error) {
setSnack('Unable to load chat session', 'error');
} finally {
setCreatingSession(false);
}
}, [requirementsSession, apiClient, candidate]);
// Fetch initial requirements
useEffect(() => {
if (!job.description || !requirementsSession || loadingRequirements) {
return;
}
const getRequirements = async () => {
setLoadingRequirements(true);
try {
const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description };
apiClient.sendMessageStream(chatMessage, {
onMessage: (msg: ChatMessage) => {
console.log(`onMessage: ${msg.type}`, msg);
const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
const requirements: { requirement: string, domain: string }[] = [];
if (job.requirements?.technicalSkills) {
job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
}
if (job.requirements?.experienceRequirements) {
job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
}
if (job.requirements?.softSkills) {
job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
}
if (job.requirements?.experience) {
job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
}
if (job.requirements?.education) {
job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
}
if (job.requirements?.certifications) {
job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
}
if (job.requirements?.preferredAttributes) {
job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
}
const initialSkillMatches = requirements.map(req => ({
requirement: req.requirement,
domain: req.domain,
status: 'waiting' as const,
matchScore: 0,
assessment: '',
description: '',
citations: []
}));
setRequirements(requirements);
setSkillMatches(initialSkillMatches);
setStatusMessage(null);
setLoadingRequirements(false);
},
onError: (error: string | ChatMessageError) => {
console.log("onError:", error);
// Type-guard to determine if this is a ChatMessageBase or a string
if (typeof error === "object" && error !== null && "content" in error) {
setSnack(error.content || 'Error obtaining requirements from job description.', "error");
} else {
setSnack(error as string, "error");
}
setLoadingRequirements(false);
},
onStreaming: (chunk: ChatMessageStreaming) => {
// console.log("onStreaming:", chunk);
},
onStatus: (status: ChatMessageStatus) => {
console.log(`onStatus: ${status}`);
},
onComplete: () => {
console.log("onComplete");
setStatusMessage(null);
setLoadingRequirements(false);
}
});
} catch (error) {
console.error('Failed to send message:', error);
setLoadingRequirements(false);
}
};
getRequirements();
}, [job, requirementsSession]);
initializeRequirements(job);
}, [job]);
// Fetch match data for each requirement
useEffect(() => {
if (!startAnalysis || analyzing || !job.requirements) {
return;
}
const fetchMatchData = async () => {
if (requirements.length === 0) return;
@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
}
};
if (!loadingRequirements) {
fetchMatchData();
}
}, [requirements, loadingRequirements]);
setAnalyzing(true);
fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) });
}, [job, startAnalysis, analyzing, requirements, loadingRequirements]);
// Get color based on match score
const getMatchColor = (score: number): string => {
@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
return <ErrorIcon color="error" />;
};
const beginAnalysis = () => {
initializeRequirements(job);
setStartAnalysis(true);
};
return (
<Box>
<Paper elevation={3} sx={{ p: 3, mb: 4 }}>
@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
</Box>
<Grid size={{ xs: 12 }} sx={{ mt: 2 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2, gap: 1 }}>
{<Button disabled={analyzing || startAnalysis} onClick={beginAnalysis} variant="contained">Start Analysis</Button>}
{overallScore !== 0 && <>
<Typography variant="h5" component="h2" sx={{ mr: 2 }}>
Overall Match:
</Typography>
@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
fontWeight: 'bold'
}}
/>
</>}
</Box>
</Grid>
</Grid>

View File

@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from './DeleteConfirmation';
import { DeleteConfirmation } from '../DeleteConfirmation';
interface CandidateInfoProps {
candidate: Candidate;

View File

@ -4,7 +4,7 @@ import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { Candidate } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate } from 'hooks/GlobalContext';

View File

@ -0,0 +1,97 @@
import React from 'react';
import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material';
import {
Card,
CardContent,
Divider,
useTheme,
} from '@mui/material';
import DeleteIcon from '@mui/icons-material/Delete';
import { useMediaQuery } from '@mui/material';
import { JobFull } from 'types/types';
import { CopyBubble } from "components/CopyBubble";
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from '../DeleteConfirmation';
interface JobInfoProps {
job: JobFull;
sx?: SxProps;
action?: string;
elevation?: number;
variant?: "small" | "normal" | null
};
const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
const { job } = props;
const { user, apiClient } = useAuth();
const {
sx,
action = '',
elevation = 1,
variant = "normal"
} = props;
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
const isAdmin = user?.isAdmin;
const deleteJob = async (jobId: string | undefined) => {
if (jobId) {
await apiClient.deleteJob(jobId);
}
}
if (!job) {
return <Box>No job loaded.</Box>;
}
return (
<Card
elevation={elevation}
sx={{
display: "flex",
borderColor: 'transparent',
borderWidth: 2,
borderStyle: 'solid',
transition: 'all 0.3s ease',
flexDirection: "column",
...sx
}}
>
<CardContent sx={{ display: "flex", flexGrow: 1, p: 3, height: '100%', flexDirection: 'column', alignItems: 'stretch', position: "relative" }}>
{variant !== "small" && <>
{job.location &&
<Typography variant="body2" sx={{ mb: 1 }}>
<strong>Location:</strong> {job.location.city}, {job.location.state || job.location.country}
</Typography>
}
{job.company &&
<Typography variant="body2" sx={{ mb: 1 }}>
<strong>Company:</strong> {job.company}
</Typography>
}
{job.summary && <Typography variant="body2">
<strong>Summary:</strong> {job.summary}
</Typography>
}
</>}
</CardContent>
<CardActions>
{isAdmin &&
<DeleteConfirmation
onDelete={() => { deleteJob(job.id); }}
sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }}
action="delete"
label="job"
title="Delete job"
icon={<DeleteIcon />}
message={`Are you sure you want to delete ${job.id}? This action cannot be undone.`}
/>}
</CardActions>
</Card>
);
};
export { JobInfo };

View File

@ -0,0 +1,69 @@
import React, { useEffect, useState } from 'react';
import { useNavigate } from "react-router-dom";
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { JobInfo } from 'components/ui/JobInfo';
import { Job, JobFull } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedJob } from 'hooks/GlobalContext';
interface JobPickerProps extends BackstoryElementProps {
onSelect?: (job: JobFull) => void
};
const JobPicker = (props: JobPickerProps) => {
const { onSelect } = props;
const { apiClient } = useAuth();
const { selectedJob, setSelectedJob } = useSelectedJob();
const { setSnack } = props;
const [jobs, setJobs] = useState<JobFull[] | null>(null);
useEffect(() => {
if (jobs !== null) {
return;
}
const getJobs = async () => {
try {
const results = await apiClient.getJobs();
const jobs: JobFull[] = results.data;
jobs.sort((a, b) => {
let result = a.company?.localeCompare(b.company || '');
if (result === 0) {
result = a.title?.localeCompare(b.title || '');
}
return result || 0;
});
setJobs(jobs);
} catch (err) {
setSnack("" + err);
}
};
getJobs();
}, [jobs, setSnack]);
return (
<Box sx={{display: "flex", flexDirection: "column", mb: 1}}>
<Box sx={{ display: "flex", gap: 1, flexWrap: "wrap", justifyContent: "center" }}>
{jobs?.map((j, i) =>
<Box key={`${j.id}`}
onClick={() => { onSelect ? onSelect(j) : setSelectedJob(j); }}
sx={{ cursor: "pointer" }}>
{selectedJob?.id === j.id &&
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", backgroundColor: "#f0f0f0", "&:hover": { border: "2px solid orange" } }} job={j} />
}
{selectedJob?.id !== j.id &&
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", border: "2px solid transparent", "&:hover": { border: "2px solid orange" } }} job={j} />
}
</Box>
)}
</Box>
</Box>
);
};
export {
JobPicker
};

View File

@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { Message } from 'components/Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useNavigate } from 'react-router-dom';
import { useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader';

View File

@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel';
import SendIcon from '@mui/icons-material/Send';
import PropagateLoader from 'react-spinners/PropagateLoader';
import { CandidateInfo } from '../components/CandidateInfo';
import { CandidateInfo } from '../components/ui/CandidateInfo';
import { Quote } from 'components/Quote';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';

View File

@ -7,31 +7,74 @@ import {
Button,
Typography,
Paper,
Avatar,
useTheme,
Snackbar,
Container,
Grid,
Alert,
Tabs,
Tab,
Card,
CardContent,
Divider,
Avatar,
Badge,
} from '@mui/material';
import {
Person,
PersonAdd,
AccountCircle,
Add,
WorkOutline,
AddCircle,
} from '@mui/icons-material';
import PersonIcon from '@mui/icons-material/Person';
import WorkIcon from '@mui/icons-material/Work';
import AssessmentIcon from '@mui/icons-material/Assessment';
import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
import { Candidate } from "types/types";
import { Candidate, Job, JobFull } from "types/types";
import { useNavigate } from 'react-router-dom';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { ComingSoon } from 'components/ui/ComingSoon';
import { JobManagement } from 'components/JobManagement';
import { LoginRequired } from 'components/ui/LoginRequired';
import { Scrollable } from 'components/Scrollable';
import { CandidatePicker } from 'components/ui/CandidatePicker';
import { JobPicker } from 'components/ui/JobPicker';
import { JobCreator } from 'components/JobCreator';
function WorkAddIcon() {
return (
<Box position="relative" display="inline-flex"
sx={{
lineHeight: "30px",
mb: "6px",
}}
>
<WorkOutline sx={{ fontSize: 24 }} />
<Add
sx={{
position: 'absolute',
bottom: -2,
right: -2,
fontSize: 14,
bgcolor: 'background.paper',
borderRadius: '50%',
boxShadow: 1,
}}
color="primary"
/>
</Box>
);
}
// Main component
const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const theme = useTheme();
const { user } = useAuth();
const { user, apiClient } = useAuth();
const navigate = useNavigate();
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
const { selectedJob, setSelectedJob } = useSelectedJob()
@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
const [activeStep, setActiveStep] = useState(0);
const [analysisStarted, setAnalysisStarted] = useState(false);
const [error, setError] = useState<string | null>(null);
const [jobTab, setJobTab] = useState<string>('load');
useEffect(() => {
if (selectedJob && activeStep === 1) {
setActiveStep(2);
console.log({ activeStep, selectedCandidate, selectedJob });
if (!selectedCandidate) {
if (activeStep !== 0) {
setActiveStep(0);
}
} else if (!selectedJob) {
if (activeStep !== 1) {
setActiveStep(1);
}
}
}, [selectedJob, activeStep]);
}, [selectedCandidate, selectedJob, activeStep])
// Steps in our process
const steps = [
@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
setSelectedJob(null);
break;
case 1: /* Select Job */
setSelectedCandidate(null);
setSelectedJob(null);
break;
case 2: /* Job Analysis */
@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
setActiveStep(1);
}
const onJobSelect = (job: Job) => {
setSelectedJob(job)
setActiveStep(2);
}
// Render function for the candidate selection step
const renderCandidateSelection = () => (
<Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}>
@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
</Paper>
);
const handleTabChange = (event: React.SyntheticEvent, value: string) => {
setJobTab(value);
};
// Render function for the job description step
const renderJobDescription = () => (
<Box sx={{ mt: 3 }}>
{selectedCandidate && (
<JobManagement
const renderJobDescription = () => {
if (!selectedCandidate) {
return;
}
return (<Box sx={{ mt: 3 }}>
<Box sx={{ borderBottom: 1, borderColor: 'divider', mb: 3 }}>
<Tabs value={jobTab} onChange={handleTabChange} centered>
<Tab value='load' icon={<WorkOutline />} label="Load" />
<Tab value='create' icon={<WorkAddIcon />} label="Create" />
</Tabs>
</Box>
{jobTab === 'load' &&
<JobPicker {...backstoryProps} onSelect={onJobSelect} />
}
{jobTab === 'create' &&
<JobCreator
{...backstoryProps}
/>
)}
onSave={onJobSelect}
/>}
</Box>
);
);
}
// Render function for the analysis step
const renderAnalysis = () => (
@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
</Button>
) : (
<Button onClick={handleNext} variant="contained">
{activeStep === steps[steps.length - 1].index - 1 ? 'Start Analysis' : 'Next'}
{activeStep === steps[steps.length - 1].index - 1 ? 'Done' : 'Next'}
</Button>
)}
</Box>

View File

@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown';
import { BackstoryPageProps } from '../components/BackstoryTab';
import { Conversation, ConversationHandle } from '../components/Conversation';
import { BackstoryQuery } from '../components/BackstoryQuery';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types';

View File

@ -52,7 +52,7 @@ interface StreamingOptions<T = Types.ChatMessage> {
signal?: AbortSignal;
}
interface DeleteCandidateResponse {
interface DeleteResponse {
success: boolean;
message: string;
}
@ -530,14 +530,24 @@ class ApiClient {
return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate');
}
async deleteCandidate(id: string): Promise<DeleteCandidateResponse> {
async deleteCandidate(id: string): Promise<DeleteResponse> {
const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
method: 'DELETE',
headers: this.defaultHeaders,
body: JSON.stringify({ id })
});
return handleApiResponse<DeleteCandidateResponse>(response);
return handleApiResponse<DeleteResponse>(response);
}
async deleteJob(id: string): Promise<DeleteResponse> {
const response = await fetch(`${this.baseUrl}/jobs/${id}`, {
method: 'DELETE',
headers: this.defaultHeaders,
body: JSON.stringify({ id })
});
return handleApiResponse<DeleteResponse>(response);
}
async uploadCandidateProfile(file: File): Promise<boolean> {
@ -641,7 +651,7 @@ class ApiClient {
return this.handleApiResponseWithConversion<Types.Job>(response, 'Job');
}
async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.JobFull>> {
const paginatedRequest = createPaginatedRequest(request);
const params = toUrlParams(formatApiRequest(paginatedRequest));
@ -649,7 +659,7 @@ class ApiClient {
headers: this.defaultHeaders
});
return this.handlePaginatedApiResponseWithConversion<Types.Job>(response, 'Job');
return this.handlePaginatedApiResponseWithConversion<Types.JobFull>(response, 'JobFull');
}
async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
@ -844,18 +854,28 @@ class ApiClient {
}
};
return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions);
// {
// method: 'POST',
// headers: {
// // Don't set Content-Type - browser will set it automatically with boundary
// 'Authorization': this.defaultHeaders['Authorization']
// },
// body: formData
// });
}
// const result = await handleApiResponse<Types.Document>(response);
// return result;
createJobFromFile(file: File, streamingOptions?: StreamingOptions<Types.Job>): StreamingResponse<Types.Job> {
const formData = new FormData()
formData.append('file', file);
formData.append('filename', file.name);
streamingOptions = {
...streamingOptions,
headers: {
// Don't set Content-Type - browser will set it automatically with boundary
'Authorization': this.defaultHeaders['Authorization']
}
};
return this.streamify<Types.Job>('/jobs/upload', formData, streamingOptions);
}
getJobRequirements(jobId: string, streamingOptions?: StreamingOptions<Types.DocumentMessage>): StreamingResponse<Types.DocumentMessage> {
streamingOptions = {
...streamingOptions,
headers: this.defaultHeaders,
};
return this.streamify<Types.DocumentMessage>(`/jobs/requirements/${jobId}`, null, streamingOptions);
}
async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> {
@ -1001,7 +1021,7 @@ class ApiClient {
* @param options callbacks, headers, and method
* @returns
*/
streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit | null, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
const abortController = new AbortController();
const signal = options.signal || abortController.signal;
const headers = options.headers || null;

View File

@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models
// Source: src/backend/models.py
// Generated on: 2025-06-05T22:02:22.004513
// Generated on: 2025-06-07T20:43:58.855207
// DO NOT EDIT MANUALLY - This file is auto-generated
// ============================
@ -323,7 +323,7 @@ export interface ChatMessageMetaData {
presencePenalty?: number;
stopSequences?: Array<string>;
ragResults?: Array<ChromaDBGetResponse>;
llmHistory?: Array<LLMMessage>;
llmHistory?: Array<any>;
evalCount: number;
evalDuration: number;
promptEvalCount: number;
@ -711,12 +711,6 @@ export interface JobResponse {
meta?: Record<string, any>;
}
export interface LLMMessage {
role: string;
content: string;
toolCalls?: Array<Record<string, any>>;
}
export interface Language {
language: string;
proficiency: "basic" | "conversational" | "fluent" | "native";

View File

@ -137,7 +137,7 @@ class Agent(BaseModel, ABC):
# llm: Any,
# model: str,
# message: ChatMessage,
# tool_message: Any, # llama response message
# tool_message: Any, # llama response
# messages: List[LLMMessage],
# ) -> AsyncGenerator[ChatMessage, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@ -270,15 +270,15 @@ class Agent(BaseModel, ABC):
# },
# stream=True,
# ):
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
# # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
# message.status = "streaming"
# message.chunk = response.message.content
# message.chunk = response.content
# message.content += message.chunk
# if not response.done:
# if not response.finish_reason:
# yield message
# if response.done:
# if response.finish_reason:
# self.collect_metrics(response)
# message.metadata.eval_count += response.eval_count
# message.metadata.eval_duration += response.eval_duration
@ -296,9 +296,9 @@ class Agent(BaseModel, ABC):
def collect_metrics(self, response):
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
response.prompt_eval_count
response.usage.prompt_eval_count
)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
"""
@ -361,7 +361,7 @@ Content: {content}
yield status_message
try:
chroma_results = user.file_watcher.find_similar(
chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold
)
if not chroma_results:
@ -430,7 +430,7 @@ Content: {content}
logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
response = None
content = ""
for response in llm.chat(
async for response in llm.chat_stream(
model=model,
messages=messages,
options={
@ -446,12 +446,12 @@ Content: {content}
yield error_message
return
content += response.message.content
content += response.content
if not response.done:
if not response.finish_reason:
streaming_message = ChatMessageStreaming(
session_id=session_id,
content=response.message.content,
content=response.content,
status=ApiStatusType.STREAMING,
)
yield streaming_message
@ -466,7 +466,7 @@ Content: {content}
self.collect_metrics(response)
self.context_tokens = (
response.prompt_eval_count + response.eval_count
response.usage.prompt_eval_count + response.usage.eval_count
)
chat_message = ChatMessage(
@ -476,10 +476,10 @@ Content: {content}
content=content,
metadata = ChatMessageMetaData(
options=options,
eval_count=response.eval_count,
eval_duration=response.eval_duration,
prompt_eval_count=response.prompt_eval_count,
prompt_eval_duration=response.prompt_eval_duration,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
)
)
@ -588,12 +588,12 @@ Content: {content}
# end_time = time.perf_counter()
# message.metadata.timers["tool_check"] = end_time - start_time
# if not response.message.tool_calls:
# if not response.tool_calls:
# logger.info("LLM indicates tools will not be used")
# # The LLM will not use tools, so disable use_tools so we can stream the full response
# use_tools = False
# else:
# tool_metadata["attempted"] = response.message.tool_calls
# tool_metadata["attempted"] = response.tool_calls
# if use_tools:
# logger.info("LLM indicates tools will be used")
@ -626,15 +626,15 @@ Content: {content}
# yield message
# return
# if response.message.tool_calls:
# tool_metadata["used"] = response.message.tool_calls
# if response.tool_calls:
# tool_metadata["used"] = response.tool_calls
# # Process all yielded items from the handler
# start_time = time.perf_counter()
# async for message in self.process_tool_calls(
# llm=llm,
# model=model,
# message=message,
# tool_message=response.message,
# tool_message=response,
# messages=messages,
# ):
# if message.status == "error":
@ -647,7 +647,7 @@ Content: {content}
# return
# logger.info("LLM indicated tools will be used, and then they weren't")
# message.content = response.message.content
# message.content = response.content
# message.status = "done"
# yield message
# return
@ -674,7 +674,7 @@ Content: {content}
content = ""
start_time = time.perf_counter()
response = None
for response in llm.chat(
async for response in llm.chat_stream(
model=model,
messages=messages,
options={
@ -690,12 +690,12 @@ Content: {content}
yield error_message
return
content += response.message.content
content += response.content
if not response.done:
if not response.finish_reason:
streaming_message = ChatMessageStreaming(
session_id=session_id,
content=response.message.content,
content=response.content,
)
yield streaming_message
@ -709,7 +709,7 @@ Content: {content}
self.collect_metrics(response)
self.context_tokens = (
response.prompt_eval_count + response.eval_count
response.usage.prompt_eval_count + response.usage.eval_count
)
end_time = time.perf_counter()
@ -720,10 +720,10 @@ Content: {content}
content=content,
metadata = ChatMessageMetaData(
options=options,
eval_count=response.eval_count,
eval_duration=response.eval_duration,
prompt_eval_count=response.prompt_eval_count,
prompt_eval_duration=response.prompt_eval_duration,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
timers={
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time

View File

@ -178,12 +178,13 @@ class RedisDatabase:
'jobs': 'job:',
'job_applications': 'job_application:',
'chat_sessions': 'chat_session:',
'chat_messages': 'chat_messages:', # This will store lists
'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:',
'users': 'user:',
'candidate_documents': 'candidate_documents:',
'candidate_documents': 'candidate_documents:',
'job_requirements': 'job_requirements:',
}
def _serialize(self, data: Any) -> str:
"""Serialize data to JSON string for Redis storage"""
if data is None:
@ -236,8 +237,9 @@ class RedisDatabase:
# Delete each document's metadata
for doc_id in document_ids:
pipe.delete(f"document:{doc_id}")
pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}")
deleted_count += 1
# Delete the candidate's document list
pipe.delete(key)
@ -250,7 +252,110 @@ class RedisDatabase:
except Exception as e:
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
raise
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Retrieve cached skill match assessment"""
try:
cached_data = await self.redis.get(cache_key)
if cached_data:
return json.loads(cached_data)
return None
except Exception as e:
logger.error(f"Error retrieving cached skill match: {e}")
return None
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
"""Cache skill match assessment with TTL (default 30 days)"""
try:
await self.redis.setex(
cache_key,
ttl,
json.dumps(assessment_data, default=str)
)
return True
except Exception as e:
logger.error(f"Error caching skill match: {e}")
return False
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
"""Get the last time candidate's skill information was updated"""
try:
# This assumes you track skill update timestamps in your candidate data
candidate_data = await self.get_candidate(candidate_id)
if candidate_data and 'skills_updated_at' in candidate_data:
return datetime.fromisoformat(candidate_data['skills_updated_at'])
return None
except Exception as e:
logger.error(f"Error getting candidate skill update time: {e}")
return None
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
"""Get the timestamp of the latest RAG data update for a specific user"""
try:
rag_update_key = f"user:{user_id}:rag_last_update"
timestamp_str = await self.redis.get(rag_update_key)
if timestamp_str:
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
return None
except Exception as e:
logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
return None
async def update_user_rag_timestamp(self, user_id: str) -> bool:
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
try:
rag_update_key = f"user:{user_id}:rag_last_update"
current_time = datetime.now(UTC).isoformat()
await self.redis.set(rag_update_key, current_time)
return True
except Exception as e:
logger.error(f"Error updating RAG timestamp for user {user_id}: {e}")
return False
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
"""Invalidate all cached skill matches for a specific candidate"""
try:
pattern = f"skill_match:{candidate_id}:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating candidate skill cache: {e}")
return 0
async def clear_all_skill_match_cache(self) -> int:
"""Clear all skill match cache (useful after major system updates)"""
try:
pattern = "skill_match:*"
keys = await self.redis.keys(pattern)
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error clearing skill match cache: {e}")
return 0
async def invalidate_user_skill_cache(self, user_id: str) -> int:
"""Invalidate all cached skill matches when a user's RAG data is updated"""
try:
# This assumes all candidates belonging to this user need cache invalidation
# You might need to adjust the pattern based on how you associate candidates with users
pattern = f"skill_match:*"
keys = await self.redis.keys(pattern)
# Filter keys that belong to candidates owned by this user
# This would require additional logic to determine candidate ownership
# For now, you might want to clear all cache when any user's RAG data updates
# or implement a more sophisticated mapping
if keys:
return await self.redis.delete(*keys)
return 0
except Exception as e:
logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
return 0
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
@ -330,7 +435,134 @@ class RedisDatabase:
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
]
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
"""Get cached job requirements analysis for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
data = await self.redis.get(key)
if data:
requirements_data = self._deserialize(data)
logger.debug(f"📋 Retrieved cached job requirements for document {document_id}")
return requirements_data
logger.debug(f"📋 No cached job requirements found for document {document_id}")
return None
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
return None
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
"""Save job requirements analysis results for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
# Add metadata to the requirements
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
}
await self.redis.set(key, self._serialize(requirements_with_meta))
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
logger.debug(f"📋 Saved job requirements for document {document_id}")
return True
except Exception as e:
logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
return False
async def delete_job_requirements(self, document_id: str) -> bool:
"""Delete cached job requirements for a document"""
try:
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
result = await self.redis.delete(key)
if result > 0:
logger.debug(f"📋 Deleted job requirements for document {document_id}")
return True
return False
except Exception as e:
logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
return False
async def get_all_job_requirements(self) -> Dict[str, Any]:
"""Get all cached job requirements"""
try:
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
if value:
result[document_id] = self._deserialize(value)
return result
except Exception as e:
logger.error(f"❌ Error retrieving all job requirements: {e}")
return {}
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
"""Get all job requirements analysis for documents belonging to a candidate"""
try:
# Get all documents for the candidate
candidate_documents = await self.get_candidate_documents(candidate_id)
if not candidate_documents:
return []
# Get job requirements for each document
job_requirements = []
for doc in candidate_documents:
doc_id = doc.get("id")
if doc_id:
requirements = await self.get_job_requirements(doc_id)
if requirements:
# Add document metadata to requirements
requirements["document_filename"] = doc.get("filename")
requirements["document_original_name"] = doc.get("originalName")
job_requirements.append(requirements)
return job_requirements
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
return []
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
"""Invalidate (delete) cached job requirements for a document"""
# This is an alias for delete_job_requirements for semantic clarity
return await self.delete_job_requirements(document_id)
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
"""Delete job requirements for multiple documents and return count of deleted items"""
try:
deleted_count = 0
pipe = self.redis.pipeline()
for doc_id in document_ids:
key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}"
pipe.delete(key)
deleted_count += 1
results = await pipe.execute()
actual_deleted = sum(1 for result in results if result > 0)
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
return actual_deleted
except Exception as e:
logger.error(f"❌ Error bulk deleting job requirements: {e}")
return 0
# Viewer operations
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
"""Get viewer by ID"""
@ -1484,6 +1716,74 @@ class RedisDatabase:
key = f"{self.KEY_PREFIXES['users']}{email}"
await self.redis.delete(key)
async def get_job_requirements_stats(self) -> Dict[str, Any]:
"""Get statistics about cached job requirements"""
try:
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
if keys:
# Get cache dates for analysis
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
stats["documents_with_requirements"].append(document_id)
# Track cache dates
cached_at = requirements_data.get("cached_at")
if cached_at:
cache_date = cached_at[:10] # Extract date part
stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
return stats
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
async def cleanup_orphaned_job_requirements(self) -> int:
"""Clean up job requirements for documents that no longer exist"""
try:
# Get all job requirements
all_requirements = await self.get_all_job_requirements()
if not all_requirements:
return 0
orphaned_count = 0
pipe = self.redis.pipeline()
for document_id in all_requirements.keys():
# Check if the document still exists
document_exists = await self.get_document(document_id)
if not document_exists:
# Document no longer exists, delete its job requirements
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
pipe.delete(key)
orphaned_count += 1
logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}")
if orphaned_count > 0:
await pipe.execute()
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
return orphaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
return 0
# Utility methods
async def clear_all_data(self):
"""Clear all data from Redis (use with caution!)"""

View File

@ -19,8 +19,9 @@ import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from llm_manager import llm_manager
import llm_proxy as llm_manager
from agents.base import Agent
from database import RedisDatabase
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@ -115,7 +116,7 @@ class CandidateEntity(Candidate):
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer
async def initialize(self, prometheus_collector: CollectorRegistry):
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
@ -140,9 +141,11 @@ class CandidateEntity(Candidate):
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
llm=llm_manager.get_llm(),
user_id=self.id,
collection_name=self.username,
persist_directory=vector_db_dir,
watch_directory=rag_content_dir,
database=database,
recreate=False, # Don't recreate if exists
)
has_username_rag = any(item["name"] == self.username for item in self.rags)

View File

@ -7,6 +7,7 @@ from pydantic import BaseModel, Field # type: ignore
from models import ( Candidate )
from .candidate_entity import CandidateEntity
from database import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager:
@ -34,9 +35,10 @@ class EntityManager:
pass
self._cleanup_task = None
def initialize(self, prometheus_collector: CollectorRegistry):
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
self._database = database
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
"""Get or create CandidateEntity with proper reference tracking"""
@ -49,7 +51,7 @@ class EntityManager:
return entity
entity = CandidateEntity(candidate=candidate)
await entity.initialize(prometheus_collector=self._prometheus_collector)
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity

View File

@ -1,41 +0,0 @@
import ollama
import defines
_llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
class llm_manager:
"""
A class to manage LLM operations using the Ollama client.
"""
@staticmethod
def get_llm() -> ollama.Client: # type: ignore
"""
Get the Ollama client instance.
Returns:
An instance of the Ollama client.
"""
return _llm
@staticmethod
def get_models() -> list[str]:
"""
Get a list of available models from the Ollama client.
Returns:
List of model names.
"""
return _llm.models()
@staticmethod
def get_model_info(model_name: str) -> dict:
"""
Get information about a specific model.
Args:
model_name: The name of the model to retrieve information for.
Returns:
A dictionary containing model information.
"""
return _llm.model(model_name)

1203
src/backend/llm_proxy.py Normal file

File diff suppressed because it is too large
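The suppressed src/backend/llm_proxy.py is not reproduced here. Going only by the call sites visible elsewhere in this commit (llm_manager.get_llm(), an awaited embeddings(model=..., input_texts=...) whose result exposes get_single_embedding(), and the LLMMessage model that models.py now re-exports with the same fields it previously defined), its public surface looks roughly like the stub below; every signature is inferred from usage, not copied from the suppressed file:

# Inferred interface sketch only; see src/backend/llm_proxy.py for the real definitions.
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

class LLMMessage(BaseModel):
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)

class EmbeddingResponse:  # assumed wrapper type; the real class name is unknown
    def get_single_embedding(self) -> List[float]: ...

class UnifiedLLMProxy:  # assumed to mirror src/multi-llm/llm_proxy.py shown later in this commit
    async def chat(self, model: str, messages: Any, **kwargs) -> Any: ...
    async def embeddings(self, model: str, input_texts: Any) -> EmbeddingResponse: ...

def get_llm() -> UnifiedLLMProxy: ...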

View File

@ -13,6 +13,9 @@ import uuid
import defines
import pathlib
from markitdown import MarkItDown, StreamInfo # type: ignore
import io
import uvicorn # type: ignore
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta, UTC
@ -53,7 +56,7 @@ import defines
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
from llm_manager import llm_manager
import llm_proxy as llm_manager
import entities
from email_service import VerificationEmailRateLimiter, email_service
from device_manager import DeviceManager
@ -116,7 +119,8 @@ async def lifespan(app: FastAPI):
try:
# Initialize database
await db_manager.initialize()
entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
@ -1827,7 +1831,7 @@ async def upload_candidate_document(
yield error_message
return
converted = False;
converted = False
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
p = pathlib.Path(file_path)
p_as_md = p.with_suffix(".md")
@ -1873,53 +1877,6 @@ async def upload_candidate_document(
content=file_content,
)
yield chat_message
# If this is a job description, process it with the job requirements agent
if not options.is_job_document:
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {candidate.first_name}'s AI agent...",
activity=ApiActivityType.INFO
)
yield status_message
await asyncio.sleep(0)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
)
yield error_message
return
message = None
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
)
yield status_message
await asyncio.sleep(0)
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=file_content
):
pass
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to process job description file"
)
yield error_message
return
yield message
try:
async def to_json(method):
try:
@ -1932,7 +1889,6 @@ async def upload_candidate_document(
logger.error(f"Error in to_json conversion: {e}")
return
# return DebugStreamingResponse(
return StreamingResponse(
to_json(upload_stream_generator(file_content)),
media_type="text/event-stream",
@ -1944,15 +1900,64 @@ async def upload_candidate_document(
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
"Transfer-Encoding": "chunked",
},
)
)
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("UPLOAD_ERROR", "Failed to upload document")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
)]),
media_type="text/event-stream"
)
async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
)
yield error_message
return
message = None
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
)
yield status_message
await asyncio.sleep(0)
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=content
):
pass
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to process job description file"
)
yield error_message
return
logger.info(f"✅ Successfully saved job requirements job {message.id}")
yield message
return
@api_router.post("/candidates/profile/upload")
async def upload_candidate_profile(
file: UploadFile = File(...),
@ -2573,6 +2578,7 @@ async def delete_candidate(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
)
@api_router.patch("/candidates/{candidate_id}")
async def update_candidate(
candidate_id: str = Path(...),
@ -2816,6 +2822,139 @@ async def create_candidate_job(
content=create_error_response("CREATION_FAILED", str(e))
)
@api_router.post("/jobs/upload")
async def create_job_from_file(
file: UploadFile = File(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB)
max_size = 10 * 1024 * 1024 # 10MB
file_content = await file.read()
if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit"
)]),
media_type="text/event-stream"
)
if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty"
)]),
media_type="text/event-stream"
)
"""Upload a document for the current candidate"""
async def upload_stream_generator(file_content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
)
yield error_message
return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename"
)
yield error_message
return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
# Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
)
yield error_message
return
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING
)
yield status_message
try:
md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
stream = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing
)
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content
logger.info(f"✅ Converted {file.filename} to Markdown format")
except Exception as e:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Failed to convert {file.filename} to Markdown.",
)
yield error_message
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
yield message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"Error in to_json conversion: {e}")
return
return StreamingResponse(
to_json(upload_stream_generator(file_content)),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-store, must-revalidate",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Nginx
"X-Content-Type-Options": "nosniff",
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
"Transfer-Encoding": "chunked",
},
)
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
)]),
media_type="text/event-stream"
)
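The /jobs/upload response is a Server-Sent Events stream: each message passes through the to_json wrapper above and arrives as a `data: {...}` line. A minimal client sketch using httpx; the base URL, route prefix, and auth header are placeholders rather than values taken from this commit:

import json
import httpx

async def upload_job_description(path: str, token: str, base_url: str = "http://localhost:8000"):
    # Stream the upload and print each JSON message as it arrives.
    async with httpx.AsyncClient(timeout=None) as client:
        with open(path, "rb") as f:
            async with client.stream(
                "POST",
                f"{base_url}/jobs/upload",  # adjust for the actual api_router prefix
                files={"file": (path, f)},
                headers={"Authorization": f"Bearer {token}"},
            ) as response:
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        message = json.loads(line[len("data: "):])
                        print(message.get("content"))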
@api_router.get("/jobs/{job_id}")
async def get_job(
job_id: str = Path(...),
@ -2931,6 +3070,49 @@ async def search_jobs(
content=create_error_response("SEARCH_FAILED", str(e))
)
@api_router.delete("/jobs/{job_id}")
async def delete_job(
job_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
):
"""Delete a Job"""
try:
# Check if admin user
if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
# Get job data
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Job not found for deletion: {job_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
# Delete job from database
await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({
"message": "Job deleted successfully",
"jobId": job_id
})
except Exception as e:
logger.error(f"❌ Delete job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)
# ============================
# Chat Endpoints
# ============================
@ -3541,7 +3723,7 @@ async def get_candidate_skill_match(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get skill match for a candidate against a requirement"""
"""Get skill match for a candidate against a requirement with caching"""
try:
# Find candidate by ID
candidate_data = await database.get_candidate(candidate_id)
@ -3553,34 +3735,89 @@ async def get_candidate_skill_match(
candidate = Candidate.model_validate(candidate_data)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
logger.info(f"🔍 Running skill match for candidate {candidate_entity.username} against requirement: {requirement}")
agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
if not agent:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
)
# Entity automatically released when done
skill_match = await get_last_item(
agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=requirement,
# Create cache key for this specific candidate + requirement combination
# (use a stable hashlib digest; Python's built-in hash() changes per process and would defeat the cache. Requires `import hashlib` at module top.)
cache_key = f"skill_match:{candidate_id}:{hashlib.sha256(requirement.encode()).hexdigest()[:16]}"
# Get cached assessment if it exists
cached_assessment = await database.get_cached_skill_match(cache_key)
# Get the last update time for the candidate's skill information
candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id)
# Get the latest RAG data update time for the current user
user_rag_update_time = await database.get_user_rag_update_time(current_user.id)
# Determine if we need to regenerate the assessment
should_regenerate = True
cached_date = None
if cached_assessment:
cached_date = cached_assessment.get('cached_at')
if cached_date:
# Check if cached result is still valid
# Regenerate if:
# 1. Candidate skills were updated after cache date
# 2. User's RAG data was updated after cache date
if (not candidate_skill_update_time or cached_date >= candidate_skill_update_time) and \
(not user_rag_update_time or cached_date >= user_rag_update_time):
should_regenerate = False
logger.info(f"🔄 Using cached skill match for candidate {candidate.id}")
if should_regenerate:
logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}")
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
if not agent:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
)
# Generate new skill match
skill_match = await get_last_item(
agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=requirement,
),
)
if skill_match is None:
if skill_match is None:
return JSONResponse(
status_code=500,
content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
)
skill_match_data = json.loads(skill_match.content)
# Cache the new assessment with current timestamp
cached_assessment = {
"skill_match": skill_match_data,
"cached_at": datetime.utcnow().isoformat(),
"candidate_id": candidate_id,
"requirement": requirement
}
await database.cache_skill_match(cache_key, cached_assessment)
logger.info(f"💾 Cached new skill match for candidate {candidate.id}")
logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
else:
# Use cached result - we know cached_assessment is not None here
if cached_assessment is None:
return JSONResponse(
status_code=500,
content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
content=create_error_response("CACHE_ERROR", "Unexpected cache state")
)
skill_match = json.loads(skill_match.content)
logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match['evidence_strength']}")
skill_match_data = cached_assessment["skill_match"]
logger.info(f"Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
return create_success_response({
"candidateId": candidate.id,
"skillMatch": skill_match
"skillMatch": skill_match_data,
"cached": not should_regenerate,
"cacheTimestamp": cached_date
})
except Exception as e:
@ -3589,8 +3826,8 @@ async def get_candidate_skill_match(
return JSONResponse(
status_code=500,
content=create_error_response("SKILL_MATCH_ERROR", str(e))
)
)
@api_router.get("/candidates/{username}/chat-sessions")
async def get_candidate_chat_sessions(
username: str = Path(...),
@ -3911,7 +4148,6 @@ async def track_requests(request, call_next):
# FastAPI Metrics
# ============================
prometheus_collector = CollectorRegistry()
entities.entity_manager.initialize(prometheus_collector)
# Keep the Instrumentator instance alive
instrumentator = Instrumentator(

View File

@ -768,11 +768,7 @@ class ChatOptions(BaseModel):
"populate_by_name": True # Allow both field names and aliases
}
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)
from llm_proxy import (LLMMessage)
class ApiMessage(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))

View File

@ -13,7 +13,6 @@ import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
@ -27,6 +26,7 @@ from .markdown_chunker import (
# When imported as a module, use relative imports
import defines
from database import RedisDatabase
from models import ChromaDBGetResponse
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
loop,
persist_directory,
collection_name,
database: RedisDatabase,
user_id: str,
chunk_size=DEFAULT_CHUNK_SIZE,
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
recreate=False,
):
self.llm = llm
self.database = database
self.user_id = user_id
self.watch_directory = watch_directory
self.persist_directory = persist_directory or defines.persist_directory
self.collection_name = collection_name
@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
def get_embedding(self, text: str) -> np.ndarray:
async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
embedding = np.array(response["embedding"])
response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
embedding = np.array(response.get_single_embedding())
except Exception as e:
logging.error(traceback.format_exc())
logging.error(f"Failed to get embedding: {e}")
raise
@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return embedding
def add_embeddings_to_collection(self, chunks: List[Chunk]):
async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
embedding = self.get_embedding(text)
embedding = await self.get_embedding(text)
try:
self.collection.add(
ids=[chunk_id],
@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
"""Find similar documents to the query."""
# collection is configured with hnsw:space cosine
query_embedding = self.get_embedding(query)
query_embedding = await self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
and existing_results["ids"]
):
self.collection.delete(ids=existing_results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
extensions = (".docx", ".xlsx", ".xls", ".pdf")
if file_path.endswith(extensions):
@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# f.write(json.dumps(chunk, indent=2))
# Add chunks to collection
self.add_embeddings_to_collection(chunks)
await self._add_embeddings_to_collection(chunks)
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Function to start the file watcher
def start_file_watcher(
llm,
user_id,
watch_directory,
persist_directory,
collection_name,
database: RedisDatabase,
initialize=False,
recreate=False,
):
@ -663,9 +674,11 @@ def start_file_watcher(
llm,
watch_directory=watch_directory,
loop=loop,
user_id=user_id,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
database=database
)
# Process all files if:

193
src/multi-llm/config.md Normal file
View File

@ -0,0 +1,193 @@
# Environment Configuration Examples
# ==== Development (Ollama only) ====
export OLLAMA_HOST="http://localhost:11434"
export DEFAULT_LLM_PROVIDER="ollama"
# ==== Production with OpenAI ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export DEFAULT_LLM_PROVIDER="openai"
# ==== Production with Anthropic ====
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export DEFAULT_LLM_PROVIDER="anthropic"
# ==== Production with multiple providers ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export GEMINI_API_KEY="your-gemini-key-here"
export OLLAMA_HOST="http://ollama-server:11434"
export DEFAULT_LLM_PROVIDER="openai"
# ==== Docker Compose Example ====
# docker-compose.yml
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- DEFAULT_LLM_PROVIDER=openai
depends_on:
- ollama
ollama:
image: ollama/ollama
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
environment:
- OLLAMA_HOST=0.0.0.0
volumes:
ollama_data:
# ==== Kubernetes ConfigMap ====
apiVersion: v1
kind: ConfigMap
metadata:
name: llm-config
data:
DEFAULT_LLM_PROVIDER: "openai"
OLLAMA_HOST: "http://ollama-service:11434"
---
apiVersion: v1
kind: Secret
metadata:
name: llm-secrets
type: Opaque
stringData:
OPENAI_API_KEY: "sk-your-key-here"
ANTHROPIC_API_KEY: "sk-ant-your-key-here"
# ==== Python configuration example ====
# config.py
import os
from llm_proxy import get_llm, LLMProvider
def configure_llm_for_environment():
"""Configure LLM based on deployment environment"""
llm = get_llm()
# Development: Use Ollama
if os.getenv('ENVIRONMENT') == 'development':
llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434')
llm.set_default_provider(LLMProvider.OLLAMA)
# Staging: Use OpenAI with rate limiting
elif os.getenv('ENVIRONMENT') == 'staging':
llm.configure_provider(LLMProvider.OPENAI,
api_key=os.getenv('OPENAI_API_KEY'),
max_retries=3,
timeout=30)
llm.set_default_provider(LLMProvider.OPENAI)
# Production: Use multiple providers with fallback
elif os.getenv('ENVIRONMENT') == 'production':
# Primary: Anthropic
llm.configure_provider(LLMProvider.ANTHROPIC,
api_key=os.getenv('ANTHROPIC_API_KEY'))
# Fallback: OpenAI
llm.configure_provider(LLMProvider.OPENAI,
api_key=os.getenv('OPENAI_API_KEY'))
# Set primary provider
llm.set_default_provider(LLMProvider.ANTHROPIC)
# ==== Usage Examples ====
# Example 1: Basic usage with default provider
async def basic_example():
llm = get_llm()
response = await llm.chat(
model="gpt-4",
messages=[{"role": "user", "content": "Hello!"}]
)
print(response.content)
# Example 2: Specify provider explicitly
async def provider_specific_example():
llm = get_llm()
# Use OpenAI specifically
response = await llm.chat(
model="gpt-4",
messages=[{"role": "user", "content": "Hello!"}],
provider=LLMProvider.OPENAI
)
# Use Anthropic specifically
response2 = await llm.chat(
model="claude-3-sonnet-20240229",
messages=[{"role": "user", "content": "Hello!"}],
provider=LLMProvider.ANTHROPIC
)
# Example 3: Streaming with provider fallback
async def streaming_with_fallback():
llm = get_llm()
try:
# chat_stream() yields chunks directly; chat(..., stream=True) returns a coroutine and cannot be iterated with async for
async for chunk in llm.chat_stream(
model="claude-3-sonnet-20240229",
messages=[{"role": "user", "content": "Write a story"}],
provider=LLMProvider.ANTHROPIC
):
print(chunk.content, end='', flush=True)
except Exception as e:
print(f"Primary provider failed: {e}")
# Fallback to OpenAI
async for chunk in llm.chat_stream(
model="gpt-4",
messages=[{"role": "user", "content": "Write a story"}],
provider=LLMProvider.OPENAI
):
print(chunk.content, end='', flush=True)
# Example 4: Load balancing between providers
import random
async def load_balanced_request():
llm = get_llm()
available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]
# Simple random load balancing
provider = random.choice(available_providers)
# Adjust model based on provider
model_mapping = {
LLMProvider.OPENAI: "gpt-4",
LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229"
}
response = await llm.chat(
model=model_mapping[provider],
messages=[{"role": "user", "content": "Hello!"}],
provider=provider
)
print(f"Response from {provider.value}: {response.content}")
# ==== FastAPI Startup Configuration ====
from fastapi import FastAPI
app = FastAPI()
@app.on_event("startup")
async def startup_event():
"""Configure LLM providers on application startup"""
configure_llm_for_environment()
# Verify configuration
llm = get_llm()
print(f"Configured providers: {[p.value for p in llm._initialized_providers]}")
print(f"Default provider: {llm.default_provider.value}")

238
src/multi-llm/example.py Normal file
View File

@ -0,0 +1,238 @@
from fastapi import FastAPI, HTTPException # type: ignore
from fastapi.responses import StreamingResponse # type: ignore
from pydantic import BaseModel # type: ignore
from typing import List, Optional, Dict, Any
import json
import asyncio
from llm_proxy import get_llm, LLMProvider, ChatMessage
app = FastAPI(title="Unified LLM API")
# Pydantic models for request/response
class ChatRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
provider: Optional[str] = None
stream: bool = False
max_tokens: Optional[int] = None
temperature: Optional[float] = None
class GenerateRequest(BaseModel):
model: str
prompt: str
provider: Optional[str] = None
stream: bool = False
max_tokens: Optional[int] = None
temperature: Optional[float] = None
class ChatResponseModel(BaseModel):
content: str
model: str
provider: str
finish_reason: Optional[str] = None
usage: Optional[Dict[str, int]] = None
@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
"""Chat endpoint that works with any configured LLM provider"""
try:
llm = get_llm()
# Convert provider string to enum if provided
provider = None
if request.provider:
try:
provider = LLMProvider(request.provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {request.provider}"
)
# Prepare kwargs
kwargs = {}
if request.max_tokens:
kwargs['max_tokens'] = request.max_tokens
if request.temperature:
kwargs['temperature'] = request.temperature
if request.stream:
async def generate_response():
try:
# Use the type-safe streaming method
async for chunk in llm.chat_stream(
model=request.model,
messages=request.messages,
provider=provider,
**kwargs
):
# Format as Server-Sent Events
data = {
"content": chunk.content,
"model": chunk.model,
"provider": provider.value if provider else llm.default_provider.value,
"finish_reason": chunk.finish_reason
}
yield f"data: {json.dumps(data)}\n\n"
except Exception as e:
error_data = {"error": str(e)}
yield f"data: {json.dumps(error_data)}\n\n"
finally:
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
)
else:
# Use the type-safe single response method
response = await llm.chat_single(
model=request.model,
messages=request.messages,
provider=provider,
**kwargs
)
return ChatResponseModel(
content=response.content,
model=response.model,
provider=provider.value if provider else llm.default_provider.value,
finish_reason=response.finish_reason,
usage=response.usage
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate")
async def generate_endpoint(request: GenerateRequest):
"""Generate endpoint for simple text generation"""
try:
llm = get_llm()
provider = None
if request.provider:
try:
provider = LLMProvider(request.provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {request.provider}"
)
kwargs = {}
if request.max_tokens:
kwargs['max_tokens'] = request.max_tokens
if request.temperature:
kwargs['temperature'] = request.temperature
if request.stream:
async def generate_response():
try:
# Use the type-safe streaming method
async for chunk in llm.generate_stream(
model=request.model,
prompt=request.prompt,
provider=provider,
**kwargs
):
data = {
"content": chunk.content,
"model": chunk.model,
"provider": provider.value if provider else llm.default_provider.value,
"finish_reason": chunk.finish_reason
}
yield f"data: {json.dumps(data)}\n\n"
except Exception as e:
error_data = {"error": str(e)}
yield f"data: {json.dumps(error_data)}\n\n"
finally:
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream"
)
else:
# Use the type-safe single response method
response = await llm.generate_single(
model=request.model,
prompt=request.prompt,
provider=provider,
**kwargs
)
return ChatResponseModel(
content=response.content,
model=response.model,
provider=provider.value if provider else llm.default_provider.value,
finish_reason=response.finish_reason,
usage=response.usage
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models(provider: Optional[str] = None):
"""List available models for a provider"""
try:
llm = get_llm()
provider_enum = None
if provider:
try:
provider_enum = LLMProvider(provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {provider}"
)
models = await llm.list_models(provider_enum)
return {
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/providers")
async def list_providers():
"""List all configured providers"""
llm = get_llm()
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value
}
@app.post("/providers/{provider}/set-default")
async def set_default_provider(provider: str):
"""Set the default provider"""
try:
llm = get_llm()
provider_enum = LLMProvider(provider.lower())
llm.set_default_provider(provider_enum)
return {"message": f"Default provider set to {provider}", "default": provider}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
llm = get_llm()
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value
}
if __name__ == "__main__":
import uvicorn # type: ignore
uvicorn.run(app, host="0.0.0.0", port=8000)
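Once this example server is running, a quick smoke test could look like the following; the model name is a placeholder and should be whatever the configured provider actually serves:

import asyncio
import httpx

async def main():
    async with httpx.AsyncClient(timeout=None) as client:
        # Non-streaming chat against the default provider
        r = await client.post("http://localhost:8000/chat", json={
            "model": "llama3",  # placeholder model name
            "messages": [{"role": "user", "content": "Hello!"}],
        })
        r.raise_for_status()
        print(r.json())

        # Which providers are configured, and which is the default?
        print((await client.get("http://localhost:8000/providers")).json())

asyncio.run(main())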

606
src/multi-llm/llm_proxy.py Normal file
View File

@ -0,0 +1,606 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
from enum import Enum
import asyncio
import json
from dataclasses import dataclass
import os
# Standard message format for all providers
@dataclass
class ChatMessage:
role: str # "user", "assistant", "system"
content: str
def __getitem__(self, key: str):
"""Allow dictionary-style access for backward compatibility"""
if key == 'role':
return self.role
elif key == 'content':
return self.content
else:
raise KeyError(f"'{key}' not found in ChatMessage")
def __setitem__(self, key: str, value: str):
"""Allow dictionary-style assignment"""
if key == 'role':
self.role = value
elif key == 'content':
self.content = value
else:
raise KeyError(f"'{key}' not found in ChatMessage")
@classmethod
def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage':
"""Create ChatMessage from dictionary"""
return cls(role=data['role'], content=data['content'])
def to_dict(self) -> Dict[str, str]:
"""Convert ChatMessage to dictionary"""
return {'role': self.role, 'content': self.content}
@dataclass
class ChatResponse:
content: str
model: str
finish_reason: Optional[str] = None
usage: Optional[Dict[str, int]] = None
class LLMProvider(Enum):
OLLAMA = "ollama"
OPENAI = "openai"
ANTHROPIC = "anthropic"
GEMINI = "gemini"
GROK = "grok"
class BaseLLMAdapter(ABC):
"""Abstract base class for all LLM adapters"""
def __init__(self, **config):
self.config = config
@abstractmethod
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages and get response"""
pass
@abstractmethod
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text from prompt"""
pass
@abstractmethod
async def list_models(self) -> List[str]:
"""List available models"""
pass
class OllamaAdapter(BaseLLMAdapter):
"""Adapter for Ollama"""
def __init__(self, **config):
super().__init__(**config)
import ollama
self.client = ollama.AsyncClient( # type: ignore
host=config.get('host', 'http://localhost:11434')
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert ChatMessage objects to Ollama format
ollama_messages = []
for msg in messages:
ollama_messages.append({
"role": msg.role,
"content": msg.content
})
if stream:
return self._stream_chat(model, ollama_messages, **kwargs)
else:
response = await self.client.chat(
model=model,
messages=ollama_messages,
stream=False,
**kwargs
)
return ChatResponse(
content=response['message']['content'],
model=model,
finish_reason=response.get('done_reason')
)
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
# chat() is a coroutine; await it to obtain the async iterator when stream=True
async for chunk in await self.client.chat(
model=model,
messages=messages,
stream=True,
**kwargs
):
if chunk.get('message', {}).get('content'):
yield ChatResponse(
content=chunk['message']['content'],
model=model,
finish_reason=chunk.get('done_reason')
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
if stream:
return self._stream_generate(model, prompt, **kwargs)
else:
response = await self.client.generate(
model=model,
prompt=prompt,
stream=False,
**kwargs
)
return ChatResponse(
content=response['response'],
model=model,
finish_reason=response.get('done_reason')
)
async def _stream_generate(self, model: str, prompt: str, **kwargs):
# generate() is a coroutine; await it to obtain the async iterator when stream=True
async for chunk in await self.client.generate(
model=model,
prompt=prompt,
stream=True,
**kwargs
):
if chunk.get('response'):
yield ChatResponse(
content=chunk['response'],
model=model,
finish_reason=chunk.get('done_reason')
)
async def list_models(self) -> List[str]:
models = await self.client.list()
return [model['name'] for model in models['models']]
class OpenAIAdapter(BaseLLMAdapter):
"""Adapter for OpenAI"""
def __init__(self, **config):
super().__init__(**config)
import openai # type: ignore
self.client = openai.AsyncOpenAI(
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert ChatMessage objects to OpenAI format
openai_messages = []
for msg in messages:
openai_messages.append({
"role": msg.role,
"content": msg.content
})
if stream:
return self._stream_chat(model, openai_messages, **kwargs)
else:
response = await self.client.chat.completions.create(
model=model,
messages=openai_messages,
stream=False,
**kwargs
)
return ChatResponse(
content=response.choices[0].message.content,
model=model,
finish_reason=response.choices[0].finish_reason,
usage=response.usage.model_dump() if response.usage else None
)
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
async for chunk in await self.client.chat.completions.create(
model=model,
messages=messages,
stream=True,
**kwargs
):
if chunk.choices[0].delta.content:
yield ChatResponse(
content=chunk.choices[0].delta.content,
model=model,
finish_reason=chunk.choices[0].finish_reason
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Convert to chat format for OpenAI
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
models = await self.client.models.list()
return [model.id for model in models.data]
class AnthropicAdapter(BaseLLMAdapter):
"""Adapter for Anthropic Claude"""
def __init__(self, **config):
super().__init__(**config)
import anthropic # type: ignore
self.client = anthropic.AsyncAnthropic(
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
)
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
# Anthropic requires system message to be separate
system_message = None
anthropic_messages = []
for msg in messages:
if msg.role == "system":
system_message = msg.content
else:
anthropic_messages.append({
"role": msg.role,
"content": msg.content
})
request_kwargs = {
"model": model,
"messages": anthropic_messages,
"max_tokens": kwargs.pop('max_tokens', 1000),
**kwargs
}
if system_message:
request_kwargs["system"] = system_message
if stream:
return self._stream_chat(**request_kwargs)
else:
response = await self.client.messages.create(
stream=False,
**request_kwargs
)
return ChatResponse(
content=response.content[0].text,
model=model,
finish_reason=response.stop_reason,
usage={
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens
}
)
async def _stream_chat(self, **kwargs):
async with self.client.messages.stream(**kwargs) as stream:
async for text in stream.text_stream:
yield ChatResponse(
content=text,
model=kwargs['model']
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
# Anthropic doesn't have a list models endpoint, return known models
return [
"claude-3-5-sonnet-20241022",
"claude-3-5-haiku-20241022",
"claude-3-opus-20240229"
]
class GeminiAdapter(BaseLLMAdapter):
"""Adapter for Google Gemini"""
def __init__(self, **config):
super().__init__(**config)
import google.generativeai as genai # type: ignore
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
self.genai = genai
async def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
model_instance = self.genai.GenerativeModel(model)
# Convert messages to Gemini format
chat_history = []
current_message = None
for i, msg in enumerate(messages[:-1]): # All but last message go to history
if msg.role == "user":
chat_history.append({"role": "user", "parts": [msg.content]})
elif msg.role == "assistant":
chat_history.append({"role": "model", "parts": [msg.content]})
# Last message is the current prompt
if messages:
current_message = messages[-1].content
chat = model_instance.start_chat(history=chat_history)
if not current_message:
raise ValueError("No current message provided for chat")
if stream:
return self._stream_chat(chat, current_message, **kwargs)
else:
response = await chat.send_message_async(current_message, **kwargs)
return ChatResponse(
content=response.text,
model=model,
finish_reason=response.candidates[0].finish_reason.name if response.candidates else None
)
async def _stream_chat(self, chat, message: str, **kwargs):
# send_message_async() is a coroutine; await it to obtain the streaming response before iterating
async for chunk in await chat.send_message_async(message, stream=True, **kwargs):
if chunk.text:
yield ChatResponse(
content=chunk.text,
model=chat.model.model_name
)
async def generate(
self,
model: str,
prompt: str,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
messages = [ChatMessage(role="user", content=prompt)]
return await self.chat(model, messages, stream, **kwargs)
async def list_models(self) -> List[str]:
models = self.genai.list_models()
return [model.name for model in models if 'generateContent' in model.supported_generation_methods]
class UnifiedLLMProxy:
"""Main proxy class that provides unified interface to all LLM providers"""
def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
self.default_provider = default_provider
self._initialized_providers = set()
def configure_provider(self, provider: LLMProvider, **config):
"""Configure a specific provider with its settings"""
adapter_classes = {
LLMProvider.OLLAMA: OllamaAdapter,
LLMProvider.OPENAI: OpenAIAdapter,
LLMProvider.ANTHROPIC: AnthropicAdapter,
LLMProvider.GEMINI: GeminiAdapter,
# Add other providers as needed
}
if provider in adapter_classes:
self.adapters[provider] = adapter_classes[provider](**config)
self._initialized_providers.add(provider)
else:
raise ValueError(f"Unsupported provider: {provider}")
def set_default_provider(self, provider: LLMProvider):
"""Set the default provider for requests"""
if provider not in self._initialized_providers:
raise ValueError(f"Provider {provider} not configured")
self.default_provider = provider
async def chat(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Send chat messages using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
# Normalize messages to ChatMessage objects
normalized_messages = []
for msg in messages:
if isinstance(msg, ChatMessage):
normalized_messages.append(msg)
elif isinstance(msg, dict):
normalized_messages.append(ChatMessage.from_dict(msg))
else:
raise ValueError(f"Invalid message type: {type(msg)}")
return await adapter.chat(model, normalized_messages, stream, **kwargs)
async def chat_stream(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
**kwargs
) -> AsyncGenerator[ChatResponse, None]:
"""Stream chat messages using specified or default provider"""
result = await self.chat(model, messages, provider, stream=True, **kwargs)
# Type checker now knows this is an AsyncGenerator due to stream=True
async for chunk in result: # type: ignore
yield chunk
async def chat_single(
self,
model: str,
messages: Union[List[ChatMessage], List[Dict[str, str]]],
provider: Optional[LLMProvider] = None,
**kwargs
) -> ChatResponse:
"""Get single chat response using specified or default provider"""
result = await self.chat(model, messages, provider, stream=False, **kwargs)
# Type checker now knows this is a ChatResponse due to stream=False
return result # type: ignore
async def generate(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
"""Generate text using specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
return await adapter.generate(model, prompt, stream, **kwargs)
async def generate_stream(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
**kwargs
) -> AsyncGenerator[ChatResponse, None]:
"""Stream text generation using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
async for chunk in result: # type: ignore
yield chunk
async def generate_single(
self,
model: str,
prompt: str,
provider: Optional[LLMProvider] = None,
**kwargs
) -> ChatResponse:
"""Get single generation response using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
return result # type: ignore
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available models for specified or default provider"""
provider = provider or self.default_provider
adapter = self._get_adapter(provider)
return await adapter.list_models()
def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
"""Get adapter for specified provider"""
if provider not in self.adapters:
raise ValueError(f"Provider {provider} not configured")
return self.adapters[provider]
# Example usage and configuration
class LLMManager:
"""Singleton manager for the unified LLM proxy"""
_instance = None
_proxy = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if LLMManager._proxy is None:
LLMManager._proxy = UnifiedLLMProxy()
self._configure_from_environment()
def _configure_from_environment(self):
"""Configure providers based on environment variables"""
if not self._proxy:
raise RuntimeError("LLM proxy not initialized")
# Configure Ollama if available
ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
# Configure OpenAI if API key is available
if os.getenv('OPENAI_API_KEY'):
self._proxy.configure_provider(LLMProvider.OPENAI)
# Configure Anthropic if API key is available
if os.getenv('ANTHROPIC_API_KEY'):
self._proxy.configure_provider(LLMProvider.ANTHROPIC)
# Configure Gemini if API key is available
if os.getenv('GEMINI_API_KEY'):
self._proxy.configure_provider(LLMProvider.GEMINI)
# Set default provider from environment or use Ollama
default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
try:
self._proxy.set_default_provider(LLMProvider(default_provider))
except ValueError:
# Fallback to Ollama if specified provider not available
self._proxy.set_default_provider(LLMProvider.OLLAMA)
def get_proxy(self) -> UnifiedLLMProxy:
"""Get the unified LLM proxy instance"""
if not self._proxy:
raise RuntimeError("LLM proxy not initialized")
return self._proxy
# Convenience function for easy access
def get_llm() -> UnifiedLLMProxy:
"""Get the configured LLM proxy"""
return LLMManager.get_instance().get_proxy()
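One loose end: LLMProvider defines GROK, but configure_provider has no adapter mapped to it. If xAI's OpenAI-compatible endpoint is used, an adapter could be a thin subclass of OpenAIAdapter along these lines; the base URL and XAI_API_KEY variable are assumptions, not something this commit configures:

class GrokAdapter(OpenAIAdapter):
    """Sketch of an xAI Grok adapter via an OpenAI-compatible endpoint (URL and env var assumed)."""
    def __init__(self, **config):
        import openai  # type: ignore
        BaseLLMAdapter.__init__(self, **config)
        self.client = openai.AsyncOpenAI(
            api_key=config.get('api_key', os.getenv('XAI_API_KEY')),
            base_url=config.get('base_url', 'https://api.x.ai/v1'),
        )

# configure_provider's adapter_classes mapping would then gain:
#     LLMProvider.GROK: GrokAdapter,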