Multi-LLM backend working with Ollama
This commit is contained in:
parent
5fed56ba76
commit
8aa8577874
541
frontend/src/components/JobCreator.tsx
Normal file
541
frontend/src/components/JobCreator.tsx
Normal file
@ -0,0 +1,541 @@
|
|||||||
|
import React, { useState, useEffect, useRef, JSX } from 'react';
|
||||||
|
import {
|
||||||
|
Box,
|
||||||
|
Button,
|
||||||
|
Typography,
|
||||||
|
Paper,
|
||||||
|
TextField,
|
||||||
|
Grid,
|
||||||
|
Dialog,
|
||||||
|
DialogTitle,
|
||||||
|
DialogContent,
|
||||||
|
DialogContentText,
|
||||||
|
DialogActions,
|
||||||
|
IconButton,
|
||||||
|
useTheme,
|
||||||
|
useMediaQuery,
|
||||||
|
Chip,
|
||||||
|
Divider,
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
CardHeader,
|
||||||
|
LinearProgress,
|
||||||
|
Stack,
|
||||||
|
Alert
|
||||||
|
} from '@mui/material';
|
||||||
|
import {
|
||||||
|
SyncAlt,
|
||||||
|
Favorite,
|
||||||
|
Settings,
|
||||||
|
Info,
|
||||||
|
Search,
|
||||||
|
AutoFixHigh,
|
||||||
|
Image,
|
||||||
|
Psychology,
|
||||||
|
Build,
|
||||||
|
CloudUpload,
|
||||||
|
Description,
|
||||||
|
Business,
|
||||||
|
LocationOn,
|
||||||
|
Work,
|
||||||
|
CheckCircle,
|
||||||
|
Star
|
||||||
|
} from '@mui/icons-material';
|
||||||
|
import { styled } from '@mui/material/styles';
|
||||||
|
import DescriptionIcon from '@mui/icons-material/Description';
|
||||||
|
import FileUploadIcon from '@mui/icons-material/FileUpload';
|
||||||
|
|
||||||
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
|
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
|
||||||
|
import { BackstoryElementProps } from './BackstoryTab';
|
||||||
|
import { LoginRequired } from 'components/ui/LoginRequired';
|
||||||
|
|
||||||
|
import * as Types from 'types/types';
|
||||||
|
|
||||||
|
const VisuallyHiddenInput = styled('input')({
|
||||||
|
clip: 'rect(0 0 0 0)',
|
||||||
|
clipPath: 'inset(50%)',
|
||||||
|
height: 1,
|
||||||
|
overflow: 'hidden',
|
||||||
|
position: 'absolute',
|
||||||
|
bottom: 0,
|
||||||
|
left: 0,
|
||||||
|
whiteSpace: 'nowrap',
|
||||||
|
width: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
const UploadBox = styled(Box)(({ theme }) => ({
|
||||||
|
border: `2px dashed ${theme.palette.primary.main}`,
|
||||||
|
borderRadius: theme.shape.borderRadius * 2,
|
||||||
|
padding: theme.spacing(4),
|
||||||
|
textAlign: 'center',
|
||||||
|
backgroundColor: theme.palette.action.hover,
|
||||||
|
transition: 'all 0.3s ease',
|
||||||
|
cursor: 'pointer',
|
||||||
|
'&:hover': {
|
||||||
|
backgroundColor: theme.palette.action.selected,
|
||||||
|
borderColor: theme.palette.primary.dark,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const StatusBox = styled(Box)(({ theme }) => ({
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: theme.spacing(1),
|
||||||
|
padding: theme.spacing(1, 2),
|
||||||
|
backgroundColor: theme.palette.background.paper,
|
||||||
|
borderRadius: theme.shape.borderRadius,
|
||||||
|
border: `1px solid ${theme.palette.divider}`,
|
||||||
|
minHeight: 48,
|
||||||
|
}));
|
||||||
|
|
||||||
|
const getIcon = (type: Types.ApiActivityType) => {
|
||||||
|
switch (type) {
|
||||||
|
case 'converting':
|
||||||
|
return <SyncAlt color="primary" />;
|
||||||
|
case 'heartbeat':
|
||||||
|
return <Favorite color="error" />;
|
||||||
|
case 'system':
|
||||||
|
return <Settings color="action" />;
|
||||||
|
case 'info':
|
||||||
|
return <Info color="info" />;
|
||||||
|
case 'searching':
|
||||||
|
return <Search color="primary" />;
|
||||||
|
case 'generating':
|
||||||
|
return <AutoFixHigh color="secondary" />;
|
||||||
|
case 'generating_image':
|
||||||
|
return <Image color="primary" />;
|
||||||
|
case 'thinking':
|
||||||
|
return <Psychology color="secondary" />;
|
||||||
|
case 'tooling':
|
||||||
|
return <Build color="action" />;
|
||||||
|
default:
|
||||||
|
return <Info color="action" />;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
interface JobCreator extends BackstoryElementProps {
|
||||||
|
onSave?: (job: Types.Job) => void;
|
||||||
|
}
|
||||||
|
const JobCreator = (props: JobCreator) => {
|
||||||
|
const { user, apiClient } = useAuth();
|
||||||
|
const { onSave } = props;
|
||||||
|
const { selectedCandidate } = useSelectedCandidate();
|
||||||
|
const { selectedJob, setSelectedJob } = useSelectedJob();
|
||||||
|
const { setSnack, submitQuery } = props;
|
||||||
|
const backstoryProps = { setSnack, submitQuery };
|
||||||
|
const theme = useTheme();
|
||||||
|
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
|
||||||
|
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
|
||||||
|
|
||||||
|
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
|
||||||
|
const [jobDescription, setJobDescription] = useState<string>('');
|
||||||
|
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
|
||||||
|
const [jobTitle, setJobTitle] = useState<string>('');
|
||||||
|
const [company, setCompany] = useState<string>('');
|
||||||
|
const [summary, setSummary] = useState<string>('');
|
||||||
|
const [jobLocation, setJobLocation] = useState<string>('');
|
||||||
|
const [jobId, setJobId] = useState<string>('');
|
||||||
|
const [jobStatus, setJobStatus] = useState<string>('');
|
||||||
|
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
|
||||||
|
const [isProcessing, setIsProcessing] = useState<boolean>(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
|
||||||
|
}, [jobTitle, jobDescription, company]);
|
||||||
|
|
||||||
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
|
||||||
|
if (!user?.id) {
|
||||||
|
return (
|
||||||
|
<LoginRequired asset="job creation" />
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const jobStatusHandlers = {
|
||||||
|
onStatus: (status: Types.ChatMessageStatus) => {
|
||||||
|
console.log('status:', status.content);
|
||||||
|
setJobStatusIcon(getIcon(status.activity));
|
||||||
|
setJobStatus(status.content);
|
||||||
|
},
|
||||||
|
onMessage: (job: Types.Job) => {
|
||||||
|
console.log('onMessage - job', job);
|
||||||
|
setCompany(job.company || '');
|
||||||
|
setJobDescription(job.description);
|
||||||
|
setSummary(job.summary || '');
|
||||||
|
setJobTitle(job.title || '');
|
||||||
|
setJobRequirements(job.requirements || null);
|
||||||
|
setJobStatusIcon(<></>);
|
||||||
|
setJobStatus('');
|
||||||
|
},
|
||||||
|
onError: (error: Types.ChatMessageError) => {
|
||||||
|
console.log('onError', error);
|
||||||
|
setSnack(error.content, "error");
|
||||||
|
setIsProcessing(false);
|
||||||
|
},
|
||||||
|
onComplete: () => {
|
||||||
|
setJobStatusIcon(<></>);
|
||||||
|
setJobStatus('');
|
||||||
|
setIsProcessing(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleJobUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
|
if (e.target.files && e.target.files[0]) {
|
||||||
|
const file = e.target.files[0];
|
||||||
|
const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
|
||||||
|
let docType: Types.DocumentType | null = null;
|
||||||
|
switch (fileExtension.substring(1)) {
|
||||||
|
case "pdf":
|
||||||
|
docType = "pdf";
|
||||||
|
break;
|
||||||
|
case "docx":
|
||||||
|
docType = "docx";
|
||||||
|
break;
|
||||||
|
case "md":
|
||||||
|
docType = "markdown";
|
||||||
|
break;
|
||||||
|
case "txt":
|
||||||
|
docType = "txt";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!docType) {
|
||||||
|
setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
setIsProcessing(true);
|
||||||
|
setJobDescription('');
|
||||||
|
setJobTitle('');
|
||||||
|
setJobRequirements(null);
|
||||||
|
setSummary('');
|
||||||
|
const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
|
||||||
|
const job = await controller.promise;
|
||||||
|
if (!job) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.log(`Job id: ${job.id}`);
|
||||||
|
e.target.value = '';
|
||||||
|
} catch (error) {
|
||||||
|
console.error(error);
|
||||||
|
setSnack('Failed to upload document', 'error');
|
||||||
|
setIsProcessing(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleUploadClick = () => {
|
||||||
|
fileInputRef.current?.click();
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => {
|
||||||
|
if (!items || items.length === 0) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box sx={{ mb: 3 }}>
|
||||||
|
<Box sx={{ display: 'flex', alignItems: 'center', mb: 1.5 }}>
|
||||||
|
{icon}
|
||||||
|
<Typography variant="subtitle1" sx={{ ml: 1, fontWeight: 600 }}>
|
||||||
|
{title}
|
||||||
|
</Typography>
|
||||||
|
{required && <Chip label="Required" size="small" color="error" sx={{ ml: 1 }} />}
|
||||||
|
</Box>
|
||||||
|
<Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
|
||||||
|
{items.map((item, index) => (
|
||||||
|
<Chip
|
||||||
|
key={index}
|
||||||
|
label={item}
|
||||||
|
variant="outlined"
|
||||||
|
size="small"
|
||||||
|
sx={{ mb: 1 }}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Stack>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderJobRequirements = () => {
|
||||||
|
if (!jobRequirements) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card elevation={2} sx={{ mt: 3 }}>
|
||||||
|
<CardHeader
|
||||||
|
title="Job Requirements Analysis"
|
||||||
|
avatar={<CheckCircle color="success" />}
|
||||||
|
sx={{ pb: 1 }}
|
||||||
|
/>
|
||||||
|
<CardContent sx={{ pt: 0 }}>
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Technical Skills (Required)",
|
||||||
|
jobRequirements.technicalSkills.required,
|
||||||
|
<Build color="primary" />,
|
||||||
|
true
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Technical Skills (Preferred)",
|
||||||
|
jobRequirements.technicalSkills.preferred,
|
||||||
|
<Build color="action" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Experience Requirements (Required)",
|
||||||
|
jobRequirements.experienceRequirements.required,
|
||||||
|
<Work color="primary" />,
|
||||||
|
true
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Experience Requirements (Preferred)",
|
||||||
|
jobRequirements.experienceRequirements.preferred,
|
||||||
|
<Work color="action" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Soft Skills",
|
||||||
|
jobRequirements.softSkills,
|
||||||
|
<Psychology color="secondary" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Experience",
|
||||||
|
jobRequirements.experience,
|
||||||
|
<Star color="warning" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Education",
|
||||||
|
jobRequirements.education,
|
||||||
|
<Description color="info" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Certifications",
|
||||||
|
jobRequirements.certifications,
|
||||||
|
<CheckCircle color="success" />
|
||||||
|
)}
|
||||||
|
{renderRequirementSection(
|
||||||
|
"Preferred Attributes",
|
||||||
|
jobRequirements.preferredAttributes,
|
||||||
|
<Star color="secondary" />
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSave = async () => {
|
||||||
|
const newJob: Types.Job = {
|
||||||
|
ownerId: user?.id || '',
|
||||||
|
ownerType: 'candidate',
|
||||||
|
description: jobDescription,
|
||||||
|
company: company,
|
||||||
|
summary: summary,
|
||||||
|
title: jobTitle,
|
||||||
|
requirements: jobRequirements || undefined
|
||||||
|
};
|
||||||
|
setIsProcessing(true);
|
||||||
|
const job = await apiClient.createJob(newJob);
|
||||||
|
setIsProcessing(false);
|
||||||
|
onSave ? onSave(job) : setSelectedJob(job);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleExtractRequirements = () => {
|
||||||
|
// Implement requirements extraction logic here
|
||||||
|
setIsProcessing(true);
|
||||||
|
// This would call your API to extract requirements from the job description
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderJobCreation = () => {
|
||||||
|
if (!user) {
|
||||||
|
return <Box>You must be logged in</Box>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box sx={{
|
||||||
|
mx: 'auto', p: { xs: 2, sm: 3 },
|
||||||
|
}}>
|
||||||
|
{/* Upload Section */}
|
||||||
|
<Card elevation={3} sx={{ mb: 4 }}>
|
||||||
|
<CardHeader
|
||||||
|
title="Job Information"
|
||||||
|
subheader="Upload a job description or enter details manually"
|
||||||
|
avatar={<Work color="primary" />}
|
||||||
|
/>
|
||||||
|
<CardContent>
|
||||||
|
<Grid container spacing={3}>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
|
||||||
|
<CloudUpload sx={{ mr: 1 }} />
|
||||||
|
Upload Job Description
|
||||||
|
</Typography>
|
||||||
|
<UploadBox onClick={handleUploadClick}>
|
||||||
|
<CloudUpload sx={{ fontSize: 48, color: 'primary.main', mb: 2 }} />
|
||||||
|
<Typography variant="h6" gutterBottom>
|
||||||
|
Drop your job description here
|
||||||
|
</Typography>
|
||||||
|
<Typography variant="body2" color="text.secondary" sx={{ mb: 2 }}>
|
||||||
|
Supported formats: PDF, DOCX, TXT, MD
|
||||||
|
</Typography>
|
||||||
|
<Button
|
||||||
|
variant="contained"
|
||||||
|
startIcon={<FileUploadIcon />}
|
||||||
|
disabled={isProcessing}
|
||||||
|
// onClick={handleUploadClick}
|
||||||
|
>
|
||||||
|
Choose File
|
||||||
|
</Button>
|
||||||
|
</UploadBox>
|
||||||
|
<VisuallyHiddenInput
|
||||||
|
ref={fileInputRef}
|
||||||
|
type="file"
|
||||||
|
accept=".txt,.md,.docx,.pdf"
|
||||||
|
onChange={handleJobUpload}
|
||||||
|
/>
|
||||||
|
</Grid>
|
||||||
|
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="h6" gutterBottom sx={{ display: 'flex', alignItems: 'center' }}>
|
||||||
|
<Description sx={{ mr: 1 }} />
|
||||||
|
Or Enter Manually
|
||||||
|
</Typography>
|
||||||
|
<TextField
|
||||||
|
fullWidth
|
||||||
|
multiline
|
||||||
|
rows={isMobile ? 8 : 12}
|
||||||
|
placeholder="Paste or type the job description here..."
|
||||||
|
variant="outlined"
|
||||||
|
value={jobDescription}
|
||||||
|
onChange={(e) => setJobDescription(e.target.value)}
|
||||||
|
disabled={isProcessing}
|
||||||
|
sx={{ mb: 2 }}
|
||||||
|
/>
|
||||||
|
{jobRequirements === null && jobDescription && (
|
||||||
|
<Button
|
||||||
|
variant="outlined"
|
||||||
|
onClick={handleExtractRequirements}
|
||||||
|
startIcon={<AutoFixHigh />}
|
||||||
|
disabled={isProcessing}
|
||||||
|
fullWidth={isMobile}
|
||||||
|
>
|
||||||
|
Extract Requirements
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</Grid>
|
||||||
|
</Grid>
|
||||||
|
|
||||||
|
{(jobStatus || isProcessing) && (
|
||||||
|
<Box sx={{ mt: 3 }}>
|
||||||
|
<StatusBox>
|
||||||
|
{jobStatusIcon}
|
||||||
|
<Typography variant="body2" sx={{ ml: 1 }}>
|
||||||
|
{jobStatus || 'Processing...'}
|
||||||
|
</Typography>
|
||||||
|
</StatusBox>
|
||||||
|
{isProcessing && <LinearProgress sx={{ mt: 1 }} />}
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Job Details Section */}
|
||||||
|
<Card elevation={3} sx={{ mb: 4 }}>
|
||||||
|
<CardHeader
|
||||||
|
title="Job Details"
|
||||||
|
subheader="Enter specific information about the position"
|
||||||
|
avatar={<Business color="primary" />}
|
||||||
|
/>
|
||||||
|
<CardContent>
|
||||||
|
<Grid container spacing={3}>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<TextField
|
||||||
|
fullWidth
|
||||||
|
label="Job Title"
|
||||||
|
variant="outlined"
|
||||||
|
value={jobTitle}
|
||||||
|
onChange={(e) => setJobTitle(e.target.value)}
|
||||||
|
required
|
||||||
|
disabled={isProcessing}
|
||||||
|
InputProps={{
|
||||||
|
startAdornment: <Work sx={{ mr: 1, color: 'text.secondary' }} />
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Grid>
|
||||||
|
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<TextField
|
||||||
|
fullWidth
|
||||||
|
label="Company"
|
||||||
|
variant="outlined"
|
||||||
|
value={company}
|
||||||
|
onChange={(e) => setCompany(e.target.value)}
|
||||||
|
required
|
||||||
|
disabled={isProcessing}
|
||||||
|
InputProps={{
|
||||||
|
startAdornment: <Business sx={{ mr: 1, color: 'text.secondary' }} />
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Grid>
|
||||||
|
|
||||||
|
{/* <Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<TextField
|
||||||
|
fullWidth
|
||||||
|
label="Job Location"
|
||||||
|
variant="outlined"
|
||||||
|
value={jobLocation}
|
||||||
|
onChange={(e) => setJobLocation(e.target.value)}
|
||||||
|
disabled={isProcessing}
|
||||||
|
InputProps={{
|
||||||
|
startAdornment: <LocationOn sx={{ mr: 1, color: 'text.secondary' }} />
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Grid> */}
|
||||||
|
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Box sx={{ display: 'flex', gap: 2, alignItems: 'flex-end', height: '100%' }}>
|
||||||
|
<Button
|
||||||
|
variant="contained"
|
||||||
|
onClick={handleSave}
|
||||||
|
disabled={!jobTitle || !company || !jobDescription || isProcessing}
|
||||||
|
fullWidth={isMobile}
|
||||||
|
size="large"
|
||||||
|
startIcon={<CheckCircle />}
|
||||||
|
>
|
||||||
|
Save Job
|
||||||
|
</Button>
|
||||||
|
</Box>
|
||||||
|
</Grid>
|
||||||
|
</Grid>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
{/* Job Summary */}
|
||||||
|
{summary !== '' &&
|
||||||
|
<Card elevation={2} sx={{ mt: 3 }}>
|
||||||
|
<CardHeader
|
||||||
|
title="Job Summary"
|
||||||
|
avatar={<CheckCircle color="success" />}
|
||||||
|
sx={{ pb: 1 }}
|
||||||
|
/>
|
||||||
|
<CardContent sx={{ pt: 0 }}>
|
||||||
|
{summary}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
}
|
||||||
|
|
||||||
|
{/* Requirements Display */}
|
||||||
|
{renderJobRequirements()}
|
||||||
|
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box className="JobManagement"
|
||||||
|
sx={{
|
||||||
|
background: "white",
|
||||||
|
p: 0,
|
||||||
|
}}>
|
||||||
|
{selectedJob === null && renderJobCreation()}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export { JobCreator };
|
@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const JobManagement = (props: BackstoryElementProps) => {
|
const JobManagement = (props: BackstoryElementProps) => {
|
||||||
const { user, apiClient } = useAuth();
|
const { user, apiClient } = useAuth();
|
||||||
const { selectedCandidate } = useSelectedCandidate();
|
|
||||||
const { selectedJob, setSelectedJob } = useSelectedJob();
|
const { selectedJob, setSelectedJob } = useSelectedJob();
|
||||||
const { setSnack, submitQuery } = props;
|
const { setSnack, submitQuery } = props;
|
||||||
const backstoryProps = { setSnack, submitQuery };
|
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
|
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
|
||||||
const isTablet = useMediaQuery(theme.breakpoints.down('md'));
|
|
||||||
|
|
||||||
const [openUploadDialog, setOpenUploadDialog] = useState<boolean>(false);
|
|
||||||
const [jobDescription, setJobDescription] = useState<string>('');
|
const [jobDescription, setJobDescription] = useState<string>('');
|
||||||
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
|
const [jobRequirements, setJobRequirements] = useState<Types.JobRequirements | null>(null);
|
||||||
const [jobTitle, setJobTitle] = useState<string>('');
|
const [jobTitle, setJobTitle] = useState<string>('');
|
||||||
const [company, setCompany] = useState<string>('');
|
const [company, setCompany] = useState<string>('');
|
||||||
const [summary, setSummary] = useState<string>('');
|
const [summary, setSummary] = useState<string>('');
|
||||||
const [jobLocation, setJobLocation] = useState<string>('');
|
|
||||||
const [jobId, setJobId] = useState<string>('');
|
|
||||||
const [jobStatus, setJobStatus] = useState<string>('');
|
const [jobStatus, setJobStatus] = useState<string>('');
|
||||||
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
|
const [jobStatusIcon, setJobStatusIcon] = useState<JSX.Element>(<></>);
|
||||||
const [isProcessing, setIsProcessing] = useState<boolean>(false);
|
const [isProcessing, setIsProcessing] = useState<boolean>(false);
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
|
|
||||||
}, [jobTitle, jobDescription, company]);
|
|
||||||
|
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
|
||||||
if (!user?.id) {
|
if (!user?.id) {
|
||||||
return (
|
return (
|
||||||
<LoginRequired asset="candidate analysis" />
|
<LoginRequired asset="candidate analysis" />
|
||||||
@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => {
|
|||||||
setJobTitle('');
|
setJobTitle('');
|
||||||
setJobRequirements(null);
|
setJobRequirements(null);
|
||||||
setSummary('');
|
setSummary('');
|
||||||
const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers);
|
const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
|
||||||
const document = await controller.promise;
|
const job = await controller.promise;
|
||||||
if (!document) {
|
if (!job) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
console.log(`Document id: ${document.id}`);
|
console.log(`Job id: ${job.id}`);
|
||||||
e.target.value = '';
|
e.target.value = '';
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(error);
|
console.error(error);
|
||||||
@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => {
|
|||||||
// This would call your API to extract requirements from the job description
|
// This would call your API to extract requirements from the job description
|
||||||
};
|
};
|
||||||
|
|
||||||
const loadJob = async () => {
|
|
||||||
const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d");
|
|
||||||
setSelectedJob(job);
|
|
||||||
}
|
|
||||||
const renderJobCreation = () => {
|
const renderJobCreation = () => {
|
||||||
if (!user) {
|
if (!user) {
|
||||||
return <Box>You must be logged in</Box>;
|
return <Box>You must be logged in</Box>;
|
||||||
@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => {
|
|||||||
<Box sx={{
|
<Box sx={{
|
||||||
mx: 'auto', p: { xs: 2, sm: 3 },
|
mx: 'auto', p: { xs: 2, sm: 3 },
|
||||||
}}>
|
}}>
|
||||||
<Button onClick={loadJob} variant="contained">Load Job</Button>
|
|
||||||
{/* Upload Section */}
|
{/* Upload Section */}
|
||||||
<Card elevation={3} sx={{ mb: 4 }}>
|
<Card elevation={3} sx={{ mb: 4 }}>
|
||||||
<CardHeader
|
<CardHeader
|
||||||
|
@ -14,7 +14,8 @@ import {
|
|||||||
CardContent,
|
CardContent,
|
||||||
useTheme,
|
useTheme,
|
||||||
LinearProgress,
|
LinearProgress,
|
||||||
useMediaQuery
|
useMediaQuery,
|
||||||
|
Button
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||||
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
|
import CheckCircleIcon from '@mui/icons-material/CheckCircle';
|
||||||
@ -28,6 +29,8 @@ import { toCamelCase } from 'types/conversion';
|
|||||||
import { Job } from 'types/types';
|
import { Job } from 'types/types';
|
||||||
import { StyledMarkdown } from './StyledMarkdown';
|
import { StyledMarkdown } from './StyledMarkdown';
|
||||||
import { Scrollable } from './Scrollable';
|
import { Scrollable } from './Scrollable';
|
||||||
|
import { start } from 'repl';
|
||||||
|
import { TypesElement } from '@uiw/react-json-view';
|
||||||
|
|
||||||
interface JobAnalysisProps extends BackstoryPageProps {
|
interface JobAnalysisProps extends BackstoryPageProps {
|
||||||
job: Job;
|
job: Job;
|
||||||
@ -56,6 +59,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
const [overallScore, setOverallScore] = useState<number>(0);
|
const [overallScore, setOverallScore] = useState<number>(0);
|
||||||
const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null);
|
const [requirementsSession, setRequirementsSession] = useState<ChatSession | null>(null);
|
||||||
const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null);
|
const [statusMessage, setStatusMessage] = useState<ChatMessage | null>(null);
|
||||||
|
const [startAnalysis, setStartAnalysis] = useState<boolean>(false);
|
||||||
|
const [analyzing, setAnalyzing] = useState<boolean>(false);
|
||||||
|
|
||||||
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
|
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
|
||||||
|
|
||||||
// Handle accordion expansion
|
// Handle accordion expansion
|
||||||
@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
setExpanded(isExpanded ? panel : false);
|
setExpanded(isExpanded ? panel : false);
|
||||||
};
|
};
|
||||||
|
|
||||||
useEffect(() => {
|
const initializeRequirements = (job: Job) => {
|
||||||
if (!job || !job.requirements) {
|
if (!job || !job.requirements) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
setSkillMatches(initialSkillMatches);
|
setSkillMatches(initialSkillMatches);
|
||||||
setStatusMessage(null);
|
setStatusMessage(null);
|
||||||
setLoadingRequirements(false);
|
setLoadingRequirements(false);
|
||||||
|
setOverallScore(0);
|
||||||
}, [job, setRequirements]);
|
}
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (requirementsSession || creatingSession) {
|
initializeRequirements(job);
|
||||||
return;
|
}, [job]);
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
setCreatingSession(true);
|
|
||||||
apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
|
|
||||||
.then(session => {
|
|
||||||
setRequirementsSession(session);
|
|
||||||
setCreatingSession(false);
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
setSnack('Unable to load chat session', 'error');
|
|
||||||
} finally {
|
|
||||||
setCreatingSession(false);
|
|
||||||
}
|
|
||||||
|
|
||||||
}, [requirementsSession, apiClient, candidate]);
|
|
||||||
|
|
||||||
// Fetch initial requirements
|
|
||||||
useEffect(() => {
|
|
||||||
if (!job.description || !requirementsSession || loadingRequirements) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const getRequirements = async () => {
|
|
||||||
setLoadingRequirements(true);
|
|
||||||
try {
|
|
||||||
const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description };
|
|
||||||
apiClient.sendMessageStream(chatMessage, {
|
|
||||||
onMessage: (msg: ChatMessage) => {
|
|
||||||
console.log(`onMessage: ${msg.type}`, msg);
|
|
||||||
const job: Job = toCamelCase<Job>(JSON.parse(msg.content || ''));
|
|
||||||
const requirements: { requirement: string, domain: string }[] = [];
|
|
||||||
if (job.requirements?.technicalSkills) {
|
|
||||||
job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
|
|
||||||
job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.experienceRequirements) {
|
|
||||||
job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
|
|
||||||
job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.softSkills) {
|
|
||||||
job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.experience) {
|
|
||||||
job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.education) {
|
|
||||||
job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.certifications) {
|
|
||||||
job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
|
|
||||||
}
|
|
||||||
if (job.requirements?.preferredAttributes) {
|
|
||||||
job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
|
|
||||||
}
|
|
||||||
|
|
||||||
const initialSkillMatches = requirements.map(req => ({
|
|
||||||
requirement: req.requirement,
|
|
||||||
domain: req.domain,
|
|
||||||
status: 'waiting' as const,
|
|
||||||
matchScore: 0,
|
|
||||||
assessment: '',
|
|
||||||
description: '',
|
|
||||||
citations: []
|
|
||||||
}));
|
|
||||||
|
|
||||||
setRequirements(requirements);
|
|
||||||
setSkillMatches(initialSkillMatches);
|
|
||||||
setStatusMessage(null);
|
|
||||||
setLoadingRequirements(false);
|
|
||||||
},
|
|
||||||
onError: (error: string | ChatMessageError) => {
|
|
||||||
console.log("onError:", error);
|
|
||||||
// Type-guard to determine if this is a ChatMessageBase or a string
|
|
||||||
if (typeof error === "object" && error !== null && "content" in error) {
|
|
||||||
setSnack(error.content || 'Error obtaining requirements from job description.', "error");
|
|
||||||
} else {
|
|
||||||
setSnack(error as string, "error");
|
|
||||||
}
|
|
||||||
setLoadingRequirements(false);
|
|
||||||
},
|
|
||||||
onStreaming: (chunk: ChatMessageStreaming) => {
|
|
||||||
// console.log("onStreaming:", chunk);
|
|
||||||
},
|
|
||||||
onStatus: (status: ChatMessageStatus) => {
|
|
||||||
console.log(`onStatus: ${status}`);
|
|
||||||
},
|
|
||||||
onComplete: () => {
|
|
||||||
console.log("onComplete");
|
|
||||||
setStatusMessage(null);
|
|
||||||
setLoadingRequirements(false);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Failed to send message:', error);
|
|
||||||
setLoadingRequirements(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
getRequirements();
|
|
||||||
}, [job, requirementsSession]);
|
|
||||||
|
|
||||||
// Fetch match data for each requirement
|
// Fetch match data for each requirement
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
if (!startAnalysis || analyzing || !job.requirements) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
const fetchMatchData = async () => {
|
const fetchMatchData = async () => {
|
||||||
if (requirements.length === 0) return;
|
if (requirements.length === 0) return;
|
||||||
|
|
||||||
@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!loadingRequirements) {
|
setAnalyzing(true);
|
||||||
fetchMatchData();
|
fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) });
|
||||||
}
|
}, [job, startAnalysis, analyzing, requirements, loadingRequirements]);
|
||||||
}, [requirements, loadingRequirements]);
|
|
||||||
|
|
||||||
// Get color based on match score
|
// Get color based on match score
|
||||||
const getMatchColor = (score: number): string => {
|
const getMatchColor = (score: number): string => {
|
||||||
@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
return <ErrorIcon color="error" />;
|
return <ErrorIcon color="error" />;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const beginAnalysis = () => {
|
||||||
|
initializeRequirements(job);
|
||||||
|
setStartAnalysis(true);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
<Paper elevation={3} sx={{ p: 3, mb: 4 }}>
|
<Paper elevation={3} sx={{ p: 3, mb: 4 }}>
|
||||||
@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
<Grid size={{ xs: 12 }} sx={{ mt: 2 }}>
|
<Grid size={{ xs: 12 }} sx={{ mt: 2 }}>
|
||||||
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2 }}>
|
<Box sx={{ display: 'flex', alignItems: 'center', mb: 2, gap: 1 }}>
|
||||||
|
{<Button disabled={analyzing || startAnalysis} onClick={beginAnalysis} variant="contained">Start Analysis</Button>}
|
||||||
|
{overallScore !== 0 && <>
|
||||||
<Typography variant="h5" component="h2" sx={{ mr: 2 }}>
|
<Typography variant="h5" component="h2" sx={{ mr: 2 }}>
|
||||||
Overall Match:
|
Overall Match:
|
||||||
</Typography>
|
</Typography>
|
||||||
@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) =
|
|||||||
fontWeight: 'bold'
|
fontWeight: 'bold'
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
</>}
|
||||||
</Box>
|
</Box>
|
||||||
</Grid>
|
</Grid>
|
||||||
</Grid>
|
</Grid>
|
||||||
|
@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble";
|
|||||||
import { rest } from 'lodash';
|
import { rest } from 'lodash';
|
||||||
import { AIBanner } from 'components/ui/AIBanner';
|
import { AIBanner } from 'components/ui/AIBanner';
|
||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { DeleteConfirmation } from './DeleteConfirmation';
|
import { DeleteConfirmation } from '../DeleteConfirmation';
|
||||||
|
|
||||||
interface CandidateInfoProps {
|
interface CandidateInfoProps {
|
||||||
candidate: Candidate;
|
candidate: Candidate;
|
@ -4,7 +4,7 @@ import Button from '@mui/material/Button';
|
|||||||
import Box from '@mui/material/Box';
|
import Box from '@mui/material/Box';
|
||||||
|
|
||||||
import { BackstoryElementProps } from 'components/BackstoryTab';
|
import { BackstoryElementProps } from 'components/BackstoryTab';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/ui/CandidateInfo';
|
||||||
import { Candidate } from "types/types";
|
import { Candidate } from "types/types";
|
||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { useSelectedCandidate } from 'hooks/GlobalContext';
|
import { useSelectedCandidate } from 'hooks/GlobalContext';
|
||||||
|
97
frontend/src/components/ui/JobInfo.tsx
Normal file
97
frontend/src/components/ui/JobInfo.tsx
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material';
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
Divider,
|
||||||
|
useTheme,
|
||||||
|
} from '@mui/material';
|
||||||
|
import DeleteIcon from '@mui/icons-material/Delete';
|
||||||
|
import { useMediaQuery } from '@mui/material';
|
||||||
|
import { JobFull } from 'types/types';
|
||||||
|
import { CopyBubble } from "components/CopyBubble";
|
||||||
|
import { rest } from 'lodash';
|
||||||
|
import { AIBanner } from 'components/ui/AIBanner';
|
||||||
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
|
import { DeleteConfirmation } from '../DeleteConfirmation';
|
||||||
|
|
||||||
|
interface JobInfoProps {
|
||||||
|
job: JobFull;
|
||||||
|
sx?: SxProps;
|
||||||
|
action?: string;
|
||||||
|
elevation?: number;
|
||||||
|
variant?: "small" | "normal" | null
|
||||||
|
};
|
||||||
|
|
||||||
|
const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
|
||||||
|
const { job } = props;
|
||||||
|
const { user, apiClient } = useAuth();
|
||||||
|
const {
|
||||||
|
sx,
|
||||||
|
action = '',
|
||||||
|
elevation = 1,
|
||||||
|
variant = "normal"
|
||||||
|
} = props;
|
||||||
|
const theme = useTheme();
|
||||||
|
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
|
||||||
|
const isAdmin = user?.isAdmin;
|
||||||
|
|
||||||
|
const deleteJob = async (jobId: string | undefined) => {
|
||||||
|
if (jobId) {
|
||||||
|
await apiClient.deleteJob(jobId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!job) {
|
||||||
|
return <Box>No user loaded.</Box>;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card
|
||||||
|
elevation={elevation}
|
||||||
|
sx={{
|
||||||
|
display: "flex",
|
||||||
|
borderColor: 'transparent',
|
||||||
|
borderWidth: 2,
|
||||||
|
borderStyle: 'solid',
|
||||||
|
transition: 'all 0.3s ease',
|
||||||
|
flexDirection: "column",
|
||||||
|
...sx
|
||||||
|
}}
|
||||||
|
{...rest}
|
||||||
|
>
|
||||||
|
<CardContent sx={{ display: "flex", flexGrow: 1, p: 3, height: '100%', flexDirection: 'column', alignItems: 'stretch', position: "relative" }}>
|
||||||
|
{variant !== "small" && <>
|
||||||
|
{job.location &&
|
||||||
|
<Typography variant="body2" sx={{ mb: 1 }}>
|
||||||
|
<strong>Location:</strong> {job.location.city}, {job.location.state || job.location.country}
|
||||||
|
</Typography>
|
||||||
|
}
|
||||||
|
{job.company &&
|
||||||
|
<Typography variant="body2" sx={{ mb: 1 }}>
|
||||||
|
<strong>Company:</strong> {job.company}
|
||||||
|
</Typography>
|
||||||
|
}
|
||||||
|
{job.summary && <Typography variant="body2">
|
||||||
|
<strong>Summary:</strong> {job.summary}
|
||||||
|
</Typography>
|
||||||
|
}
|
||||||
|
</>}
|
||||||
|
</CardContent>
|
||||||
|
<CardActions>
|
||||||
|
{isAdmin &&
|
||||||
|
<DeleteConfirmation
|
||||||
|
onDelete={() => { deleteJob(job.id); }}
|
||||||
|
sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }}
|
||||||
|
action="delete"
|
||||||
|
label="job"
|
||||||
|
title="Delete job"
|
||||||
|
icon=<DeleteIcon />
|
||||||
|
message={`Are you sure you want to delete ${job.id}? This action cannot be undone.`}
|
||||||
|
/>}
|
||||||
|
</CardActions>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export { JobInfo };
|
69
frontend/src/components/ui/JobPicker.tsx
Normal file
69
frontend/src/components/ui/JobPicker.tsx
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import React, { useEffect, useState } from 'react';
|
||||||
|
import { useNavigate } from "react-router-dom";
|
||||||
|
import Button from '@mui/material/Button';
|
||||||
|
import Box from '@mui/material/Box';
|
||||||
|
|
||||||
|
import { BackstoryElementProps } from 'components/BackstoryTab';
|
||||||
|
import { JobInfo } from 'components/ui/JobInfo';
|
||||||
|
import { Job, JobFull } from "types/types";
|
||||||
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
|
import { useSelectedJob } from 'hooks/GlobalContext';
|
||||||
|
|
||||||
|
interface JobPickerProps extends BackstoryElementProps {
|
||||||
|
onSelect?: (job: JobFull) => void
|
||||||
|
};
|
||||||
|
|
||||||
|
const JobPicker = (props: JobPickerProps) => {
|
||||||
|
const { onSelect } = props;
|
||||||
|
const { apiClient } = useAuth();
|
||||||
|
const { selectedJob, setSelectedJob } = useSelectedJob();
|
||||||
|
const { setSnack } = props;
|
||||||
|
const [jobs, setJobs] = useState<JobFull[] | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (jobs !== null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const getJobs = async () => {
|
||||||
|
try {
|
||||||
|
const results = await apiClient.getJobs();
|
||||||
|
const jobs: JobFull[] = results.data;
|
||||||
|
jobs.sort((a, b) => {
|
||||||
|
let result = a.company?.localeCompare(b.company || '');
|
||||||
|
if (result === 0) {
|
||||||
|
result = a.title?.localeCompare(b.title || '');
|
||||||
|
}
|
||||||
|
return result || 0;
|
||||||
|
});
|
||||||
|
setJobs(jobs);
|
||||||
|
} catch (err) {
|
||||||
|
setSnack("" + err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
getJobs();
|
||||||
|
}, [jobs, setSnack]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box sx={{display: "flex", flexDirection: "column", mb: 1}}>
|
||||||
|
<Box sx={{ display: "flex", gap: 1, flexWrap: "wrap", justifyContent: "center" }}>
|
||||||
|
{jobs?.map((j, i) =>
|
||||||
|
<Box key={`${j.id}`}
|
||||||
|
onClick={() => { onSelect ? onSelect(j) : setSelectedJob(j); }}
|
||||||
|
sx={{ cursor: "pointer" }}>
|
||||||
|
{selectedJob?.id === j.id &&
|
||||||
|
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", backgroundColor: "#f0f0f0", "&:hover": { border: "2px solid orange" } }} job={j} />
|
||||||
|
}
|
||||||
|
{selectedJob?.id !== j.id &&
|
||||||
|
<JobInfo sx={{ maxWidth: "320px", "cursor": "pointer", border: "2px solid transparent", "&:hover": { border: "2px solid orange" } }} job={j} />
|
||||||
|
}
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export {
|
||||||
|
JobPicker
|
||||||
|
};
|
@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation';
|
|||||||
import { BackstoryPageProps } from 'components/BackstoryTab';
|
import { BackstoryPageProps } from 'components/BackstoryTab';
|
||||||
import { Message } from 'components/Message';
|
import { Message } from 'components/Message';
|
||||||
import { DeleteConfirmation } from 'components/DeleteConfirmation';
|
import { DeleteConfirmation } from 'components/DeleteConfirmation';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/ui/CandidateInfo';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
import { useSelectedCandidate } from 'hooks/GlobalContext';
|
import { useSelectedCandidate } from 'hooks/GlobalContext';
|
||||||
import PropagateLoader from 'react-spinners/PropagateLoader';
|
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||||
|
@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel';
|
|||||||
import SendIcon from '@mui/icons-material/Send';
|
import SendIcon from '@mui/icons-material/Send';
|
||||||
import PropagateLoader from 'react-spinners/PropagateLoader';
|
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||||
|
|
||||||
import { CandidateInfo } from '../components/CandidateInfo';
|
import { CandidateInfo } from '../components/ui/CandidateInfo';
|
||||||
import { Quote } from 'components/Quote';
|
import { Quote } from 'components/Quote';
|
||||||
import { BackstoryElementProps } from 'components/BackstoryTab';
|
import { BackstoryElementProps } from 'components/BackstoryTab';
|
||||||
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
||||||
|
@ -7,31 +7,74 @@ import {
|
|||||||
Button,
|
Button,
|
||||||
Typography,
|
Typography,
|
||||||
Paper,
|
Paper,
|
||||||
Avatar,
|
|
||||||
useTheme,
|
useTheme,
|
||||||
Snackbar,
|
Snackbar,
|
||||||
|
Container,
|
||||||
|
Grid,
|
||||||
Alert,
|
Alert,
|
||||||
|
Tabs,
|
||||||
|
Tab,
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
Divider,
|
||||||
|
Avatar,
|
||||||
|
Badge,
|
||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
|
import {
|
||||||
|
Person,
|
||||||
|
PersonAdd,
|
||||||
|
AccountCircle,
|
||||||
|
Add,
|
||||||
|
WorkOutline,
|
||||||
|
AddCircle,
|
||||||
|
} from '@mui/icons-material';
|
||||||
import PersonIcon from '@mui/icons-material/Person';
|
import PersonIcon from '@mui/icons-material/Person';
|
||||||
import WorkIcon from '@mui/icons-material/Work';
|
import WorkIcon from '@mui/icons-material/Work';
|
||||||
import AssessmentIcon from '@mui/icons-material/Assessment';
|
import AssessmentIcon from '@mui/icons-material/Assessment';
|
||||||
import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
|
import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
|
||||||
import { Candidate } from "types/types";
|
import { Candidate, Job, JobFull } from "types/types";
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
import { BackstoryPageProps } from 'components/BackstoryTab';
|
import { BackstoryPageProps } from 'components/BackstoryTab';
|
||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
|
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/ui/CandidateInfo';
|
||||||
import { ComingSoon } from 'components/ui/ComingSoon';
|
import { ComingSoon } from 'components/ui/ComingSoon';
|
||||||
import { JobManagement } from 'components/JobManagement';
|
import { JobManagement } from 'components/JobManagement';
|
||||||
import { LoginRequired } from 'components/ui/LoginRequired';
|
import { LoginRequired } from 'components/ui/LoginRequired';
|
||||||
import { Scrollable } from 'components/Scrollable';
|
import { Scrollable } from 'components/Scrollable';
|
||||||
import { CandidatePicker } from 'components/ui/CandidatePicker';
|
import { CandidatePicker } from 'components/ui/CandidatePicker';
|
||||||
|
import { JobPicker } from 'components/ui/JobPicker';
|
||||||
|
import { JobCreator } from 'components/JobCreator';
|
||||||
|
|
||||||
|
function WorkAddIcon() {
|
||||||
|
return (
|
||||||
|
<Box position="relative" display="inline-flex"
|
||||||
|
sx={{
|
||||||
|
lineHeight: "30px",
|
||||||
|
mb: "6px",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<WorkOutline sx={{ fontSize: 24 }} />
|
||||||
|
<Add
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
bottom: -2,
|
||||||
|
right: -2,
|
||||||
|
fontSize: 14,
|
||||||
|
bgcolor: 'background.paper',
|
||||||
|
borderRadius: '50%',
|
||||||
|
boxShadow: 1,
|
||||||
|
}}
|
||||||
|
color="primary"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Main component
|
// Main component
|
||||||
const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
|
const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const { user } = useAuth();
|
const { user, apiClient } = useAuth();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
|
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
|
||||||
const { selectedJob, setSelectedJob } = useSelectedJob()
|
const { selectedJob, setSelectedJob } = useSelectedJob()
|
||||||
@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
|
|||||||
const [activeStep, setActiveStep] = useState(0);
|
const [activeStep, setActiveStep] = useState(0);
|
||||||
const [analysisStarted, setAnalysisStarted] = useState(false);
|
const [analysisStarted, setAnalysisStarted] = useState(false);
|
||||||
const [error, setError] = useState<string | null>(null);
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [jobTab, setJobTab] = useState<string>('load');
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedJob && activeStep === 1) {
|
console.log({ activeStep, selectedCandidate, selectedJob });
|
||||||
setActiveStep(2);
|
|
||||||
|
if (!selectedCandidate) {
|
||||||
|
if (activeStep !== 0) {
|
||||||
|
setActiveStep(0);
|
||||||
|
}
|
||||||
|
} else if (!selectedJob) {
|
||||||
|
if (activeStep !== 1) {
|
||||||
|
setActiveStep(1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, [selectedJob, activeStep]);
|
}, [selectedCandidate, selectedJob, activeStep])
|
||||||
|
|
||||||
// Steps in our process
|
// Steps in our process
|
||||||
const steps = [
|
const steps = [
|
||||||
@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
|
|||||||
setSelectedJob(null);
|
setSelectedJob(null);
|
||||||
break;
|
break;
|
||||||
case 1: /* Select Job */
|
case 1: /* Select Job */
|
||||||
setSelectedCandidate(null);
|
|
||||||
setSelectedJob(null);
|
setSelectedJob(null);
|
||||||
break;
|
break;
|
||||||
case 2: /* Job Analysis */
|
case 2: /* Job Analysis */
|
||||||
@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
|
|||||||
setActiveStep(1);
|
setActiveStep(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const onJobSelect = (job: Job) => {
|
||||||
|
setSelectedJob(job)
|
||||||
|
setActiveStep(2);
|
||||||
|
}
|
||||||
|
|
||||||
// Render function for the candidate selection step
|
// Render function for the candidate selection step
|
||||||
const renderCandidateSelection = () => (
|
const renderCandidateSelection = () => (
|
||||||
<Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}>
|
<Paper elevation={3} sx={{ p: 3, mt: 3, mb: 4, borderRadius: 2 }}>
|
||||||
@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
|
|||||||
</Paper>
|
</Paper>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleTabChange = (event: React.SyntheticEvent, value: string) => {
|
||||||
|
setJobTab(value);
|
||||||
|
};
|
||||||
|
|
||||||
// Render function for the job description step
|
// Render function for the job description step
|
||||||
const renderJobDescription = () => (
|
const renderJobDescription = () => {
|
||||||
<Box sx={{ mt: 3 }}>
|
if (!selectedCandidate) {
|
||||||
{selectedCandidate && (
|
return;
|
||||||
<JobManagement
|
}
|
||||||
|
|
||||||
|
return (<Box sx={{ mt: 3 }}>
|
||||||
|
<Box sx={{ borderBottom: 1, borderColor: 'divider', mb: 3 }}>
|
||||||
|
<Tabs value={jobTab} onChange={handleTabChange} centered>
|
||||||
|
<Tab value='load' icon={<WorkOutline />} label="Load" />
|
||||||
|
<Tab value='create' icon={<WorkAddIcon />} label="Create" />
|
||||||
|
</Tabs>
|
||||||
|
</Box>
|
||||||
|
|
||||||
|
{jobTab === 'load' &&
|
||||||
|
<JobPicker {...backstoryProps} onSelect={onJobSelect} />
|
||||||
|
}
|
||||||
|
{jobTab === 'create' &&
|
||||||
|
<JobCreator
|
||||||
{...backstoryProps}
|
{...backstoryProps}
|
||||||
/>
|
onSave={onJobSelect}
|
||||||
)}
|
/>}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Render function for the analysis step
|
// Render function for the analysis step
|
||||||
const renderAnalysis = () => (
|
const renderAnalysis = () => (
|
||||||
@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps
|
|||||||
</Button>
|
</Button>
|
||||||
) : (
|
) : (
|
||||||
<Button onClick={handleNext} variant="contained">
|
<Button onClick={handleNext} variant="contained">
|
||||||
{activeStep === steps[steps.length - 1].index - 1 ? 'Start Analysis' : 'Next'}
|
{activeStep === steps[steps.length - 1].index - 1 ? 'Done' : 'Next'}
|
||||||
</Button>
|
</Button>
|
||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
|
@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown';
|
|||||||
import { BackstoryPageProps } from '../components/BackstoryTab';
|
import { BackstoryPageProps } from '../components/BackstoryTab';
|
||||||
import { Conversation, ConversationHandle } from '../components/Conversation';
|
import { Conversation, ConversationHandle } from '../components/Conversation';
|
||||||
import { BackstoryQuery } from '../components/BackstoryQuery';
|
import { BackstoryQuery } from '../components/BackstoryQuery';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/ui/CandidateInfo';
|
||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { Candidate } from 'types/types';
|
import { Candidate } from 'types/types';
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ interface StreamingOptions<T = Types.ChatMessage> {
|
|||||||
signal?: AbortSignal;
|
signal?: AbortSignal;
|
||||||
}
|
}
|
||||||
|
|
||||||
interface DeleteCandidateResponse {
|
interface DeleteResponse {
|
||||||
success: boolean;
|
success: boolean;
|
||||||
message: string;
|
message: string;
|
||||||
}
|
}
|
||||||
@ -530,14 +530,24 @@ class ApiClient {
|
|||||||
return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate');
|
return this.handleApiResponseWithConversion<Types.Candidate>(response, 'Candidate');
|
||||||
}
|
}
|
||||||
|
|
||||||
async deleteCandidate(id: string): Promise<DeleteCandidateResponse> {
|
async deleteCandidate(id: string): Promise<DeleteResponse> {
|
||||||
const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
|
const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: this.defaultHeaders,
|
headers: this.defaultHeaders,
|
||||||
body: JSON.stringify({ id })
|
body: JSON.stringify({ id })
|
||||||
});
|
});
|
||||||
|
|
||||||
return handleApiResponse<DeleteCandidateResponse>(response);
|
return handleApiResponse<DeleteResponse>(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
async deleteJob(id: string): Promise<DeleteResponse> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/jobs/${id}`, {
|
||||||
|
method: 'DELETE',
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
body: JSON.stringify({ id })
|
||||||
|
});
|
||||||
|
|
||||||
|
return handleApiResponse<DeleteResponse>(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
async uploadCandidateProfile(file: File): Promise<boolean> {
|
async uploadCandidateProfile(file: File): Promise<boolean> {
|
||||||
@ -641,7 +651,7 @@ class ApiClient {
|
|||||||
return this.handleApiResponseWithConversion<Types.Job>(response, 'Job');
|
return this.handleApiResponseWithConversion<Types.Job>(response, 'Job');
|
||||||
}
|
}
|
||||||
|
|
||||||
async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
|
async getJobs(request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.JobFull>> {
|
||||||
const paginatedRequest = createPaginatedRequest(request);
|
const paginatedRequest = createPaginatedRequest(request);
|
||||||
const params = toUrlParams(formatApiRequest(paginatedRequest));
|
const params = toUrlParams(formatApiRequest(paginatedRequest));
|
||||||
|
|
||||||
@ -649,7 +659,7 @@ class ApiClient {
|
|||||||
headers: this.defaultHeaders
|
headers: this.defaultHeaders
|
||||||
});
|
});
|
||||||
|
|
||||||
return this.handlePaginatedApiResponseWithConversion<Types.Job>(response, 'Job');
|
return this.handlePaginatedApiResponseWithConversion<Types.JobFull>(response, 'JobFull');
|
||||||
}
|
}
|
||||||
|
|
||||||
async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
|
async getJobsByEmployer(employerId: string, request: Partial<PaginatedRequest> = {}): Promise<PaginatedResponse<Types.Job>> {
|
||||||
@ -844,18 +854,28 @@ class ApiClient {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions);
|
return this.streamify<Types.DocumentMessage>('/candidates/documents/upload', formData, streamingOptions);
|
||||||
// {
|
}
|
||||||
// method: 'POST',
|
|
||||||
// headers: {
|
|
||||||
// // Don't set Content-Type - browser will set it automatically with boundary
|
|
||||||
// 'Authorization': this.defaultHeaders['Authorization']
|
|
||||||
// },
|
|
||||||
// body: formData
|
|
||||||
// });
|
|
||||||
|
|
||||||
// const result = await handleApiResponse<Types.Document>(response);
|
createJobFromFile(file: File, streamingOptions?: StreamingOptions<Types.Job>): StreamingResponse<Types.Job> {
|
||||||
|
const formData = new FormData()
|
||||||
|
formData.append('file', file);
|
||||||
|
formData.append('filename', file.name);
|
||||||
|
streamingOptions = {
|
||||||
|
...streamingOptions,
|
||||||
|
headers: {
|
||||||
|
// Don't set Content-Type - browser will set it automatically with boundary
|
||||||
|
'Authorization': this.defaultHeaders['Authorization']
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return this.streamify<Types.Job>('/jobs/upload', formData, streamingOptions);
|
||||||
|
}
|
||||||
|
|
||||||
// return result;
|
getJobRequirements(jobId: string, streamingOptions?: StreamingOptions<Types.DocumentMessage>): StreamingResponse<Types.DocumentMessage> {
|
||||||
|
streamingOptions = {
|
||||||
|
...streamingOptions,
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
};
|
||||||
|
return this.streamify<Types.DocumentMessage>(`/jobs/requirements/${jobId}`, null, streamingOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> {
|
async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise<Types.SkillMatch> {
|
||||||
@ -1001,7 +1021,7 @@ class ApiClient {
|
|||||||
* @param options callbacks, headers, and method
|
* @param options callbacks, headers, and method
|
||||||
* @returns
|
* @returns
|
||||||
*/
|
*/
|
||||||
streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
|
streamify<T = Types.ChatMessage[]>(api: string, data: BodyInit | null, options: StreamingOptions<T> = {}) : StreamingResponse<T> {
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
const signal = options.signal || abortController.signal;
|
const signal = options.signal || abortController.signal;
|
||||||
const headers = options.headers || null;
|
const headers = options.headers || null;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
// Generated TypeScript types from Pydantic models
|
// Generated TypeScript types from Pydantic models
|
||||||
// Source: src/backend/models.py
|
// Source: src/backend/models.py
|
||||||
// Generated on: 2025-06-05T22:02:22.004513
|
// Generated on: 2025-06-07T20:43:58.855207
|
||||||
// DO NOT EDIT MANUALLY - This file is auto-generated
|
// DO NOT EDIT MANUALLY - This file is auto-generated
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
@ -323,7 +323,7 @@ export interface ChatMessageMetaData {
|
|||||||
presencePenalty?: number;
|
presencePenalty?: number;
|
||||||
stopSequences?: Array<string>;
|
stopSequences?: Array<string>;
|
||||||
ragResults?: Array<ChromaDBGetResponse>;
|
ragResults?: Array<ChromaDBGetResponse>;
|
||||||
llmHistory?: Array<LLMMessage>;
|
llmHistory?: Array<any>;
|
||||||
evalCount: number;
|
evalCount: number;
|
||||||
evalDuration: number;
|
evalDuration: number;
|
||||||
promptEvalCount: number;
|
promptEvalCount: number;
|
||||||
@ -711,12 +711,6 @@ export interface JobResponse {
|
|||||||
meta?: Record<string, any>;
|
meta?: Record<string, any>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface LLMMessage {
|
|
||||||
role: string;
|
|
||||||
content: string;
|
|
||||||
toolCalls?: Array<Record<string, any>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface Language {
|
export interface Language {
|
||||||
language: string;
|
language: string;
|
||||||
proficiency: "basic" | "conversational" | "fluent" | "native";
|
proficiency: "basic" | "conversational" | "fluent" | "native";
|
||||||
|
@ -137,7 +137,7 @@ class Agent(BaseModel, ABC):
|
|||||||
# llm: Any,
|
# llm: Any,
|
||||||
# model: str,
|
# model: str,
|
||||||
# message: ChatMessage,
|
# message: ChatMessage,
|
||||||
# tool_message: Any, # llama response message
|
# tool_message: Any, # llama response
|
||||||
# messages: List[LLMMessage],
|
# messages: List[LLMMessage],
|
||||||
# ) -> AsyncGenerator[ChatMessage, None]:
|
# ) -> AsyncGenerator[ChatMessage, None]:
|
||||||
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
@ -270,15 +270,15 @@ class Agent(BaseModel, ABC):
|
|||||||
# },
|
# },
|
||||||
# stream=True,
|
# stream=True,
|
||||||
# ):
|
# ):
|
||||||
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
|
# # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
|
||||||
# message.status = "streaming"
|
# message.status = "streaming"
|
||||||
# message.chunk = response.message.content
|
# message.chunk = response.content
|
||||||
# message.content += message.chunk
|
# message.content += message.chunk
|
||||||
|
|
||||||
# if not response.done:
|
# if not response.finish_reason:
|
||||||
# yield message
|
# yield message
|
||||||
|
|
||||||
# if response.done:
|
# if response.finish_reason:
|
||||||
# self.collect_metrics(response)
|
# self.collect_metrics(response)
|
||||||
# message.metadata.eval_count += response.eval_count
|
# message.metadata.eval_count += response.eval_count
|
||||||
# message.metadata.eval_duration += response.eval_duration
|
# message.metadata.eval_duration += response.eval_duration
|
||||||
@ -296,9 +296,9 @@ class Agent(BaseModel, ABC):
|
|||||||
|
|
||||||
def collect_metrics(self, response):
|
def collect_metrics(self, response):
|
||||||
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
||||||
response.prompt_eval_count
|
response.usage.prompt_eval_count
|
||||||
)
|
)
|
||||||
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
|
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)
|
||||||
|
|
||||||
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
|
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
|
||||||
"""
|
"""
|
||||||
@ -361,7 +361,7 @@ Content: {content}
|
|||||||
yield status_message
|
yield status_message
|
||||||
|
|
||||||
try:
|
try:
|
||||||
chroma_results = user.file_watcher.find_similar(
|
chroma_results = await user.file_watcher.find_similar(
|
||||||
query=prompt, top_k=top_k, threshold=threshold
|
query=prompt, top_k=top_k, threshold=threshold
|
||||||
)
|
)
|
||||||
if not chroma_results:
|
if not chroma_results:
|
||||||
@ -430,7 +430,7 @@ Content: {content}
|
|||||||
logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
|
logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
|
||||||
response = None
|
response = None
|
||||||
content = ""
|
content = ""
|
||||||
for response in llm.chat(
|
async for response in llm.chat_stream(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
options={
|
options={
|
||||||
@ -446,12 +446,12 @@ Content: {content}
|
|||||||
yield error_message
|
yield error_message
|
||||||
return
|
return
|
||||||
|
|
||||||
content += response.message.content
|
content += response.content
|
||||||
|
|
||||||
if not response.done:
|
if not response.finish_reason:
|
||||||
streaming_message = ChatMessageStreaming(
|
streaming_message = ChatMessageStreaming(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
content=response.message.content,
|
content=response.content,
|
||||||
status=ApiStatusType.STREAMING,
|
status=ApiStatusType.STREAMING,
|
||||||
)
|
)
|
||||||
yield streaming_message
|
yield streaming_message
|
||||||
@ -466,7 +466,7 @@ Content: {content}
|
|||||||
|
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
self.context_tokens = (
|
self.context_tokens = (
|
||||||
response.prompt_eval_count + response.eval_count
|
response.usage.prompt_eval_count + response.usage.eval_count
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_message = ChatMessage(
|
chat_message = ChatMessage(
|
||||||
@ -476,10 +476,10 @@ Content: {content}
|
|||||||
content=content,
|
content=content,
|
||||||
metadata = ChatMessageMetaData(
|
metadata = ChatMessageMetaData(
|
||||||
options=options,
|
options=options,
|
||||||
eval_count=response.eval_count,
|
eval_count=response.usage.eval_count,
|
||||||
eval_duration=response.eval_duration,
|
eval_duration=response.usage.eval_duration,
|
||||||
prompt_eval_count=response.prompt_eval_count,
|
prompt_eval_count=response.usage.prompt_eval_count,
|
||||||
prompt_eval_duration=response.prompt_eval_duration,
|
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||||
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -588,12 +588,12 @@ Content: {content}
|
|||||||
|
|
||||||
# end_time = time.perf_counter()
|
# end_time = time.perf_counter()
|
||||||
# message.metadata.timers["tool_check"] = end_time - start_time
|
# message.metadata.timers["tool_check"] = end_time - start_time
|
||||||
# if not response.message.tool_calls:
|
# if not response.tool_calls:
|
||||||
# logger.info("LLM indicates tools will not be used")
|
# logger.info("LLM indicates tools will not be used")
|
||||||
# # The LLM will not use tools, so disable use_tools so we can stream the full response
|
# # The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||||
# use_tools = False
|
# use_tools = False
|
||||||
# else:
|
# else:
|
||||||
# tool_metadata["attempted"] = response.message.tool_calls
|
# tool_metadata["attempted"] = response.tool_calls
|
||||||
|
|
||||||
# if use_tools:
|
# if use_tools:
|
||||||
# logger.info("LLM indicates tools will be used")
|
# logger.info("LLM indicates tools will be used")
|
||||||
@ -626,15 +626,15 @@ Content: {content}
|
|||||||
# yield message
|
# yield message
|
||||||
# return
|
# return
|
||||||
|
|
||||||
# if response.message.tool_calls:
|
# if response.tool_calls:
|
||||||
# tool_metadata["used"] = response.message.tool_calls
|
# tool_metadata["used"] = response.tool_calls
|
||||||
# # Process all yielded items from the handler
|
# # Process all yielded items from the handler
|
||||||
# start_time = time.perf_counter()
|
# start_time = time.perf_counter()
|
||||||
# async for message in self.process_tool_calls(
|
# async for message in self.process_tool_calls(
|
||||||
# llm=llm,
|
# llm=llm,
|
||||||
# model=model,
|
# model=model,
|
||||||
# message=message,
|
# message=message,
|
||||||
# tool_message=response.message,
|
# tool_message=response,
|
||||||
# messages=messages,
|
# messages=messages,
|
||||||
# ):
|
# ):
|
||||||
# if message.status == "error":
|
# if message.status == "error":
|
||||||
@ -647,7 +647,7 @@ Content: {content}
|
|||||||
# return
|
# return
|
||||||
|
|
||||||
# logger.info("LLM indicated tools will be used, and then they weren't")
|
# logger.info("LLM indicated tools will be used, and then they weren't")
|
||||||
# message.content = response.message.content
|
# message.content = response.content
|
||||||
# message.status = "done"
|
# message.status = "done"
|
||||||
# yield message
|
# yield message
|
||||||
# return
|
# return
|
||||||
@ -674,7 +674,7 @@ Content: {content}
|
|||||||
content = ""
|
content = ""
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
response = None
|
response = None
|
||||||
for response in llm.chat(
|
async for response in llm.chat_stream(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
options={
|
options={
|
||||||
@ -690,12 +690,12 @@ Content: {content}
|
|||||||
yield error_message
|
yield error_message
|
||||||
return
|
return
|
||||||
|
|
||||||
content += response.message.content
|
content += response.content
|
||||||
|
|
||||||
if not response.done:
|
if not response.finish_reason:
|
||||||
streaming_message = ChatMessageStreaming(
|
streaming_message = ChatMessageStreaming(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
content=response.message.content,
|
content=response.content,
|
||||||
)
|
)
|
||||||
yield streaming_message
|
yield streaming_message
|
||||||
|
|
||||||
@ -709,7 +709,7 @@ Content: {content}
|
|||||||
|
|
||||||
self.collect_metrics(response)
|
self.collect_metrics(response)
|
||||||
self.context_tokens = (
|
self.context_tokens = (
|
||||||
response.prompt_eval_count + response.eval_count
|
response.usage.prompt_eval_count + response.usage.eval_count
|
||||||
)
|
)
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
@ -720,10 +720,10 @@ Content: {content}
|
|||||||
content=content,
|
content=content,
|
||||||
metadata = ChatMessageMetaData(
|
metadata = ChatMessageMetaData(
|
||||||
options=options,
|
options=options,
|
||||||
eval_count=response.eval_count,
|
eval_count=response.usage.eval_count,
|
||||||
eval_duration=response.eval_duration,
|
eval_duration=response.usage.eval_duration,
|
||||||
prompt_eval_count=response.prompt_eval_count,
|
prompt_eval_count=response.usage.prompt_eval_count,
|
||||||
prompt_eval_duration=response.prompt_eval_duration,
|
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||||
timers={
|
timers={
|
||||||
"llm_streamed": end_time - start_time,
|
"llm_streamed": end_time - start_time,
|
||||||
"llm_with_tools": 0, # Placeholder for tool processing time
|
"llm_with_tools": 0, # Placeholder for tool processing time
|
||||||
|
@ -178,10 +178,11 @@ class RedisDatabase:
|
|||||||
'jobs': 'job:',
|
'jobs': 'job:',
|
||||||
'job_applications': 'job_application:',
|
'job_applications': 'job_application:',
|
||||||
'chat_sessions': 'chat_session:',
|
'chat_sessions': 'chat_session:',
|
||||||
'chat_messages': 'chat_messages:', # This will store lists
|
'chat_messages': 'chat_messages:',
|
||||||
'ai_parameters': 'ai_parameters:',
|
'ai_parameters': 'ai_parameters:',
|
||||||
'users': 'user:',
|
'users': 'user:',
|
||||||
'candidate_documents': 'candidate_documents:',
|
'candidate_documents': 'candidate_documents:',
|
||||||
|
'job_requirements': 'job_requirements:', # Add this line
|
||||||
}
|
}
|
||||||
|
|
||||||
def _serialize(self, data: Any) -> str:
|
def _serialize(self, data: Any) -> str:
|
||||||
@ -236,6 +237,7 @@ class RedisDatabase:
|
|||||||
# Delete each document's metadata
|
# Delete each document's metadata
|
||||||
for doc_id in document_ids:
|
for doc_id in document_ids:
|
||||||
pipe.delete(f"document:{doc_id}")
|
pipe.delete(f"document:{doc_id}")
|
||||||
|
pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}")
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
|
|
||||||
# Delete the candidate's document list
|
# Delete the candidate's document list
|
||||||
@ -251,6 +253,109 @@ class RedisDatabase:
|
|||||||
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
|
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Retrieve cached skill match assessment"""
|
||||||
|
try:
|
||||||
|
cached_data = await self.redis.get(cache_key)
|
||||||
|
if cached_data:
|
||||||
|
return json.loads(cached_data)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error retrieving cached skill match: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
|
||||||
|
"""Cache skill match assessment with TTL (default 30 days)"""
|
||||||
|
try:
|
||||||
|
await self.redis.setex(
|
||||||
|
cache_key,
|
||||||
|
ttl,
|
||||||
|
json.dumps(assessment_data, default=str)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error caching skill match: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
|
||||||
|
"""Get the last time candidate's skill information was updated"""
|
||||||
|
try:
|
||||||
|
# This assumes you track skill update timestamps in your candidate data
|
||||||
|
candidate_data = await self.get_candidate(candidate_id)
|
||||||
|
if candidate_data and 'skills_updated_at' in candidate_data:
|
||||||
|
return datetime.fromisoformat(candidate_data['skills_updated_at'])
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting candidate skill update time: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
||||||
|
"""Get the timestamp of the latest RAG data update for a specific user"""
|
||||||
|
try:
|
||||||
|
rag_update_key = f"user:{user_id}:rag_last_update"
|
||||||
|
timestamp_str = await self.redis.get(rag_update_key)
|
||||||
|
if timestamp_str:
|
||||||
|
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
||||||
|
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
|
||||||
|
try:
|
||||||
|
rag_update_key = f"user:{user_id}:rag_last_update"
|
||||||
|
current_time = datetime.utcnow().isoformat()
|
||||||
|
await self.redis.set(rag_update_key, current_time)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating RAG timestamp for user {user_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
|
||||||
|
"""Invalidate all cached skill matches for a specific candidate"""
|
||||||
|
try:
|
||||||
|
pattern = f"skill_match:{candidate_id}:*"
|
||||||
|
keys = await self.redis.keys(pattern)
|
||||||
|
if keys:
|
||||||
|
return await self.redis.delete(*keys)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error invalidating candidate skill cache: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def clear_all_skill_match_cache(self) -> int:
|
||||||
|
"""Clear all skill match cache (useful after major system updates)"""
|
||||||
|
try:
|
||||||
|
pattern = "skill_match:*"
|
||||||
|
keys = await self.redis.keys(pattern)
|
||||||
|
if keys:
|
||||||
|
return await self.redis.delete(*keys)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error clearing skill match cache: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def invalidate_user_skill_cache(self, user_id: str) -> int:
|
||||||
|
"""Invalidate all cached skill matches when a user's RAG data is updated"""
|
||||||
|
try:
|
||||||
|
# This assumes all candidates belonging to this user need cache invalidation
|
||||||
|
# You might need to adjust the pattern based on how you associate candidates with users
|
||||||
|
pattern = f"skill_match:*"
|
||||||
|
keys = await self.redis.keys(pattern)
|
||||||
|
|
||||||
|
# Filter keys that belong to candidates owned by this user
|
||||||
|
# This would require additional logic to determine candidate ownership
|
||||||
|
# For now, you might want to clear all cache when any user's RAG data updates
|
||||||
|
# or implement a more sophisticated mapping
|
||||||
|
|
||||||
|
if keys:
|
||||||
|
return await self.redis.delete(*keys)
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||||
"""Get all documents for a specific candidate"""
|
"""Get all documents for a specific candidate"""
|
||||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||||
@ -331,6 +436,133 @@ class RedisDatabase:
|
|||||||
query_lower in doc.get("originalName", "").lower())
|
query_lower in doc.get("originalName", "").lower())
|
||||||
]
|
]
|
||||||
|
|
||||||
|
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
|
||||||
|
"""Get cached job requirements analysis for a document"""
|
||||||
|
try:
|
||||||
|
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||||
|
data = await self.redis.get(key)
|
||||||
|
if data:
|
||||||
|
requirements_data = self._deserialize(data)
|
||||||
|
logger.debug(f"📋 Retrieved cached job requirements for document {document_id}")
|
||||||
|
return requirements_data
|
||||||
|
logger.debug(f"📋 No cached job requirements found for document {document_id}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
|
||||||
|
"""Save job requirements analysis results for a document"""
|
||||||
|
try:
|
||||||
|
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||||
|
|
||||||
|
# Add metadata to the requirements
|
||||||
|
requirements_with_meta = {
|
||||||
|
**requirements,
|
||||||
|
"cached_at": datetime.now(UTC).isoformat(),
|
||||||
|
"document_id": document_id
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.redis.set(key, self._serialize(requirements_with_meta))
|
||||||
|
|
||||||
|
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
|
||||||
|
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
|
||||||
|
|
||||||
|
logger.debug(f"📋 Saved job requirements for document {document_id}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def delete_job_requirements(self, document_id: str) -> bool:
|
||||||
|
"""Delete cached job requirements for a document"""
|
||||||
|
try:
|
||||||
|
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||||
|
result = await self.redis.delete(key)
|
||||||
|
if result > 0:
|
||||||
|
logger.debug(f"📋 Deleted job requirements for document {document_id}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_all_job_requirements(self) -> Dict[str, Any]:
|
||||||
|
"""Get all cached job requirements"""
|
||||||
|
try:
|
||||||
|
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
|
||||||
|
keys = await self.redis.keys(pattern)
|
||||||
|
|
||||||
|
if not keys:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
|
||||||
|
if value:
|
||||||
|
result[document_id] = self._deserialize(value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error retrieving all job requirements: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||||
|
"""Get all job requirements analysis for documents belonging to a candidate"""
|
||||||
|
try:
|
||||||
|
# Get all documents for the candidate
|
||||||
|
candidate_documents = await self.get_candidate_documents(candidate_id)
|
||||||
|
|
||||||
|
if not candidate_documents:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get job requirements for each document
|
||||||
|
job_requirements = []
|
||||||
|
for doc in candidate_documents:
|
||||||
|
doc_id = doc.get("id")
|
||||||
|
if doc_id:
|
||||||
|
requirements = await self.get_job_requirements(doc_id)
|
||||||
|
if requirements:
|
||||||
|
# Add document metadata to requirements
|
||||||
|
requirements["document_filename"] = doc.get("filename")
|
||||||
|
requirements["document_original_name"] = doc.get("originalName")
|
||||||
|
job_requirements.append(requirements)
|
||||||
|
|
||||||
|
return job_requirements
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
|
||||||
|
"""Invalidate (delete) cached job requirements for a document"""
|
||||||
|
# This is an alias for delete_job_requirements for semantic clarity
|
||||||
|
return await self.delete_job_requirements(document_id)
|
||||||
|
|
||||||
|
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
|
||||||
|
"""Delete job requirements for multiple documents and return count of deleted items"""
|
||||||
|
try:
|
||||||
|
deleted_count = 0
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
|
||||||
|
for doc_id in document_ids:
|
||||||
|
key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}"
|
||||||
|
pipe.delete(key)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
results = await pipe.execute()
|
||||||
|
actual_deleted = sum(1 for result in results if result > 0)
|
||||||
|
|
||||||
|
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
|
||||||
|
return actual_deleted
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error bulk deleting job requirements: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
# Viewer operations
|
# Viewer operations
|
||||||
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
|
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
|
||||||
"""Get viewer by ID"""
|
"""Get viewer by ID"""
|
||||||
@ -1484,6 +1716,74 @@ class RedisDatabase:
|
|||||||
key = f"{self.KEY_PREFIXES['users']}{email}"
|
key = f"{self.KEY_PREFIXES['users']}{email}"
|
||||||
await self.redis.delete(key)
|
await self.redis.delete(key)
|
||||||
|
|
||||||
|
async def get_job_requirements_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get statistics about cached job requirements"""
|
||||||
|
try:
|
||||||
|
pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
|
||||||
|
keys = await self.redis.keys(pattern)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total_cached_requirements": len(keys),
|
||||||
|
"cache_dates": {},
|
||||||
|
"documents_with_requirements": []
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys:
|
||||||
|
# Get cache dates for analysis
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for key in keys:
|
||||||
|
pipe.get(key)
|
||||||
|
values = await pipe.execute()
|
||||||
|
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
if value:
|
||||||
|
requirements_data = self._deserialize(value)
|
||||||
|
if requirements_data:
|
||||||
|
document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
|
||||||
|
stats["documents_with_requirements"].append(document_id)
|
||||||
|
|
||||||
|
# Track cache dates
|
||||||
|
cached_at = requirements_data.get("cached_at")
|
||||||
|
if cached_at:
|
||||||
|
cache_date = cached_at[:10] # Extract date part
|
||||||
|
stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
|
||||||
|
|
||||||
|
return stats
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error getting job requirements stats: {e}")
|
||||||
|
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
|
||||||
|
|
||||||
|
async def cleanup_orphaned_job_requirements(self) -> int:
|
||||||
|
"""Clean up job requirements for documents that no longer exist"""
|
||||||
|
try:
|
||||||
|
# Get all job requirements
|
||||||
|
all_requirements = await self.get_all_job_requirements()
|
||||||
|
|
||||||
|
if not all_requirements:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
orphaned_count = 0
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
|
||||||
|
for document_id in all_requirements.keys():
|
||||||
|
# Check if the document still exists
|
||||||
|
document_exists = await self.get_document(document_id)
|
||||||
|
if not document_exists:
|
||||||
|
# Document no longer exists, delete its job requirements
|
||||||
|
key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
|
||||||
|
pipe.delete(key)
|
||||||
|
orphaned_count += 1
|
||||||
|
logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}")
|
||||||
|
|
||||||
|
if orphaned_count > 0:
|
||||||
|
await pipe.execute()
|
||||||
|
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
|
||||||
|
|
||||||
|
return orphaned_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
# Utility methods
|
# Utility methods
|
||||||
async def clear_all_data(self):
|
async def clear_all_data(self):
|
||||||
"""Clear all data from Redis (use with caution!)"""
|
"""Clear all data from Redis (use with caution!)"""
|
||||||
|
@ -19,8 +19,9 @@ import defines
|
|||||||
from logger import logger
|
from logger import logger
|
||||||
import agents as agents
|
import agents as agents
|
||||||
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
||||||
from llm_manager import llm_manager
|
import llm_proxy as llm_manager
|
||||||
from agents.base import Agent
|
from agents.base import Agent
|
||||||
|
from database import RedisDatabase
|
||||||
|
|
||||||
class CandidateEntity(Candidate):
|
class CandidateEntity(Candidate):
|
||||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||||
@ -115,7 +116,7 @@ class CandidateEntity(Candidate):
|
|||||||
raise ValueError("initialize() has not been called.")
|
raise ValueError("initialize() has not been called.")
|
||||||
return self.CandidateEntity__observer
|
return self.CandidateEntity__observer
|
||||||
|
|
||||||
async def initialize(self, prometheus_collector: CollectorRegistry):
|
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||||
if self.CandidateEntity__initialized:
|
if self.CandidateEntity__initialized:
|
||||||
# Initialization can only be attempted once; if there are multiple attempts, it means
|
# Initialization can only be attempted once; if there are multiple attempts, it means
|
||||||
# a subsystem is failing or there is a logic bug in the code.
|
# a subsystem is failing or there is a logic bug in the code.
|
||||||
@ -140,9 +141,11 @@ class CandidateEntity(Candidate):
|
|||||||
|
|
||||||
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
|
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
|
||||||
llm=llm_manager.get_llm(),
|
llm=llm_manager.get_llm(),
|
||||||
|
user_id=self.id,
|
||||||
collection_name=self.username,
|
collection_name=self.username,
|
||||||
persist_directory=vector_db_dir,
|
persist_directory=vector_db_dir,
|
||||||
watch_directory=rag_content_dir,
|
watch_directory=rag_content_dir,
|
||||||
|
database=database,
|
||||||
recreate=False, # Don't recreate if exists
|
recreate=False, # Don't recreate if exists
|
||||||
)
|
)
|
||||||
has_username_rag = any(item["name"] == self.username for item in self.rags)
|
has_username_rag = any(item["name"] == self.username for item in self.rags)
|
||||||
|
@ -7,6 +7,7 @@ from pydantic import BaseModel, Field # type: ignore
|
|||||||
|
|
||||||
from models import ( Candidate )
|
from models import ( Candidate )
|
||||||
from .candidate_entity import CandidateEntity
|
from .candidate_entity import CandidateEntity
|
||||||
|
from database import RedisDatabase
|
||||||
from prometheus_client import CollectorRegistry # type: ignore
|
from prometheus_client import CollectorRegistry # type: ignore
|
||||||
|
|
||||||
class EntityManager:
|
class EntityManager:
|
||||||
@ -34,9 +35,10 @@ class EntityManager:
|
|||||||
pass
|
pass
|
||||||
self._cleanup_task = None
|
self._cleanup_task = None
|
||||||
|
|
||||||
def initialize(self, prometheus_collector: CollectorRegistry):
|
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||||
"""Initialize the EntityManager with Prometheus collector"""
|
"""Initialize the EntityManager with Prometheus collector"""
|
||||||
self._prometheus_collector = prometheus_collector
|
self._prometheus_collector = prometheus_collector
|
||||||
|
self._database = database
|
||||||
|
|
||||||
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
|
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
|
||||||
"""Get or create CandidateEntity with proper reference tracking"""
|
"""Get or create CandidateEntity with proper reference tracking"""
|
||||||
@ -49,7 +51,7 @@ class EntityManager:
|
|||||||
return entity
|
return entity
|
||||||
|
|
||||||
entity = CandidateEntity(candidate=candidate)
|
entity = CandidateEntity(candidate=candidate)
|
||||||
await entity.initialize(prometheus_collector=self._prometheus_collector)
|
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
|
||||||
|
|
||||||
# Store with reference tracking
|
# Store with reference tracking
|
||||||
self._entities[candidate.id] = entity
|
self._entities[candidate.id] = entity
|
||||||
|
@ -1,41 +0,0 @@
|
|||||||
import ollama
|
|
||||||
import defines
|
|
||||||
|
|
||||||
_llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
|
||||||
|
|
||||||
class llm_manager:
|
|
||||||
"""
|
|
||||||
A class to manage LLM operations using the Ollama client.
|
|
||||||
"""
|
|
||||||
@staticmethod
|
|
||||||
def get_llm() -> ollama.Client: # type: ignore
|
|
||||||
"""
|
|
||||||
Get the Ollama client instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of the Ollama client.
|
|
||||||
"""
|
|
||||||
return _llm
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_models() -> list[str]:
|
|
||||||
"""
|
|
||||||
Get a list of available models from the Ollama client.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of model names.
|
|
||||||
"""
|
|
||||||
return _llm.models()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_model_info(model_name: str) -> dict:
|
|
||||||
"""
|
|
||||||
Get information about a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: The name of the model to retrieve information for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary containing model information.
|
|
||||||
"""
|
|
||||||
return _llm.model(model_name)
|
|
1203
src/backend/llm_proxy.py
Normal file
1203
src/backend/llm_proxy.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -13,6 +13,9 @@ import uuid
|
|||||||
import defines
|
import defines
|
||||||
import pathlib
|
import pathlib
|
||||||
|
|
||||||
|
from markitdown import MarkItDown, StreamInfo # type: ignore
|
||||||
|
import io
|
||||||
|
|
||||||
import uvicorn # type: ignore
|
import uvicorn # type: ignore
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
from datetime import datetime, timedelta, UTC
|
from datetime import datetime, timedelta, UTC
|
||||||
@ -53,7 +56,7 @@ import defines
|
|||||||
from logger import logger
|
from logger import logger
|
||||||
from database import RedisDatabase, redis_manager, DatabaseManager
|
from database import RedisDatabase, redis_manager, DatabaseManager
|
||||||
from metrics import Metrics
|
from metrics import Metrics
|
||||||
from llm_manager import llm_manager
|
import llm_proxy as llm_manager
|
||||||
import entities
|
import entities
|
||||||
from email_service import VerificationEmailRateLimiter, email_service
|
from email_service import VerificationEmailRateLimiter, email_service
|
||||||
from device_manager import DeviceManager
|
from device_manager import DeviceManager
|
||||||
@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
|
|||||||
try:
|
try:
|
||||||
# Initialize database
|
# Initialize database
|
||||||
await db_manager.initialize()
|
await db_manager.initialize()
|
||||||
|
entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
|
||||||
|
|
||||||
signal.signal(signal.SIGTERM, signal_handler)
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
@ -1827,7 +1831,7 @@ async def upload_candidate_document(
|
|||||||
yield error_message
|
yield error_message
|
||||||
return
|
return
|
||||||
|
|
||||||
converted = False;
|
converted = False
|
||||||
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
|
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
|
||||||
p = pathlib.Path(file_path)
|
p = pathlib.Path(file_path)
|
||||||
p_as_md = p.with_suffix(".md")
|
p_as_md = p.with_suffix(".md")
|
||||||
@ -1873,53 +1877,6 @@ async def upload_candidate_document(
|
|||||||
content=file_content,
|
content=file_content,
|
||||||
)
|
)
|
||||||
yield chat_message
|
yield chat_message
|
||||||
|
|
||||||
# If this is a job description, process it with the job requirements agent
|
|
||||||
if not options.is_job_document:
|
|
||||||
return
|
|
||||||
|
|
||||||
status_message = ChatMessageStatus(
|
|
||||||
session_id=MOCK_UUID, # No session ID for document uploads
|
|
||||||
content=f"Initiating connection with {candidate.first_name}'s AI agent...",
|
|
||||||
activity=ApiActivityType.INFO
|
|
||||||
)
|
|
||||||
yield status_message
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
|
||||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
|
||||||
if not chat_agent:
|
|
||||||
error_message = ChatMessageError(
|
|
||||||
session_id=MOCK_UUID, # No session ID for document uploads
|
|
||||||
content="No agent found for job requirements chat type"
|
|
||||||
)
|
|
||||||
yield error_message
|
|
||||||
return
|
|
||||||
message = None
|
|
||||||
status_message = ChatMessageStatus(
|
|
||||||
session_id=MOCK_UUID, # No session ID for document uploads
|
|
||||||
content=f"Analyzing document for company and requirement details...",
|
|
||||||
activity=ApiActivityType.SEARCHING
|
|
||||||
)
|
|
||||||
yield status_message
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
|
|
||||||
async for message in chat_agent.generate(
|
|
||||||
llm=llm_manager.get_llm(),
|
|
||||||
model=defines.model,
|
|
||||||
session_id=MOCK_UUID,
|
|
||||||
prompt=file_content
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
if not message or not isinstance(message, JobRequirementsMessage):
|
|
||||||
error_message = ChatMessageError(
|
|
||||||
session_id=MOCK_UUID, # No session ID for document uploads
|
|
||||||
content="Failed to process job description file"
|
|
||||||
)
|
|
||||||
yield error_message
|
|
||||||
return
|
|
||||||
yield message
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async def to_json(method):
|
async def to_json(method):
|
||||||
try:
|
try:
|
||||||
@ -1932,7 +1889,6 @@ async def upload_candidate_document(
|
|||||||
logger.error(f"Error in to_json conversion: {e}")
|
logger.error(f"Error in to_json conversion: {e}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# return DebugStreamingResponse(
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
to_json(upload_stream_generator(file_content)),
|
to_json(upload_stream_generator(file_content)),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
@ -1948,11 +1904,60 @@ async def upload_candidate_document(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(backstory_traceback.format_exc())
|
logger.error(backstory_traceback.format_exc())
|
||||||
logger.error(f"❌ Document upload error: {e}")
|
logger.error(f"❌ Document upload error: {e}")
|
||||||
return JSONResponse(
|
return StreamingResponse(
|
||||||
status_code=500,
|
iter([ChatMessageError(
|
||||||
content=create_error_response("UPLOAD_ERROR", "Failed to upload document")
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="Failed to upload document"
|
||||||
|
)]),
|
||||||
|
media_type="text/event-stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
|
||||||
|
status_message = ChatMessageStatus(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||||
|
activity=ApiActivityType.INFO
|
||||||
|
)
|
||||||
|
yield status_message
|
||||||
|
await asyncio.sleep(0) # Let the status message propagate
|
||||||
|
|
||||||
|
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
|
||||||
|
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||||
|
if not chat_agent:
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="No agent found for job requirements chat type"
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
return
|
||||||
|
message = None
|
||||||
|
status_message = ChatMessageStatus(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content=f"Analyzing document for company and requirement details...",
|
||||||
|
activity=ApiActivityType.SEARCHING
|
||||||
|
)
|
||||||
|
yield status_message
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async for message in chat_agent.generate(
|
||||||
|
llm=llm_manager.get_llm(),
|
||||||
|
model=defines.model,
|
||||||
|
session_id=MOCK_UUID,
|
||||||
|
prompt=content
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
if not message or not isinstance(message, JobRequirementsMessage):
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="Failed to process job description file"
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"✅ Successfully saved job requirements job {message.id}")
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
@api_router.post("/candidates/profile/upload")
|
@api_router.post("/candidates/profile/upload")
|
||||||
async def upload_candidate_profile(
|
async def upload_candidate_profile(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@ -2573,6 +2578,7 @@ async def delete_candidate(
|
|||||||
status_code=500,
|
status_code=500,
|
||||||
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
|
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
|
||||||
)
|
)
|
||||||
|
|
||||||
@api_router.patch("/candidates/{candidate_id}")
|
@api_router.patch("/candidates/{candidate_id}")
|
||||||
async def update_candidate(
|
async def update_candidate(
|
||||||
candidate_id: str = Path(...),
|
candidate_id: str = Path(...),
|
||||||
@ -2816,6 +2822,139 @@ async def create_candidate_job(
|
|||||||
content=create_error_response("CREATION_FAILED", str(e))
|
content=create_error_response("CREATION_FAILED", str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_router.post("/jobs/upload")
|
||||||
|
async def create_job_from_file(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Upload a job document for the current candidate and create a Job"""
|
||||||
|
# Check file size (limit to 10MB)
|
||||||
|
max_size = 10 * 1024 * 1024 # 10MB
|
||||||
|
file_content = await file.read()
|
||||||
|
if len(file_content) > max_size:
|
||||||
|
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="File size exceeds 10MB limit"
|
||||||
|
)]),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
if len(file_content) == 0:
|
||||||
|
logger.info(f"⚠️ File is empty: {file.filename}")
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="File is empty"
|
||||||
|
)]),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
"""Upload a document for the current candidate"""
|
||||||
|
async def upload_stream_generator(file_content):
|
||||||
|
# Verify user is a candidate
|
||||||
|
if current_user.user_type != "candidate":
|
||||||
|
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="Only candidates can upload documents"
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
return
|
||||||
|
|
||||||
|
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
|
||||||
|
if not file.filename or file.filename.strip() == "":
|
||||||
|
logger.warning("⚠️ File upload attempt with missing filename")
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="File must have a valid filename"
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
|
||||||
|
|
||||||
|
# Validate file type
|
||||||
|
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
|
||||||
|
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
|
||||||
|
|
||||||
|
if file_extension not in allowed_types:
|
||||||
|
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
return
|
||||||
|
|
||||||
|
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
|
||||||
|
|
||||||
|
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
|
||||||
|
status_message = ChatMessageStatus(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content=f"Converting content from {document_type}...",
|
||||||
|
activity=ApiActivityType.CONVERTING
|
||||||
|
)
|
||||||
|
yield status_message
|
||||||
|
try:
|
||||||
|
md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
|
||||||
|
stream = io.BytesIO(file_content)
|
||||||
|
stream_info = StreamInfo(
|
||||||
|
extension=file_extension, # e.g., ".pdf"
|
||||||
|
url=file.filename # optional, helps with logging and guessing
|
||||||
|
)
|
||||||
|
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
|
||||||
|
file_content = result.text_content
|
||||||
|
logger.info(f"✅ Converted {file.filename} to Markdown format")
|
||||||
|
except Exception as e:
|
||||||
|
error_message = ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content=f"Failed to convert {file.filename} to Markdown.",
|
||||||
|
)
|
||||||
|
yield error_message
|
||||||
|
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
|
||||||
|
return
|
||||||
|
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
|
||||||
|
yield message
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async def to_json(method):
|
||||||
|
try:
|
||||||
|
async for message in method:
|
||||||
|
json_data = message.model_dump(mode='json', by_alias=True)
|
||||||
|
json_str = json.dumps(json_data)
|
||||||
|
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(backstory_traceback.format_exc())
|
||||||
|
logger.error(f"Error in to_json conversion: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
to_json(upload_stream_generator(file_content)),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no", # Nginx
|
||||||
|
"X-Content-Type-Options": "nosniff",
|
||||||
|
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(backstory_traceback.format_exc())
|
||||||
|
logger.error(f"❌ Document upload error: {e}")
|
||||||
|
return StreamingResponse(
|
||||||
|
iter([ChatMessageError(
|
||||||
|
session_id=MOCK_UUID, # No session ID for document uploads
|
||||||
|
content="Failed to upload document"
|
||||||
|
)]),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
|
||||||
@api_router.get("/jobs/{job_id}")
|
@api_router.get("/jobs/{job_id}")
|
||||||
async def get_job(
|
async def get_job(
|
||||||
job_id: str = Path(...),
|
job_id: str = Path(...),
|
||||||
@ -2931,6 +3070,49 @@ async def search_jobs(
|
|||||||
content=create_error_response("SEARCH_FAILED", str(e))
|
content=create_error_response("SEARCH_FAILED", str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_router.delete("/jobs/{job_id}")
|
||||||
|
async def delete_job(
|
||||||
|
job_id: str = Path(...),
|
||||||
|
admin_user = Depends(get_current_admin),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Delete a Job"""
|
||||||
|
try:
|
||||||
|
# Check if admin user
|
||||||
|
if not admin_user.is_admin:
|
||||||
|
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content=create_error_response("FORBIDDEN", "Only admins can delete")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get candidate data
|
||||||
|
job_data = await database.get_job(job_id)
|
||||||
|
if not job_data:
|
||||||
|
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content=create_error_response("NOT_FOUND", "Job not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete job from database
|
||||||
|
await database.delete_job(job_id)
|
||||||
|
|
||||||
|
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
|
||||||
|
|
||||||
|
return create_success_response({
|
||||||
|
"message": "Job deleted successfully",
|
||||||
|
"jobId": job_id
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Delete job error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("DELETE_ERROR", "Failed to delete job")
|
||||||
|
)
|
||||||
|
|
||||||
# ============================
|
# ============================
|
||||||
# Chat Endpoints
|
# Chat Endpoints
|
||||||
# ============================
|
# ============================
|
||||||
@ -3541,7 +3723,7 @@ async def get_candidate_skill_match(
|
|||||||
current_user = Depends(get_current_user),
|
current_user = Depends(get_current_user),
|
||||||
database: RedisDatabase = Depends(get_database)
|
database: RedisDatabase = Depends(get_database)
|
||||||
):
|
):
|
||||||
"""Get skill match for a candidate against a requirement"""
|
"""Get skill match for a candidate against a requirement with caching"""
|
||||||
try:
|
try:
|
||||||
# Find candidate by ID
|
# Find candidate by ID
|
||||||
candidate_data = await database.get_candidate(candidate_id)
|
candidate_data = await database.get_candidate(candidate_id)
|
||||||
@ -3553,34 +3735,89 @@ async def get_candidate_skill_match(
|
|||||||
|
|
||||||
candidate = Candidate.model_validate(candidate_data)
|
candidate = Candidate.model_validate(candidate_data)
|
||||||
|
|
||||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
# Create cache key for this specific candidate + requirement combination
|
||||||
logger.info(f"🔍 Running skill match for candidate {candidate_entity.username} against requirement: {requirement}")
|
cache_key = f"skill_match:{candidate_id}:{hash(requirement)}"
|
||||||
agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
|
|
||||||
if not agent:
|
# Get cached assessment if it exists
|
||||||
return JSONResponse(
|
cached_assessment = await database.get_cached_skill_match(cache_key)
|
||||||
status_code=400,
|
|
||||||
content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
|
# Get the last update time for the candidate's skill information
|
||||||
)
|
candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id)
|
||||||
# Entity automatically released when done
|
|
||||||
skill_match = await get_last_item(
|
# Get the latest RAG data update time for the current user
|
||||||
agent.generate(
|
user_rag_update_time = await database.get_user_rag_update_time(current_user.id)
|
||||||
llm=llm_manager.get_llm(),
|
|
||||||
model=defines.model,
|
# Determine if we need to regenerate the assessment
|
||||||
session_id=MOCK_UUID,
|
should_regenerate = True
|
||||||
prompt=requirement,
|
cached_date = None
|
||||||
|
|
||||||
|
if cached_assessment:
|
||||||
|
cached_date = cached_assessment.get('cached_at')
|
||||||
|
if cached_date:
|
||||||
|
# Check if cached result is still valid
|
||||||
|
# Regenerate if:
|
||||||
|
# 1. Candidate skills were updated after cache date
|
||||||
|
# 2. User's RAG data was updated after cache date
|
||||||
|
if (not candidate_skill_update_time or cached_date >= candidate_skill_update_time) and \
|
||||||
|
(not user_rag_update_time or cached_date >= user_rag_update_time):
|
||||||
|
should_regenerate = False
|
||||||
|
logger.info(f"🔄 Using cached skill match for candidate {candidate.id}")
|
||||||
|
|
||||||
|
if should_regenerate:
|
||||||
|
logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}")
|
||||||
|
|
||||||
|
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||||
|
agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
|
||||||
|
if not agent:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate new skill match
|
||||||
|
skill_match = await get_last_item(
|
||||||
|
agent.generate(
|
||||||
|
llm=llm_manager.get_llm(),
|
||||||
|
model=defines.model,
|
||||||
|
session_id=MOCK_UUID,
|
||||||
|
prompt=requirement,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if skill_match is None:
|
|
||||||
|
if skill_match is None:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
|
||||||
|
)
|
||||||
|
|
||||||
|
skill_match_data = json.loads(skill_match.content)
|
||||||
|
|
||||||
|
# Cache the new assessment with current timestamp
|
||||||
|
cached_assessment = {
|
||||||
|
"skill_match": skill_match_data,
|
||||||
|
"cached_at": datetime.utcnow().isoformat(),
|
||||||
|
"candidate_id": candidate_id,
|
||||||
|
"requirement": requirement
|
||||||
|
}
|
||||||
|
|
||||||
|
await database.cache_skill_match(cache_key, cached_assessment)
|
||||||
|
logger.info(f"💾 Cached new skill match for candidate {candidate.id}")
|
||||||
|
logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
|
||||||
|
else:
|
||||||
|
# Use cached result - we know cached_assessment is not None here
|
||||||
|
if cached_assessment is None:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
|
content=create_error_response("CACHE_ERROR", "Unexpected cache state")
|
||||||
)
|
)
|
||||||
skill_match = json.loads(skill_match.content)
|
skill_match_data = cached_assessment["skill_match"]
|
||||||
logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match['evidence_strength']}")
|
logger.info(f"✅ Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
|
||||||
|
|
||||||
return create_success_response({
|
return create_success_response({
|
||||||
"candidateId": candidate.id,
|
"candidateId": candidate.id,
|
||||||
"skillMatch": skill_match
|
"skillMatch": skill_match_data,
|
||||||
|
"cached": not should_regenerate,
|
||||||
|
"cacheTimestamp": cached_date
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -3911,7 +4148,6 @@ async def track_requests(request, call_next):
|
|||||||
# FastAPI Metrics
|
# FastAPI Metrics
|
||||||
# ============================
|
# ============================
|
||||||
prometheus_collector = CollectorRegistry()
|
prometheus_collector = CollectorRegistry()
|
||||||
entities.entity_manager.initialize(prometheus_collector)
|
|
||||||
|
|
||||||
# Keep the Instrumentator instance alive
|
# Keep the Instrumentator instance alive
|
||||||
instrumentator = Instrumentator(
|
instrumentator = Instrumentator(
|
||||||
|
@ -768,11 +768,7 @@ class ChatOptions(BaseModel):
|
|||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True # Allow both field names and aliases
|
||||||
}
|
}
|
||||||
|
|
||||||
|
from llm_proxy import (LLMMessage)
|
||||||
class LLMMessage(BaseModel):
|
|
||||||
role: str = Field(default="")
|
|
||||||
content: str = Field(default="")
|
|
||||||
tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)
|
|
||||||
|
|
||||||
class ApiMessage(BaseModel):
|
class ApiMessage(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
@ -13,7 +13,6 @@ import numpy as np # type: ignore
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import chromadb # type: ignore
|
import chromadb # type: ignore
|
||||||
import ollama
|
|
||||||
from watchdog.observers import Observer # type: ignore
|
from watchdog.observers import Observer # type: ignore
|
||||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
from watchdog.events import FileSystemEventHandler # type: ignore
|
||||||
import umap # type: ignore
|
import umap # type: ignore
|
||||||
@ -27,6 +26,7 @@ from .markdown_chunker import (
|
|||||||
|
|
||||||
# When imported as a module, use relative imports
|
# When imported as a module, use relative imports
|
||||||
import defines
|
import defines
|
||||||
|
from database import RedisDatabase
|
||||||
from models import ChromaDBGetResponse
|
from models import ChromaDBGetResponse
|
||||||
|
|
||||||
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
|
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
|
||||||
@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
loop,
|
loop,
|
||||||
persist_directory,
|
persist_directory,
|
||||||
collection_name,
|
collection_name,
|
||||||
|
database: RedisDatabase,
|
||||||
|
user_id: str,
|
||||||
chunk_size=DEFAULT_CHUNK_SIZE,
|
chunk_size=DEFAULT_CHUNK_SIZE,
|
||||||
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
|
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
|
||||||
recreate=False,
|
recreate=False,
|
||||||
):
|
):
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
|
self.database = database
|
||||||
|
self.user_id = user_id
|
||||||
|
self.database = database
|
||||||
self.watch_directory = watch_directory
|
self.watch_directory = watch_directory
|
||||||
self.persist_directory = persist_directory or defines.persist_directory
|
self.persist_directory = persist_directory or defines.persist_directory
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
|
|
||||||
if results and "ids" in results and results["ids"]:
|
if results and "ids" in results and results["ids"]:
|
||||||
self.collection.delete(ids=results["ids"])
|
self.collection.delete(ids=results["ids"])
|
||||||
|
await self.database.update_user_rag_timestamp(self.user_id)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
|
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
|
||||||
)
|
)
|
||||||
@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_embedding(self, text: str) -> np.ndarray:
|
async def get_embedding(self, text: str) -> np.ndarray:
|
||||||
"""Generate and normalize an embedding for the given text."""
|
"""Generate and normalize an embedding for the given text."""
|
||||||
|
|
||||||
# Get embedding
|
# Get embedding
|
||||||
try:
|
try:
|
||||||
response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
|
response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
|
||||||
embedding = np.array(response["embedding"])
|
embedding = np.array(response.get_single_embedding())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logging.error(traceback.format_exc())
|
||||||
logging.error(f"Failed to get embedding: {e}")
|
logging.error(f"Failed to get embedding: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
|
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
def add_embeddings_to_collection(self, chunks: List[Chunk]):
|
async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
|
||||||
"""Add embeddings for chunks to the collection."""
|
"""Add embeddings for chunks to the collection."""
|
||||||
|
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
|
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
|
||||||
chunk_id = f"{path_hash}_{i}_{content_hash}"
|
chunk_id = f"{path_hash}_{i}_{content_hash}"
|
||||||
|
|
||||||
embedding = self.get_embedding(text)
|
embedding = await self.get_embedding(text)
|
||||||
try:
|
try:
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
ids=[chunk_id],
|
ids=[chunk_id],
|
||||||
@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
|
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
|
||||||
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
|
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
|
||||||
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
|
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
|
||||||
def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
|
async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
|
||||||
"""Find similar documents to the query."""
|
"""Find similar documents to the query."""
|
||||||
|
|
||||||
# collection is configured with hnsw:space cosine
|
# collection is configured with hnsw:space cosine
|
||||||
query_embedding = self.get_embedding(query)
|
query_embedding = await self.get_embedding(query)
|
||||||
results = self.collection.query(
|
results = self.collection.query(
|
||||||
query_embeddings=[query_embedding],
|
query_embeddings=[query_embedding],
|
||||||
n_results=top_k,
|
n_results=top_k,
|
||||||
@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
and existing_results["ids"]
|
and existing_results["ids"]
|
||||||
):
|
):
|
||||||
self.collection.delete(ids=existing_results["ids"])
|
self.collection.delete(ids=existing_results["ids"])
|
||||||
|
await self.database.update_user_rag_timestamp(self.user_id)
|
||||||
|
|
||||||
extensions = (".docx", ".xlsx", ".xls", ".pdf")
|
extensions = (".docx", ".xlsx", ".xls", ".pdf")
|
||||||
if file_path.endswith(extensions):
|
if file_path.endswith(extensions):
|
||||||
@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
# f.write(json.dumps(chunk, indent=2))
|
# f.write(json.dumps(chunk, indent=2))
|
||||||
|
|
||||||
# Add chunks to collection
|
# Add chunks to collection
|
||||||
self.add_embeddings_to_collection(chunks)
|
await self._add_embeddings_to_collection(chunks)
|
||||||
|
await self.database.update_user_rag_timestamp(self.user_id)
|
||||||
|
|
||||||
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
|
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
|
||||||
|
|
||||||
@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
# Function to start the file watcher
|
# Function to start the file watcher
|
||||||
def start_file_watcher(
|
def start_file_watcher(
|
||||||
llm,
|
llm,
|
||||||
|
user_id,
|
||||||
watch_directory,
|
watch_directory,
|
||||||
persist_directory,
|
persist_directory,
|
||||||
collection_name,
|
collection_name,
|
||||||
|
database: RedisDatabase,
|
||||||
initialize=False,
|
initialize=False,
|
||||||
recreate=False,
|
recreate=False,
|
||||||
):
|
):
|
||||||
@ -663,9 +674,11 @@ def start_file_watcher(
|
|||||||
llm,
|
llm,
|
||||||
watch_directory=watch_directory,
|
watch_directory=watch_directory,
|
||||||
loop=loop,
|
loop=loop,
|
||||||
|
user_id=user_id,
|
||||||
persist_directory=persist_directory,
|
persist_directory=persist_directory,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
recreate=recreate,
|
recreate=recreate,
|
||||||
|
database=database
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process all files if:
|
# Process all files if:
|
||||||
|
193
src/multi-llm/config.md
Normal file
193
src/multi-llm/config.md
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
# Environment Configuration Examples
|
||||||
|
|
||||||
|
# ==== Development (Ollama only) ====
|
||||||
|
export OLLAMA_HOST="http://localhost:11434"
|
||||||
|
export DEFAULT_LLM_PROVIDER="ollama"
|
||||||
|
|
||||||
|
# ==== Production with OpenAI ====
|
||||||
|
export OPENAI_API_KEY="sk-your-openai-key-here"
|
||||||
|
export DEFAULT_LLM_PROVIDER="openai"
|
||||||
|
|
||||||
|
# ==== Production with Anthropic ====
|
||||||
|
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
|
||||||
|
export DEFAULT_LLM_PROVIDER="anthropic"
|
||||||
|
|
||||||
|
# ==== Production with multiple providers ====
|
||||||
|
export OPENAI_API_KEY="sk-your-openai-key-here"
|
||||||
|
export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
|
||||||
|
export GEMINI_API_KEY="your-gemini-key-here"
|
||||||
|
export OLLAMA_HOST="http://ollama-server:11434"
|
||||||
|
export DEFAULT_LLM_PROVIDER="openai"
|
||||||
|
|
||||||
|
# ==== Docker Compose Example ====
|
||||||
|
# docker-compose.yml
|
||||||
|
version: '3.8'
|
||||||
|
services:
|
||||||
|
api:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
|
environment:
|
||||||
|
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||||
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||||
|
- DEFAULT_LLM_PROVIDER=openai
|
||||||
|
depends_on:
|
||||||
|
- ollama
|
||||||
|
|
||||||
|
ollama:
|
||||||
|
image: ollama/ollama
|
||||||
|
ports:
|
||||||
|
- "11434:11434"
|
||||||
|
volumes:
|
||||||
|
- ollama_data:/root/.ollama
|
||||||
|
environment:
|
||||||
|
- OLLAMA_HOST=0.0.0.0
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
ollama_data:
|
||||||
|
|
||||||
|
# ==== Kubernetes ConfigMap ====
|
||||||
|
apiVersion: v1
|
||||||
|
kind: ConfigMap
|
||||||
|
metadata:
|
||||||
|
name: llm-config
|
||||||
|
data:
|
||||||
|
DEFAULT_LLM_PROVIDER: "openai"
|
||||||
|
OLLAMA_HOST: "http://ollama-service:11434"
|
||||||
|
|
||||||
|
---
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Secret
|
||||||
|
metadata:
|
||||||
|
name: llm-secrets
|
||||||
|
type: Opaque
|
||||||
|
stringData:
|
||||||
|
OPENAI_API_KEY: "sk-your-key-here"
|
||||||
|
ANTHROPIC_API_KEY: "sk-ant-your-key-here"
|
||||||
|
|
||||||
|
# ==== Python configuration example ====
|
||||||
|
# config.py
|
||||||
|
import os
|
||||||
|
from llm_proxy import get_llm, LLMProvider
|
||||||
|
|
||||||
|
def configure_llm_for_environment():
|
||||||
|
"""Configure LLM based on deployment environment"""
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
# Development: Use Ollama
|
||||||
|
if os.getenv('ENVIRONMENT') == 'development':
|
||||||
|
llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434')
|
||||||
|
llm.set_default_provider(LLMProvider.OLLAMA)
|
||||||
|
|
||||||
|
# Staging: Use OpenAI with rate limiting
|
||||||
|
elif os.getenv('ENVIRONMENT') == 'staging':
|
||||||
|
llm.configure_provider(LLMProvider.OPENAI,
|
||||||
|
api_key=os.getenv('OPENAI_API_KEY'),
|
||||||
|
max_retries=3,
|
||||||
|
timeout=30)
|
||||||
|
llm.set_default_provider(LLMProvider.OPENAI)
|
||||||
|
|
||||||
|
# Production: Use multiple providers with fallback
|
||||||
|
elif os.getenv('ENVIRONMENT') == 'production':
|
||||||
|
# Primary: Anthropic
|
||||||
|
llm.configure_provider(LLMProvider.ANTHROPIC,
|
||||||
|
api_key=os.getenv('ANTHROPIC_API_KEY'))
|
||||||
|
|
||||||
|
# Fallback: OpenAI
|
||||||
|
llm.configure_provider(LLMProvider.OPENAI,
|
||||||
|
api_key=os.getenv('OPENAI_API_KEY'))
|
||||||
|
|
||||||
|
# Set primary provider
|
||||||
|
llm.set_default_provider(LLMProvider.ANTHROPIC)
|
||||||
|
|
||||||
|
# ==== Usage Examples ====
|
||||||
|
|
||||||
|
# Example 1: Basic usage with default provider
|
||||||
|
async def basic_example():
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
response = await llm.chat(
|
||||||
|
model="gpt-4",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}]
|
||||||
|
)
|
||||||
|
print(response.content)
|
||||||
|
|
||||||
|
# Example 2: Specify provider explicitly
|
||||||
|
async def provider_specific_example():
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
# Use OpenAI specifically
|
||||||
|
response = await llm.chat(
|
||||||
|
model="gpt-4",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}],
|
||||||
|
provider=LLMProvider.OPENAI
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use Anthropic specifically
|
||||||
|
response2 = await llm.chat(
|
||||||
|
model="claude-3-sonnet-20240229",
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}],
|
||||||
|
provider=LLMProvider.ANTHROPIC
|
||||||
|
)
|
||||||
|
|
||||||
|
# Example 3: Streaming with provider fallback
|
||||||
|
async def streaming_with_fallback():
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for chunk in llm.chat(
|
||||||
|
model="claude-3-sonnet-20240229",
|
||||||
|
messages=[{"role": "user", "content": "Write a story"}],
|
||||||
|
provider=LLMProvider.ANTHROPIC,
|
||||||
|
stream=True
|
||||||
|
):
|
||||||
|
print(chunk.content, end='', flush=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Primary provider failed: {e}")
|
||||||
|
# Fallback to OpenAI
|
||||||
|
async for chunk in llm.chat(
|
||||||
|
model="gpt-4",
|
||||||
|
messages=[{"role": "user", "content": "Write a story"}],
|
||||||
|
provider=LLMProvider.OPENAI,
|
||||||
|
stream=True
|
||||||
|
):
|
||||||
|
print(chunk.content, end='', flush=True)
|
||||||
|
|
||||||
|
# Example 4: Load balancing between providers
|
||||||
|
import random
|
||||||
|
|
||||||
|
async def load_balanced_request():
|
||||||
|
llm = get_llm()
|
||||||
|
available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]
|
||||||
|
|
||||||
|
# Simple random load balancing
|
||||||
|
provider = random.choice(available_providers)
|
||||||
|
|
||||||
|
# Adjust model based on provider
|
||||||
|
model_mapping = {
|
||||||
|
LLMProvider.OPENAI: "gpt-4",
|
||||||
|
LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await llm.chat(
|
||||||
|
model=model_mapping[provider],
|
||||||
|
messages=[{"role": "user", "content": "Hello!"}],
|
||||||
|
provider=provider
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Response from {provider.value}: {response.content}")
|
||||||
|
|
||||||
|
# ==== FastAPI Startup Configuration ====
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""Configure LLM providers on application startup"""
|
||||||
|
configure_llm_for_environment()
|
||||||
|
|
||||||
|
# Verify configuration
|
||||||
|
llm = get_llm()
|
||||||
|
print(f"Configured providers: {[p.value for p in llm._initialized_providers]}")
|
||||||
|
print(f"Default provider: {llm.default_provider.value}")
|
238
src/multi-llm/example.py
Normal file
238
src/multi-llm/example.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException # type: ignore
|
||||||
|
from fastapi.responses import StreamingResponse # type: ignore
|
||||||
|
from pydantic import BaseModel # type: ignore
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from llm_proxy import get_llm, LLMProvider, ChatMessage
|
||||||
|
|
||||||
|
app = FastAPI(title="Unified LLM API")
|
||||||
|
|
||||||
|
# Pydantic models for request/response
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: List[Dict[str, str]]
|
||||||
|
provider: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
provider: Optional[str] = None
|
||||||
|
stream: bool = False
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
|
||||||
|
class ChatResponseModel(BaseModel):
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
provider: str
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
usage: Optional[Dict[str, int]] = None
|
||||||
|
|
||||||
|
@app.post("/chat")
|
||||||
|
async def chat_endpoint(request: ChatRequest):
|
||||||
|
"""Chat endpoint that works with any configured LLM provider"""
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
# Convert provider string to enum if provided
|
||||||
|
provider = None
|
||||||
|
if request.provider:
|
||||||
|
try:
|
||||||
|
provider = LLMProvider(request.provider.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unsupported provider: {request.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare kwargs
|
||||||
|
kwargs = {}
|
||||||
|
if request.max_tokens:
|
||||||
|
kwargs['max_tokens'] = request.max_tokens
|
||||||
|
if request.temperature:
|
||||||
|
kwargs['temperature'] = request.temperature
|
||||||
|
|
||||||
|
if request.stream:
|
||||||
|
async def generate_response():
|
||||||
|
try:
|
||||||
|
# Use the type-safe streaming method
|
||||||
|
async for chunk in llm.chat_stream(
|
||||||
|
model=request.model,
|
||||||
|
messages=request.messages,
|
||||||
|
provider=provider,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
# Format as Server-Sent Events
|
||||||
|
data = {
|
||||||
|
"content": chunk.content,
|
||||||
|
"model": chunk.model,
|
||||||
|
"provider": provider.value if provider else llm.default_provider.value,
|
||||||
|
"finish_reason": chunk.finish_reason
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
error_data = {"error": str(e)}
|
||||||
|
yield f"data: {json.dumps(error_data)}\n\n"
|
||||||
|
finally:
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_response(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use the type-safe single response method
|
||||||
|
response = await llm.chat_single(
|
||||||
|
model=request.model,
|
||||||
|
messages=request.messages,
|
||||||
|
provider=provider,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponseModel(
|
||||||
|
content=response.content,
|
||||||
|
model=response.model,
|
||||||
|
provider=provider.value if provider else llm.default_provider.value,
|
||||||
|
finish_reason=response.finish_reason,
|
||||||
|
usage=response.usage
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.post("/generate")
|
||||||
|
async def generate_endpoint(request: GenerateRequest):
|
||||||
|
"""Generate endpoint for simple text generation"""
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
provider = None
|
||||||
|
if request.provider:
|
||||||
|
try:
|
||||||
|
provider = LLMProvider(request.provider.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unsupported provider: {request.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if request.max_tokens:
|
||||||
|
kwargs['max_tokens'] = request.max_tokens
|
||||||
|
if request.temperature:
|
||||||
|
kwargs['temperature'] = request.temperature
|
||||||
|
|
||||||
|
if request.stream:
|
||||||
|
async def generate_response():
|
||||||
|
try:
|
||||||
|
# Use the type-safe streaming method
|
||||||
|
async for chunk in llm.generate_stream(
|
||||||
|
model=request.model,
|
||||||
|
prompt=request.prompt,
|
||||||
|
provider=provider,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
data = {
|
||||||
|
"content": chunk.content,
|
||||||
|
"model": chunk.model,
|
||||||
|
"provider": provider.value if provider else llm.default_provider.value,
|
||||||
|
"finish_reason": chunk.finish_reason
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
error_data = {"error": str(e)}
|
||||||
|
yield f"data: {json.dumps(error_data)}\n\n"
|
||||||
|
finally:
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_response(),
|
||||||
|
media_type="text/event-stream"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use the type-safe single response method
|
||||||
|
response = await llm.generate_single(
|
||||||
|
model=request.model,
|
||||||
|
prompt=request.prompt,
|
||||||
|
provider=provider,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChatResponseModel(
|
||||||
|
content=response.content,
|
||||||
|
model=response.model,
|
||||||
|
provider=provider.value if provider else llm.default_provider.value,
|
||||||
|
finish_reason=response.finish_reason,
|
||||||
|
usage=response.usage
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/models")
|
||||||
|
async def list_models(provider: Optional[str] = None):
|
||||||
|
"""List available models for a provider"""
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
provider_enum = None
|
||||||
|
if provider:
|
||||||
|
try:
|
||||||
|
provider_enum = LLMProvider(provider.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Unsupported provider: {provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
models = await llm.list_models(provider_enum)
|
||||||
|
return {
|
||||||
|
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
|
||||||
|
"models": models
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@app.get("/providers")
|
||||||
|
async def list_providers():
|
||||||
|
"""List all configured providers"""
|
||||||
|
llm = get_llm()
|
||||||
|
return {
|
||||||
|
"providers": [provider.value for provider in llm._initialized_providers],
|
||||||
|
"default": llm.default_provider.value
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.post("/providers/{provider}/set-default")
|
||||||
|
async def set_default_provider(provider: str):
|
||||||
|
"""Set the default provider"""
|
||||||
|
try:
|
||||||
|
llm = get_llm()
|
||||||
|
provider_enum = LLMProvider(provider.lower())
|
||||||
|
llm.set_default_provider(provider_enum)
|
||||||
|
return {"message": f"Default provider set to {provider}", "default": provider}
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
# Health check endpoint
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""Health check endpoint"""
|
||||||
|
llm = get_llm()
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"providers_configured": len(llm._initialized_providers),
|
||||||
|
"default_provider": llm.default_provider.value
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn # type: ignore
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
606
src/multi-llm/llm_proxy.py
Normal file
606
src/multi-llm/llm_proxy.py
Normal file
@ -0,0 +1,606 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
|
||||||
|
from enum import Enum
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Standard message format for all providers
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
role: str # "user", "assistant", "system"
|
||||||
|
content: str
|
||||||
|
|
||||||
|
def __getitem__(self, key: str):
|
||||||
|
"""Allow dictionary-style access for backward compatibility"""
|
||||||
|
if key == 'role':
|
||||||
|
return self.role
|
||||||
|
elif key == 'content':
|
||||||
|
return self.content
|
||||||
|
else:
|
||||||
|
raise KeyError(f"'{key}' not found in ChatMessage")
|
||||||
|
|
||||||
|
def __setitem__(self, key: str, value: str):
|
||||||
|
"""Allow dictionary-style assignment"""
|
||||||
|
if key == 'role':
|
||||||
|
self.role = value
|
||||||
|
elif key == 'content':
|
||||||
|
self.content = value
|
||||||
|
else:
|
||||||
|
raise KeyError(f"'{key}' not found in ChatMessage")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage':
|
||||||
|
"""Create ChatMessage from dictionary"""
|
||||||
|
return cls(role=data['role'], content=data['content'])
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, str]:
|
||||||
|
"""Convert ChatMessage to dictionary"""
|
||||||
|
return {'role': self.role, 'content': self.content}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatResponse:
|
||||||
|
content: str
|
||||||
|
model: str
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
usage: Optional[Dict[str, int]] = None
|
||||||
|
|
||||||
|
class LLMProvider(Enum):
|
||||||
|
OLLAMA = "ollama"
|
||||||
|
OPENAI = "openai"
|
||||||
|
ANTHROPIC = "anthropic"
|
||||||
|
GEMINI = "gemini"
|
||||||
|
GROK = "grok"
|
||||||
|
|
||||||
|
class BaseLLMAdapter(ABC):
|
||||||
|
"""Abstract base class for all LLM adapters"""
|
||||||
|
|
||||||
|
def __init__(self, **config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
"""Send chat messages and get response"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
"""Generate text from prompt"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
"""List available models"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class OllamaAdapter(BaseLLMAdapter):
|
||||||
|
"""Adapter for Ollama"""
|
||||||
|
|
||||||
|
def __init__(self, **config):
|
||||||
|
super().__init__(**config)
|
||||||
|
import ollama
|
||||||
|
self.client = ollama.AsyncClient( # type: ignore
|
||||||
|
host=config.get('host', 'http://localhost:11434')
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
# Convert ChatMessage objects to Ollama format
|
||||||
|
ollama_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
ollama_messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(model, ollama_messages, **kwargs)
|
||||||
|
else:
|
||||||
|
response = await self.client.chat(
|
||||||
|
model=model,
|
||||||
|
messages=ollama_messages,
|
||||||
|
stream=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return ChatResponse(
|
||||||
|
content=response['message']['content'],
|
||||||
|
model=model,
|
||||||
|
finish_reason=response.get('done_reason')
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
|
||||||
|
async for chunk in self.client.chat(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if chunk.get('message', {}).get('content'):
|
||||||
|
yield ChatResponse(
|
||||||
|
content=chunk['message']['content'],
|
||||||
|
model=model,
|
||||||
|
finish_reason=chunk.get('done_reason')
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_generate(model, prompt, **kwargs)
|
||||||
|
else:
|
||||||
|
response = await self.client.generate(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
stream=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return ChatResponse(
|
||||||
|
content=response['response'],
|
||||||
|
model=model,
|
||||||
|
finish_reason=response.get('done_reason')
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_generate(self, model: str, prompt: str, **kwargs):
|
||||||
|
async for chunk in self.client.generate(
|
||||||
|
model=model,
|
||||||
|
prompt=prompt,
|
||||||
|
stream=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if chunk.get('response'):
|
||||||
|
yield ChatResponse(
|
||||||
|
content=chunk['response'],
|
||||||
|
model=model,
|
||||||
|
finish_reason=chunk.get('done_reason')
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
models = await self.client.list()
|
||||||
|
return [model['name'] for model in models['models']]
|
||||||
|
|
||||||
|
class OpenAIAdapter(BaseLLMAdapter):
|
||||||
|
"""Adapter for OpenAI"""
|
||||||
|
|
||||||
|
def __init__(self, **config):
|
||||||
|
super().__init__(**config)
|
||||||
|
import openai # type: ignore
|
||||||
|
self.client = openai.AsyncOpenAI(
|
||||||
|
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
# Convert ChatMessage objects to OpenAI format
|
||||||
|
openai_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
openai_messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(model, openai_messages, **kwargs)
|
||||||
|
else:
|
||||||
|
response = await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=openai_messages,
|
||||||
|
stream=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return ChatResponse(
|
||||||
|
content=response.choices[0].message.content,
|
||||||
|
model=model,
|
||||||
|
finish_reason=response.choices[0].finish_reason,
|
||||||
|
usage=response.usage.model_dump() if response.usage else None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
|
||||||
|
async for chunk in await self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
stream=True,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if chunk.choices[0].delta.content:
|
||||||
|
yield ChatResponse(
|
||||||
|
content=chunk.choices[0].delta.content,
|
||||||
|
model=model,
|
||||||
|
finish_reason=chunk.choices[0].finish_reason
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
# Convert to chat format for OpenAI
|
||||||
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
|
return await self.chat(model, messages, stream, **kwargs)
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
models = await self.client.models.list()
|
||||||
|
return [model.id for model in models.data]
|
||||||
|
|
||||||
|
class AnthropicAdapter(BaseLLMAdapter):
|
||||||
|
"""Adapter for Anthropic Claude"""
|
||||||
|
|
||||||
|
def __init__(self, **config):
|
||||||
|
super().__init__(**config)
|
||||||
|
import anthropic # type: ignore
|
||||||
|
self.client = anthropic.AsyncAnthropic(
|
||||||
|
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
||||||
|
)
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
# Anthropic requires system message to be separate
|
||||||
|
system_message = None
|
||||||
|
anthropic_messages = []
|
||||||
|
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "system":
|
||||||
|
system_message = msg.content
|
||||||
|
else:
|
||||||
|
anthropic_messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
request_kwargs = {
|
||||||
|
"model": model,
|
||||||
|
"messages": anthropic_messages,
|
||||||
|
"max_tokens": kwargs.pop('max_tokens', 1000),
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
if system_message:
|
||||||
|
request_kwargs["system"] = system_message
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(**request_kwargs)
|
||||||
|
else:
|
||||||
|
response = await self.client.messages.create(
|
||||||
|
stream=False,
|
||||||
|
**request_kwargs
|
||||||
|
)
|
||||||
|
return ChatResponse(
|
||||||
|
content=response.content[0].text,
|
||||||
|
model=model,
|
||||||
|
finish_reason=response.stop_reason,
|
||||||
|
usage={
|
||||||
|
"input_tokens": response.usage.input_tokens,
|
||||||
|
"output_tokens": response.usage.output_tokens
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_chat(self, **kwargs):
|
||||||
|
async with self.client.messages.stream(**kwargs) as stream:
|
||||||
|
async for text in stream.text_stream:
|
||||||
|
yield ChatResponse(
|
||||||
|
content=text,
|
||||||
|
model=kwargs['model']
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
|
return await self.chat(model, messages, stream, **kwargs)
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
# Anthropic doesn't have a list models endpoint, return known models
|
||||||
|
return [
|
||||||
|
"claude-3-5-sonnet-20241022",
|
||||||
|
"claude-3-5-haiku-20241022",
|
||||||
|
"claude-3-opus-20240229"
|
||||||
|
]
|
||||||
|
|
||||||
|
class GeminiAdapter(BaseLLMAdapter):
|
||||||
|
"""Adapter for Google Gemini"""
|
||||||
|
|
||||||
|
def __init__(self, **config):
|
||||||
|
super().__init__(**config)
|
||||||
|
import google.generativeai as genai # type: ignore
|
||||||
|
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
||||||
|
self.genai = genai
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[ChatMessage],
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
model_instance = self.genai.GenerativeModel(model)
|
||||||
|
|
||||||
|
# Convert messages to Gemini format
|
||||||
|
chat_history = []
|
||||||
|
current_message = None
|
||||||
|
|
||||||
|
for i, msg in enumerate(messages[:-1]): # All but last message go to history
|
||||||
|
if msg.role == "user":
|
||||||
|
chat_history.append({"role": "user", "parts": [msg.content]})
|
||||||
|
elif msg.role == "assistant":
|
||||||
|
chat_history.append({"role": "model", "parts": [msg.content]})
|
||||||
|
|
||||||
|
# Last message is the current prompt
|
||||||
|
if messages:
|
||||||
|
current_message = messages[-1].content
|
||||||
|
|
||||||
|
chat = model_instance.start_chat(history=chat_history)
|
||||||
|
|
||||||
|
if not current_message:
|
||||||
|
raise ValueError("No current message provided for chat")
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat(chat, current_message, **kwargs)
|
||||||
|
else:
|
||||||
|
response = await chat.send_message_async(current_message, **kwargs)
|
||||||
|
return ChatResponse(
|
||||||
|
content=response.text,
|
||||||
|
model=model,
|
||||||
|
finish_reason=response.candidates[0].finish_reason.name if response.candidates else None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_chat(self, chat, message: str, **kwargs):
|
||||||
|
async for chunk in chat.send_message_async(message, stream=True, **kwargs):
|
||||||
|
if chunk.text:
|
||||||
|
yield ChatResponse(
|
||||||
|
content=chunk.text,
|
||||||
|
model=chat.model.model_name
|
||||||
|
)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
|
||||||
|
messages = [ChatMessage(role="user", content=prompt)]
|
||||||
|
return await self.chat(model, messages, stream, **kwargs)
|
||||||
|
|
||||||
|
async def list_models(self) -> List[str]:
|
||||||
|
models = self.genai.list_models()
|
||||||
|
return [model.name for model in models if 'generateContent' in model.supported_generation_methods]
|
||||||
|
|
||||||
|
class UnifiedLLMProxy:
|
||||||
|
"""Main proxy class that provides unified interface to all LLM providers"""
|
||||||
|
|
||||||
|
def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
|
||||||
|
self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
|
||||||
|
self.default_provider = default_provider
|
||||||
|
self._initialized_providers = set()
|
||||||
|
|
||||||
|
def configure_provider(self, provider: LLMProvider, **config):
|
||||||
|
"""Configure a specific provider with its settings"""
|
||||||
|
adapter_classes = {
|
||||||
|
LLMProvider.OLLAMA: OllamaAdapter,
|
||||||
|
LLMProvider.OPENAI: OpenAIAdapter,
|
||||||
|
LLMProvider.ANTHROPIC: AnthropicAdapter,
|
||||||
|
LLMProvider.GEMINI: GeminiAdapter,
|
||||||
|
# Add other providers as needed
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider in adapter_classes:
|
||||||
|
self.adapters[provider] = adapter_classes[provider](**config)
|
||||||
|
self._initialized_providers.add(provider)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider: {provider}")
|
||||||
|
|
||||||
|
def set_default_provider(self, provider: LLMProvider):
|
||||||
|
"""Set the default provider for requests"""
|
||||||
|
if provider not in self._initialized_providers:
|
||||||
|
raise ValueError(f"Provider {provider} not configured")
|
||||||
|
self.default_provider = provider
|
||||||
|
|
||||||
|
async def chat(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Union[List[ChatMessage], List[Dict[str, str]]],
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
"""Send chat messages using specified or default provider"""
|
||||||
|
|
||||||
|
provider = provider or self.default_provider
|
||||||
|
adapter = self._get_adapter(provider)
|
||||||
|
|
||||||
|
# Normalize messages to ChatMessage objects
|
||||||
|
normalized_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg, ChatMessage):
|
||||||
|
normalized_messages.append(msg)
|
||||||
|
elif isinstance(msg, dict):
|
||||||
|
normalized_messages.append(ChatMessage.from_dict(msg))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid message type: {type(msg)}")
|
||||||
|
|
||||||
|
return await adapter.chat(model, normalized_messages, stream, **kwargs)
|
||||||
|
|
||||||
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Union[List[ChatMessage], List[Dict[str, str]]],
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[ChatResponse, None]:
|
||||||
|
"""Stream chat messages using specified or default provider"""
|
||||||
|
|
||||||
|
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
||||||
|
# Type checker now knows this is an AsyncGenerator due to stream=True
|
||||||
|
async for chunk in result: # type: ignore
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def chat_single(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: Union[List[ChatMessage], List[Dict[str, str]]],
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Get single chat response using specified or default provider"""
|
||||||
|
|
||||||
|
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
||||||
|
# Type checker now knows this is a ChatResponse due to stream=False
|
||||||
|
return result # type: ignore
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
|
||||||
|
"""Generate text using specified or default provider"""
|
||||||
|
|
||||||
|
provider = provider or self.default_provider
|
||||||
|
adapter = self._get_adapter(provider)
|
||||||
|
|
||||||
|
return await adapter.generate(model, prompt, stream, **kwargs)
|
||||||
|
|
||||||
|
async def generate_stream(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[ChatResponse, None]:
|
||||||
|
"""Stream text generation using specified or default provider"""
|
||||||
|
|
||||||
|
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
||||||
|
async for chunk in result: # type: ignore
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def generate_single(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
provider: Optional[LLMProvider] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Get single generation response using specified or default provider"""
|
||||||
|
|
||||||
|
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
||||||
|
return result # type: ignore
|
||||||
|
|
||||||
|
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
|
||||||
|
"""List available models for specified or default provider"""
|
||||||
|
|
||||||
|
provider = provider or self.default_provider
|
||||||
|
adapter = self._get_adapter(provider)
|
||||||
|
|
||||||
|
return await adapter.list_models()
|
||||||
|
|
||||||
|
def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
|
||||||
|
"""Get adapter for specified provider"""
|
||||||
|
if provider not in self.adapters:
|
||||||
|
raise ValueError(f"Provider {provider} not configured")
|
||||||
|
return self.adapters[provider]
|
||||||
|
|
||||||
|
# Example usage and configuration
|
||||||
|
class LLMManager:
|
||||||
|
"""Singleton manager for the unified LLM proxy"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_proxy = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if LLMManager._proxy is None:
|
||||||
|
LLMManager._proxy = UnifiedLLMProxy()
|
||||||
|
self._configure_from_environment()
|
||||||
|
|
||||||
|
def _configure_from_environment(self):
|
||||||
|
"""Configure providers based on environment variables"""
|
||||||
|
|
||||||
|
if not self._proxy:
|
||||||
|
raise RuntimeError("LLM proxy not initialized")
|
||||||
|
|
||||||
|
# Configure Ollama if available
|
||||||
|
ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
|
||||||
|
self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
|
||||||
|
|
||||||
|
# Configure OpenAI if API key is available
|
||||||
|
if os.getenv('OPENAI_API_KEY'):
|
||||||
|
self._proxy.configure_provider(LLMProvider.OPENAI)
|
||||||
|
|
||||||
|
# Configure Anthropic if API key is available
|
||||||
|
if os.getenv('ANTHROPIC_API_KEY'):
|
||||||
|
self._proxy.configure_provider(LLMProvider.ANTHROPIC)
|
||||||
|
|
||||||
|
# Configure Gemini if API key is available
|
||||||
|
if os.getenv('GEMINI_API_KEY'):
|
||||||
|
self._proxy.configure_provider(LLMProvider.GEMINI)
|
||||||
|
|
||||||
|
# Set default provider from environment or use Ollama
|
||||||
|
default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
|
||||||
|
try:
|
||||||
|
self._proxy.set_default_provider(LLMProvider(default_provider))
|
||||||
|
except ValueError:
|
||||||
|
# Fallback to Ollama if specified provider not available
|
||||||
|
self._proxy.set_default_provider(LLMProvider.OLLAMA)
|
||||||
|
|
||||||
|
def get_proxy(self) -> UnifiedLLMProxy:
|
||||||
|
"""Get the unified LLM proxy instance"""
|
||||||
|
if not self._proxy:
|
||||||
|
raise RuntimeError("LLM proxy not initialized")
|
||||||
|
return self._proxy
|
||||||
|
|
||||||
|
# Convenience function for easy access
|
||||||
|
def get_llm() -> UnifiedLLMProxy:
|
||||||
|
"""Get the configured LLM proxy"""
|
||||||
|
return LLMManager.get_instance().get_proxy()
|
Loading…
x
Reference in New Issue
Block a user