Multi-LLM backend working with Ollama
parent 5fed56ba76
commit 8aa8577874

frontend/src/components/JobCreator.tsx: 541 lines (new file)

@@ -0,0 +1,541 @@
import React, { useState, useEffect, useRef, JSX } from 'react';
import {
  Box,
  Button,
  Typography,
  Paper,
  TextField,
  Grid,
  Dialog,
  DialogTitle,
  DialogContent,
  DialogContentText,
  DialogActions,
  IconButton,
  useTheme,
  useMediaQuery,
  Chip,
  Divider,
  Card,
  CardContent,
  CardHeader,
  LinearProgress,
  Stack,
  Alert
} from '@mui/material';
import {
  SyncAlt,
  Favorite,
  Settings,
  Info,
  Search,
  AutoFixHigh,
  Image,
  Psychology,
  Build,
  CloudUpload,
  Description,
  Business,
  LocationOn,
  Work,
  CheckCircle,
  Star
} from '@mui/icons-material';
import { styled } from '@mui/material/styles';
import DescriptionIcon from '@mui/icons-material/Description';
import FileUploadIcon from '@mui/icons-material/FileUpload';

import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { BackstoryElementProps } from './BackstoryTab';
import { LoginRequired } from 'components/ui/LoginRequired';

import * as Types from 'types/types';

const VisuallyHiddenInput = styled('input')({
  clip: 'rect(0 0 0 0)',
  clipPath: 'inset(50%)',
  height: 1,
  overflow: 'hidden',
  position: 'absolute',
  bottom: 0,
  left: 0,
  whiteSpace: 'nowrap',
  width: 1,
});

const UploadBox = styled(Box)(({ theme }) => ({
  border: `2px dashed ${theme.palette.primary.main}`,
  borderRadius: theme.shape.borderRadius * 2,
  padding: theme.spacing(4),
  textAlign: 'center',
  backgroundColor: theme.palette.action.hover,
  transition: 'all 0.3s ease',
  cursor: 'pointer',
  '&:hover': {
    backgroundColor: theme.palette.action.selected,
    borderColor: theme.palette.primary.dark,
  },
}));

const StatusBox = styled(Box)(({ theme }) => ({
  display: 'flex',
  alignItems: 'center',
  gap: theme.spacing(1),
  padding: theme.spacing(1, 2),
  backgroundColor: theme.palette.background.paper,
  borderRadius: theme.shape.borderRadius,
  border: `1px solid ${theme.palette.divider}`,
  minHeight: 48,
}));

const getIcon = (type: Types.ApiActivityType) => {
  switch (type) {
    case 'converting':
      return <SyncAlt color="primary" />;
    case 'heartbeat':
      return <Favorite color="error" />;
    case 'system':
      return <Settings color="action" />;
    case 'info':
      return <Info color="info" />;
    case 'searching':
      return <Search color="primary" />;
    case 'generating':
      return <AutoFixHigh color="secondary" />;
    case 'generating_image':
      return <Image color="primary" />;
    case 'thinking':
      return <Psychology color="secondary" />;
    case 'tooling':
      return <Build color="action" />;
    default:
      return <Info color="action" />;
  }
};

interface JobCreator extends BackstoryElementProps {
  onSave?: (job: Types.Job) => void;
}

const JobCreator = (props: JobCreator) => {
  const { user, apiClient } = useAuth();
  const { onSave } = props;
  const { selectedCandidate } = useSelectedCandidate();
  const { selectedJob, setSelectedJob } = useSelectedJob();
  const { setSnack, submitQuery } = props;
  const backstoryProps = { setSnack, submitQuery };
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
  const isTablet = useMediaQuery(theme.breakpoints.down('md'));

  const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
  const [jobDescription, setJobDescription] = useState<string>('');
  const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
  const [jobTitle, setJobTitle] = useState<string>('');
  const [company, setCompany] = useState<string>('');
  const [summary, setSummary] = useState<string>('');
  const [jobLocation, setJobLocation] = useState<string>('');
  const [jobId, setJobId] = useState<string>('');
  const [jobStatus, setJobStatus] = useState<string>('');
  const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
  const [isProcessing, setIsProcessing] = useState<boolean>(false);

  useEffect(() => {

  }, [jobTitle, jobDescription, company]);

  const fileInputRef = useRef<HTMLInputElement>(null);

  if (!user?.id) {
    return (
      <LoginRequired asset="job creation" />
    );
  }

  const jobStatusHandlers = {
    onStatus: (status: Types.ChatMessageStatus) => {
      console.log('status:', status.content);
      setJobStatusIcon(getIcon(status.activity));
      setJobStatus(status.content);
    },
    onMessage: (job: Types.Job) => {
      console.log('onMessage - job', job);
      setCompany(job.company || '');
      setJobDescription(job.description);
      setSummary(job.summary || '');
      setJobTitle(job.title || '');
      setJobRequirements(job.requirements || null);
      setJobStatusIcon(<></>);
      setJobStatus('');
    },
    onError: (error: Types.ChatMessageError) => {
      console.log('onError', error);
      setSnack(error.content, "error");
      setIsProcessing(false);
    },
    onComplete: () => {
      setJobStatusIcon(<></>);
      setJobStatus('');
      setIsProcessing(false);
    }
  };

  const handleJobUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
    if (e.target.files && e.target.files[0]) {
      const file = e.target.files[0];
      const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
      let docType: Types.DocumentType | null = null;
      switch (fileExtension.substring(1)) {
        case "pdf":
          docType = "pdf";
          break;
        case "docx":
          docType = "docx";
          break;
        case "md":
          docType = "markdown";
          break;
        case "txt":
          docType = "txt";
          break;
      }

      if (!docType) {
        setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error');
        return;
      }

      try {
        setIsProcessing(true);
        setJobDescription('');
        setJobTitle('');
        setJobRequirements(null);
        setSummary('');
        const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
        const job = await controller.promise;
        if (!job) {
          return;
        }
        console.log(`Job id: ${job.id}`);
        e.target.value = '';
      } catch (error) {
        console.error(error);
        setSnack('Failed to upload document', 'error');
        setIsProcessing(false);
      }
    }
  };

  const handleUploadClick = () => {
    fileInputRef.current?.click();
  };

  const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => {
    if (!items || items.length === 0) return null;

    return (
      <Box sx={{ mb: 3 }}>
        <Box sx={{ display: 'flex', alignItems: 'center', mb: 1.5 }}>
          {icon}
          <Typography variant="subtitle1" sx={{ ml: 1, fontWeight: 600 }}>
            {title}
          </Typography>
          {required && <Chip label="Required" size="small" color="error" sx={{ ml: 1 }} />}
        </Box>
        <Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
          {items.map((item, index) => (
            <Chip
              key={index}
              label={item}
              variant="outlined"
              size="small"
              sx={{ mb: 1 }}
            />
          ))}
        </Stack>
      </Box>
    );
  };

  const renderJobRequirements = () => {
    if (!jobRequirements) return null;

    return (
      <Card elevation={2} sx={{ mt: 3 }}>
        <CardHeader
          title="Job Requirements Analysis"
          avatar={<CheckCircle color="success" />}
          sx={{ pb: 1 }}
        />
        <CardContent sx={{ pt: 0 }}>
          {renderRequirementSection(
            "Technical Skills (Required)",
            jobRequirements.technicalSkills.required,
            <Build color="primary" />,
            true
          )}
          {renderRequirementSection(
            "Technical Skills (Preferred)",
            jobRequirements.technicalSkills.preferred,
            <Build color="action" />
          )}
          {renderRequirementSection(
            "Experience Requirements (Required)",
            jobRequirements.experienceRequirements.required,
            <Work color="primary" />,
            true
          )}
          {renderRequirementSection(
            "Experience Requirements (Preferred)",
            jobRequirements.experienceRequirements.preferred,
            <Work color="action" />
          )}
          {renderRequirementSection(
            "Soft Skills",
            jobRequirements.softSkills,
            <Psychology color="secondary" />
          )}
          {renderRequirementSection(
            "Experience",
            jobRequirements.experience,
            <Star color="warning" />
          )}
          {renderRequirementSection(
            "Education",
            jobRequirements.education,
            <Description color="info" />
          )}
          {renderRequirementSection(
            "Certifications",
            jobRequirements.certifications,
            <CheckCircle color="success" />
          )}
          {renderRequirementSection(
            "Preferred Attributes",
            jobRequirements.preferredAttributes,
            <Star color="secondary" />
          )}
        </CardContent>
      </Card>
    );
  };

  const handleSave = async () => {
    const newJob: Types.Job = {
      ownerId: user?.id || '',
      ownerType: 'candidate',
      description: jobDescription,
      company: company,
      summary: summary,
      title: jobTitle,
      requirements: jobRequirements || undefined
    };
    setIsProcessing(true);
    const job = await apiClient.createJob(newJob);
    setIsProcessing(false);
    onSave ? onSave(job) : setSelectedJob(job);
  };

  const handleExtractRequirements = () => {
    // Implement requirements extraction logic here
    setIsProcessing(true);
    // This would call your API to extract requirements from the job description
  };

  const renderJobCreation = () => {
    if (!user) {
      return <Box>You must be logged in</Box>;
    }

    return (
      <Box sx={{
        mx: 'auto', p: { xs: 2, sm: 3 },
      }}>
        {/* Upload Section */}
        <Card elevation={3} sx={{ mb: 4 }}>
          <CardHeader
            title="Job Information"
            subheader="Upload a job description or enter details manually"
            avatar={<Work color="primary" />}
          />
          <CardContent>
            <Grid container spacing={3}>
              <Grid size={{ xs: 12, md: 6 }}>
                <Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
                  <CloudUpload sx={{ mr: 1 }} />
                  Upload Job Description
                </Typography>
                <UploadBox onClick={handleUploadClick}>
                  <CloudUpload sx={{ fontSize: 48, color: 'primary.main', mb: 2 }} />
                  <Typography variant="h6" gutterBottom>
                    Drop your job description here
                  </Typography>
                  <Typography variant="body2" color="text.secondary" sx={{ mb: 2 }}>
                    Supported formats: PDF, DOCX, TXT, MD
                  </Typography>
                  <Button
                    variant="contained"
                    startIcon={<FileUploadIcon />}
                    disabled={isProcessing}
                    // onClick={handleUploadClick}
                  >
                    Choose File
                  </Button>
                </UploadBox>
                <VisuallyHiddenInput
                  ref={fileInputRef}
                  type="file"
                  accept=".txt,.md,.docx,.pdf"
                  onChange={handleJobUpload}
                />
              </Grid>

              <Grid size={{ xs: 12, md: 6 }}>
                <Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
                  <Description sx={{ mr: 1 }} />
                  Or Enter Manually
                </Typography>
                <TextField
                  fullWidth
                  multiline
                  rows={isMobile ? 8 : 12}
                  placeholder="Paste or type the job description here..."
                  variant="outlined"
                  value={jobDescription}
                  onChange={(e) => setJobDescription(e.target.value)}
                  disabled={isProcessing}
                  sx={{ mb: 2 }}
                />
                {jobRequirements === null && jobDescription && (
                  <Button
                    variant="outlined"
                    onClick={handleExtractRequirements}
                    startIcon={<AutoFixHigh />}
                    disabled={isProcessing}
                    fullWidth={isMobile}
                  >
                    Extract Requirements
                  </Button>
                )}
              </Grid>
            </Grid>

            {(jobStatus || isProcessing) && (
              <Box sx={{ mt: 3 }}>
                <StatusBox>
                  {jobStatusIcon}
                  <Typography variant="body2" sx={{ ml: 1 }}>
                    {jobStatus || 'Processing...'}
                  </Typography>
                </StatusBox>
                {isProcessing && <LinearProgress sx={{ mt: 1 }} />}
              </Box>
            )}
          </CardContent>
        </Card>

        {/* Job Details Section */}
        <Card elevation={3} sx={{ mb: 4 }}>
          <CardHeader
            title="Job Details"
            subheader="Enter specific information about the position"
            avatar={<Business color="primary" />}
          />
          <CardContent>
            <Grid container spacing={3}>
              <Grid size={{ xs: 12, md: 6 }}>
                <TextField
                  fullWidth
                  label="Job Title"
                  variant="outlined"
                  value={jobTitle}
                  onChange={(e) => setJobTitle(e.target.value)}
                  required
                  disabled={isProcessing}
                  InputProps={{
                    startAdornment: <Work sx={{ mr: 1, color: 'text.secondary' }} />
                  }}
                />
              </Grid>

              <Grid size={{ xs: 12, md: 6 }}>
                <TextField
                  fullWidth
                  label="Company"
                  variant="outlined"
                  value={company}
                  onChange={(e) => setCompany(e.target.value)}
                  required
                  disabled={isProcessing}
                  InputProps={{
                    startAdornment: <Business sx={{ mr: 1, color: 'text.secondary' }} />
                  }}
                />
              </Grid>

              {/* <Grid size={{ xs: 12, md: 6 }}>
                <TextField
                  fullWidth
                  label="Job Location"
                  variant="outlined"
                  value={jobLocation}
                  onChange={(e) => setJobLocation(e.target.value)}
                  disabled={isProcessing}
                  InputProps={{
                    startAdornment: <LocationOn sx={{ mr: 1, color: 'text.secondary' }} />
                  }}
                />
              </Grid> */}

              <Grid size={{ xs: 12, md: 6 }}>
                <Box sx={{ display: 'flex', gap: 2, alignItems: 'flex-end', height: '100%' }}>
                  <Button
                    variant="contained"
                    onClick={handleSave}
                    disabled={!jobTitle || !company || !jobDescription || isProcessing}
                    fullWidth={isMobile}
                    size="large"
                    startIcon={<CheckCircle />}
                  >
                    Save Job
                  </Button>
                </Box>
              </Grid>
            </Grid>
          </CardContent>
        </Card>

        {/* Job Summary */}
        {summary !== '' &&
          <Card elevation={2} sx={{ mt: 3 }}>
            <CardHeader
              title="Job Summary"
              avatar={<CheckCircle color="success" />}
              sx={{ pb: 1 }}
            />
            <CardContent sx={{ pt: 0 }}>
              {summary}
            </CardContent>
          </Card>
        }

        {/* Requirements Display */}
        {renderJobRequirements()}

      </Box>
    );
  };

  return (
    <Box className="JobManagement"
      sx={{
        background: "white",
        p: 0,
      }}>
      {selectedJob === null && renderJobCreation()}
    </Box>
  );
};

export { JobCreator };
@@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => {
};

const JobManagement = (props: BackstoryElementProps) => {
  const { user, apiClient } = useAuth();
  const { selectedCandidate } = useSelectedCandidate();
  const { user, apiClient } = useAuth();
  const { selectedJob, setSelectedJob } = useSelectedJob();
  const { setSnack, submitQuery } = props;
  const backstoryProps = { setSnack, submitQuery };
  const { setSnack, submitQuery } = props;
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
  const isTablet = useMediaQuery(theme.breakpoints.down('md'));
  const isMobile = useMediaQuery(theme.breakpoints.down('sm'));

  const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
  const [jobDescription, setJobDescription] = useState<string>('');
  const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
  const [jobTitle, setJobTitle] = useState<string>('');
  const [company, setCompany] = useState<string>('');
  const [summary, setSummary] = useState<string>('');
  const [jobLocation, setJobLocation] = useState<string>('');
  const [jobId, setJobId] = useState<string>('');
  const [jobStatus, setJobStatus] = useState<string>('');
  const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
  const [isProcessing, setIsProcessing] = useState<boolean>(false);

  useEffect(() => {

  }, [jobTitle, jobDescription, company]);

  const fileInputRef = useRef<HTMLInputElement>(null);

  if (!user?.id) {
    return (
      <LoginRequired asset="candidate analysis" />
@@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => {
        setJobTitle('');
        setJobRequirements(null);
        setSummary('');
        const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers);
        const document = await controller.promise;
        if (!document) {
        const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
        const job = await controller.promise;
        if (!job) {
          return;
        }
        console.log(`Document id: ${document.id}`);
        console.log(`Job id: ${job.id}`);
        e.target.value = '';
      } catch (error) {
        console.error(error);
@@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => {
    // This would call your API to extract requirements from the job description
  };

  const loadJob = async () => {
    const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d");
    setSelectedJob(job);
  }
  const renderJobCreation = () => {
    if (!user) {
      return <Box>You must be logged in</Box>;
@@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => {
    <Box sx={{
      mx: 'auto', p: { xs: 2, sm: 3 },
    }}>
      <Button onClick={loadJob} variant="contained">Load Job</Button>
      {/* Upload Section */}
      <Card elevation={3} sx={{ mb: 4 }}>
        <CardHeader

@@ -14,7 +14,8 @@ import {
  CardContent,
  useTheme,
  LinearProgress,
  useMediaQuery
  useMediaQuery,
  Button
} from '@mui/material';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
@@ -28,6 +29,8 @@ import { toCamelCase } from 'types/conversion';
import { Job } from 'types/types';
import { StyledMarkdown } from './StyledMarkdown';
import { Scrollable } from './Scrollable';
import { start } from 'repl';
import { TypesElement } from '@uiw/react-json-view';

interface JobAnalysisProps extends BackstoryPageProps {
  job: Job;
@@ -56,6 +59,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
  const [overallScore, setOverallScore] = useState<number>(0);
  const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null);
  const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null);
  const [startAnalysis, setStartAnalysis] = useState<boolean>(false);
  const [analyzing, setAnalyzing] = useState<boolean>(false);

  const isMobile = useMediaQuery(theme.breakpoints.down('sm'));

  // Handle accordion expansion
@@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
    setExpanded(isExpanded ? panel : false);
  };

  useEffect(() => {
  const initializeRequirements = (job: Job) => {
    if (!job || !job.requirements) {
      return;
    }
@@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
    setSkillMatches(initialSkillMatches);
    setStatusMessage(null);
    setLoadingRequirements(false);

  }, [job, setRequirements]);
    setOverallScore(0);
  }

  useEffect(() => {
    if (requirementsSession || creatingSession) {
      return;
    }

    try {
      setCreatingSession(true);
      apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
        .then(session => {
          setRequirementsSession(session);
          setCreatingSession(false);
        });
    } catch (error) {
      setSnack('Unable to load chat session', 'error');
    } finally {
      setCreatingSession(false);
    }

  }, [requirementsSession, apiClient, candidate]);

  // Fetch initial requirements
  useEffect(() => {
    if (!job.description || !requirementsSession || loadingRequirements) {
      return;
    }

    const getRequirements = async () => {
      setLoadingRequirements(true);
      try {
        const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description };
        apiClient.sendMessageStream(chatMessage, {
          onMessage: (msg: ChatMessage) => {
            console.log(`onMessage: ${msg.type}`, msg);
            const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
            const requirements: { requirement: string, domain: string }[] = [];
            if (job.requirements?.technicalSkills) {
              job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
              job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
            }
            if (job.requirements?.experienceRequirements) {
              job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
              job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
            }
            if (job.requirements?.softSkills) {
              job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
            }
            if (job.requirements?.experience) {
              job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
            }
            if (job.requirements?.education) {
              job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
            }
            if (job.requirements?.certifications) {
              job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
            }
            if (job.requirements?.preferredAttributes) {
              job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
            }

            const initialSkillMatches = requirements.map(req => ({
              requirement: req.requirement,
              domain: req.domain,
              status: 'waiting' as const,
              matchScore: 0,
              assessment: '',
              description: '',
              citations: []
            }));

            setRequirements(requirements);
            setSkillMatches(initialSkillMatches);
            setStatusMessage(null);
            setLoadingRequirements(false);
          },
          onError: (error: string | ChatMessageError) => {
            console.log("onError:", error);
            // Type-guard to determine if this is a ChatMessageBase or a string
            if (typeof error === "object" && error !== null && "content" in error) {
              setSnack(error.content || 'Error obtaining requirements from job description.', "error");
            } else {
              setSnack(error as string, "error");
            }
            setLoadingRequirements(false);
          },
          onStreaming: (chunk: ChatMessageStreaming) => {
            // console.log("onStreaming:", chunk);
          },
          onStatus: (status: ChatMessageStatus) => {
            console.log(`onStatus: ${status}`);
          },
          onComplete: () => {
            console.log("onComplete");
            setStatusMessage(null);
            setLoadingRequirements(false);
          }
        });
      } catch (error) {
        console.error('Failed to send message:', error);
        setLoadingRequirements(false);
      }
    };

    getRequirements();
  }, [job, requirementsSession]);
    initializeRequirements(job);
  }, [job]);

  // Fetch match data for each requirement
  useEffect(() => {
    if (!startAnalysis || analyzing || !job.requirements) {
      return;
    }

    const fetchMatchData = async () => {
      if (requirements.length === 0) return;

@@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
      }
    };

    if (!loadingRequirements) {
      fetchMatchData();
    }
  }, [requirements, loadingRequirements]);
    setAnalyzing(true);
    fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) });
  }, [job, startAnalysis, analyzing, requirements, loadingRequirements]);

  // Get color based on match score
  const getMatchColor = (score: number): string => {
@@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
    return <ErrorIcon color="error" />;
  };

  const beginAnalysis = () => {
    initializeRequirements(job);
    setStartAnalysis(true);
  };

  return (
    <Box>
      <Paper elevation={3} sx={{ p: 3, mb: 4 }}>
@@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
        </Box>

        <Grid size={{ xs: 12 }} sx={{ mt: 2 }}>
          <Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
          <Box sx={{ display: 'flex', alignItems: 'center', mb: 2, gap: 1 }}>
            {<Button disabled={analyzing || startAnalysis} onClick={beginAnalysis} variant="contained">Start Analysis</Button>}
            {overallScore !== 0 && <>
              <Typography variant="h5" component="h2" sx={{ mr: 2 }}>
                Overall Match:
              </Typography>
@@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
                  fontWeight: 'bold'
                }}
              />
            </>}
          </Box>
        </Grid>
      </Grid>

@@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from './DeleteConfirmation';
import { DeleteConfirmation } from '../DeleteConfirmation';

interface CandidateInfoProps {
  candidate: Candidate;
@@ -4,7 +4,7 @@ import Button from '@mui/material/Button';
import Box from '@mui/material/Box';

import { BackstoryElementProps } from 'components/BackstoryTab';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { Candidate } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate } from 'hooks/GlobalContext';

frontend/src/components/ui/JobInfo.tsx: 97 lines (new file)

@@ -0,0 +1,97 @@
import React from 'react';
import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material';
import {
  Card,
  CardContent,
  Divider,
  useTheme,
} from '@mui/material';
import DeleteIcon from '@mui/icons-material/Delete';
import { useMediaQuery } from '@mui/material';
import { JobFull } from 'types/types';
import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
import { DeleteConfirmation } from '../DeleteConfirmation';

interface JobInfoProps {
  job: JobFull;
  sx?: SxProps;
  action?: string;
  elevation?: number;
  variant?: "small" | "normal" | null
};

const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
  const { job } = props;
  const { user, apiClient } = useAuth();
  const {
    sx,
    action = '',
    elevation = 1,
    variant = "normal"
  } = props;
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('md'));
  const isAdmin = user?.isAdmin;

  const deleteJob = async (jobId: string | undefined) => {
    if (jobId) {
      await apiClient.deleteJob(jobId);
    }
  }

  if (!job) {
    return <Box>No user loaded.</Box>;
  }

  return (
    <Card
      elevation={elevation}
      sx={{
        display: "flex",
        borderColor: 'transparent',
        borderWidth: 2,
        borderStyle: 'solid',
        transition: 'all 0.3s ease',
        flexDirection: "column",
        ...sx
      }}
      {...rest}
    >
      <CardContent sx={{ display: "flex", flexGrow: 1, p: 3, height: '100%', flexDirection: 'column', alignItems: 'stretch', position: "relative" }}>
        {variant !== "small" && <>
          {job.location &&
            <Typography variant="body2" sx={{ mb: 1 }}>
              <strong>Location:</strong> {job.location.city}, {job.location.state || job.location.country}
            </Typography>
          }
          {job.company &&
            <Typography variant="body2" sx={{ mb: 1 }}>
              <strong>Company:</strong> {job.company}
            </Typography>
          }
          {job.summary && <Typography variant="body2">
            <strong>Summary:</strong> {job.summary}
          </Typography>
          }
        </>}
      </CardContent>
      <CardActions>
        {isAdmin &&
          <DeleteConfirmation
            onDelete={() => { deleteJob(job.id); }}
            sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }}
            action="delete"
            label="job"
            title="Delete job"
            icon=<DeleteIcon />
            message={`Are you sure you want to delete ${job.id}? This action cannot be undone.`}
          />}
      </CardActions>
    </Card>
  );
};

export { JobInfo };
frontend/src/components/ui/JobPicker.tsx: 69 lines (new file)

@@ -0,0 +1,69 @@
import React, { useEffect, useState } from 'react';
import { useNavigate } from "react-router-dom";
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';

import { BackstoryElementProps } from 'components/BackstoryTab';
import { JobInfo } from 'components/ui/JobInfo';
import { Job, JobFull } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedJob } from 'hooks/GlobalContext';

interface JobPickerProps extends BackstoryElementProps {
  onSelect?: (job: JobFull) => void
};

const JobPicker = (props: JobPickerProps) => {
  const { onSelect } = props;
  const { apiClient } = useAuth();
  const { selectedJob, setSelectedJob } = useSelectedJob();
  const { setSnack } = props;
  const [jobs, setJobs] = useState<JobFull[] | null>(null);

  useEffect(() => {
    if (jobs !== null) {
      return;
    }
    const getJobs = async () => {
      try {
        const results = await apiClient.getJobs();
        const jobs: JobFull[] = results.data;
        jobs.sort((a, b) => {
          let result = a.company?.localeCompare(b.company || '');
          if (result === 0) {
            result = a.title?.localeCompare(b.title || '');
          }
          return result || 0;
        });
        setJobs(jobs);
      } catch (err) {
        setSnack("" + err);
      }
    };

    getJobs();
  }, [jobs, setSnack]);

  return (
    <Box sx={{display: "flex", flexDirection: "column", mb: 1}}>
      <Box sx={{ display: "flex", gap: 1, flexWrap: "wrap", justifyContent: "center" }}>
        {jobs?.map((j, i) =>
          <Box key={`${j.id}`}
            onClick={() => { onSelect ? onSelect(j) : setSelectedJob(j); }}
            sx={{ cursor: "pointer" }}>
            {selectedJob?.id === j.id &&
              <JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", backgroundColor: "#f0f0f0", "&:hover": { border: "2px solid orange" } }} job={j} />
            }
            {selectedJob?.id !== j.id &&
              <JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", border: "2px solid transparent", "&:hover": { border: "2px solid orange" } }} job={j} />
            }
          </Box>
        )}
      </Box>
    </Box>
  );
};

export {
  JobPicker
};
@@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { Message } from 'components/Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useNavigate } from 'react-router-dom';
import { useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader';

@@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel';
import SendIcon from '@mui/icons-material/Send';
import PropagateLoader from 'react-spinners/PropagateLoader';

import { CandidateInfo } from '../components/CandidateInfo';
import { CandidateInfo } from '../components/ui/CandidateInfo';
import { Quote } from 'components/Quote';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';

@@ -7,31 +7,74 @@ import {
  Button,
  Typography,
  Paper,
  Avatar,
  useTheme,
  Snackbar,
  Container,
  Grid,
  Alert,
  Tabs,
  Tab,
  Card,
  CardContent,
  Divider,
  Avatar,
  Badge,
} from '@mui/material';
import {
  Person,
  PersonAdd,
  AccountCircle,
  Add,
  WorkOutline,
  AddCircle,
} from '@mui/icons-material';
import PersonIcon from '@mui/icons-material/Person';
import WorkIcon from '@mui/icons-material/Work';
import AssessmentIcon from '@mui/icons-material/Assessment';
import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
import { Candidate } from "types/types";
import { Candidate, Job, JobFull } from "types/types";
import { useNavigate } from 'react-router-dom';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { ComingSoon } from 'components/ui/ComingSoon';
import { JobManagement } from 'components/JobManagement';
import { LoginRequired } from 'components/ui/LoginRequired';
import { Scrollable } from 'components/Scrollable';
import { CandidatePicker } from 'components/ui/CandidatePicker';
import { JobPicker } from 'components/ui/JobPicker';
import { JobCreator } from 'components/JobCreator';

function WorkAddIcon() {
  return (
    <Box position="relative" display="inline-flex"
      sx={{
        lineHeight: "30px",
        mb: "6px",
      }}
    >
      <WorkOutline sx={{ fontSize: 24 }} />
      <Add
        sx={{
          position: 'absolute',
          bottom: -2,
          right: -2,
          fontSize: 14,
          bgcolor: 'background.paper',
          borderRadius: '50%',
          boxShadow: 1,
        }}
        color="primary"
      />
    </Box>
  );
}

// Main component
const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
  const theme = useTheme();
  const { user } = useAuth();
  const { user, apiClient } = useAuth();
  const navigate = useNavigate();
  const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
  const { selectedJob, setSelectedJob } = useSelectedJob()
@@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
  const [activeStep, setActiveStep] = useState(0);
  const [analysisStarted, setAnalysisStarted] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [jobTab, setJobTab] = useState<string>('load');

  useEffect(() => {
    if (selectedJob && activeStep === 1) {
      setActiveStep(2);
    console.log({ activeStep, selectedCandidate, selectedJob });

    if (!selectedCandidate) {
      if (activeStep !== 0) {
        setActiveStep(0);
      }
    } else if (!selectedJob) {
      if (activeStep !== 1) {
        setActiveStep(1);
      }
    }
  }, [selectedJob, activeStep]);
  }, [selectedCandidate, selectedJob, activeStep])

  // Steps in our process
  const steps = [
@@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
        setSelectedJob(null);
        break;
      case 1: /* Select Job */
        setSelectedCandidate(null);
        setSelectedJob(null);
        break;
      case 2: /* Job Analysis */
@@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
    setActiveStep(1);
  }

  const onJobSelect = (job: Job) => {
    setSelectedJob(job)
    setActiveStep(2);
  }

  // Render function for the candidate selection step
  const renderCandidateSelection = () => (
    <Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}>
@@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
    </Paper>
  );

  const handleTabChange = (event: React.SyntheticEvent, value: string) => {
    setJobTab(value);
  };

  // Render function for the job description step
  const renderJobDescription = () => (
    <Box sx={{ mt: 3 }}>
      {selectedCandidate && (
        <JobManagement
  const renderJobDescription = () => {
    if (!selectedCandidate) {
      return;
    }

    return (<Box sx={{ mt: 3 }}>
      <Box sx={{ borderBottom: 1, borderColor: 'divider', mb: 3 }}>
        <Tabs value={jobTab} onChange={handleTabChange} centered>
          <Tab value='load' icon={<WorkOutline />} label="Load" />
          <Tab value='create' icon={<WorkAddIcon />} label="Create" />
        </Tabs>
      </Box>

      {jobTab === 'load' &&
        <JobPicker {...backstoryProps} onSelect={onJobSelect} />
      }
      {jobTab === 'create' &&
        <JobCreator
          {...backstoryProps}
        />
      )}
          onSave={onJobSelect}
        />}
    </Box>
    );
    );
  }

  // Render function for the analysis step
  const renderAnalysis = () => (
@@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
          </Button>
        ) : (
          <Button onClick={handleNext} variant="contained">
            {activeStep === steps[steps.length - 1].index - 1 ? 'Start Analysis' : 'Next'}
            {activeStep === steps[steps.length - 1].index - 1 ? 'Done' : 'Next'}
          </Button>
        )}
      </Box>

@@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown';
import { BackstoryPageProps } from '../components/BackstoryTab';
import { Conversation, ConversationHandle } from '../components/Conversation';
import { BackstoryQuery } from '../components/BackstoryQuery';
import { CandidateInfo } from 'components/CandidateInfo';
import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types';

@@ -52,7 +52,7 @@ interface StreamingOptions<T = Types.ChatMessage> {
  signal?: AbortSignal;
}

interface DeleteCandidateResponse {
interface DeleteResponse {
  success: boolean;
  message: string;
}
@@ -530,14 +530,24 @@ class ApiClient {
    return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate');
  }

  async deleteCandidate(id: string): Promise<DeleteCandidateResponse> {
  async deleteCandidate(id: string): Promise<DeleteResponse> {
    const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
      method: 'DELETE',
      headers: this.defaultHeaders,
      body: JSON.stringify({ id })
    });

    return handleApiResponse<DeleteCandidateResponse>(response);
    return handleApiResponse<DeleteResponse>(response);
  }

  async deleteJob(id: string): Promise<DeleteResponse> {
    const response = await fetch(`${this.baseUrl}/jobs/${id}`, {
      method: 'DELETE',
      headers: this.defaultHeaders,
      body: JSON.stringify({ id })
    });

    return handleApiResponse<DeleteResponse>(response);
  }

  async uploadCandidateProfile(file: File): Promise<boolean> {
@@ -641,7 +651,7 @@ class ApiClient {
    return this.handleApiResponseWithConversion<Types.Job>(response, 'Job');
  }

  async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
  async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.JobFull>> {
    const paginatedRequest = createPaginatedRequest(request);
    const params = toUrlParams(formatApiRequest(paginatedRequest));

@@ -649,7 +659,7 @@ class ApiClient {
      headers: this.defaultHeaders
    });

    return this.handlePaginatedApiResponseWithConversion<Types.Job>(response, 'Job');
    return this.handlePaginatedApiResponseWithConversion<Types.JobFull>(response, 'JobFull');
  }

  async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
@@ -844,18 +854,28 @@ class ApiClient {
      }
    };
    return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions);
    // {
    //   method: 'POST',
    //   headers: {
    //     // Don't set Content-Type - browser will set it automatically with boundary
    //     'Authorization': this.defaultHeaders['Authorization']
    //   },
    //   body: formData
    // });
  }

  // const result = await handleApiResponse<Types.Document>(response);

  // return result;
  createJobFromFile(file: File, streamingOptions?: StreamingOptions<Types.Job>): StreamingResponse<Types.Job> {
    const formData = new FormData()
    formData.append('file', file);
    formData.append('filename', file.name);
    streamingOptions = {
      ...streamingOptions,
      headers: {
        // Don't set Content-Type - browser will set it automatically with boundary
        'Authorization': this.defaultHeaders['Authorization']
      }
    };
    return this.streamify<Types.Job>('/jobs/upload', formData, streamingOptions);
  }

  getJobRequirements(jobId: string, streamingOptions?: StreamingOptions<Types.DocumentMessage>): StreamingResponse<Types.DocumentMessage> {
    streamingOptions = {
      ...streamingOptions,
      headers: this.defaultHeaders,
    };
    return this.streamify<Types.DocumentMessage>(`/jobs/requirements/${jobId}`, null, streamingOptions);
  }

  async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> {
@@ -1001,7 +1021,7 @@ class ApiClient {
   * @param options callbacks, headers, and method
   * @returns
   */
  streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
  streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit | null, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
    const abortController = new AbortController();
    const signal = options.signal || abortController.signal;
    const headers = options.headers || null;

@@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models
// Source: src/backend/models.py
// Generated on: 2025-06-05T22:02:22.004513
// Generated on: 2025-06-07T20:43:58.855207
// DO NOT EDIT MANUALLY - This file is auto-generated

// ============================
@@ -323,7 +323,7 @@ export interface ChatMessageMetaData {
  presencePenalty?: number;
  stopSequences?: Array<string>;
  ragResults?: Array<ChromaDBGetResponse>;
  llmHistory?: Array<LLMMessage>;
  llmHistory?: Array<any>;
  evalCount: number;
  evalDuration: number;
  promptEvalCount: number;
@@ -711,12 +711,6 @@ export interface JobResponse {
  meta?: Record<string, any>;
}

export interface LLMMessage {
  role: string;
  content: string;
  toolCalls?: Array<Record<string, any>>;
}

export interface Language {
  language: string;
  proficiency: "basic" | "conversational" | "fluent" | "native";

@@ -137,7 +137,7 @@ class Agent(BaseModel, ABC):
    #     llm: Any,
    #     model: str,
    #     message: ChatMessage,
    #     tool_message: Any,  # llama response message
    #     tool_message: Any,  # llama response
    #     messages: List[LLMMessage],
    # ) -> AsyncGenerator[ChatMessage, None]:
    #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@@ -270,15 +270,15 @@ class Agent(BaseModel, ABC):
    #         },
    #         stream=True,
    #     ):
    #         # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
    #         # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
    #         message.status = "streaming"
    #         message.chunk = response.message.content
    #         message.chunk = response.content
    #         message.content += message.chunk

    #         if not response.done:
    #         if not response.finish_reason:
    #             yield message

    #         if response.done:
    #         if response.finish_reason:
    #             self.collect_metrics(response)
    #             message.metadata.eval_count += response.eval_count
    #             message.metadata.eval_duration += response.eval_duration
@@ -296,9 +296,9 @@ class Agent(BaseModel, ABC):

    def collect_metrics(self, response):
        self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
            response.prompt_eval_count
            response.usage.prompt_eval_count
        )
        self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
        self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)

    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
        """
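
The collect_metrics change above reflects the new multi-LLM abstraction: token counters move from Ollama-specific top-level fields to a provider-neutral response.usage object, and streaming chunks expose content and finish_reason directly instead of response.message.content and response.done. A minimal sketch of the response shape these call sites appear to assume (field names are taken from the diff; the class names and defaults are hypothetical):

from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class LLMUsage:
    # Token accounting, normalized across providers (Ollama, etc.)
    prompt_eval_count: int = 0
    prompt_eval_duration: float = 0.0
    eval_count: int = 0
    eval_duration: float = 0.0

@dataclass
class LLMResponse:
    # Incremental or final text from chat_stream()
    content: str = ""
    # None while streaming; set (e.g. "stop") on the final chunk,
    # replacing the Ollama-specific response.done flag
    finish_reason: Optional[str] = None
    # Tool invocations requested by the model, if any
    tool_calls: List[Any] = field(default_factory=list)
    # Populated on the final chunk; read by collect_metrics()
    usage: LLMUsage = field(default_factory=LLMUsage)
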
@@ -361,7 +361,7 @@ Content: {content}
        yield status_message

        try:
            chroma_results = user.file_watcher.find_similar(
            chroma_results = await user.file_watcher.find_similar(
                query=prompt, top_k=top_k, threshold=threshold
            )
            if not chroma_results:
@@ -430,7 +430,7 @@ Content: {content}
        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
        response = None
        content = ""
        for response in llm.chat(
        async for response in llm.chat_stream(
            model=model,
            messages=messages,
            options={
@@ -446,12 +446,12 @@ Content: {content}
                yield error_message
                return

            content += response.message.content
            content += response.content

            if not response.done:
            if not response.finish_reason:
                streaming_message = ChatMessageStreaming(
                    session_id=session_id,
                    content=response.message.content,
                    content=response.content,
                    status=ApiStatusType.STREAMING,
                )
                yield streaming_message
@@ -466,7 +466,7 @@ Content: {content}

        self.collect_metrics(response)
        self.context_tokens = (
            response.prompt_eval_count + response.eval_count
            response.usage.prompt_eval_count + response.usage.eval_count
        )

        chat_message = ChatMessage(
@@ -476,10 +476,10 @@ Content: {content}
            content=content,
            metadata = ChatMessageMetaData(
                options=options,
                eval_count=response.eval_count,
                eval_duration=response.eval_duration,
                prompt_eval_count=response.prompt_eval_count,
                prompt_eval_duration=response.prompt_eval_duration,
                eval_count=response.usage.eval_count,
                eval_duration=response.usage.eval_duration,
                prompt_eval_count=response.usage.prompt_eval_count,
                prompt_eval_duration=response.usage.prompt_eval_duration,

            )
        )
@@ -588,12 +588,12 @@ Content: {content}

        # end_time = time.perf_counter()
        # message.metadata.timers["tool_check"] = end_time - start_time
        # if not response.message.tool_calls:
        # if not response.tool_calls:
        #     logger.info("LLM indicates tools will not be used")
        #     # The LLM will not use tools, so disable use_tools so we can stream the full response
        #     use_tools = False
        # else:
        #     tool_metadata["attempted"] = response.message.tool_calls
        #     tool_metadata["attempted"] = response.tool_calls

        # if use_tools:
        #     logger.info("LLM indicates tools will be used")
@@ -626,15 +626,15 @@ Content: {content}
        #         yield message
        #         return

        #     if response.message.tool_calls:
        #         tool_metadata["used"] = response.message.tool_calls
        #     if response.tool_calls:
        #         tool_metadata["used"] = response.tool_calls
        #         # Process all yielded items from the handler
        #         start_time = time.perf_counter()
        #         async for message in self.process_tool_calls(
        #             llm=llm,
        #             model=model,
        #             message=message,
        #             tool_message=response.message,
        #             tool_message=response,
        #             messages=messages,
        #         ):
        #             if message.status == "error":
@@ -647,7 +647,7 @@ Content: {content}
        #         return

        #     logger.info("LLM indicated tools will be used, and then they weren't")
        #     message.content = response.message.content
        #     message.content = response.content
        #     message.status = "done"
        #     yield message
        #     return
@@ -674,7 +674,7 @@ Content: {content}
        content = ""
        start_time = time.perf_counter()
        response = None
        for response in llm.chat(
        async for response in llm.chat_stream(
            model=model,
            messages=messages,
            options={
@@ -690,12 +690,12 @@ Content: {content}
                yield error_message
                return

            content += response.message.content
            content += response.content

            if not response.done:
            if not response.finish_reason:
                streaming_message = ChatMessageStreaming(
                    session_id=session_id,
                    content=response.message.content,
                    content=response.content,
                )
                yield streaming_message

@@ -709,7 +709,7 @@ Content: {content}

        self.collect_metrics(response)
        self.context_tokens = (
            response.prompt_eval_count + response.eval_count
            response.usage.prompt_eval_count + response.usage.eval_count
        )
        end_time = time.perf_counter()

@@ -720,10 +720,10 @@ Content: {content}
            content=content,
            metadata = ChatMessageMetaData(
                options=options,
                eval_count=response.eval_count,
                eval_duration=response.eval_duration,
                prompt_eval_count=response.prompt_eval_count,
                prompt_eval_duration=response.prompt_eval_duration,
                eval_count=response.usage.eval_count,
                eval_duration=response.usage.eval_duration,
                prompt_eval_count=response.usage.prompt_eval_count,
                prompt_eval_duration=response.usage.prompt_eval_duration,
                timers={
                    "llm_streamed": end_time - start_time,
                    "llm_with_tools": 0,  # Placeholder for tool processing time

@@ -178,12 +178,13 @@ class RedisDatabase:
        'jobs': 'job:',
        'job_applications': 'job_application:',
        'chat_sessions': 'chat_session:',
        'chat_messages': 'chat_messages:',  # This will store lists
        'chat_messages': 'chat_messages:',
        'ai_parameters': 'ai_parameters:',
        'users': 'user:',
        'candidate_documents': 'candidate_documents:',
        'candidate_documents': 'candidate_documents:',
        'job_requirements': 'job_requirements:',  # Add this line
    }

    def _serialize(self, data: Any) -> str:
        """Serialize data to JSON string for Redis storage"""
        if data is None:
@@ -236,8 +237,9 @@ class RedisDatabase:
            # Delete each document's metadata
            for doc_id in document_ids:
                pipe.delete(f"document:{doc_id}")
                pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}")
                deleted_count += 1

            # Delete the candidate's document list
            pipe.delete(key)

@@ -250,7 +252,110 @@ class RedisDatabase:
        except Exception as e:
            logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
            raise

    async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Retrieve cached skill match assessment"""
        try:
            cached_data = await self.redis.get(cache_key)
            if cached_data:
                return json.loads(cached_data)
            return None
        except Exception as e:
            logger.error(f"Error retrieving cached skill match: {e}")
            return None

    async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
        """Cache skill match assessment with TTL (default 30 days)"""
        try:
            await self.redis.setex(
                cache_key,
                ttl,
                json.dumps(assessment_data, default=str)
            )
            return True
        except Exception as e:
            logger.error(f"Error caching skill match: {e}")
            return False
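
A minimal usage sketch for the two cache helpers above, assuming a cache key of the form skill_match:{candidate_id}:{requirement_hash} (the prefix matches the invalidation pattern used further down; the hashing scheme and compute_skill_match are hypothetical):

import hashlib

async def get_or_compute_skill_match(db, candidate_id: str, requirement: str):
    # Key prefix must line up with the "skill_match:{candidate_id}:*"
    # pattern that invalidate_candidate_skill_cache() deletes
    requirement_hash = hashlib.sha256(requirement.encode()).hexdigest()[:16]
    cache_key = f"skill_match:{candidate_id}:{requirement_hash}"

    cached = await db.get_cached_skill_match(cache_key)
    if cached is not None:
        return cached

    assessment = await compute_skill_match(candidate_id, requirement)  # hypothetical
    await db.cache_skill_match(cache_key, assessment)  # default TTL: 30 days
    return assessment
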
    async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
        """Get the last time candidate's skill information was updated"""
        try:
            # This assumes you track skill update timestamps in your candidate data
            candidate_data = await self.get_candidate(candidate_id)
            if candidate_data and 'skills_updated_at' in candidate_data:
                return datetime.fromisoformat(candidate_data['skills_updated_at'])
            return None
        except Exception as e:
            logger.error(f"Error getting candidate skill update time: {e}")
            return None

    async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
        """Get the timestamp of the latest RAG data update for a specific user"""
        try:
            rag_update_key = f"user:{user_id}:rag_last_update"
            timestamp_str = await self.redis.get(rag_update_key)
            if timestamp_str:
                return datetime.fromisoformat(timestamp_str.decode('utf-8'))
            return None
        except Exception as e:
            logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
            return None

    async def update_user_rag_timestamp(self, user_id: str) -> bool:
        """Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
        try:
            rag_update_key = f"user:{user_id}:rag_last_update"
            current_time = datetime.utcnow().isoformat()
            await self.redis.set(rag_update_key, current_time)
            return True
        except Exception as e:
            logger.error(f"Error updating RAG timestamp for user {user_id}: {e}")
            return False
||||
|
||||
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
|
||||
"""Invalidate all cached skill matches for a specific candidate"""
|
||||
try:
|
||||
pattern = f"skill_match:{candidate_id}:*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
return await self.redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating candidate skill cache: {e}")
|
||||
return 0
|
||||
|
||||
async def clear_all_skill_match_cache(self) -> int:
|
||||
"""Clear all skill match cache (useful after major system updates)"""
|
||||
try:
|
||||
pattern = "skill_match:*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
return await self.redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing skill match cache: {e}")
|
||||
return 0
|
||||
|
||||
async def invalidate_user_skill_cache(self, user_id: str) -> int:
|
||||
"""Invalidate all cached skill matches when a user's RAG data is updated"""
|
||||
try:
|
||||
# This assumes all candidates belonging to this user need cache invalidation
|
||||
# You might need to adjust the pattern based on how you associate candidates with users
|
||||
pattern = f"skill_match:*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
|
||||
# Filter keys that belong to candidates owned by this user
|
||||
# This would require additional logic to determine candidate ownership
|
||||
# For now, you might want to clear all cache when any user's RAG data updates
|
||||
# or implement a more sophisticated mapping
|
||||
|
||||
if keys:
|
||||
return await self.redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
|
||||
return 0
|
||||
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
"""Get all documents for a specific candidate"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
@ -330,7 +435,134 @@ class RedisDatabase:
|
||||
if (query_lower in doc.get("filename", "").lower() or
|
||||
query_lower in doc.get("originalName", "").lower())
|
||||
]
|
||||
|
||||
|
||||
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
|
||||
"""Get cached job requirements analysis for a document"""
|
||||
try:
|
||||
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||
data = await self.redis.get(key)
|
||||
if data:
|
||||
requirements_data = self._deserialize(data)
|
||||
logger.debug(f"📋 Retrieved cached job requirements for document {document_id}")
|
||||
return requirements_data
|
||||
logger.debug(f"📋 No cached job requirements found for document {document_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
|
||||
return None
|
||||
|
||||
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
|
||||
"""Save job requirements analysis results for a document"""
|
||||
try:
|
||||
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||
|
||||
# Add metadata to the requirements
|
||||
requirements_with_meta = {
|
||||
**requirements,
|
||||
"cached_at": datetime.now(UTC).isoformat(),
|
||||
"document_id": document_id
|
||||
}
|
||||
|
||||
await self.redis.set(key, self._serialize(requirements_with_meta))
|
||||
|
||||
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
|
||||
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
|
||||
|
||||
logger.debug(f"📋 Saved job requirements for document {document_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
|
||||
return False
|
||||
|
||||
async def delete_job_requirements(self, document_id: str) -> bool:
|
||||
"""Delete cached job requirements for a document"""
|
||||
try:
|
||||
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||
result = await self.redis.delete(key)
|
||||
if result > 0:
|
||||
logger.debug(f"📋 Deleted job requirements for document {document_id}")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
|
||||
return False
|
||||
|
||||
async def get_all_job_requirements(self) -> Dict[str, Any]:
|
||||
"""Get all cached job requirements"""
|
||||
try:
|
||||
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
|
||||
if not keys:
|
||||
return {}
|
||||
|
||||
pipe = self.redis.pipeline()
|
||||
for key in keys:
|
||||
pipe.get(key)
|
||||
values = await pipe.execute()
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
|
||||
if value:
|
||||
result[document_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving all job requirements: {e}")
|
||||
return {}
|
||||
|
||||
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||
"""Get all job requirements analysis for documents belonging to a candidate"""
|
||||
try:
|
||||
# Get all documents for the candidate
|
||||
candidate_documents = await self.get_candidate_documents(candidate_id)
|
||||
|
||||
if not candidate_documents:
|
||||
return []
|
||||
|
||||
# Get job requirements for each document
|
||||
job_requirements = []
|
||||
for doc in candidate_documents:
|
||||
doc_id = doc.get("id")
|
||||
if doc_id:
|
||||
requirements = await self.get_job_requirements(doc_id)
|
||||
if requirements:
|
||||
# Add document metadata to requirements
|
||||
requirements["document_filename"] = doc.get("filename")
|
||||
requirements["document_original_name"] = doc.get("originalName")
|
||||
job_requirements.append(requirements)
|
||||
|
||||
return job_requirements
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
|
||||
return []
|
||||
|
||||
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
|
||||
"""Invalidate (delete) cached job requirements for a document"""
|
||||
# This is an alias for delete_job_requirements for semantic clarity
|
||||
return await self.delete_job_requirements(document_id)
|
||||
|
||||
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
|
||||
"""Delete job requirements for multiple documents and return count of deleted items"""
|
||||
try:
|
||||
deleted_count = 0
|
||||
pipe = self.redis.pipeline()
|
||||
|
||||
for doc_id in document_ids:
|
||||
key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}"
|
||||
pipe.delete(key)
|
||||
deleted_count += 1
|
||||
|
||||
results = await pipe.execute()
|
||||
actual_deleted = sum(1 for result in results if result > 0)
|
||||
|
||||
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
|
||||
return actual_deleted
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error bulk deleting job requirements: {e}")
|
||||
return 0
|
||||
|
||||
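Taken together, these helpers give each uploaded document a cache-aside lifecycle. A minimal sketch of the intended flow, assuming db is a RedisDatabase and with analyze_job_document standing in for the job-requirements agent (it is not a real function in this commit):

async def requirements_for(db, document_id: str, text: str) -> dict:
    # Cache-aside: return the stored analysis, or compute and persist it.
    cached = await db.get_job_requirements(document_id)
    if cached is not None:
        return cached
    requirements = await analyze_job_document(text)  # hypothetical agent call
    await db.save_job_requirements(document_id, requirements)
    return requirements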
    # Viewer operations
    async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
        """Get viewer by ID"""
@@ -1484,6 +1716,74 @@ class RedisDatabase:
        key = f"{self.KEY_PREFIXES['users']}{email}"
        await self.redis.delete(key)

    async def get_job_requirements_stats(self) -> Dict[str, Any]:
        """Get statistics about cached job requirements"""
        try:
            pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
            keys = await self.redis.keys(pattern)

            stats = {
                "total_cached_requirements": len(keys),
                "cache_dates": {},
                "documents_with_requirements": []
            }

            if keys:
                # Get cache dates for analysis
                pipe = self.redis.pipeline()
                for key in keys:
                    pipe.get(key)
                values = await pipe.execute()

                for key, value in zip(keys, values):
                    if value:
                        requirements_data = self._deserialize(value)
                        if requirements_data:
                            document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
                            stats["documents_with_requirements"].append(document_id)

                            # Track cache dates
                            cached_at = requirements_data.get("cached_at")
                            if cached_at:
                                cache_date = cached_at[:10]  # Extract date part
                                stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1

            return stats
        except Exception as e:
            logger.error(f"❌ Error getting job requirements stats: {e}")
            return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}

    async def cleanup_orphaned_job_requirements(self) -> int:
        """Clean up job requirements for documents that no longer exist"""
        try:
            # Get all job requirements
            all_requirements = await self.get_all_job_requirements()

            if not all_requirements:
                return 0

            orphaned_count = 0
            pipe = self.redis.pipeline()

            for document_id in all_requirements.keys():
                # Check if the document still exists
                document_exists = await self.get_document(document_id)
                if not document_exists:
                    # Document no longer exists, delete its job requirements
                    key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
                    pipe.delete(key)
                    orphaned_count += 1
                    logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}")

            if orphaned_count > 0:
                await pipe.execute()
                logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")

            return orphaned_count
        except Exception as e:
            logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
            return 0

    # Utility methods
    async def clear_all_data(self):
        """Clear all data from Redis (use with caution!)"""
@@ -19,8 +19,9 @@ import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from llm_manager import llm_manager
import llm_proxy as llm_manager
from agents.base import Agent
from database import RedisDatabase

class CandidateEntity(Candidate):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc

@@ -115,7 +116,7 @@ class CandidateEntity(Candidate):
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    async def initialize(self, prometheus_collector: CollectorRegistry):
    async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.

@@ -140,9 +141,11 @@ class CandidateEntity(Candidate):

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            user_id=self.id,
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            database=database,
            recreate=False,  # Don't recreate if exists
        )
        has_username_rag = any(item["name"] == self.username for item in self.rags)
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field  # type: ignore

from models import ( Candidate )
from .candidate_entity import CandidateEntity
from database import RedisDatabase
from prometheus_client import CollectorRegistry  # type: ignore

class EntityManager:

@@ -34,9 +35,10 @@ class EntityManager:
            pass
        self._cleanup_task = None

    def initialize(self, prometheus_collector: CollectorRegistry):
    def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
        """Initialize the EntityManager with Prometheus collector"""
        self._prometheus_collector = prometheus_collector
        self._database = database

    async def get_entity(self, candidate: Candidate) -> CandidateEntity:
        """Get or create CandidateEntity with proper reference tracking"""

@@ -49,7 +51,7 @@ class EntityManager:
            return entity

        entity = CandidateEntity(candidate=candidate)
        await entity.initialize(prometheus_collector=self._prometheus_collector)
        await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)

        # Store with reference tracking
        self._entities[candidate.id] = entity
@@ -1,41 +0,0 @@
import ollama
import defines

_llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore

class llm_manager:
    """
    A class to manage LLM operations using the Ollama client.
    """
    @staticmethod
    def get_llm() -> ollama.Client:  # type: ignore
        """
        Get the Ollama client instance.

        Returns:
            An instance of the Ollama client.
        """
        return _llm

    @staticmethod
    def get_models() -> list[str]:
        """
        Get a list of available models from the Ollama client.

        Returns:
            List of model names.
        """
        return _llm.models()

    @staticmethod
    def get_model_info(model_name: str) -> dict:
        """
        Get information about a specific model.

        Args:
            model_name: The name of the model to retrieve information for.

        Returns:
            A dictionary containing model information.
        """
        return _llm.model(model_name)

1203
src/backend/llm_proxy.py
Normal file
File diff suppressed because it is too large
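Since this 1203-line file is the heart of the commit and its diff is suppressed, here is a sketch of the surface the rest of the commit calls into, inferred purely from the call sites. get_llm, chat, generate, embeddings, and get_single_embedding appear verbatim at call sites; everything else below (class names, parameter lists) is an assumption:

# Inferred interface sketch for src/backend/llm_proxy.py; not the actual file.
from typing import Any, List

class EmbeddingResponse:
    def get_single_embedding(self) -> List[float]: ...  # used by src/backend/rag

class UnifiedLLMProxy:
    async def chat(self, model: str, messages: Any, **kwargs: Any) -> Any: ...
    async def generate(self, model: str, prompt: str, **kwargs: Any) -> Any: ...
    async def embeddings(self, model: str, input_texts: str) -> EmbeddingResponse: ...

_proxy = UnifiedLLMProxy()

def get_llm() -> UnifiedLLMProxy:
    # Drop-in replacement for the deleted llm_manager.get_llm(); this is what
    # lets call sites switch with a bare `import llm_proxy as llm_manager`.
    return _proxy

The smaller src/multi-llm/llm_proxy.py included at the end of this commit shows the adapter pattern it presumably builds on.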
@@ -13,6 +13,9 @@ import uuid
import defines
import pathlib

from markitdown import MarkItDown, StreamInfo  # type: ignore
import io

import uvicorn  # type: ignore
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta, UTC

@@ -53,7 +56,7 @@ import defines
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
from llm_manager import llm_manager
import llm_proxy as llm_manager
import entities
from email_service import VerificationEmailRateLimiter, email_service
from device_manager import DeviceManager

@@ -116,7 +119,8 @@ async def lifespan(app: FastAPI):
    try:
        # Initialize database
        await db_manager.initialize()

        entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

@@ -1827,7 +1831,7 @@ async def upload_candidate_document(
            yield error_message
            return

        converted = False;
        converted = False
        if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
            p = pathlib.Path(file_path)
            p_as_md = p.with_suffix(".md")

@@ -1873,53 +1877,6 @@ async def upload_candidate_document(
            content=file_content,
        )
        yield chat_message

        # If this is a job description, process it with the job requirements agent
        if not options.is_job_document:
            return

        status_message = ChatMessageStatus(
            session_id=MOCK_UUID,  # No session ID for document uploads
            content=f"Initiating connection with {candidate.first_name}'s AI agent...",
            activity=ApiActivityType.INFO
        )
        yield status_message
        await asyncio.sleep(0)

        async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
            chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
            if not chat_agent:
                error_message = ChatMessageError(
                    session_id=MOCK_UUID,  # No session ID for document uploads
                    content="No agent found for job requirements chat type"
                )
                yield error_message
                return
            message = None
            status_message = ChatMessageStatus(
                session_id=MOCK_UUID,  # No session ID for document uploads
                content=f"Analyzing document for company and requirement details...",
                activity=ApiActivityType.SEARCHING
            )
            yield status_message
            await asyncio.sleep(0)

            async for message in chat_agent.generate(
                llm=llm_manager.get_llm(),
                model=defines.model,
                session_id=MOCK_UUID,
                prompt=file_content
            ):
                pass
            if not message or not isinstance(message, JobRequirementsMessage):
                error_message = ChatMessageError(
                    session_id=MOCK_UUID,  # No session ID for document uploads
                    content="Failed to process job description file"
                )
                yield error_message
                return
            yield message

    try:
        async def to_json(method):
            try:

@@ -1932,7 +1889,6 @@ async def upload_candidate_document(
                logger.error(f"Error in to_json conversion: {e}")
                return

        # return DebugStreamingResponse(
        return StreamingResponse(
            to_json(upload_stream_generator(file_content)),
            media_type="text/event-stream",

@@ -1944,15 +1900,64 @@ async def upload_candidate_document(
                "Access-Control-Allow-Origin": "*",  # Adjust for your CORS needs
                "Transfer-Encoding": "chunked",
            },
        )
        )
    except Exception as e:
        logger.error(backstory_traceback.format_exc())
        logger.error(f"❌ Document upload error: {e}")
        return JSONResponse(
            status_code=500,
            content=create_error_response("UPLOAD_ERROR", "Failed to upload document")
        return StreamingResponse(
            iter([ChatMessageError(
                session_id=MOCK_UUID,  # No session ID for document uploads
                content="Failed to upload document"
            )]),
            media_type="text/event-stream"
        )
async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
    status_message = ChatMessageStatus(
        session_id=MOCK_UUID,  # No session ID for document uploads
        content=f"Initiating connection with {current_user.first_name}'s AI agent...",
        activity=ApiActivityType.INFO
    )
    yield status_message
    await asyncio.sleep(0)  # Let the status message propagate

    async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
        chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
        if not chat_agent:
            error_message = ChatMessageError(
                session_id=MOCK_UUID,  # No session ID for document uploads
                content="No agent found for job requirements chat type"
            )
            yield error_message
            return
        message = None
        status_message = ChatMessageStatus(
            session_id=MOCK_UUID,  # No session ID for document uploads
            content=f"Analyzing document for company and requirement details...",
            activity=ApiActivityType.SEARCHING
        )
        yield status_message
        await asyncio.sleep(0)

        async for message in chat_agent.generate(
            llm=llm_manager.get_llm(),
            model=defines.model,
            session_id=MOCK_UUID,
            prompt=content
        ):
            pass
        if not message or not isinstance(message, JobRequirementsMessage):
            error_message = ChatMessageError(
                session_id=MOCK_UUID,  # No session ID for document uploads
                content="Failed to process job description file"
            )
            yield error_message
            return

        logger.info(f"✅ Successfully saved job requirements job {message.id}")
        yield message
        return
@api_router.post("/candidates/profile/upload")
|
||||
async def upload_candidate_profile(
|
||||
file: UploadFile = File(...),
|
||||
@ -2573,6 +2578,7 @@ async def delete_candidate(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
|
||||
)
|
||||
|
||||
@api_router.patch("/candidates/{candidate_id}")
|
||||
async def update_candidate(
|
||||
candidate_id: str = Path(...),
|
||||
@ -2816,6 +2822,139 @@ async def create_candidate_job(
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
|
||||
|
||||
@api_router.post("/jobs/upload")
|
||||
async def create_job_from_file(
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Upload a job document for the current candidate and create a Job"""
|
||||
# Check file size (limit to 10MB)
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
file_content = await file.read()
|
||||
if len(file_content) > max_size:
|
||||
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||
return StreamingResponse(
|
||||
iter([ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File size exceeds 10MB limit"
|
||||
)]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
if len(file_content) == 0:
|
||||
logger.info(f"⚠️ File is empty: {file.filename}")
|
||||
return StreamingResponse(
|
||||
iter([ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File is empty"
|
||||
)]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
"""Upload a document for the current candidate"""
|
||||
async def upload_stream_generator(file_content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
|
||||
if not file.filename or file.filename.strip() == "":
|
||||
logger.warning("⚠️ File upload attempt with missing filename")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File must have a valid filename"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
|
||||
|
||||
# Validate file type
|
||||
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
|
||||
|
||||
if file_extension not in allowed_types:
|
||||
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
|
||||
|
||||
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Converting content from {document_type}...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
try:
|
||||
md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
|
||||
stream = io.BytesIO(file_content)
|
||||
stream_info = StreamInfo(
|
||||
extension=file_extension, # e.g., ".pdf"
|
||||
url=file.filename # optional, helps with logging and guessing
|
||||
)
|
||||
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
|
||||
file_content = result.text_content
|
||||
logger.info(f"✅ Converted {file.filename} to Markdown format")
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Failed to convert {file.filename} to Markdown.",
|
||||
)
|
||||
yield error_message
|
||||
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
|
||||
return
|
||||
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
|
||||
yield message
|
||||
return
|
||||
|
||||
try:
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"Error in to_json conversion: {e}")
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
to_json(upload_stream_generator(file_content)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Nginx
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
)]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
@api_router.get("/jobs/{job_id}")
|
||||
async def get_job(
|
||||
job_id: str = Path(...),
|
||||
@ -2931,6 +3070,49 @@ async def search_jobs(
|
||||
content=create_error_response("SEARCH_FAILED", str(e))
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete("/jobs/{job_id}")
|
||||
async def delete_job(
|
||||
job_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a Job"""
|
||||
try:
|
||||
# Check if admin user
|
||||
if not admin_user.is_admin:
|
||||
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Only admins can delete")
|
||||
)
|
||||
|
||||
# Get candidate data
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
|
||||
# Delete job from database
|
||||
await database.delete_job(job_id)
|
||||
|
||||
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Job deleted successfully",
|
||||
"jobId": job_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", "Failed to delete job")
|
||||
)
|
||||
|
||||
# ============================
|
||||
# Chat Endpoints
|
||||
# ============================
|
||||
@@ -3541,7 +3723,7 @@ async def get_candidate_skill_match(
    current_user = Depends(get_current_user),
    database: RedisDatabase = Depends(get_database)
):
    """Get skill match for a candidate against a requirement"""
    """Get skill match for a candidate against a requirement with caching"""
    try:
        # Find candidate by ID
        candidate_data = await database.get_candidate(candidate_id)

@@ -3553,34 +3735,89 @@ async def get_candidate_skill_match(

        candidate = Candidate.model_validate(candidate_data)

        async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
            logger.info(f"🔍 Running skill match for candidate {candidate_entity.username} against requirement: {requirement}")
            agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
            if not agent:
                return JSONResponse(
                    status_code=400,
                    content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
                )
            # Entity automatically released when done
            skill_match = await get_last_item(
                agent.generate(
                    llm=llm_manager.get_llm(),
                    model=defines.model,
                    session_id=MOCK_UUID,
                    prompt=requirement,
        # Create cache key for this specific candidate + requirement combination
        cache_key = f"skill_match:{candidate_id}:{hash(requirement)}"

        # Get cached assessment if it exists
        cached_assessment = await database.get_cached_skill_match(cache_key)

        # Get the last update time for the candidate's skill information
        candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id)

        # Get the latest RAG data update time for the current user
        user_rag_update_time = await database.get_user_rag_update_time(current_user.id)

        # Determine if we need to regenerate the assessment
        should_regenerate = True
        cached_date = None

        if cached_assessment:
            cached_date = cached_assessment.get('cached_at')
            if cached_date:
                # Check if cached result is still valid
                # Regenerate if:
                # 1. Candidate skills were updated after cache date
                # 2. User's RAG data was updated after cache date
                cached_dt = datetime.fromisoformat(cached_date)  # cached_at is stored as an ISO string
                if (not candidate_skill_update_time or cached_dt >= candidate_skill_update_time) and \
                   (not user_rag_update_time or cached_dt >= user_rag_update_time):
                    should_regenerate = False
                    logger.info(f"🔄 Using cached skill match for candidate {candidate.id}")

        if should_regenerate:
            logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}")

            async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
                agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
                if not agent:
                    return JSONResponse(
                        status_code=400,
                        content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
                    )

                # Generate new skill match
                skill_match = await get_last_item(
                    agent.generate(
                        llm=llm_manager.get_llm(),
                        model=defines.model,
                        session_id=MOCK_UUID,
                        prompt=requirement,
                    ),
                )
                if skill_match is None:

            if skill_match is None:
                return JSONResponse(
                    status_code=500,
                    content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
                )

            skill_match_data = json.loads(skill_match.content)

            # Cache the new assessment with current timestamp
            cached_assessment = {
                "skill_match": skill_match_data,
                "cached_at": datetime.utcnow().isoformat(),
                "candidate_id": candidate_id,
                "requirement": requirement
            }

            await database.cache_skill_match(cache_key, cached_assessment)
            logger.info(f"💾 Cached new skill match for candidate {candidate.id}")
            logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
        else:
            # Use cached result - we know cached_assessment is not None here
            if cached_assessment is None:
                return JSONResponse(
                    status_code=500,
                    content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
                    content=create_error_response("CACHE_ERROR", "Unexpected cache state")
                )
            skill_match = json.loads(skill_match.content)
            logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match['evidence_strength']}")
            skill_match_data = cached_assessment["skill_match"]
            logger.info(f"✅ Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}")

        return create_success_response({
            "candidateId": candidate.id,
            "skillMatch": skill_match
            "skillMatch": skill_match_data,
            "cached": not should_regenerate,
            "cacheTimestamp": cached_date
        })

    except Exception as e:

@@ -3589,8 +3826,8 @@ async def get_candidate_skill_match(
        return JSONResponse(
            status_code=500,
            content=create_error_response("SKILL_MATCH_ERROR", str(e))
        )

        )
@api_router.get("/candidates/{username}/chat-sessions")
|
||||
async def get_candidate_chat_sessions(
|
||||
username: str = Path(...),
|
||||
@ -3911,7 +4148,6 @@ async def track_requests(request, call_next):
|
||||
# FastAPI Metrics
|
||||
# ============================
|
||||
prometheus_collector = CollectorRegistry()
|
||||
entities.entity_manager.initialize(prometheus_collector)
|
||||
|
||||
# Keep the Instrumentator instance alive
|
||||
instrumentator = Instrumentator(
|
||||
|
@@ -768,11 +768,7 @@ class ChatOptions(BaseModel):
        "populate_by_name": True  # Allow both field names and aliases
    }

class LLMMessage(BaseModel):
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)
from llm_proxy import (LLMMessage)

class ApiMessage(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
@@ -13,7 +13,6 @@ import numpy as np  # type: ignore
import traceback

import chromadb  # type: ignore
import ollama
from watchdog.observers import Observer  # type: ignore
from watchdog.events import FileSystemEventHandler  # type: ignore
import umap  # type: ignore

@@ -27,6 +26,7 @@ from .markdown_chunker import (

# When imported as a module, use relative imports
import defines
from database import RedisDatabase
from models import ChromaDBGetResponse

__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]

@@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        loop,
        persist_directory,
        collection_name,
        database: RedisDatabase,
        user_id: str,
        chunk_size=DEFAULT_CHUNK_SIZE,
        chunk_overlap=DEFAULT_CHUNK_OVERLAP,
        recreate=False,
    ):
        self.llm = llm
        self.database = database
        self.user_id = user_id
        self.watch_directory = watch_directory
        self.persist_directory = persist_directory or defines.persist_directory
        self.collection_name = collection_name

@@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

        if results and "ids" in results and results["ids"]:
            self.collection.delete(ids=results["ids"])
            await self.database.update_user_rag_timestamp(self.user_id)
            logging.info(
                f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
            )

@@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            name=self.collection_name, metadata={"hnsw:space": "cosine"}
        )

    def get_embedding(self, text: str) -> np.ndarray:
    async def get_embedding(self, text: str) -> np.ndarray:
        """Generate and normalize an embedding for the given text."""

        # Get embedding
        try:
            response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
            embedding = np.array(response["embedding"])
            response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
            embedding = np.array(response.get_single_embedding())
        except Exception as e:
            logging.error(traceback.format_exc())
            logging.error(f"Failed to get embedding: {e}")
            raise

@@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):

        return embedding

    def add_embeddings_to_collection(self, chunks: List[Chunk]):
    async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
        """Add embeddings for chunks to the collection."""

        for i, chunk in enumerate(chunks):

@@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
            chunk_id = f"{path_hash}_{i}_{content_hash}"

            embedding = self.get_embedding(text)
            embedding = await self.get_embedding(text)
            try:
                self.collection.add(
                    ids=[chunk_id],

@@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
    # 0.5 - 0.7    0.65 - 0.75    Balanced precision/recall
    # 0.7 - 0.9    0.55 - 0.65    Higher recall, more inclusive
    # 0.9 - 1.2    0.40 - 0.55    Very inclusive, may include tangential content
    def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
    async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
        """Find similar documents to the query."""

        # collection is configured with hnsw:space cosine
        query_embedding = self.get_embedding(query)
        query_embedding = await self.get_embedding(query)
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,

@@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            and existing_results["ids"]
        ):
            self.collection.delete(ids=existing_results["ids"])
            await self.database.update_user_rag_timestamp(self.user_id)

        extensions = (".docx", ".xlsx", ".xls", ".pdf")
        if file_path.endswith(extensions):

@@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        #     f.write(json.dumps(chunk, indent=2))

        # Add chunks to collection
        self.add_embeddings_to_collection(chunks)
        await self._add_embeddings_to_collection(chunks)
        await self.database.update_user_rag_timestamp(self.user_id)

        logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")

@@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Function to start the file watcher
def start_file_watcher(
    llm,
    user_id,
    watch_directory,
    persist_directory,
    collection_name,
    database: RedisDatabase,
    initialize=False,
    recreate=False,
):

@@ -663,9 +674,11 @@ def start_file_watcher(
        llm,
        watch_directory=watch_directory,
        loop=loop,
        user_id=user_id,
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
        database=database
    )

    # Process all files if:
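The update_user_rag_timestamp() calls threaded into the watcher above are what feed the cache-freshness check in the skill-match endpoint earlier in this commit: any chunk deletion or re-embedding bumps the per-user timestamp, and cached assessments older than it are regenerated. Simplified sketch of that comparison (mirrors the endpoint logic, helper name is illustrative):

from datetime import datetime
from typing import Optional

def cache_is_fresh(cached_at_iso: str, rag_updated: Optional[datetime]) -> bool:
    # A cached assessment is reusable only if the user's RAG data has not
    # changed since it was stored.
    return rag_updated is None or datetime.fromisoformat(cached_at_iso) >= rag_updated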
193
src/multi-llm/config.md
Normal file
@@ -0,0 +1,193 @@
# Environment Configuration Examples

# ==== Development (Ollama only) ====
export OLLAMA_HOST="http://localhost:11434"
export DEFAULT_LLM_PROVIDER="ollama"

# ==== Production with OpenAI ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export DEFAULT_LLM_PROVIDER="openai"

# ==== Production with Anthropic ====
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export DEFAULT_LLM_PROVIDER="anthropic"

# ==== Production with multiple providers ====
export OPENAI_API_KEY="sk-your-openai-key-here"
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
export GEMINI_API_KEY="your-gemini-key-here"
export OLLAMA_HOST="http://ollama-server:11434"
export DEFAULT_LLM_PROVIDER="openai"

# ==== Docker Compose Example ====
# docker-compose.yml
version: '3.8'
services:
  api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
      - DEFAULT_LLM_PROVIDER=openai
    depends_on:
      - ollama

  ollama:
    image: ollama/ollama
    ports:
      - "11434:11434"
    volumes:
      - ollama_data:/root/.ollama
    environment:
      - OLLAMA_HOST=0.0.0.0

volumes:
  ollama_data:

# ==== Kubernetes ConfigMap ====
apiVersion: v1
kind: ConfigMap
metadata:
  name: llm-config
data:
  DEFAULT_LLM_PROVIDER: "openai"
  OLLAMA_HOST: "http://ollama-service:11434"

---
apiVersion: v1
kind: Secret
metadata:
  name: llm-secrets
type: Opaque
stringData:
  OPENAI_API_KEY: "sk-your-key-here"
  ANTHROPIC_API_KEY: "sk-ant-your-key-here"

# ==== Python configuration example ====
# config.py
import os
from llm_proxy import get_llm, LLMProvider

def configure_llm_for_environment():
    """Configure LLM based on deployment environment"""
    llm = get_llm()

    # Development: Use Ollama
    if os.getenv('ENVIRONMENT') == 'development':
        llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434')
        llm.set_default_provider(LLMProvider.OLLAMA)

    # Staging: Use OpenAI with rate limiting
    elif os.getenv('ENVIRONMENT') == 'staging':
        llm.configure_provider(LLMProvider.OPENAI,
                               api_key=os.getenv('OPENAI_API_KEY'),
                               max_retries=3,
                               timeout=30)
        llm.set_default_provider(LLMProvider.OPENAI)

    # Production: Use multiple providers with fallback
    elif os.getenv('ENVIRONMENT') == 'production':
        # Primary: Anthropic
        llm.configure_provider(LLMProvider.ANTHROPIC,
                               api_key=os.getenv('ANTHROPIC_API_KEY'))

        # Fallback: OpenAI
        llm.configure_provider(LLMProvider.OPENAI,
                               api_key=os.getenv('OPENAI_API_KEY'))

        # Set primary provider
        llm.set_default_provider(LLMProvider.ANTHROPIC)

# ==== Usage Examples ====

# Example 1: Basic usage with default provider
async def basic_example():
    llm = get_llm()

    response = await llm.chat(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hello!"}]
    )
    print(response.content)

# Example 2: Specify provider explicitly
async def provider_specific_example():
    llm = get_llm()

    # Use OpenAI specifically
    response = await llm.chat(
        model="gpt-4",
        messages=[{"role": "user", "content": "Hello!"}],
        provider=LLMProvider.OPENAI
    )

    # Use Anthropic specifically
    response2 = await llm.chat(
        model="claude-3-sonnet-20240229",
        messages=[{"role": "user", "content": "Hello!"}],
        provider=LLMProvider.ANTHROPIC
    )

# Example 3: Streaming with provider fallback
async def streaming_with_fallback():
    llm = get_llm()

    try:
        async for chunk in llm.chat(
            model="claude-3-sonnet-20240229",
            messages=[{"role": "user", "content": "Write a story"}],
            provider=LLMProvider.ANTHROPIC,
            stream=True
        ):
            print(chunk.content, end='', flush=True)
    except Exception as e:
        print(f"Primary provider failed: {e}")
        # Fallback to OpenAI
        async for chunk in llm.chat(
            model="gpt-4",
            messages=[{"role": "user", "content": "Write a story"}],
            provider=LLMProvider.OPENAI,
            stream=True
        ):
            print(chunk.content, end='', flush=True)

# Example 4: Load balancing between providers
import random

async def load_balanced_request():
    llm = get_llm()
    available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]

    # Simple random load balancing
    provider = random.choice(available_providers)

    # Adjust model based on provider
    model_mapping = {
        LLMProvider.OPENAI: "gpt-4",
        LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229"
    }

    response = await llm.chat(
        model=model_mapping[provider],
        messages=[{"role": "user", "content": "Hello!"}],
        provider=provider
    )

    print(f"Response from {provider.value}: {response.content}")
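Examples 3 and 4 hand-roll fallback and load balancing; the pattern can be factored into one helper. A minimal sketch against the same assumed get_llm()/LLMProvider API, using the chat_single method that example.py demonstrates (the helper name chat_with_fallback is illustrative):

from llm_proxy import get_llm, LLMProvider

PROVIDER_MODELS = {
    LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229",
    LLMProvider.OPENAI: "gpt-4",
}

async def chat_with_fallback(messages, providers=(LLMProvider.ANTHROPIC, LLMProvider.OPENAI)):
    # Try each provider in order; return the first successful response.
    llm = get_llm()
    last_error = None
    for provider in providers:
        try:
            return await llm.chat_single(
                model=PROVIDER_MODELS[provider],
                messages=messages,
                provider=provider,
            )
        except Exception as e:  # fall through to the next provider
            last_error = e
    raise RuntimeError(f"All providers failed; last error: {last_error}")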
# ==== FastAPI Startup Configuration ====
from fastapi import FastAPI

app = FastAPI()

@app.on_event("startup")
async def startup_event():
    """Configure LLM providers on application startup"""
    configure_llm_for_environment()

    # Verify configuration
    llm = get_llm()
    print(f"Configured providers: {[p.value for p in llm._initialized_providers]}")
    print(f"Default provider: {llm.default_provider.value}")
238
src/multi-llm/example.py
Normal file
@@ -0,0 +1,238 @@
from fastapi import FastAPI, HTTPException  # type: ignore
from fastapi.responses import StreamingResponse  # type: ignore
from pydantic import BaseModel  # type: ignore
from typing import List, Optional, Dict, Any
import json
import asyncio
from llm_proxy import get_llm, LLMProvider, ChatMessage

app = FastAPI(title="Unified LLM API")

# Pydantic models for request/response
class ChatRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    provider: Optional[str] = None
    stream: bool = False
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

class GenerateRequest(BaseModel):
    model: str
    prompt: str
    provider: Optional[str] = None
    stream: bool = False
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

class ChatResponseModel(BaseModel):
    content: str
    model: str
    provider: str
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None

@app.post("/chat")
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint that works with any configured LLM provider"""
    try:
        llm = get_llm()

        # Convert provider string to enum if provided
        provider = None
        if request.provider:
            try:
                provider = LLMProvider(request.provider.lower())
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported provider: {request.provider}"
                )

        # Prepare kwargs
        kwargs = {}
        if request.max_tokens:
            kwargs['max_tokens'] = request.max_tokens
        if request.temperature:
            kwargs['temperature'] = request.temperature

        if request.stream:
            async def generate_response():
                try:
                    # Use the type-safe streaming method
                    async for chunk in llm.chat_stream(
                        model=request.model,
                        messages=request.messages,
                        provider=provider,
                        **kwargs
                    ):
                        # Format as Server-Sent Events
                        data = {
                            "content": chunk.content,
                            "model": chunk.model,
                            "provider": provider.value if provider else llm.default_provider.value,
                            "finish_reason": chunk.finish_reason
                        }
                        yield f"data: {json.dumps(data)}\n\n"
                except Exception as e:
                    error_data = {"error": str(e)}
                    yield f"data: {json.dumps(error_data)}\n\n"
                finally:
                    yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_response(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive"
                }
            )
        else:
            # Use the type-safe single response method
            response = await llm.chat_single(
                model=request.model,
                messages=request.messages,
                provider=provider,
                **kwargs
            )

            return ChatResponseModel(
                content=response.content,
                model=response.model,
                provider=provider.value if provider else llm.default_provider.value,
                finish_reason=response.finish_reason,
                usage=response.usage
            )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate")
async def generate_endpoint(request: GenerateRequest):
    """Generate endpoint for simple text generation"""
    try:
        llm = get_llm()

        provider = None
        if request.provider:
            try:
                provider = LLMProvider(request.provider.lower())
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported provider: {request.provider}"
                )

        kwargs = {}
        if request.max_tokens:
            kwargs['max_tokens'] = request.max_tokens
        if request.temperature:
            kwargs['temperature'] = request.temperature

        if request.stream:
            async def generate_response():
                try:
                    # Use the type-safe streaming method
                    async for chunk in llm.generate_stream(
                        model=request.model,
                        prompt=request.prompt,
                        provider=provider,
                        **kwargs
                    ):
                        data = {
                            "content": chunk.content,
                            "model": chunk.model,
                            "provider": provider.value if provider else llm.default_provider.value,
                            "finish_reason": chunk.finish_reason
                        }
                        yield f"data: {json.dumps(data)}\n\n"
                except Exception as e:
                    error_data = {"error": str(e)}
                    yield f"data: {json.dumps(error_data)}\n\n"
                finally:
                    yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_response(),
                media_type="text/event-stream"
            )
        else:
            # Use the type-safe single response method
            response = await llm.generate_single(
                model=request.model,
                prompt=request.prompt,
                provider=provider,
                **kwargs
            )

            return ChatResponseModel(
                content=response.content,
                model=response.model,
                provider=provider.value if provider else llm.default_provider.value,
                finish_reason=response.finish_reason,
                usage=response.usage
            )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/models")
async def list_models(provider: Optional[str] = None):
    """List available models for a provider"""
    try:
        llm = get_llm()

        provider_enum = None
        if provider:
            try:
                provider_enum = LLMProvider(provider.lower())
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Unsupported provider: {provider}"
                )

        models = await llm.list_models(provider_enum)
        return {
            "provider": provider_enum.value if provider_enum else llm.default_provider.value,
            "models": models
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/providers")
async def list_providers():
    """List all configured providers"""
    llm = get_llm()
    return {
        "providers": [provider.value for provider in llm._initialized_providers],
        "default": llm.default_provider.value
    }

@app.post("/providers/{provider}/set-default")
async def set_default_provider(provider: str):
    """Set the default provider"""
    try:
        llm = get_llm()
        provider_enum = LLMProvider(provider.lower())
        llm.set_default_provider(provider_enum)
        return {"message": f"Default provider set to {provider}", "default": provider}
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    llm = get_llm()
    return {
        "status": "healthy",
        "providers_configured": len(llm._initialized_providers),
        "default_provider": llm.default_provider.value
    }

if __name__ == "__main__":
    import uvicorn  # type: ignore
    uvicorn.run(app, host="0.0.0.0", port=8000)
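A quick way to exercise the streaming /chat endpoint above; a minimal client sketch assuming the server is running on localhost:8000 (httpx is an assumption here, any SSE-capable HTTP client works):

import asyncio
import json
import httpx

async def main():
    payload = {
        "model": "llama3",  # any model available to the default provider
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/chat", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                print(chunk.get("content", ""), end="", flush=True)

asyncio.run(main())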
606
src/multi-llm/llm_proxy.py
Normal file
@@ -0,0 +1,606 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
from enum import Enum
import asyncio
import json
from dataclasses import dataclass
import os

# Standard message format for all providers
@dataclass
class ChatMessage:
    role: str  # "user", "assistant", "system"
    content: str

    def __getitem__(self, key: str):
        """Allow dictionary-style access for backward compatibility"""
        if key == 'role':
            return self.role
        elif key == 'content':
            return self.content
        else:
            raise KeyError(f"'{key}' not found in ChatMessage")

    def __setitem__(self, key: str, value: str):
        """Allow dictionary-style assignment"""
        if key == 'role':
            self.role = value
        elif key == 'content':
            self.content = value
        else:
            raise KeyError(f"'{key}' not found in ChatMessage")

    @classmethod
    def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage':
        """Create ChatMessage from dictionary"""
        return cls(role=data['role'], content=data['content'])

    def to_dict(self) -> Dict[str, str]:
        """Convert ChatMessage to dictionary"""
        return {'role': self.role, 'content': self.content}
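The dict-style accessors above exist so older call sites that treated messages as plain dicts keep working; a quick usage sketch:

msg = ChatMessage(role="user", content="hi")
assert msg["role"] == "user"          # dict-style read
msg["content"] = "hello"              # dict-style write
assert msg.to_dict() == {"role": "user", "content": "hello"}
assert ChatMessage.from_dict(msg.to_dict()) == msg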
@dataclass
class ChatResponse:
    content: str
    model: str
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, int]] = None

class LLMProvider(Enum):
    OLLAMA = "ollama"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GEMINI = "gemini"
    GROK = "grok"

class BaseLLMAdapter(ABC):
    """Abstract base class for all LLM adapters"""

    def __init__(self, **config):
        self.config = config

    @abstractmethod
    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        """Send chat messages and get response"""
        pass

    @abstractmethod
    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        """Generate text from prompt"""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models"""
        pass

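New backends plug in by subclassing BaseLLMAdapter and implementing the three abstract methods. A minimal sketch of a hypothetical EchoAdapter (illustrative only, not part of this commit) shows the required shape:

class EchoAdapter(BaseLLMAdapter):
    """Hypothetical test adapter that echoes the last message back."""

    async def chat(self, model, messages, stream=False, **kwargs):
        reply = ChatResponse(content=messages[-1].content, model=model, finish_reason="stop")
        if stream:
            async def _one_chunk():
                yield reply  # a real adapter would yield many partial chunks
            return _one_chunk()
        return reply

    async def generate(self, model, prompt, stream=False, **kwargs):
        # Mirror the pattern the chat-only providers below use
        return await self.chat(model, [ChatMessage(role="user", content=prompt)], stream, **kwargs)

    async def list_models(self):
        return ["echo-v1"]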
class OllamaAdapter(BaseLLMAdapter):
    """Adapter for Ollama"""

    def __init__(self, **config):
        super().__init__(**config)
        import ollama
        self.client = ollama.AsyncClient(  # type: ignore
            host=config.get('host', 'http://localhost:11434')
        )

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        # Convert ChatMessage objects to Ollama format
        ollama_messages = []
        for msg in messages:
            ollama_messages.append({
                "role": msg.role,
                "content": msg.content
            })

        if stream:
            return self._stream_chat(model, ollama_messages, **kwargs)
        else:
            response = await self.client.chat(
                model=model,
                messages=ollama_messages,
                stream=False,
                **kwargs
            )
            return ChatResponse(
                content=response['message']['content'],
                model=model,
                finish_reason=response.get('done_reason')
            )

    async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
        # AsyncClient.chat is a coroutine: with stream=True it must be awaited
        # to obtain the async iterator of chunks
        async for chunk in await self.client.chat(
            model=model,
            messages=messages,
            stream=True,
            **kwargs
        ):
            if chunk.get('message', {}).get('content'):
                yield ChatResponse(
                    content=chunk['message']['content'],
                    model=model,
                    finish_reason=chunk.get('done_reason')
                )

    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        if stream:
            return self._stream_generate(model, prompt, **kwargs)
        else:
            response = await self.client.generate(
                model=model,
                prompt=prompt,
                stream=False,
                **kwargs
            )
            return ChatResponse(
                content=response['response'],
                model=model,
                finish_reason=response.get('done_reason')
            )

    async def _stream_generate(self, model: str, prompt: str, **kwargs):
        # Same awaiting rule as _stream_chat
        async for chunk in await self.client.generate(
            model=model,
            prompt=prompt,
            stream=True,
            **kwargs
        ):
            if chunk.get('response'):
                yield ChatResponse(
                    content=chunk['response'],
                    model=model,
                    finish_reason=chunk.get('done_reason')
                )

    async def list_models(self) -> List[str]:
        models = await self.client.list()
        return [model['name'] for model in models['models']]

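Used directly, the adapter streams like this (a sketch; assumes a local Ollama daemon and that the `llama3` model has already been pulled):

import asyncio

async def demo_ollama():
    adapter = OllamaAdapter(host="http://localhost:11434")
    stream = await adapter.chat(
        model="llama3",  # assumed model name
        messages=[ChatMessage(role="user", content="Say hi in one word.")],
        stream=True,
    )
    async for part in stream:
        print(part.content, end="", flush=True)

asyncio.run(demo_ollama())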
class OpenAIAdapter(BaseLLMAdapter):
    """Adapter for OpenAI"""

    def __init__(self, **config):
        super().__init__(**config)
        import openai  # type: ignore
        self.client = openai.AsyncOpenAI(
            api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
        )

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        # Convert ChatMessage objects to OpenAI format
        openai_messages = []
        for msg in messages:
            openai_messages.append({
                "role": msg.role,
                "content": msg.content
            })

        if stream:
            return self._stream_chat(model, openai_messages, **kwargs)
        else:
            response = await self.client.chat.completions.create(
                model=model,
                messages=openai_messages,
                stream=False,
                **kwargs
            )
            return ChatResponse(
                content=response.choices[0].message.content,
                model=model,
                finish_reason=response.choices[0].finish_reason,
                usage=response.usage.model_dump() if response.usage else None
            )

    async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
        async for chunk in await self.client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            **kwargs
        ):
            # Some stream events (e.g. usage-only chunks) carry no choices
            if chunk.choices and chunk.choices[0].delta.content:
                yield ChatResponse(
                    content=chunk.choices[0].delta.content,
                    model=model,
                    finish_reason=chunk.choices[0].finish_reason
                )

    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        # Convert to chat format for OpenAI
        messages = [ChatMessage(role="user", content=prompt)]
        return await self.chat(model, messages, stream, **kwargs)

    async def list_models(self) -> List[str]:
        models = await self.client.models.list()
        return [model.id for model in models.data]

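Of the adapters here, only the OpenAI one forwards token accounting, via `response.usage.model_dump()`. A sketch of a non-streaming call (the model name is an assumption; the key falls back to OPENAI_API_KEY):

import asyncio

async def demo_openai():
    adapter = OpenAIAdapter()  # reads OPENAI_API_KEY from the environment
    resp = await adapter.chat(
        model="gpt-4o-mini",  # assumed model name; substitute any chat model
        messages=[ChatMessage(role="user", content="ping")],
    )
    print(resp.content)
    print(resp.usage)  # e.g. {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}

asyncio.run(demo_openai())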
class AnthropicAdapter(BaseLLMAdapter):
    """Adapter for Anthropic Claude"""

    def __init__(self, **config):
        super().__init__(**config)
        import anthropic  # type: ignore
        self.client = anthropic.AsyncAnthropic(
            api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
        )

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        # Anthropic requires the system message to be passed separately
        system_message = None
        anthropic_messages = []

        for msg in messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                anthropic_messages.append({
                    "role": msg.role,
                    "content": msg.content
                })

        request_kwargs = {
            "model": model,
            "messages": anthropic_messages,
            "max_tokens": kwargs.pop('max_tokens', 1000),
            **kwargs
        }

        if system_message:
            request_kwargs["system"] = system_message

        if stream:
            return self._stream_chat(**request_kwargs)
        else:
            response = await self.client.messages.create(
                stream=False,
                **request_kwargs
            )
            return ChatResponse(
                content=response.content[0].text,
                model=model,
                finish_reason=response.stop_reason,
                usage={
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens
                }
            )

    async def _stream_chat(self, **kwargs):
        async with self.client.messages.stream(**kwargs) as stream:
            async for text in stream.text_stream:
                yield ChatResponse(
                    content=text,
                    model=kwargs['model']
                )

    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        messages = [ChatMessage(role="user", content=prompt)]
        return await self.chat(model, messages, stream, **kwargs)

    async def list_models(self) -> List[str]:
        # Anthropic doesn't expose a list-models endpoint, so return known models
        return [
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            "claude-3-opus-20240229"
        ]

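One detail worth calling out: Anthropic takes the system prompt as a top-level `system` parameter rather than as a message, which is why chat() splits it out of the list. Traced on a small input:

msgs = [
    ChatMessage(role="system", content="You are terse."),
    ChatMessage(role="user", content="Explain DNS in one line."),
]
# After the split performed in chat() above:
#   system_message      -> "You are terse."
#   anthropic_messages  -> [{"role": "user", "content": "Explain DNS in one line."}]
# and the request goes out as messages.create(system=..., messages=..., max_tokens=1000)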
class GeminiAdapter(BaseLLMAdapter):
    """Adapter for Google Gemini"""

    def __init__(self, **config):
        super().__init__(**config)
        import google.generativeai as genai  # type: ignore
        genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
        self.genai = genai

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        model_instance = self.genai.GenerativeModel(model)

        # Convert messages to Gemini format
        chat_history = []
        current_message = None

        for msg in messages[:-1]:  # All but the last message go to history
            if msg.role == "user":
                chat_history.append({"role": "user", "parts": [msg.content]})
            elif msg.role == "assistant":
                chat_history.append({"role": "model", "parts": [msg.content]})

        # The last message is the current prompt
        if messages:
            current_message = messages[-1].content

        chat = model_instance.start_chat(history=chat_history)

        if not current_message:
            raise ValueError("No current message provided for chat")

        if stream:
            return self._stream_chat(chat, current_message, **kwargs)
        else:
            response = await chat.send_message_async(current_message, **kwargs)
            return ChatResponse(
                content=response.text,
                model=model,
                finish_reason=response.candidates[0].finish_reason.name if response.candidates else None
            )

    async def _stream_chat(self, chat, message: str, **kwargs):
        # send_message_async is a coroutine; await it to get the streaming response
        async for chunk in await chat.send_message_async(message, stream=True, **kwargs):
            if chunk.text:
                yield ChatResponse(
                    content=chunk.text,
                    model=chat.model.model_name
                )

    async def generate(
        self,
        model: str,
        prompt: str,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:

        messages = [ChatMessage(role="user", content=prompt)]
        return await self.chat(model, messages, stream, **kwargs)

    async def list_models(self) -> List[str]:
        models = self.genai.list_models()
        return [model.name for model in models if 'generateContent' in model.supported_generation_methods]

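The history conversion is the non-obvious part of this adapter: Gemini names the assistant role "model", wraps content in a `parts` list, and takes the newest user turn as a separate send_message call. Traced on a short conversation:

msgs = [
    ChatMessage(role="user", content="Hi"),
    ChatMessage(role="assistant", content="Hello!"),
    ChatMessage(role="user", content="What is 2 + 2?"),
]
# chat_history     -> [{"role": "user", "parts": ["Hi"]},
#                      {"role": "model", "parts": ["Hello!"]}]
# current_message  -> "What is 2 + 2?"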
class UnifiedLLMProxy:
    """Main proxy class that provides a unified interface to all LLM providers"""

    def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
        self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
        self.default_provider = default_provider
        self._initialized_providers = set()

    def configure_provider(self, provider: LLMProvider, **config):
        """Configure a specific provider with its settings"""
        adapter_classes = {
            LLMProvider.OLLAMA: OllamaAdapter,
            LLMProvider.OPENAI: OpenAIAdapter,
            LLMProvider.ANTHROPIC: AnthropicAdapter,
            LLMProvider.GEMINI: GeminiAdapter,
            # Add other providers as needed
        }

        if provider in adapter_classes:
            self.adapters[provider] = adapter_classes[provider](**config)
            self._initialized_providers.add(provider)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

    def set_default_provider(self, provider: LLMProvider):
        """Set the default provider for requests"""
        if provider not in self._initialized_providers:
            raise ValueError(f"Provider {provider} not configured")
        self.default_provider = provider

    async def chat(
        self,
        model: str,
        messages: Union[List[ChatMessage], List[Dict[str, str]]],
        provider: Optional[LLMProvider] = None,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        """Send chat messages using the specified or default provider"""

        provider = provider or self.default_provider
        adapter = self._get_adapter(provider)

        # Normalize messages to ChatMessage objects
        normalized_messages = []
        for msg in messages:
            if isinstance(msg, ChatMessage):
                normalized_messages.append(msg)
            elif isinstance(msg, dict):
                normalized_messages.append(ChatMessage.from_dict(msg))
            else:
                raise ValueError(f"Invalid message type: {type(msg)}")

        return await adapter.chat(model, normalized_messages, stream, **kwargs)

    async def chat_stream(
        self,
        model: str,
        messages: Union[List[ChatMessage], List[Dict[str, str]]],
        provider: Optional[LLMProvider] = None,
        **kwargs
    ) -> AsyncGenerator[ChatResponse, None]:
        """Stream chat messages using the specified or default provider"""

        result = await self.chat(model, messages, provider, stream=True, **kwargs)
        # With stream=True the result is an AsyncGenerator
        async for chunk in result:  # type: ignore
            yield chunk

    async def chat_single(
        self,
        model: str,
        messages: Union[List[ChatMessage], List[Dict[str, str]]],
        provider: Optional[LLMProvider] = None,
        **kwargs
    ) -> ChatResponse:
        """Get a single chat response using the specified or default provider"""

        result = await self.chat(model, messages, provider, stream=False, **kwargs)
        # With stream=False the result is a ChatResponse
        return result  # type: ignore

    async def generate(
        self,
        model: str,
        prompt: str,
        provider: Optional[LLMProvider] = None,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
        """Generate text using the specified or default provider"""

        provider = provider or self.default_provider
        adapter = self._get_adapter(provider)

        return await adapter.generate(model, prompt, stream, **kwargs)

    async def generate_stream(
        self,
        model: str,
        prompt: str,
        provider: Optional[LLMProvider] = None,
        **kwargs
    ) -> AsyncGenerator[ChatResponse, None]:
        """Stream text generation using the specified or default provider"""

        result = await self.generate(model, prompt, provider, stream=True, **kwargs)
        async for chunk in result:  # type: ignore
            yield chunk

    async def generate_single(
        self,
        model: str,
        prompt: str,
        provider: Optional[LLMProvider] = None,
        **kwargs
    ) -> ChatResponse:
        """Get a single generation response using the specified or default provider"""

        result = await self.generate(model, prompt, provider, stream=False, **kwargs)
        return result  # type: ignore

    async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
        """List available models for the specified or default provider"""

        provider = provider or self.default_provider
        adapter = self._get_adapter(provider)

        return await adapter.list_models()

    def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
        """Get the adapter for the specified provider"""
        if provider not in self.adapters:
            raise ValueError(f"Provider {provider} not configured")
        return self.adapters[provider]

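Putting the proxy together end to end, with Ollama as the only configured backend (a sketch; the model name is an assumption):

import asyncio

async def main():
    proxy = UnifiedLLMProxy()
    proxy.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
    proxy.set_default_provider(LLMProvider.OLLAMA)

    # Single-shot response; dict messages are normalized to ChatMessage
    reply = await proxy.chat_single("llama3", [{"role": "user", "content": "ping"}])
    print(reply.content)

    # Token-by-token streaming through the same interface
    async for chunk in proxy.chat_stream("llama3", [{"role": "user", "content": "Count to 3."}]):
        print(chunk.content, end="", flush=True)

asyncio.run(main())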
# Example usage and configuration
class LLMManager:
    """Singleton manager for the unified LLM proxy"""

    _instance = None
    _proxy = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        if LLMManager._proxy is None:
            LLMManager._proxy = UnifiedLLMProxy()
            self._configure_from_environment()

    def _configure_from_environment(self):
        """Configure providers based on environment variables"""

        if not self._proxy:
            raise RuntimeError("LLM proxy not initialized")

        # Configure Ollama if available
        ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
        self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)

        # Configure OpenAI if an API key is available
        if os.getenv('OPENAI_API_KEY'):
            self._proxy.configure_provider(LLMProvider.OPENAI)

        # Configure Anthropic if an API key is available
        if os.getenv('ANTHROPIC_API_KEY'):
            self._proxy.configure_provider(LLMProvider.ANTHROPIC)

        # Configure Gemini if an API key is available
        if os.getenv('GEMINI_API_KEY'):
            self._proxy.configure_provider(LLMProvider.GEMINI)

        # Set the default provider from the environment, or use Ollama
        default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
        try:
            self._proxy.set_default_provider(LLMProvider(default_provider))
        except ValueError:
            # Fall back to Ollama if the specified provider isn't available
            self._proxy.set_default_provider(LLMProvider.OLLAMA)

    def get_proxy(self) -> UnifiedLLMProxy:
        """Get the unified LLM proxy instance"""
        if not self._proxy:
            raise RuntimeError("LLM proxy not initialized")
        return self._proxy

# Convenience function for easy access
def get_llm() -> UnifiedLLMProxy:
    """Get the configured LLM proxy"""
    return LLMManager.get_instance().get_proxy()
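In application code, the environment-driven singleton keeps call sites to one line (a sketch; assumes DEFAULT_LLM_PROVIDER and any API keys are set in the environment, and the model name is an assumption):

import asyncio

async def summarize(text: str) -> str:
    llm = get_llm()  # configured once from the environment
    resp = await llm.generate_single("llama3", f"Summarize in one sentence: {text}")
    return resp.content

print(asyncio.run(summarize("Adapters map each vendor SDK onto one ChatResponse type.")))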