diff --git a/frontend/src/components/JobCreator.tsx b/frontend/src/components/JobCreator.tsx
new file mode 100644
index 0000000..dfe936e
--- /dev/null
+++ b/frontend/src/components/JobCreator.tsx
@@ -0,0 +1,541 @@
+import React, { useState, useEffect, useRef, JSX } from 'react';
+import {
+ Box,
+ Button,
+ Typography,
+ Paper,
+ TextField,
+ Grid,
+ Dialog,
+ DialogTitle,
+ DialogContent,
+ DialogContentText,
+ DialogActions,
+ IconButton,
+ useTheme,
+ useMediaQuery,
+ Chip,
+ Divider,
+ Card,
+ CardContent,
+ CardHeader,
+ LinearProgress,
+ Stack,
+ Alert
+} from '@mui/material';
+import {
+ SyncAlt,
+ Favorite,
+ Settings,
+ Info,
+ Search,
+ AutoFixHigh,
+ Image,
+ Psychology,
+ Build,
+ CloudUpload,
+ Description,
+ Business,
+ LocationOn,
+ Work,
+ CheckCircle,
+ Star
+} from '@mui/icons-material';
+import { styled } from '@mui/material/styles';
+import DescriptionIcon from '@mui/icons-material/Description';
+import FileUploadIcon from '@mui/icons-material/FileUpload';
+
+import { useAuth } from 'hooks/AuthContext';
+import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
+import { BackstoryElementProps } from './BackstoryTab';
+import { LoginRequired } from 'components/ui/LoginRequired';
+
+import * as Types from 'types/types';
+
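+// Visually hidden native file input: it stays in the DOM (and accessible) while the styled upload button triggers it programmatically via a ref.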
+const VisuallyHiddenInput = styled('input')({
+ clip: 'rect(0 0 0 0)',
+ clipPath: 'inset(50%)',
+ height: 1,
+ overflow: 'hidden',
+ position: 'absolute',
+ bottom: 0,
+ left: 0,
+ whiteSpace: 'nowrap',
+ width: 1,
+});
+
+const UploadBox = styled(Box)(({ theme }) => ({
+ border: `2px dashed ${theme.palette.primary.main}`,
+ borderRadius: theme.shape.borderRadius * 2,
+ padding: theme.spacing(4),
+ textAlign: 'center',
+ backgroundColor: theme.palette.action.hover,
+ transition: 'all 0.3s ease',
+ cursor: 'pointer',
+ '&:hover': {
+ backgroundColor: theme.palette.action.selected,
+ borderColor: theme.palette.primary.dark,
+ },
+}));
+
+const StatusBox = styled(Box)(({ theme }) => ({
+ display: 'flex',
+ alignItems: 'center',
+ gap: theme.spacing(1),
+ padding: theme.spacing(1, 2),
+ backgroundColor: theme.palette.background.paper,
+ borderRadius: theme.shape.borderRadius,
+ border: `1px solid ${theme.palette.divider}`,
+ minHeight: 48,
+}));
+
+const getIcon = (type: Types.ApiActivityType) => {
+ switch (type) {
+        case 'converting':
+            return <SyncAlt />;
+        case 'heartbeat':
+            return <Favorite />;
+        case 'system':
+            return <Settings />;
+        case 'info':
+            return <Info />;
+        case 'searching':
+            return <Search />;
+        case 'generating':
+            return <AutoFixHigh />;
+        case 'generating_image':
+            return <Image />;
+        case 'thinking':
+            return <Psychology />;
+        case 'tooling':
+            return <Build />;
+        default:
+            return <Info />;
+ }
+};
+
+interface JobCreator extends BackstoryElementProps {
+ onSave?: (job: Types.Job) => void;
+}
+const JobCreator = (props: JobCreator) => {
+ const { user, apiClient } = useAuth();
+ const { onSave } = props;
+ const { selectedCandidate } = useSelectedCandidate();
+ const { selectedJob, setSelectedJob } = useSelectedJob();
+ const { setSnack, submitQuery } = props;
+ const backstoryProps = { setSnack, submitQuery };
+ const theme = useTheme();
+ const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
+ const isTablet = useMediaQuery(theme.breakpoints.down('md'));
+
+ const [openUploadDialog, setOpenUploadDialog] = useState(false);
+ const [jobDescription, setJobDescription] = useState('');
+ const [jobRequirements, setJobRequirements] = useState(null);
+ const [jobTitle, setJobTitle] = useState('');
+ const [company, setCompany] = useState('');
+ const [summary, setSummary] = useState('');
+ const [jobLocation, setJobLocation] = useState('');
+ const [jobId, setJobId] = useState('');
+ const [jobStatus, setJobStatus] = useState('');
+    const [jobStatusIcon, setJobStatusIcon] = useState(<></>);
+ const [isProcessing, setIsProcessing] = useState(false);
+
+ useEffect(() => {
+
+ }, [jobTitle, jobDescription, company]);
+
+    const fileInputRef = useRef<HTMLInputElement>(null);
+
+
+    if (!user?.id) {
+        return (
+            <LoginRequired />
+        );
+    }
+
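+    // Streaming callbacks for apiClient.createJobFromFile: status events drive the activity icon/text, the final message carries the parsed Job, and error/complete clear the processing state.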
+ const jobStatusHandlers = {
+ onStatus: (status: Types.ChatMessageStatus) => {
+ console.log('status:', status.content);
+ setJobStatusIcon(getIcon(status.activity));
+ setJobStatus(status.content);
+ },
+ onMessage: (job: Types.Job) => {
+ console.log('onMessage - job', job);
+ setCompany(job.company || '');
+ setJobDescription(job.description);
+ setSummary(job.summary || '');
+ setJobTitle(job.title || '');
+ setJobRequirements(job.requirements || null);
+            setJobStatusIcon(<></>);
+ setJobStatus('');
+ },
+ onError: (error: Types.ChatMessageError) => {
+ console.log('onError', error);
+ setSnack(error.content, "error");
+ setIsProcessing(false);
+ },
+ onComplete: () => {
+            setJobStatusIcon(<></>);
+ setJobStatus('');
+ setIsProcessing(false);
+ }
+ };
+
+    const handleJobUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
+ if (e.target.files && e.target.files[0]) {
+ const file = e.target.files[0];
+ const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase();
+ let docType: Types.DocumentType | null = null;
+ switch (fileExtension.substring(1)) {
+ case "pdf":
+ docType = "pdf";
+ break;
+ case "docx":
+ docType = "docx";
+ break;
+ case "md":
+ docType = "markdown";
+ break;
+ case "txt":
+ docType = "txt";
+ break;
+ }
+
+ if (!docType) {
+ setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error');
+ return;
+ }
+
+ try {
+ setIsProcessing(true);
+ setJobDescription('');
+ setJobTitle('');
+ setJobRequirements(null);
+ setSummary('');
+ const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
+ const job = await controller.promise;
+ if (!job) {
+ return;
+ }
+ console.log(`Job id: ${job.id}`);
+ e.target.value = '';
+ } catch (error) {
+ console.error(error);
+ setSnack('Failed to upload document', 'error');
+ setIsProcessing(false);
+ }
+ }
+ };
+
+ const handleUploadClick = () => {
+ fileInputRef.current?.click();
+ };
+
+ const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => {
+ if (!items || items.length === 0) return null;
+
+ return (
+
+
+ {icon}
+
+ {title}
+
+ {required && }
+
+
+ {items.map((item, index) => (
+
+ ))}
+
+
+ );
+ };
+
+ const renderJobRequirements = () => {
+ if (!jobRequirements) return null;
+
+ return (
+
+ }
+ sx={{ pb: 1 }}
+ />
+
+ {renderRequirementSection(
+ "Technical Skills (Required)",
+ jobRequirements.technicalSkills.required,
+ ,
+ true
+ )}
+ {renderRequirementSection(
+ "Technical Skills (Preferred)",
+ jobRequirements.technicalSkills.preferred,
+
+ )}
+ {renderRequirementSection(
+ "Experience Requirements (Required)",
+ jobRequirements.experienceRequirements.required,
+ ,
+ true
+ )}
+ {renderRequirementSection(
+ "Experience Requirements (Preferred)",
+ jobRequirements.experienceRequirements.preferred,
+
+ )}
+ {renderRequirementSection(
+ "Soft Skills",
+ jobRequirements.softSkills,
+
+ )}
+ {renderRequirementSection(
+ "Experience",
+ jobRequirements.experience,
+
+ )}
+ {renderRequirementSection(
+ "Education",
+ jobRequirements.education,
+
+ )}
+ {renderRequirementSection(
+ "Certifications",
+ jobRequirements.certifications,
+
+ )}
+ {renderRequirementSection(
+ "Preferred Attributes",
+ jobRequirements.preferredAttributes,
+
+ )}
+
+
+ );
+ };
+
+ const handleSave = async () => {
+ const newJob: Types.Job = {
+ ownerId: user?.id || '',
+ ownerType: 'candidate',
+ description: jobDescription,
+ company: company,
+ summary: summary,
+ title: jobTitle,
+ requirements: jobRequirements || undefined
+ };
+ setIsProcessing(true);
+ const job = await apiClient.createJob(newJob);
+ setIsProcessing(false);
+ onSave ? onSave(job) : setSelectedJob(job);
+ };
+
+ const handleExtractRequirements = () => {
+ // Implement requirements extraction logic here
+ setIsProcessing(true);
+ // This would call your API to extract requirements from the job description
+ };
+
+ const renderJobCreation = () => {
+ if (!user) {
+            return <Typography>You must be logged in</Typography>;
+ }
+
+ return (
+
+ {/* Upload Section */}
+
+ }
+ />
+
+
+
+
+
+ Upload Job Description
+
+
+
+
+ Drop your job description here
+
+
+ Supported formats: PDF, DOCX, TXT, MD
+
+ }
+ disabled={isProcessing}
+ // onClick={handleUploadClick}
+ >
+ Choose File
+
+
+
+
+
+
+
+
+ Or Enter Manually
+
+ setJobDescription(e.target.value)}
+ disabled={isProcessing}
+ sx={{ mb: 2 }}
+ />
+ {jobRequirements === null && jobDescription && (
+ }
+ disabled={isProcessing}
+ fullWidth={isMobile}
+ >
+ Extract Requirements
+
+ )}
+
+
+
+ {(jobStatus || isProcessing) && (
+
+
+ {jobStatusIcon}
+
+ {jobStatus || 'Processing...'}
+
+
+ {isProcessing && }
+
+ )}
+
+
+
+ {/* Job Details Section */}
+
+ }
+ />
+
+
+
+ setJobTitle(e.target.value)}
+ required
+ disabled={isProcessing}
+ InputProps={{
+ startAdornment:
+ }}
+ />
+
+
+
+ setCompany(e.target.value)}
+ required
+ disabled={isProcessing}
+ InputProps={{
+ startAdornment:
+ }}
+ />
+
+
+ {/*
+ setJobLocation(e.target.value)}
+ disabled={isProcessing}
+ InputProps={{
+ startAdornment:
+ }}
+ />
+ */}
+
+
+
+ }
+ >
+ Save Job
+
+
+
+
+
+
+
+ {/* Job Summary */}
+ {summary !== '' &&
+
+ }
+ sx={{ pb: 1 }}
+ />
+
+ {summary}
+
+
+ }
+
+ {/* Requirements Display */}
+ {renderJobRequirements()}
+
+
+ );
+ };
+
+ return (
+
+ {selectedJob === null && renderJobCreation()}
+
+ );
+};
+
+export { JobCreator };
\ No newline at end of file
diff --git a/frontend/src/components/JobManagement.tsx b/frontend/src/components/JobManagement.tsx
index 87a7ad5..59834b5 100644
--- a/frontend/src/components/JobManagement.tsx
+++ b/frontend/src/components/JobManagement.tsx
@@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => {
};
const JobManagement = (props: BackstoryElementProps) => {
- const { user, apiClient } = useAuth();
- const { selectedCandidate } = useSelectedCandidate();
+ const { user, apiClient } = useAuth();
const { selectedJob, setSelectedJob } = useSelectedJob();
- const { setSnack, submitQuery } = props;
- const backstoryProps = { setSnack, submitQuery };
+ const { setSnack, submitQuery } = props;
const theme = useTheme();
- const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
- const isTablet = useMediaQuery(theme.breakpoints.down('md'));
+ const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
- const [openUploadDialog, setOpenUploadDialog] = useState(false);
const [jobDescription, setJobDescription] = useState('');
const [jobRequirements, setJobRequirements] = useState(null);
const [jobTitle, setJobTitle] = useState('');
const [company, setCompany] = useState('');
const [summary, setSummary] = useState('');
- const [jobLocation, setJobLocation] = useState('');
- const [jobId, setJobId] = useState('');
const [jobStatus, setJobStatus] = useState('');
     const [jobStatusIcon, setJobStatusIcon] = useState(<></>);
const [isProcessing, setIsProcessing] = useState(false);
- useEffect(() => {
-
- }, [jobTitle, jobDescription, company]);
-
const fileInputRef = useRef(null);
-
if (!user?.id) {
return (
@@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => {
setJobTitle('');
setJobRequirements(null);
setSummary('');
- const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers);
- const document = await controller.promise;
- if (!document) {
+ const controller = apiClient.createJobFromFile(file, jobStatusHandlers);
+ const job = await controller.promise;
+ if (!job) {
return;
}
- console.log(`Document id: ${document.id}`);
+ console.log(`Job id: ${job.id}`);
e.target.value = '';
} catch (error) {
console.error(error);
@@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => {
// This would call your API to extract requirements from the job description
};
- const loadJob = async () => {
- const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d");
- setSelectedJob(job);
- }
const renderJobCreation = () => {
if (!user) {
return You must be logged in;
@@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => {
-
{/* Upload Section */}
diff --git a/frontend/src/components/JobMatchAnalysis.tsx b/frontend/src/components/JobMatchAnalysis.tsx
--- a/frontend/src/components/JobMatchAnalysis.tsx
+++ b/frontend/src/components/JobMatchAnalysis.tsx
@@ ... @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
const [overallScore, setOverallScore] = useState(0);
const [requirementsSession, setRequirementsSession] = useState(null);
const [statusMessage, setStatusMessage] = useState(null);
+ const [startAnalysis, setStartAnalysis] = useState(false);
+ const [analyzing, setAnalyzing] = useState(false);
+
const isMobile = useMediaQuery(theme.breakpoints.down('sm'));
// Handle accordion expansion
@@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
setExpanded(isExpanded ? panel : false);
};
- useEffect(() => {
+ const initializeRequirements = (job: Job) => {
if (!job || !job.requirements) {
return;
}
@@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
setSkillMatches(initialSkillMatches);
setStatusMessage(null);
setLoadingRequirements(false);
-
- }, [job, setRequirements]);
+ setOverallScore(0);
+ }
useEffect(() => {
- if (requirementsSession || creatingSession) {
- return;
- }
-
- try {
- setCreatingSession(true);
- apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements')
- .then(session => {
- setRequirementsSession(session);
- setCreatingSession(false);
- });
- } catch (error) {
- setSnack('Unable to load chat session', 'error');
- } finally {
- setCreatingSession(false);
- }
-
- }, [requirementsSession, apiClient, candidate]);
-
- // Fetch initial requirements
- useEffect(() => {
- if (!job.description || !requirementsSession || loadingRequirements) {
- return;
- }
-
- const getRequirements = async () => {
- setLoadingRequirements(true);
- try {
- const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description };
- apiClient.sendMessageStream(chatMessage, {
- onMessage: (msg: ChatMessage) => {
- console.log(`onMessage: ${msg.type}`, msg);
- const job: Job = toCamelCase(JSON.parse(msg.content || ''));
- const requirements: { requirement: string, domain: string }[] = [];
- if (job.requirements?.technicalSkills) {
- job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' }));
- job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' }));
- }
- if (job.requirements?.experienceRequirements) {
- job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' }));
- job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' }));
- }
- if (job.requirements?.softSkills) {
- job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' }));
- }
- if (job.requirements?.experience) {
- job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' }));
- }
- if (job.requirements?.education) {
- job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' }));
- }
- if (job.requirements?.certifications) {
- job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' }));
- }
- if (job.requirements?.preferredAttributes) {
- job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' }));
- }
-
- const initialSkillMatches = requirements.map(req => ({
- requirement: req.requirement,
- domain: req.domain,
- status: 'waiting' as const,
- matchScore: 0,
- assessment: '',
- description: '',
- citations: []
- }));
-
- setRequirements(requirements);
- setSkillMatches(initialSkillMatches);
- setStatusMessage(null);
- setLoadingRequirements(false);
- },
- onError: (error: string | ChatMessageError) => {
- console.log("onError:", error);
- // Type-guard to determine if this is a ChatMessageBase or a string
- if (typeof error === "object" && error !== null && "content" in error) {
- setSnack(error.content || 'Error obtaining requirements from job description.', "error");
- } else {
- setSnack(error as string, "error");
- }
- setLoadingRequirements(false);
- },
- onStreaming: (chunk: ChatMessageStreaming) => {
- // console.log("onStreaming:", chunk);
- },
- onStatus: (status: ChatMessageStatus) => {
- console.log(`onStatus: ${status}`);
- },
- onComplete: () => {
- console.log("onComplete");
- setStatusMessage(null);
- setLoadingRequirements(false);
- }
- });
- } catch (error) {
- console.error('Failed to send message:', error);
- setLoadingRequirements(false);
- }
- };
-
- getRequirements();
- }, [job, requirementsSession]);
+ initializeRequirements(job);
+ }, [job]);
// Fetch match data for each requirement
useEffect(() => {
+ if (!startAnalysis || analyzing || !job.requirements) {
+ return;
+ }
+
const fetchMatchData = async () => {
if (requirements.length === 0) return;
@@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
}
};
- if (!loadingRequirements) {
- fetchMatchData();
- }
- }, [requirements, loadingRequirements]);
+ setAnalyzing(true);
+ fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) });
+ }, [job, startAnalysis, analyzing, requirements, loadingRequirements]);
// Get color based on match score
const getMatchColor = (score: number): string => {
@@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
return ;
};
+ const beginAnalysis = () => {
+ initializeRequirements(job);
+ setStartAnalysis(true);
+ };
+
return (
@@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
-
+
+ {}
+ {overallScore !== 0 && <>
Overall Match:
@@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
fontWeight: 'bold'
}}
/>
+        </>}
diff --git a/frontend/src/components/CandidateInfo.tsx b/frontend/src/components/ui/CandidateInfo.tsx
similarity index 98%
rename from frontend/src/components/CandidateInfo.tsx
rename to frontend/src/components/ui/CandidateInfo.tsx
index 95b866e..9b1e15c 100644
--- a/frontend/src/components/CandidateInfo.tsx
+++ b/frontend/src/components/ui/CandidateInfo.tsx
@@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble";
import { rest } from 'lodash';
import { AIBanner } from 'components/ui/AIBanner';
import { useAuth } from 'hooks/AuthContext';
-import { DeleteConfirmation } from './DeleteConfirmation';
+import { DeleteConfirmation } from '../DeleteConfirmation';
interface CandidateInfoProps {
candidate: Candidate;
diff --git a/frontend/src/components/ui/CandidatePicker.tsx b/frontend/src/components/ui/CandidatePicker.tsx
index 1a895f0..decf152 100644
--- a/frontend/src/components/ui/CandidatePicker.tsx
+++ b/frontend/src/components/ui/CandidatePicker.tsx
@@ -4,7 +4,7 @@ import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import { BackstoryElementProps } from 'components/BackstoryTab';
-import { CandidateInfo } from 'components/CandidateInfo';
+import { CandidateInfo } from 'components/ui/CandidateInfo';
import { Candidate } from "types/types";
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate } from 'hooks/GlobalContext';
diff --git a/frontend/src/components/ui/JobInfo.tsx b/frontend/src/components/ui/JobInfo.tsx
new file mode 100644
index 0000000..c850c01
--- /dev/null
+++ b/frontend/src/components/ui/JobInfo.tsx
@@ -0,0 +1,97 @@
+import React from 'react';
+import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material';
+import {
+ Card,
+ CardContent,
+ Divider,
+ useTheme,
+} from '@mui/material';
+import DeleteIcon from '@mui/icons-material/Delete';
+import { useMediaQuery } from '@mui/material';
+import { JobFull } from 'types/types';
+import { CopyBubble } from "components/CopyBubble";
+import { rest } from 'lodash';
+import { AIBanner } from 'components/ui/AIBanner';
+import { useAuth } from 'hooks/AuthContext';
+import { DeleteConfirmation } from '../DeleteConfirmation';
+
+interface JobInfoProps {
+ job: JobFull;
+ sx?: SxProps;
+ action?: string;
+ elevation?: number;
+ variant?: "small" | "normal" | null
+};
+
+const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
+ const { job } = props;
+ const { user, apiClient } = useAuth();
+ const {
+ sx,
+ action = '',
+ elevation = 1,
+ variant = "normal"
+ } = props;
+ const theme = useTheme();
+ const isMobile = useMediaQuery(theme.breakpoints.down('md'));
+ const isAdmin = user?.isAdmin;
+
+ const deleteJob = async (jobId: string | undefined) => {
+ if (jobId) {
+ await apiClient.deleteJob(jobId);
+ }
+ }
+
+ if (!job) {
+        return <Typography>No job loaded.</Typography>;
+ }
+
+ return (
+
+
+ {variant !== "small" && <>
+ {job.location &&
+
+ Location: {job.location.city}, {job.location.state || job.location.country}
+
+ }
+ {job.company &&
+
+ Company: {job.company}
+
+ }
+ {job.summary &&
+ Summary: {job.summary}
+
+ }
+        </>}
+
+
+ {isAdmin &&
+ { deleteJob(job.id); }}
+ sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }}
+ action="delete"
+ label="job"
+ title="Delete job"
+ icon=
+ message={`Are you sure you want to delete ${job.id}? This action cannot be undone.`}
+ />}
+
+
+ );
+};
+
+export { JobInfo };
diff --git a/frontend/src/components/ui/JobPicker.tsx b/frontend/src/components/ui/JobPicker.tsx
new file mode 100644
index 0000000..9ae5057
--- /dev/null
+++ b/frontend/src/components/ui/JobPicker.tsx
@@ -0,0 +1,69 @@
+import React, { useEffect, useState } from 'react';
+import { useNavigate } from "react-router-dom";
+import Button from '@mui/material/Button';
+import Box from '@mui/material/Box';
+
+import { BackstoryElementProps } from 'components/BackstoryTab';
+import { JobInfo } from 'components/ui/JobInfo';
+import { Job, JobFull } from "types/types";
+import { useAuth } from 'hooks/AuthContext';
+import { useSelectedJob } from 'hooks/GlobalContext';
+
+interface JobPickerProps extends BackstoryElementProps {
+ onSelect?: (job: JobFull) => void
+};
+
+const JobPicker = (props: JobPickerProps) => {
+ const { onSelect } = props;
+ const { apiClient } = useAuth();
+ const { selectedJob, setSelectedJob } = useSelectedJob();
+ const { setSnack } = props;
+    const [jobs, setJobs] = useState<JobFull[] | null>(null);
+
+ useEffect(() => {
+ if (jobs !== null) {
+ return;
+ }
+ const getJobs = async () => {
+ try {
+ const results = await apiClient.getJobs();
+ const jobs: JobFull[] = results.data;
+ jobs.sort((a, b) => {
+ let result = a.company?.localeCompare(b.company || '');
+ if (result === 0) {
+ result = a.title?.localeCompare(b.title || '');
+ }
+ return result || 0;
+ });
+ setJobs(jobs);
+ } catch (err) {
+ setSnack("" + err);
+ }
+ };
+
+ getJobs();
+ }, [jobs, setSnack]);
+
+ return (
+
+
+ {jobs?.map((j, i) =>
+ { onSelect ? onSelect(j) : setSelectedJob(j); }}
+ sx={{ cursor: "pointer" }}>
+ {selectedJob?.id === j.id &&
+
+ }
+ {selectedJob?.id !== j.id &&
+
+ }
+
+ )}
+
+
+ );
+};
+
+export {
+ JobPicker
+};
\ No newline at end of file
diff --git a/frontend/src/pages/CandidateChatPage.tsx b/frontend/src/pages/CandidateChatPage.tsx
index 132a235..95f8eda 100644
--- a/frontend/src/pages/CandidateChatPage.tsx
+++ b/frontend/src/pages/CandidateChatPage.tsx
@@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { Message } from 'components/Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation';
-import { CandidateInfo } from 'components/CandidateInfo';
+import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useNavigate } from 'react-router-dom';
import { useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader';
diff --git a/frontend/src/pages/GenerateCandidate.tsx b/frontend/src/pages/GenerateCandidate.tsx
index a38630b..2b8efbf 100644
--- a/frontend/src/pages/GenerateCandidate.tsx
+++ b/frontend/src/pages/GenerateCandidate.tsx
@@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel';
import SendIcon from '@mui/icons-material/Send';
import PropagateLoader from 'react-spinners/PropagateLoader';
-import { CandidateInfo } from '../components/CandidateInfo';
+import { CandidateInfo } from '../components/ui/CandidateInfo';
import { Quote } from 'components/Quote';
import { BackstoryElementProps } from 'components/BackstoryTab';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
diff --git a/frontend/src/pages/JobAnalysisPage.tsx b/frontend/src/pages/JobAnalysisPage.tsx
index dcfdfac..a7b7123 100644
--- a/frontend/src/pages/JobAnalysisPage.tsx
+++ b/frontend/src/pages/JobAnalysisPage.tsx
@@ -7,31 +7,74 @@ import {
Button,
Typography,
Paper,
- Avatar,
useTheme,
Snackbar,
+ Container,
+ Grid,
Alert,
+ Tabs,
+ Tab,
+ Card,
+ CardContent,
+ Divider,
+ Avatar,
+ Badge,
} from '@mui/material';
+import {
+ Person,
+ PersonAdd,
+ AccountCircle,
+ Add,
+ WorkOutline,
+ AddCircle,
+} from '@mui/icons-material';
import PersonIcon from '@mui/icons-material/Person';
import WorkIcon from '@mui/icons-material/Work';
import AssessmentIcon from '@mui/icons-material/Assessment';
import { JobMatchAnalysis } from 'components/JobMatchAnalysis';
-import { Candidate } from "types/types";
+import { Candidate, Job, JobFull } from "types/types";
import { useNavigate } from 'react-router-dom';
import { BackstoryPageProps } from 'components/BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext';
-import { CandidateInfo } from 'components/CandidateInfo';
+import { CandidateInfo } from 'components/ui/CandidateInfo';
import { ComingSoon } from 'components/ui/ComingSoon';
import { JobManagement } from 'components/JobManagement';
import { LoginRequired } from 'components/ui/LoginRequired';
import { Scrollable } from 'components/Scrollable';
import { CandidatePicker } from 'components/ui/CandidatePicker';
+import { JobPicker } from 'components/ui/JobPicker';
+import { JobCreator } from 'components/JobCreator';
+
+function WorkAddIcon() {
+ return (
+
+
+
+
+ );
+}
// Main component
 const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const theme = useTheme();
- const { user } = useAuth();
+ const { user, apiClient } = useAuth();
const navigate = useNavigate();
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
const { selectedJob, setSelectedJob } = useSelectedJob()
@@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const [activeStep, setActiveStep] = useState(0);
const [analysisStarted, setAnalysisStarted] = useState(false);
const [error, setError] = useState(null);
+ const [jobTab, setJobTab] = useState('load');
useEffect(() => {
- if (selectedJob && activeStep === 1) {
- setActiveStep(2);
+ console.log({ activeStep, selectedCandidate, selectedJob });
+
+ if (!selectedCandidate) {
+ if (activeStep !== 0) {
+ setActiveStep(0);
+ }
+ } else if (!selectedJob) {
+ if (activeStep !== 1) {
+ setActiveStep(1);
+ }
}
- }, [selectedJob, activeStep]);
+ }, [selectedCandidate, selectedJob, activeStep])
// Steps in our process
const steps = [
@@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
setSelectedJob(null);
break;
case 1: /* Select Job */
- setSelectedCandidate(null);
setSelectedJob(null);
break;
case 2: /* Job Analysis */
@@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
setActiveStep(1);
}
+ const onJobSelect = (job: Job) => {
+ setSelectedJob(job)
+ setActiveStep(2);
+ }
+
// Render function for the candidate selection step
const renderCandidateSelection = () => (
@@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
);
+ const handleTabChange = (event: React.SyntheticEvent, value: string) => {
+ setJobTab(value);
+ };
+
// Render function for the job description step
- const renderJobDescription = () => (
-
- {selectedCandidate && (
- {
+ if (!selectedCandidate) {
+ return;
+ }
+
+ return (
+
+
+ } label="Load" />
+ } label="Create" />
+
+
+
+ {jobTab === 'load' &&
+
+ }
+ {jobTab === 'create' &&
+
- )}
+ onSave={onJobSelect}
+ />}
- );
+ );
+ }
// Render function for the analysis step
const renderAnalysis = () => (
@@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
) : (
)}
diff --git a/frontend/src/pages/OldChatPage.tsx b/frontend/src/pages/OldChatPage.tsx
index c58f549..a4ec2cc 100644
--- a/frontend/src/pages/OldChatPage.tsx
+++ b/frontend/src/pages/OldChatPage.tsx
@@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown';
import { BackstoryPageProps } from '../components/BackstoryTab';
import { Conversation, ConversationHandle } from '../components/Conversation';
import { BackstoryQuery } from '../components/BackstoryQuery';
-import { CandidateInfo } from 'components/CandidateInfo';
+import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types';
diff --git a/frontend/src/services/api-client.ts b/frontend/src/services/api-client.ts
index 1aef0fe..36b7b06 100644
--- a/frontend/src/services/api-client.ts
+++ b/frontend/src/services/api-client.ts
@@ -52,7 +52,7 @@ interface StreamingOptions {
signal?: AbortSignal;
}
-interface DeleteCandidateResponse {
+interface DeleteResponse {
success: boolean;
message: string;
}
@@ -530,14 +530,24 @@ class ApiClient {
return this.handleApiResponseWithConversion(response, 'Candidate');
}
-  async deleteCandidate(id: string): Promise<DeleteCandidateResponse> {
+  async deleteCandidate(id: string): Promise<DeleteResponse> {
const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
method: 'DELETE',
headers: this.defaultHeaders,
body: JSON.stringify({ id })
});
-    return handleApiResponse<DeleteCandidateResponse>(response);
+    return handleApiResponse<DeleteResponse>(response);
+ }
+
+  async deleteJob(id: string): Promise<DeleteResponse> {
+ const response = await fetch(`${this.baseUrl}/jobs/${id}`, {
+ method: 'DELETE',
+ headers: this.defaultHeaders,
+ body: JSON.stringify({ id })
+ });
+
+    return handleApiResponse<DeleteResponse>(response);
}
async uploadCandidateProfile(file: File): Promise {
@@ -641,7 +651,7 @@ class ApiClient {
return this.handleApiResponseWithConversion(response, 'Job');
}
- async getJobs(request: Partial = {}): Promise> {
+ async getJobs(request: Partial = {}): Promise> {
const paginatedRequest = createPaginatedRequest(request);
const params = toUrlParams(formatApiRequest(paginatedRequest));
@@ -649,7 +659,7 @@ class ApiClient {
headers: this.defaultHeaders
});
- return this.handlePaginatedApiResponseWithConversion(response, 'Job');
+ return this.handlePaginatedApiResponseWithConversion(response, 'JobFull');
}
async getJobsByEmployer(employerId: string, request: Partial = {}): Promise> {
@@ -844,18 +854,28 @@ class ApiClient {
}
};
return this.streamify('/candidates/documents/upload', formData, streamingOptions);
- // {
- // method: 'POST',
- // headers: {
- // // Don't set Content-Type - browser will set it automatically with boundary
- // 'Authorization': this.defaultHeaders['Authorization']
- // },
- // body: formData
- // });
+ }
- // const result = await handleApiResponse(response);
-
- // return result;
+ createJobFromFile(file: File, streamingOptions?: StreamingOptions): StreamingResponse {
+ const formData = new FormData()
+ formData.append('file', file);
+ formData.append('filename', file.name);
+ streamingOptions = {
+ ...streamingOptions,
+ headers: {
+ // Don't set Content-Type - browser will set it automatically with boundary
+ 'Authorization': this.defaultHeaders['Authorization']
+ }
+ };
+ return this.streamify('/jobs/upload', formData, streamingOptions);
+ }
+
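+    // Streams job-requirements extraction for an existing job; no request body is needed, so streamify is invoked with a null payload.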
+ getJobRequirements(jobId: string, streamingOptions?: StreamingOptions): StreamingResponse {
+ streamingOptions = {
+ ...streamingOptions,
+ headers: this.defaultHeaders,
+ };
+ return this.streamify(`/jobs/requirements/${jobId}`, null, streamingOptions);
}
async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise {
@@ -1001,7 +1021,7 @@ class ApiClient {
* @param options callbacks, headers, and method
* @returns
*/
- streamify(api: string, data: BodyInit, options: StreamingOptions = {}) : StreamingResponse {
+ streamify(api: string, data: BodyInit | null, options: StreamingOptions = {}) : StreamingResponse {
const abortController = new AbortController();
const signal = options.signal || abortController.signal;
const headers = options.headers || null;
diff --git a/frontend/src/types/types.ts b/frontend/src/types/types.ts
index dfc452e..a422f84 100644
--- a/frontend/src/types/types.ts
+++ b/frontend/src/types/types.ts
@@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models
// Source: src/backend/models.py
-// Generated on: 2025-06-05T22:02:22.004513
+// Generated on: 2025-06-07T20:43:58.855207
// DO NOT EDIT MANUALLY - This file is auto-generated
// ============================
@@ -323,7 +323,7 @@ export interface ChatMessageMetaData {
presencePenalty?: number;
stopSequences?: Array;
ragResults?: Array;
- llmHistory?: Array;
+ llmHistory?: Array;
evalCount: number;
evalDuration: number;
promptEvalCount: number;
@@ -711,12 +711,6 @@ export interface JobResponse {
meta?: Record;
}
-export interface LLMMessage {
- role: string;
- content: string;
- toolCalls?: Array>;
-}
-
export interface Language {
language: string;
proficiency: "basic" | "conversational" | "fluent" | "native";
diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py
index 12a7af8..e74fed8 100644
--- a/src/backend/agents/base.py
+++ b/src/backend/agents/base.py
@@ -137,7 +137,7 @@ class Agent(BaseModel, ABC):
# llm: Any,
# model: str,
# message: ChatMessage,
- # tool_message: Any, # llama response message
+ # tool_message: Any, # llama response
# messages: List[LLMMessage],
# ) -> AsyncGenerator[ChatMessage, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
@@ -270,15 +270,15 @@ class Agent(BaseModel, ABC):
# },
# stream=True,
# ):
- # # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
+ # # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
# message.status = "streaming"
- # message.chunk = response.message.content
+ # message.chunk = response.content
# message.content += message.chunk
- # if not response.done:
+ # if not response.finish_reason:
# yield message
- # if response.done:
+ # if response.finish_reason:
# self.collect_metrics(response)
# message.metadata.eval_count += response.eval_count
# message.metadata.eval_duration += response.eval_duration
@@ -296,9 +296,9 @@ class Agent(BaseModel, ABC):
def collect_metrics(self, response):
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
- response.prompt_eval_count
+ response.usage.prompt_eval_count
)
- self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
+ self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
"""
@@ -361,7 +361,7 @@ Content: {content}
yield status_message
try:
- chroma_results = user.file_watcher.find_similar(
+ chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold
)
if not chroma_results:
@@ -430,7 +430,7 @@ Content: {content}
logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
response = None
content = ""
- for response in llm.chat(
+ async for response in llm.chat_stream(
model=model,
messages=messages,
options={
@@ -446,12 +446,12 @@ Content: {content}
yield error_message
return
- content += response.message.content
+ content += response.content
- if not response.done:
+ if not response.finish_reason:
streaming_message = ChatMessageStreaming(
session_id=session_id,
- content=response.message.content,
+ content=response.content,
status=ApiStatusType.STREAMING,
)
yield streaming_message
@@ -466,7 +466,7 @@ Content: {content}
self.collect_metrics(response)
self.context_tokens = (
- response.prompt_eval_count + response.eval_count
+ response.usage.prompt_eval_count + response.usage.eval_count
)
chat_message = ChatMessage(
@@ -476,10 +476,10 @@ Content: {content}
content=content,
metadata = ChatMessageMetaData(
options=options,
- eval_count=response.eval_count,
- eval_duration=response.eval_duration,
- prompt_eval_count=response.prompt_eval_count,
- prompt_eval_duration=response.prompt_eval_duration,
+ eval_count=response.usage.eval_count,
+ eval_duration=response.usage.eval_duration,
+ prompt_eval_count=response.usage.prompt_eval_count,
+ prompt_eval_duration=response.usage.prompt_eval_duration,
)
)
@@ -588,12 +588,12 @@ Content: {content}
# end_time = time.perf_counter()
# message.metadata.timers["tool_check"] = end_time - start_time
- # if not response.message.tool_calls:
+ # if not response.tool_calls:
# logger.info("LLM indicates tools will not be used")
# # The LLM will not use tools, so disable use_tools so we can stream the full response
# use_tools = False
# else:
- # tool_metadata["attempted"] = response.message.tool_calls
+ # tool_metadata["attempted"] = response.tool_calls
# if use_tools:
# logger.info("LLM indicates tools will be used")
@@ -626,15 +626,15 @@ Content: {content}
# yield message
# return
- # if response.message.tool_calls:
- # tool_metadata["used"] = response.message.tool_calls
+ # if response.tool_calls:
+ # tool_metadata["used"] = response.tool_calls
# # Process all yielded items from the handler
# start_time = time.perf_counter()
# async for message in self.process_tool_calls(
# llm=llm,
# model=model,
# message=message,
- # tool_message=response.message,
+ # tool_message=response,
# messages=messages,
# ):
# if message.status == "error":
@@ -647,7 +647,7 @@ Content: {content}
# return
# logger.info("LLM indicated tools will be used, and then they weren't")
- # message.content = response.message.content
+ # message.content = response.content
# message.status = "done"
# yield message
# return
@@ -674,7 +674,7 @@ Content: {content}
content = ""
start_time = time.perf_counter()
response = None
- for response in llm.chat(
+ async for response in llm.chat_stream(
model=model,
messages=messages,
options={
@@ -690,12 +690,12 @@ Content: {content}
yield error_message
return
- content += response.message.content
+ content += response.content
- if not response.done:
+ if not response.finish_reason:
streaming_message = ChatMessageStreaming(
session_id=session_id,
- content=response.message.content,
+ content=response.content,
)
yield streaming_message
@@ -709,7 +709,7 @@ Content: {content}
self.collect_metrics(response)
self.context_tokens = (
- response.prompt_eval_count + response.eval_count
+ response.usage.prompt_eval_count + response.usage.eval_count
)
end_time = time.perf_counter()
@@ -720,10 +720,10 @@ Content: {content}
content=content,
metadata = ChatMessageMetaData(
options=options,
- eval_count=response.eval_count,
- eval_duration=response.eval_duration,
- prompt_eval_count=response.prompt_eval_count,
- prompt_eval_duration=response.prompt_eval_duration,
+ eval_count=response.usage.eval_count,
+ eval_duration=response.usage.eval_duration,
+ prompt_eval_count=response.usage.prompt_eval_count,
+ prompt_eval_duration=response.usage.prompt_eval_duration,
timers={
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time
diff --git a/src/backend/database.py b/src/backend/database.py
index e42fa82..739e6b0 100644
--- a/src/backend/database.py
+++ b/src/backend/database.py
@@ -178,12 +178,13 @@ class RedisDatabase:
'jobs': 'job:',
'job_applications': 'job_application:',
'chat_sessions': 'chat_session:',
- 'chat_messages': 'chat_messages:', # This will store lists
+ 'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:',
'users': 'user:',
- 'candidate_documents': 'candidate_documents:',
+ 'candidate_documents': 'candidate_documents:',
+        'job_requirements': 'job_requirements:',
}
-
+
def _serialize(self, data: Any) -> str:
"""Serialize data to JSON string for Redis storage"""
if data is None:
@@ -236,8 +237,9 @@ class RedisDatabase:
# Delete each document's metadata
for doc_id in document_ids:
pipe.delete(f"document:{doc_id}")
+ pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}")
deleted_count += 1
-
+
# Delete the candidate's document list
pipe.delete(key)
@@ -250,7 +252,110 @@ class RedisDatabase:
except Exception as e:
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
raise
-
+
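+    # Skill-match assessments are cached under keys of the form "skill_match:{candidate_id}:..." (see invalidate_candidate_skill_cache); callers pass the fully built cache_key.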
+ async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
+ """Retrieve cached skill match assessment"""
+ try:
+ cached_data = await self.redis.get(cache_key)
+ if cached_data:
+ return json.loads(cached_data)
+ return None
+ except Exception as e:
+ logger.error(f"Error retrieving cached skill match: {e}")
+ return None
+
+ async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
+ """Cache skill match assessment with TTL (default 30 days)"""
+ try:
+ await self.redis.setex(
+ cache_key,
+ ttl,
+ json.dumps(assessment_data, default=str)
+ )
+ return True
+ except Exception as e:
+ logger.error(f"Error caching skill match: {e}")
+ return False
+
+ async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
+ """Get the last time candidate's skill information was updated"""
+ try:
+ # This assumes you track skill update timestamps in your candidate data
+ candidate_data = await self.get_candidate(candidate_id)
+ if candidate_data and 'skills_updated_at' in candidate_data:
+ return datetime.fromisoformat(candidate_data['skills_updated_at'])
+ return None
+ except Exception as e:
+ logger.error(f"Error getting candidate skill update time: {e}")
+ return None
+
+ async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
+ """Get the timestamp of the latest RAG data update for a specific user"""
+ try:
+ rag_update_key = f"user:{user_id}:rag_last_update"
+ timestamp_str = await self.redis.get(rag_update_key)
+ if timestamp_str:
+                return datetime.fromisoformat(timestamp_str.decode('utf-8') if isinstance(timestamp_str, bytes) else timestamp_str)
+ return None
+ except Exception as e:
+ logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
+ return None
+
+ async def update_user_rag_timestamp(self, user_id: str) -> bool:
+ """Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
+ try:
+ rag_update_key = f"user:{user_id}:rag_last_update"
+ current_time = datetime.utcnow().isoformat()
+ await self.redis.set(rag_update_key, current_time)
+ return True
+ except Exception as e:
+ logger.error(f"Error updating RAG timestamp for user {user_id}: {e}")
+ return False
+
+ async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
+ """Invalidate all cached skill matches for a specific candidate"""
+ try:
+ pattern = f"skill_match:{candidate_id}:*"
+ keys = await self.redis.keys(pattern)
+ if keys:
+ return await self.redis.delete(*keys)
+ return 0
+ except Exception as e:
+ logger.error(f"Error invalidating candidate skill cache: {e}")
+ return 0
+
+ async def clear_all_skill_match_cache(self) -> int:
+ """Clear all skill match cache (useful after major system updates)"""
+ try:
+ pattern = "skill_match:*"
+ keys = await self.redis.keys(pattern)
+ if keys:
+ return await self.redis.delete(*keys)
+ return 0
+ except Exception as e:
+ logger.error(f"Error clearing skill match cache: {e}")
+ return 0
+
+ async def invalidate_user_skill_cache(self, user_id: str) -> int:
+ """Invalidate all cached skill matches when a user's RAG data is updated"""
+ try:
+ # This assumes all candidates belonging to this user need cache invalidation
+ # You might need to adjust the pattern based on how you associate candidates with users
+            pattern = "skill_match:*"
+ keys = await self.redis.keys(pattern)
+
+ # Filter keys that belong to candidates owned by this user
+ # This would require additional logic to determine candidate ownership
+ # For now, you might want to clear all cache when any user's RAG data updates
+ # or implement a more sophisticated mapping
+
+ if keys:
+ return await self.redis.delete(*keys)
+ return 0
+ except Exception as e:
+ logger.error(f"Error invalidating user skill cache for user {user_id}: {e}")
+ return 0
+
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
@@ -330,7 +435,134 @@ class RedisDatabase:
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
]
-
+
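+    # Job-requirements analyses are cached per source document under "job_requirements:{document_id}".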
+ async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
+ """Get cached job requirements analysis for a document"""
+ try:
+ key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
+ data = await self.redis.get(key)
+ if data:
+ requirements_data = self._deserialize(data)
+ logger.debug(f"📋 Retrieved cached job requirements for document {document_id}")
+ return requirements_data
+ logger.debug(f"📋 No cached job requirements found for document {document_id}")
+ return None
+ except Exception as e:
+ logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}")
+ return None
+
+ async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
+ """Save job requirements analysis results for a document"""
+ try:
+ key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
+
+ # Add metadata to the requirements
+ requirements_with_meta = {
+ **requirements,
+ "cached_at": datetime.now(UTC).isoformat(),
+ "document_id": document_id
+ }
+
+ await self.redis.set(key, self._serialize(requirements_with_meta))
+
+ # Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
+ # await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
+
+ logger.debug(f"📋 Saved job requirements for document {document_id}")
+ return True
+ except Exception as e:
+ logger.error(f"❌ Error saving job requirements for document {document_id}: {e}")
+ return False
+
+ async def delete_job_requirements(self, document_id: str) -> bool:
+ """Delete cached job requirements for a document"""
+ try:
+ key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
+ result = await self.redis.delete(key)
+ if result > 0:
+ logger.debug(f"📋 Deleted job requirements for document {document_id}")
+ return True
+ return False
+ except Exception as e:
+ logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}")
+ return False
+
+ async def get_all_job_requirements(self) -> Dict[str, Any]:
+ """Get all cached job requirements"""
+ try:
+ pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
+ keys = await self.redis.keys(pattern)
+
+ if not keys:
+ return {}
+
+ pipe = self.redis.pipeline()
+ for key in keys:
+ pipe.get(key)
+ values = await pipe.execute()
+
+ result = {}
+ for key, value in zip(keys, values):
+ document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
+ if value:
+ result[document_id] = self._deserialize(value)
+
+ return result
+ except Exception as e:
+ logger.error(f"❌ Error retrieving all job requirements: {e}")
+ return {}
+
+ async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
+ """Get all job requirements analysis for documents belonging to a candidate"""
+ try:
+ # Get all documents for the candidate
+ candidate_documents = await self.get_candidate_documents(candidate_id)
+
+ if not candidate_documents:
+ return []
+
+ # Get job requirements for each document
+ job_requirements = []
+ for doc in candidate_documents:
+ doc_id = doc.get("id")
+ if doc_id:
+ requirements = await self.get_job_requirements(doc_id)
+ if requirements:
+ # Add document metadata to requirements
+ requirements["document_filename"] = doc.get("filename")
+ requirements["document_original_name"] = doc.get("originalName")
+ job_requirements.append(requirements)
+
+ return job_requirements
+ except Exception as e:
+ logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
+ return []
+
+ async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
+ """Invalidate (delete) cached job requirements for a document"""
+ # This is an alias for delete_job_requirements for semantic clarity
+ return await self.delete_job_requirements(document_id)
+
+ async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
+ """Delete job requirements for multiple documents and return count of deleted items"""
+ try:
+ deleted_count = 0
+ pipe = self.redis.pipeline()
+
+ for doc_id in document_ids:
+ key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}"
+ pipe.delete(key)
+ deleted_count += 1
+
+ results = await pipe.execute()
+ actual_deleted = sum(1 for result in results if result > 0)
+
+ logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
+ return actual_deleted
+ except Exception as e:
+ logger.error(f"❌ Error bulk deleting job requirements: {e}")
+ return 0
+
# Viewer operations
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
"""Get viewer by ID"""
@@ -1484,6 +1716,74 @@ class RedisDatabase:
key = f"{self.KEY_PREFIXES['users']}{email}"
await self.redis.delete(key)
+ async def get_job_requirements_stats(self) -> Dict[str, Any]:
+ """Get statistics about cached job requirements"""
+ try:
+ pattern = f"{self.KEY_PREFIXES['job_requirements']}*"
+ keys = await self.redis.keys(pattern)
+
+ stats = {
+ "total_cached_requirements": len(keys),
+ "cache_dates": {},
+ "documents_with_requirements": []
+ }
+
+ if keys:
+ # Get cache dates for analysis
+ pipe = self.redis.pipeline()
+ for key in keys:
+ pipe.get(key)
+ values = await pipe.execute()
+
+ for key, value in zip(keys, values):
+ if value:
+ requirements_data = self._deserialize(value)
+ if requirements_data:
+ document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '')
+ stats["documents_with_requirements"].append(document_id)
+
+ # Track cache dates
+ cached_at = requirements_data.get("cached_at")
+ if cached_at:
+ cache_date = cached_at[:10] # Extract date part
+ stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
+
+ return stats
+ except Exception as e:
+ logger.error(f"❌ Error getting job requirements stats: {e}")
+ return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
+
+ async def cleanup_orphaned_job_requirements(self) -> int:
+ """Clean up job requirements for documents that no longer exist"""
+ try:
+ # Get all job requirements
+ all_requirements = await self.get_all_job_requirements()
+
+ if not all_requirements:
+ return 0
+
+ orphaned_count = 0
+ pipe = self.redis.pipeline()
+
+ for document_id in all_requirements.keys():
+ # Check if the document still exists
+ document_exists = await self.get_document(document_id)
+ if not document_exists:
+ # Document no longer exists, delete its job requirements
+ key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}"
+ pipe.delete(key)
+ orphaned_count += 1
+ logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}")
+
+ if orphaned_count > 0:
+ await pipe.execute()
+ logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
+
+ return orphaned_count
+ except Exception as e:
+ logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
+ return 0
+
# Utility methods
async def clear_all_data(self):
"""Clear all data from Redis (use with caution!)"""
diff --git a/src/backend/entities/candidate_entity.py b/src/backend/entities/candidate_entity.py
index 91a83a1..a35da0a 100644
--- a/src/backend/entities/candidate_entity.py
+++ b/src/backend/entities/candidate_entity.py
@@ -19,8 +19,9 @@ import defines
from logger import logger
import agents as agents
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
-from llm_manager import llm_manager
+import llm_proxy as llm_manager
from agents.base import Agent
+from database import RedisDatabase
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@@ -115,7 +116,7 @@ class CandidateEntity(Candidate):
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer
- async def initialize(self, prometheus_collector: CollectorRegistry):
+ async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
@@ -140,9 +141,11 @@ class CandidateEntity(Candidate):
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
llm=llm_manager.get_llm(),
+ user_id=self.id,
collection_name=self.username,
persist_directory=vector_db_dir,
watch_directory=rag_content_dir,
+ database=database,
recreate=False, # Don't recreate if exists
)
has_username_rag = any(item["name"] == self.username for item in self.rags)
diff --git a/src/backend/entities/entity_manager.py b/src/backend/entities/entity_manager.py
index cebb79d..4b535b4 100644
--- a/src/backend/entities/entity_manager.py
+++ b/src/backend/entities/entity_manager.py
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field # type: ignore
from models import ( Candidate )
from .candidate_entity import CandidateEntity
+from database import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager:
@@ -34,9 +35,10 @@ class EntityManager:
pass
self._cleanup_task = None
- def initialize(self, prometheus_collector: CollectorRegistry):
+ def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
+ self._database = database
async def get_entity(self, candidate: Candidate) -> CandidateEntity:
"""Get or create CandidateEntity with proper reference tracking"""
@@ -49,7 +51,7 @@ class EntityManager:
return entity
entity = CandidateEntity(candidate=candidate)
- await entity.initialize(prometheus_collector=self._prometheus_collector)
+ await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity
diff --git a/src/backend/llm_manager.py b/src/backend/llm_manager.py
deleted file mode 100644
index 43281c8..0000000
--- a/src/backend/llm_manager.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import ollama
-import defines
-
-_llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
-
-class llm_manager:
- """
- A class to manage LLM operations using the Ollama client.
- """
- @staticmethod
- def get_llm() -> ollama.Client: # type: ignore
- """
- Get the Ollama client instance.
-
- Returns:
- An instance of the Ollama client.
- """
- return _llm
-
- @staticmethod
- def get_models() -> list[str]:
- """
- Get a list of available models from the Ollama client.
-
- Returns:
- List of model names.
- """
- return _llm.models()
-
- @staticmethod
- def get_model_info(model_name: str) -> dict:
- """
- Get information about a specific model.
-
- Args:
- model_name: The name of the model to retrieve information for.
-
- Returns:
- A dictionary containing model information.
- """
- return _llm.model(model_name)
\ No newline at end of file
diff --git a/src/backend/llm_proxy.py b/src/backend/llm_proxy.py
new file mode 100644
index 0000000..3e0f77c
--- /dev/null
+++ b/src/backend/llm_proxy.py
@@ -0,0 +1,1203 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Any, AsyncGenerator, Optional, Union
+from pydantic import BaseModel, Field # type: ignore
+from enum import Enum
+import asyncio
+import json
+from dataclasses import dataclass
+import os
+import defines
+from logger import logger
+
+# Standard message format for all providers
+@dataclass
+class LLMMessage:
+ role: str # "user", "assistant", "system"
+ content: str
+
+ def __getitem__(self, key: str):
+ """Allow dictionary-style access for backward compatibility"""
+ if key == 'role':
+ return self.role
+ elif key == 'content':
+ return self.content
+ else:
+ raise KeyError(f"'{key}' not found in LLMMessage")
+
+ def __setitem__(self, key: str, value: str):
+ """Allow dictionary-style assignment"""
+ if key == 'role':
+ self.role = value
+ elif key == 'content':
+ self.content = value
+ else:
+ raise KeyError(f"'{key}' not found in LLMMessage")
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, str]) -> 'LLMMessage':
+ """Create LLMMessage from dictionary"""
+ return cls(role=data['role'], content=data['content'])
+
+ def to_dict(self) -> Dict[str, str]:
+ """Convert LLMMessage to dictionary"""
+ return {'role': self.role, 'content': self.content}
+
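+# Illustrative use of the dual access styles above (comment only, not executed at import time):
+#   msg = LLMMessage(role="user", content="hi")
+#   msg.content == msg["content"]      # attribute and dict-style reads agree
+#   msg["role"] = "system"             # dict-style writes update the dataclass field
+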
+# Enhanced usage statistics model
+class UsageStats(BaseModel):
+ """Comprehensive usage statistics across all providers"""
+ # Token counts (standardized across providers)
+ prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt")
+ completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion")
+ total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used")
+
+ # Ollama-specific detailed stats
+ prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt")
+ prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)")
+ eval_count: Optional[int] = Field(default=None, description="Number of tokens generated")
+ eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)")
+ total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)")
+
+ # Performance metrics
+ tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second")
+ prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed")
+
+ # Additional provider-specific stats
+ extra_stats: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Provider-specific additional statistics")
+
+ def calculate_derived_stats(self) -> None:
+ """Calculate derived statistics where possible"""
+ # Calculate tokens per second for Ollama
+ if self.eval_count and self.eval_duration and self.eval_duration > 0:
+ # Convert nanoseconds to seconds and calculate tokens/sec
+ duration_seconds = self.eval_duration / 1_000_000_000
+ self.tokens_per_second = self.eval_count / duration_seconds
+
+ # Calculate prompt processing speed for Ollama
+ if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0:
+ duration_seconds = self.prompt_eval_duration / 1_000_000_000
+ self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds
+
+ # Standardize token counts across providers
+ if not self.total_tokens and self.prompt_tokens and self.completion_tokens:
+ self.total_tokens = self.prompt_tokens + self.completion_tokens
+
+ # Map Ollama counts to standard format if not already set
+ if not self.prompt_tokens and self.prompt_eval_count:
+ self.prompt_tokens = self.prompt_eval_count
+ if not self.completion_tokens and self.eval_count:
+ self.completion_tokens = self.eval_count
+
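+# Quick sanity check of the derivation above (illustrative numbers): eval_count=120 tokens over
+# eval_duration=3_000_000_000 ns (3.0 s) gives tokens_per_second = 120 / 3.0 = 40.0.
+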
+# Embedding response models
+class EmbeddingData(BaseModel):
+ """Single embedding result"""
+ embedding: List[float] = Field(description="The embedding vector")
+ index: int = Field(description="Index in the input list")
+
+class EmbeddingResponse(BaseModel):
+ """Response from embeddings API"""
+ data: List[EmbeddingData] = Field(description="List of embedding results")
+ model: str = Field(description="Model used for embeddings")
+ usage: Optional[UsageStats] = Field(default=None, description="Usage statistics")
+
+ def get_single_embedding(self) -> List[float]:
+ """Get the first embedding for single-text requests"""
+ if not self.data:
+ raise ValueError("No embeddings in response")
+ return self.data[0].embedding
+
+ def get_embeddings(self) -> List[List[float]]:
+ """Get all embeddings as a list of vectors"""
+ return [item.embedding for item in self.data]
+
+class ChatResponse(BaseModel):
+ content: str
+ model: str
+ finish_reason: Optional[str] = Field(default="")
+ usage: Optional[UsageStats] = Field(default=None)
+ # Keep legacy usage field for backward compatibility
+ usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict")
+
+ def get_usage_dict(self) -> Dict[str, Any]:
+ """Get usage statistics as dictionary for backward compatibility"""
+ if self.usage:
+ return self.usage.model_dump(exclude_none=True)
+ return self.usage_legacy or {}
+
+class LLMProvider(str, Enum):
+ OLLAMA = "ollama"
+ OPENAI = "openai"
+ ANTHROPIC = "anthropic"
+ GEMINI = "gemini"
+ GROK = "grok"
+
+class BaseLLMAdapter(ABC):
+ """Abstract base class for all LLM adapters"""
+
+ def __init__(self, **config):
+ self.config = config
+
+ @abstractmethod
+ async def chat(
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Send chat messages and get response"""
+ pass
+
+ @abstractmethod
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Generate text from prompt"""
+ pass
+
+ @abstractmethod
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Generate embeddings for input text(s)"""
+ pass
+
+ @abstractmethod
+ async def list_models(self) -> List[str]:
+ """List available models"""
+ pass
+
+class OllamaAdapter(BaseLLMAdapter):
+ """Adapter for Ollama with enhanced statistics"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import ollama
+ self.client = ollama.AsyncClient( # type: ignore
+ host=config.get('host', defines.ollama_api_url)
+ )
+
+ def _create_usage_stats(self, response_data: Dict[str, Any]) -> UsageStats:
+ """Create comprehensive usage statistics from Ollama response"""
+ usage_stats = UsageStats()
+
+ # Extract Ollama-specific stats
+ if 'prompt_eval_count' in response_data:
+ usage_stats.prompt_eval_count = response_data['prompt_eval_count']
+ if 'prompt_eval_duration' in response_data:
+ usage_stats.prompt_eval_duration = response_data['prompt_eval_duration']
+ if 'eval_count' in response_data:
+ usage_stats.eval_count = response_data['eval_count']
+ if 'eval_duration' in response_data:
+ usage_stats.eval_duration = response_data['eval_duration']
+ if 'total_duration' in response_data:
+ usage_stats.total_duration = response_data['total_duration']
+
+ # Store any additional stats
+ extra_stats = {}
+        if isinstance(response_data, dict):
+ for key, value in response_data.items():
+ if key not in ['prompt_eval_count', 'prompt_eval_duration', 'eval_count',
+ 'eval_duration', 'total_duration', 'message', 'done_reason', 'response']:
+ extra_stats[key] = value
+
+ if extra_stats:
+ usage_stats.extra_stats = extra_stats
+
+ # Calculate derived statistics
+ usage_stats.calculate_derived_stats()
+
+ return usage_stats
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert LLMMessage objects to Ollama format
+ ollama_messages = []
+ for msg in messages:
+ ollama_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ if stream:
+ return self._stream_chat(model, ollama_messages, **kwargs)
+ else:
+ response = await self.client.chat(
+ model=model,
+ messages=ollama_messages,
+ stream=False,
+ **kwargs
+ )
+
+ usage_stats = self._create_usage_stats(response)
+
+ return ChatResponse(
+ content=response['message']['content'],
+ model=model,
+ finish_reason=response.get('done_reason'),
+ usage=usage_stats
+ )
+
+ async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
+ # Await the chat call first, then iterate over the result
+ stream = await self.client.chat(
+ model=model,
+ messages=messages,
+ stream=True,
+ **kwargs
+ )
+
+ # Accumulate stats for final chunk
+ accumulated_stats = {}
+
+ async for chunk in stream:
+ # Update accumulated stats
+ accumulated_stats.update(chunk)
+ content = chunk.get('message', {}).get('content', '')
+ usage_stats = None
+            if chunk.get('done', False):
+ usage_stats = self._create_usage_stats(accumulated_stats)
+ yield ChatResponse(
+ content=content,
+ model=model,
+ finish_reason=chunk.get('done_reason'),
+ usage=usage_stats
+ )
+ else:
+ yield ChatResponse(
+ content=content,
+ model=model,
+ finish_reason=None,
+ usage=None
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ if stream:
+ return self._stream_generate(model, prompt, **kwargs)
+ else:
+ response = await self.client.generate(
+ model=model,
+ prompt=prompt,
+ stream=False,
+ **kwargs
+ )
+
+ usage_stats = self._create_usage_stats(response)
+
+ return ChatResponse(
+ content=response['response'],
+ model=model,
+ finish_reason=response.get('done_reason'),
+ usage=usage_stats
+ )
+
+ async def _stream_generate(self, model: str, prompt: str, **kwargs):
+ # Await the generate call first, then iterate over the result
+ stream = await self.client.generate(
+ model=model,
+ prompt=prompt,
+ stream=True,
+ **kwargs
+ )
+
+ accumulated_stats = {}
+
+        async for chunk in stream:
+            # Merge every chunk's fields; the final done-chunk carries the eval/duration counters
+            accumulated_stats.update(chunk)
+
+ usage_stats = None
+ if chunk.get('done', False): # Only create stats for final chunk
+ usage_stats = self._create_usage_stats(accumulated_stats)
+
+ yield ChatResponse(
+ content=chunk['response'],
+ model=model,
+ finish_reason=chunk.get('done_reason'),
+ usage=usage_stats
+ )
+
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Generate embeddings using Ollama"""
+ # Normalize input to list
+ if isinstance(input_texts, str):
+ texts = [input_texts]
+ else:
+ texts = input_texts
+
+ # Ollama embeddings API typically handles one text at a time
+ results = []
+ final_response = None
+
+ for i, text in enumerate(texts):
+ response = await self.client.embeddings(
+ model=model,
+ prompt=text,
+ **kwargs
+ )
+
+ results.append(EmbeddingData(
+ embedding=response['embedding'],
+ index=i
+ ))
+ final_response = response
+
+ # Create usage stats if available from the last response
+ usage_stats = None
+ if final_response and len(results) == 1:
+ usage_stats = self._create_usage_stats(final_response)
+
+ return EmbeddingResponse(
+ data=results,
+ model=model,
+ usage=usage_stats
+ )
+
+ async def list_models(self) -> List[str]:
+ models = await self.client.list()
+ return [model['name'] for model in models['models']]
+
+class OpenAIAdapter(BaseLLMAdapter):
+ """Adapter for OpenAI with enhanced statistics"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import openai # type: ignore
+ self.client = openai.AsyncOpenAI(
+ api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
+ )
+
+ def _create_usage_stats(self, usage_data: Any) -> UsageStats:
+ """Create usage statistics from OpenAI response"""
+ if not usage_data:
+ return UsageStats()
+
+ usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data
+
+ return UsageStats(
+ prompt_tokens=usage_dict.get('prompt_tokens'),
+ completion_tokens=usage_dict.get('completion_tokens'),
+ total_tokens=usage_dict.get('total_tokens'),
+ extra_stats={k: v for k, v in usage_dict.items()
+ if k not in ['prompt_tokens', 'completion_tokens', 'total_tokens']}
+ )
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert LLMMessage objects to OpenAI format
+ openai_messages = []
+ for msg in messages:
+ openai_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ if stream:
+ return self._stream_chat(model, openai_messages, **kwargs)
+ else:
+ response = await self.client.chat.completions.create(
+ model=model,
+ messages=openai_messages,
+ stream=False,
+ **kwargs
+ )
+
+ usage_stats = self._create_usage_stats(response.usage)
+
+ return ChatResponse(
+ content=response.choices[0].message.content,
+ model=model,
+ finish_reason=response.choices[0].finish_reason,
+ usage=usage_stats
+ )
+
+ async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
+ # Await the stream creation first, then iterate
+ stream = await self.client.chat.completions.create(
+ model=model,
+ messages=messages,
+ stream=True,
+ **kwargs
+ )
+
+        async for chunk in stream:
+            if not chunk.choices:
+                # Usage-only chunks (stream_options include_usage) carry no choices
+                continue
+            choice = chunk.choices[0]
+            if choice.delta.content or choice.finish_reason is not None:
+                # Usage stats are only available in the final chunk for OpenAI streaming
+                usage_stats = None
+                if (choice.finish_reason is not None and
+                    hasattr(chunk, 'usage') and chunk.usage):
+                    usage_stats = self._create_usage_stats(chunk.usage)
+
+                yield ChatResponse(
+                    content=choice.delta.content or "",
+                    model=model,
+                    finish_reason=choice.finish_reason,
+                    usage=usage_stats
+                )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert to chat format for OpenAI
+ messages = [LLMMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Generate embeddings using OpenAI"""
+ # Normalize input to list
+ if isinstance(input_texts, str):
+ texts = [input_texts]
+ else:
+ texts = input_texts
+
+ response = await self.client.embeddings.create(
+ model=model,
+ input=texts,
+ **kwargs
+ )
+
+ # Convert OpenAI response to our format
+ results = []
+ for item in response.data:
+ results.append(EmbeddingData(
+ embedding=item.embedding,
+ index=item.index
+ ))
+
+ # Create usage stats
+ usage_stats = self._create_usage_stats(response.usage)
+
+ return EmbeddingResponse(
+ data=results,
+ model=model,
+ usage=usage_stats
+ )
+
+ async def list_models(self) -> List[str]:
+ models = await self.client.models.list()
+ return [model.id for model in models.data]
+
+class AnthropicAdapter(BaseLLMAdapter):
+ """Adapter for Anthropic Claude with enhanced statistics"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import anthropic # type: ignore
+ self.client = anthropic.AsyncAnthropic(
+ api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
+ )
+
+ def _create_usage_stats(self, usage_data: Any) -> UsageStats:
+ """Create usage statistics from Anthropic response"""
+ if not usage_data:
+ return UsageStats()
+
+ usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data
+
+ return UsageStats(
+ prompt_tokens=usage_dict.get('input_tokens'),
+ completion_tokens=usage_dict.get('output_tokens'),
+ total_tokens=(usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)) or None,
+ extra_stats={k: v for k, v in usage_dict.items()
+ if k not in ['input_tokens', 'output_tokens']}
+ )
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Anthropic requires system message to be separate
+ system_message = None
+ anthropic_messages = []
+
+ for msg in messages:
+ if msg.role == "system":
+ system_message = msg.content
+ else:
+ anthropic_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ request_kwargs = {
+ "model": model,
+ "messages": anthropic_messages,
+ "max_tokens": kwargs.pop('max_tokens', 1000),
+ **kwargs
+ }
+
+ if system_message:
+ request_kwargs["system"] = system_message
+
+ if stream:
+ return self._stream_chat(**request_kwargs)
+ else:
+ response = await self.client.messages.create(
+ stream=False,
+ **request_kwargs
+ )
+
+ usage_stats = self._create_usage_stats(response.usage)
+
+ return ChatResponse(
+ content=response.content[0].text,
+ model=model,
+ finish_reason=response.stop_reason,
+ usage=usage_stats
+ )
+
+ async def _stream_chat(self, **kwargs):
+ model = kwargs['model']
+
+ async with self.client.messages.stream(**kwargs) as stream:
+ final_usage_stats = None
+
+ async for text in stream.text_stream:
+ yield ChatResponse(
+ content=text,
+ model=model,
+ usage=None # Usage stats not available until stream completion
+ )
+
+ # Collect usage stats after stream completion
+ try:
+ if hasattr(stream, 'get_final_message'):
+ final_message = await stream.get_final_message()
+ if final_message and hasattr(final_message, 'usage'):
+ final_usage_stats = self._create_usage_stats(final_message.usage)
+
+ # Yield a final empty response with usage stats
+ yield ChatResponse(
+ content="",
+ model=model,
+ usage=final_usage_stats,
+ finish_reason="stop"
+ )
+ except Exception as e:
+ logger.debug(f"Could not retrieve final usage stats: {e}")
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ messages = [LLMMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Anthropic doesn't provide embeddings API"""
+ raise NotImplementedError(
+ "Anthropic does not provide embeddings API. "
+ "Consider using OpenAI, Ollama, or Gemini for embeddings."
+ )
+
+ async def list_models(self) -> List[str]:
+ # Anthropic doesn't have a list models endpoint, return known models
+ return [
+ "claude-3-5-sonnet-20241022",
+ "claude-3-5-haiku-20241022",
+ "claude-3-opus-20240229"
+ ]
+
+class GeminiAdapter(BaseLLMAdapter):
+ """Adapter for Google Gemini with enhanced statistics"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import google.generativeai as genai # type: ignore
+ genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
+ self.genai = genai
+
+ def _create_usage_stats(self, response: Any) -> UsageStats:
+ """Create usage statistics from Gemini response"""
+ usage_stats = UsageStats()
+
+ # Gemini usage metadata extraction
+ if hasattr(response, 'usage_metadata'):
+ usage = response.usage_metadata
+ usage_stats.prompt_tokens = getattr(usage, 'prompt_token_count', None)
+ usage_stats.completion_tokens = getattr(usage, 'candidates_token_count', None)
+ usage_stats.total_tokens = getattr(usage, 'total_token_count', None)
+
+ # Store additional Gemini-specific stats
+ extra_stats = {}
+            for attr in dir(usage):
+                if not attr.startswith('_') and attr not in ['prompt_token_count', 'candidates_token_count', 'total_token_count']:
+                    value = getattr(usage, attr)
+                    # Record only data attributes; skip bound methods picked up by dir()
+                    if not callable(value):
+                        extra_stats[attr] = value
+
+ if extra_stats:
+ usage_stats.extra_stats = extra_stats
+
+ return usage_stats
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[LLMMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ model_instance = self.genai.GenerativeModel(model)
+
+ # Convert messages to Gemini format
+ chat_history = []
+ current_message = None
+
+ for i, msg in enumerate(messages[:-1]): # All but last message go to history
+ if msg.role == "user":
+ chat_history.append({"role": "user", "parts": [msg.content]})
+ elif msg.role == "assistant":
+ chat_history.append({"role": "model", "parts": [msg.content]})
+
+ # Last message is the current prompt
+ if messages:
+ current_message = messages[-1].content
+
+ chat = model_instance.start_chat(history=chat_history)
+ if not current_message:
+ raise ValueError("No current message provided for chat")
+
+ if stream:
+ return self._stream_chat(chat, current_message, **kwargs)
+ else:
+ response = await chat.send_message_async(current_message, **kwargs)
+ usage_stats = self._create_usage_stats(response)
+
+ return ChatResponse(
+ content=response.text,
+ model=model,
+ finish_reason=response.candidates[0].finish_reason.name if response.candidates else None,
+ usage=usage_stats
+ )
+
+ async def _stream_chat(self, chat, message: str, **kwargs):
+ # Await the stream creation first, then iterate
+ stream = await chat.send_message_async(message, stream=True, **kwargs)
+
+ final_chunk = None
+
+ async for chunk in stream:
+ if chunk.text:
+ final_chunk = chunk # Keep reference to last chunk
+
+ yield ChatResponse(
+ content=chunk.text,
+ model=chat.model.model_name,
+ usage=None # Don't include usage stats until final chunk
+ )
+
+ # After streaming is complete, yield final response with usage stats
+ if final_chunk:
+ usage_stats = self._create_usage_stats(final_chunk)
+ if usage_stats and any([usage_stats.prompt_tokens, usage_stats.completion_tokens, usage_stats.total_tokens]):
+ yield ChatResponse(
+ content="",
+ model=chat.model.model_name,
+ usage=usage_stats,
+ finish_reason=final_chunk.candidates[0].finish_reason.name if final_chunk.candidates else None
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ messages = [LLMMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Generate embeddings using Google Gemini"""
+ # Normalize input to list
+ if isinstance(input_texts, str):
+ texts = [input_texts]
+ else:
+ texts = input_texts
+
+ results = []
+
+ # Gemini embeddings - use embedding model
+ for i, text in enumerate(texts):
+ response = self.genai.embed_content(
+ model=model,
+ content=text,
+ **kwargs
+ )
+
+ results.append(EmbeddingData(
+ embedding=response['embedding'],
+ index=i
+ ))
+
+ return EmbeddingResponse(
+ data=results,
+ model=model,
+ usage=None # Gemini embeddings don't typically return usage stats
+ )
+
+ async def list_models(self) -> List[str]:
+ models = self.genai.list_models()
+ return [model.name for model in models if 'generateContent' in model.supported_generation_methods]
+
+class UnifiedLLMProxy:
+ """Main proxy class that provides unified interface to all LLM providers"""
+
+ def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
+ self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
+ self.default_provider = default_provider
+ self._initialized_providers = set()
+
+ def configure_provider(self, provider: LLMProvider, **config):
+ """Configure a specific provider with its settings"""
+ adapter_classes = {
+ LLMProvider.OLLAMA: OllamaAdapter,
+ LLMProvider.OPENAI: OpenAIAdapter,
+ LLMProvider.ANTHROPIC: AnthropicAdapter,
+ LLMProvider.GEMINI: GeminiAdapter,
+ # Add other providers as needed
+ }
+
+ if provider in adapter_classes:
+ self.adapters[provider] = adapter_classes[provider](**config)
+ self._initialized_providers.add(provider)
+ else:
+ raise ValueError(f"Unsupported provider: {provider}")
+
+ def set_default_provider(self, provider: LLMProvider):
+ """Set the default provider for requests"""
+ if provider not in self._initialized_providers:
+ raise ValueError(f"Provider {provider} not configured")
+ self.default_provider = provider
+
+ async def chat(
+ self,
+ model: str,
+ messages: Union[List[LLMMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Send chat messages using specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ # Normalize messages to LLMMessage objects
+ normalized_messages = []
+ for msg in messages:
+ if isinstance(msg, LLMMessage):
+ normalized_messages.append(msg)
+ elif isinstance(msg, dict):
+ normalized_messages.append(LLMMessage.from_dict(msg))
+ else:
+ raise ValueError(f"Invalid message type: {type(msg)}")
+
+ return await adapter.chat(model, normalized_messages, stream, **kwargs)
+
+ async def chat_stream(
+ self,
+ model: str,
+ messages: Union[List[LLMMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ stream: bool = True,
+ **kwargs
+ ) -> AsyncGenerator[ChatResponse, None]:
+ """Stream chat messages using specified or default provider"""
+ if stream is False:
+ raise ValueError("stream must be True for chat_stream")
+ result = await self.chat(model, messages, provider, stream=True, **kwargs)
+ # Type checker now knows this is an AsyncGenerator due to stream=True
+ async for chunk in result: # type: ignore
+ yield chunk
+
+ async def chat_single(
+ self,
+ model: str,
+ messages: Union[List[LLMMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> ChatResponse:
+ """Get single chat response using specified or default provider"""
+
+ result = await self.chat(model, messages, provider, stream=False, **kwargs)
+ # Type checker now knows this is a ChatResponse due to stream=False
+ return result # type: ignore
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Generate text using specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ return await adapter.generate(model, prompt, stream, **kwargs)
+
+ async def generate_stream(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> AsyncGenerator[ChatResponse, None]:
+ """Stream text generation using specified or default provider"""
+
+ result = await self.generate(model, prompt, provider, stream=True, **kwargs)
+ async for chunk in result: # type: ignore
+ yield chunk
+
+ async def generate_single(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> ChatResponse:
+ """Get single generation response using specified or default provider"""
+
+ result = await self.generate(model, prompt, provider, stream=False, **kwargs)
+ return result # type: ignore
+
+ async def embeddings(
+ self,
+ model: str,
+ input_texts: Union[str, List[str]],
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> EmbeddingResponse:
+ """Generate embeddings using specified or default provider"""
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ return await adapter.embeddings(model, input_texts, **kwargs)
+
+ async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
+ """List available models for specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ return await adapter.list_models()
+
+ async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
+ """List available embedding models for specified or default provider"""
+ provider = provider or self.default_provider
+
+ # Provider-specific embedding models
+ embedding_models = {
+ LLMProvider.OLLAMA: [
+ "nomic-embed-text",
+ "mxbai-embed-large",
+ "all-minilm",
+ "snowflake-arctic-embed"
+ ],
+ LLMProvider.OPENAI: [
+ "text-embedding-3-small",
+ "text-embedding-3-large",
+ "text-embedding-ada-002"
+ ],
+ LLMProvider.ANTHROPIC: [], # No embeddings API
+ LLMProvider.GEMINI: [
+ "models/embedding-001",
+ "models/text-embedding-004"
+ ]
+ }
+
+ if provider == LLMProvider.ANTHROPIC:
+ raise NotImplementedError("Anthropic does not provide embeddings API")
+
+ # For Ollama, check which embedding models are actually available
+ if provider == LLMProvider.OLLAMA:
+ try:
+ all_models = await self.list_models(provider)
+ available_embedding_models = []
+ suggested_models = embedding_models[provider]
+
+ for model in all_models:
+ # Check if it's a known embedding model
+ if any(emb_model in model for emb_model in suggested_models):
+ available_embedding_models.append(model)
+
+ return available_embedding_models if available_embedding_models else suggested_models
+            except Exception:
+ return embedding_models[provider]
+
+ return embedding_models.get(provider, [])
+
+ def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
+ """Get adapter for specified provider"""
+ if provider not in self.adapters:
+ raise ValueError(f"Provider {provider} not configured")
+ return self.adapters[provider]
+
+# Example usage and configuration
+class LLMManager:
+ """Singleton manager for the unified LLM proxy"""
+
+ _instance = None
+ _proxy = None
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+ def __init__(self):
+ if LLMManager._proxy is None:
+ LLMManager._proxy = UnifiedLLMProxy()
+ self._configure_from_environment()
+
+ def _configure_from_environment(self):
+ """Configure providers based on environment variables"""
+ if not self._proxy:
+ raise RuntimeError("UnifiedLLMProxy instance not initialized")
+ # Configure Ollama if available
+ ollama_host = os.getenv('OLLAMA_HOST', defines.ollama_api_url)
+ self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
+
+ # Configure OpenAI if API key is available
+ if os.getenv('OPENAI_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.OPENAI)
+
+ # Configure Anthropic if API key is available
+ if os.getenv('ANTHROPIC_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.ANTHROPIC)
+
+ # Configure Gemini if API key is available
+ if os.getenv('GEMINI_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.GEMINI)
+
+ # Set default provider from environment or use Ollama
+ default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
+ try:
+ self._proxy.set_default_provider(LLMProvider(default_provider))
+ except ValueError:
+ # Fallback to Ollama if specified provider not available
+ self._proxy.set_default_provider(LLMProvider.OLLAMA)
+
+ def get_proxy(self) -> UnifiedLLMProxy:
+ """Get the unified LLM proxy instance"""
+ if not self._proxy:
+ raise RuntimeError("UnifiedLLMProxy instance not initialized")
+ return self._proxy
+
+# Convenience function for easy access
+def get_llm() -> UnifiedLLMProxy:
+ """Get the configured LLM proxy"""
+ return LLMManager.get_instance().get_proxy()
+
+# Example usage with detailed statistics
+async def example_usage():
+ """Example showing how to access detailed statistics"""
+ llm = get_llm()
+
+ # Configure providers
+ llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
+
+ # Simple chat
+ messages = [
+ LLMMessage(role="user", content="Explain quantum computing in one paragraph")
+ ]
+
+ response = await llm.chat_single("llama2", messages)
+
+ print(f"Content: {response.content}")
+ print(f"Model: {response.model}")
+
+ if response.usage:
+ print(f"Usage Statistics:")
+ print(f" Prompt tokens: {response.usage.prompt_tokens}")
+ print(f" Completion tokens: {response.usage.completion_tokens}")
+ print(f" Total tokens: {response.usage.total_tokens}")
+
+ # Ollama-specific stats
+ if response.usage.prompt_eval_count:
+ print(f" Prompt eval count: {response.usage.prompt_eval_count}")
+ if response.usage.eval_count:
+ print(f" Eval count: {response.usage.eval_count}")
+ if response.usage.tokens_per_second:
+ print(f" Generation speed: {response.usage.tokens_per_second:.2f} tokens/sec")
+ if response.usage.prompt_tokens_per_second:
+ print(f" Prompt processing speed: {response.usage.prompt_tokens_per_second:.2f} tokens/sec")
+
+ # Access as dictionary for backward compatibility
+ usage_dict = response.get_usage_dict()
+ print(f"Usage as dict: {usage_dict}")
+
+async def example_embeddings_usage():
+ """Example showing how to use the embeddings API"""
+ llm = get_llm()
+
+ # Configure providers
+ llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434")
+ if os.getenv('OPENAI_API_KEY'):
+ llm.configure_provider(LLMProvider.OPENAI)
+
+ # List available embedding models
+ print("=== Available Embedding Models ===")
+ try:
+ ollama_embedding_models = await llm.list_embedding_models(LLMProvider.OLLAMA)
+ print(f"Ollama embedding models: {ollama_embedding_models}")
+ except Exception as e:
+ print(f"Could not list Ollama embedding models: {e}")
+
+ if os.getenv('OPENAI_API_KEY'):
+ try:
+ openai_embedding_models = await llm.list_embedding_models(LLMProvider.OPENAI)
+ print(f"OpenAI embedding models: {openai_embedding_models}")
+ except Exception as e:
+ print(f"Could not list OpenAI embedding models: {e}")
+
+ # Single text embedding
+ print("\n=== Single Text Embedding ===")
+ single_text = "The quick brown fox jumps over the lazy dog"
+
+ try:
+ response = await llm.embeddings("nomic-embed-text", single_text, provider=LLMProvider.OLLAMA)
+ print(f"Model: {response.model}")
+ print(f"Embedding dimension: {len(response.get_single_embedding())}")
+ print(f"First 5 values: {response.get_single_embedding()[:5]}")
+
+ if response.usage:
+ print(f"Usage: {response.usage.model_dump(exclude_none=True)}")
+ except Exception as e:
+ print(f"Ollama embedding failed: {e}")
+
+ # Batch embeddings
+ print("\n=== Batch Text Embeddings ===")
+ texts = [
+ "Machine learning is fascinating",
+ "Natural language processing enables AI communication",
+ "Deep learning uses neural networks",
+ "Transformers revolutionized AI"
+ ]
+
+ try:
+ # Try OpenAI if available
+ if os.getenv('OPENAI_API_KEY'):
+ response = await llm.embeddings("text-embedding-3-small", texts, provider=LLMProvider.OPENAI)
+ print(f"OpenAI Model: {response.model}")
+ print(f"Number of embeddings: {len(response.data)}")
+ print(f"Embedding dimension: {len(response.data[0].embedding)}")
+
+ # Calculate similarity between first two texts (requires numpy)
+ try:
+ import numpy as np # type: ignore
+ emb1 = np.array(response.data[0].embedding)
+ emb2 = np.array(response.data[1].embedding)
+ similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
+ print(f"Cosine similarity between first two texts: {similarity:.4f}")
+ except ImportError:
+ print("Install numpy to calculate similarity: pip install numpy")
+
+ if response.usage:
+ print(f"Usage: {response.usage.model_dump(exclude_none=True)}")
+ else:
+ # Fallback to Ollama for batch
+ response = await llm.embeddings("nomic-embed-text", texts, provider=LLMProvider.OLLAMA)
+ print(f"Ollama Model: {response.model}")
+ print(f"Number of embeddings: {len(response.data)}")
+ print(f"Embedding dimension: {len(response.data[0].embedding)}")
+
+ except Exception as e:
+ print(f"Batch embedding failed: {e}")
+
+async def example_streaming_with_stats():
+ """Example showing how to collect usage stats from streaming responses"""
+ llm = get_llm()
+
+ messages = [
+ LLMMessage(role="user", content="Write a short story about AI")
+ ]
+
+ print("Streaming response:")
+ final_stats = None
+
+ async for chunk in llm.chat_stream("llama2", messages):
+ # Print content as it streams
+ if chunk.content:
+ print(chunk.content, end="", flush=True)
+
+ # Collect usage stats from final chunk
+ if chunk.usage:
+ final_stats = chunk.usage
+
+ print("\n\nFinal Usage Statistics:")
+ if final_stats:
+ print(f" Total tokens: {final_stats.total_tokens}")
+ print(f" Generation speed: {final_stats.tokens_per_second:.2f} tokens/sec" if final_stats.tokens_per_second else " Generation speed: N/A")
+ else:
+ print(" No usage statistics available")
+
+if __name__ == "__main__":
+ asyncio.run(example_usage())
+ print("\n" + "="*50 + "\n")
+ asyncio.run(example_streaming_with_stats())
+ print("\n" + "="*50 + "\n")
+ asyncio.run(example_embeddings_usage())
\ No newline at end of file
diff --git a/src/backend/main.py b/src/backend/main.py
index e6b1ce4..4af0067 100644
--- a/src/backend/main.py
+++ b/src/backend/main.py
@@ -13,6 +13,9 @@ import uuid
import defines
import pathlib
+from markitdown import MarkItDown, StreamInfo # type: ignore
+import io
+
import uvicorn # type: ignore
from typing import List, Optional, Dict, Any
from datetime import datetime, timedelta, UTC
@@ -53,7 +56,7 @@ import defines
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
-from llm_manager import llm_manager
+import llm_proxy as llm_manager
import entities
from email_service import VerificationEmailRateLimiter, email_service
from device_manager import DeviceManager
@@ -116,7 +119,8 @@ async def lifespan(app: FastAPI):
try:
# Initialize database
await db_manager.initialize()
-
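+        # EntityManager now needs the Redis database up front so it can hand it to each CandidateEntity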
+ entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database())
+
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
@@ -1827,7 +1831,7 @@ async def upload_candidate_document(
yield error_message
return
- converted = False;
+ converted = False
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
p = pathlib.Path(file_path)
p_as_md = p.with_suffix(".md")
@@ -1873,53 +1877,6 @@ async def upload_candidate_document(
content=file_content,
)
yield chat_message
-
- # If this is a job description, process it with the job requirements agent
- if not options.is_job_document:
- return
-
- status_message = ChatMessageStatus(
- session_id=MOCK_UUID, # No session ID for document uploads
- content=f"Initiating connection with {candidate.first_name}'s AI agent...",
- activity=ApiActivityType.INFO
- )
- yield status_message
- await asyncio.sleep(0)
-
- async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
- chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
- if not chat_agent:
- error_message = ChatMessageError(
- session_id=MOCK_UUID, # No session ID for document uploads
- content="No agent found for job requirements chat type"
- )
- yield error_message
- return
- message = None
- status_message = ChatMessageStatus(
- session_id=MOCK_UUID, # No session ID for document uploads
- content=f"Analyzing document for company and requirement details...",
- activity=ApiActivityType.SEARCHING
- )
- yield status_message
- await asyncio.sleep(0)
-
- async for message in chat_agent.generate(
- llm=llm_manager.get_llm(),
- model=defines.model,
- session_id=MOCK_UUID,
- prompt=file_content
- ):
- pass
- if not message or not isinstance(message, JobRequirementsMessage):
- error_message = ChatMessageError(
- session_id=MOCK_UUID, # No session ID for document uploads
- content="Failed to process job description file"
- )
- yield error_message
- return
- yield message
-
try:
async def to_json(method):
try:
@@ -1932,7 +1889,6 @@ async def upload_candidate_document(
logger.error(f"Error in to_json conversion: {e}")
return
-# return DebugStreamingResponse(
return StreamingResponse(
to_json(upload_stream_generator(file_content)),
media_type="text/event-stream",
@@ -1944,15 +1900,64 @@ async def upload_candidate_document(
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
"Transfer-Encoding": "chunked",
},
- )
+ )
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
- return JSONResponse(
- status_code=500,
- content=create_error_response("UPLOAD_ERROR", "Failed to upload document")
+ return StreamingResponse(
+ iter([ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="Failed to upload document"
+ )]),
+ media_type="text/event-stream"
)
+async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
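+    """Stream status updates and the extracted JobRequirementsMessage for raw job-description content."""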
+ status_message = ChatMessageStatus(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content=f"Initiating connection with {current_user.first_name}'s AI agent...",
+ activity=ApiActivityType.INFO
+ )
+ yield status_message
+ await asyncio.sleep(0) # Let the status message propagate
+
+ async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
+ chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
+ if not chat_agent:
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="No agent found for job requirements chat type"
+ )
+ yield error_message
+ return
+ message = None
+ status_message = ChatMessageStatus(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content=f"Analyzing document for company and requirement details...",
+ activity=ApiActivityType.SEARCHING
+ )
+ yield status_message
+ await asyncio.sleep(0)
+
+ async for message in chat_agent.generate(
+ llm=llm_manager.get_llm(),
+ model=defines.model,
+ session_id=MOCK_UUID,
+ prompt=content
+ ):
+ pass
+ if not message or not isinstance(message, JobRequirementsMessage):
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="Failed to process job description file"
+ )
+ yield error_message
+ return
+
+    logger.info(f"✅ Successfully generated job requirements {message.id}")
+ yield message
+ return
+
@api_router.post("/candidates/profile/upload")
async def upload_candidate_profile(
file: UploadFile = File(...),
@@ -2573,6 +2578,7 @@ async def delete_candidate(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete candidate")
)
+
@api_router.patch("/candidates/{candidate_id}")
async def update_candidate(
candidate_id: str = Path(...),
@@ -2816,6 +2822,139 @@ async def create_candidate_job(
content=create_error_response("CREATION_FAILED", str(e))
)
+
+@api_router.post("/jobs/upload")
+async def create_job_from_file(
+ file: UploadFile = File(...),
+ current_user = Depends(get_current_user),
+ database: RedisDatabase = Depends(get_database)
+):
+ """Upload a job document for the current candidate and create a Job"""
+ # Check file size (limit to 10MB)
+ max_size = 10 * 1024 * 1024 # 10MB
+ file_content = await file.read()
+ if len(file_content) > max_size:
+ logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
+ return StreamingResponse(
+ iter([ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="File size exceeds 10MB limit"
+ )]),
+ media_type="text/event-stream"
+ )
+ if len(file_content) == 0:
+ logger.info(f"⚠️ File is empty: {file.filename}")
+ return StreamingResponse(
+ iter([ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="File is empty"
+ )]),
+ media_type="text/event-stream"
+ )
+
+    # Stream the upload processing back to the client as server-sent events
+ async def upload_stream_generator(file_content):
+ # Verify user is a candidate
+ if current_user.user_type != "candidate":
+ logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="Only candidates can upload documents"
+ )
+ yield error_message
+ return
+
+ file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
+ if not file.filename or file.filename.strip() == "":
+ logger.warning("⚠️ File upload attempt with missing filename")
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="File must have a valid filename"
+ )
+ yield error_message
+ return
+
+ logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
+
+ # Validate file type
+ allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
+ file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
+
+ if file_extension not in allowed_types:
+ logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
+ )
+ yield error_message
+ return
+
+ document_type = get_document_type_from_filename(file.filename or "unknown.txt")
+
+ if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
+ status_message = ChatMessageStatus(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content=f"Converting content from {document_type}...",
+ activity=ApiActivityType.CONVERTING
+ )
+ yield status_message
+ try:
+ md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
+ stream = io.BytesIO(file_content)
+ stream_info = StreamInfo(
+ extension=file_extension, # e.g., ".pdf"
+ url=file.filename # optional, helps with logging and guessing
+ )
+ result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
+ file_content = result.text_content
+ logger.info(f"✅ Converted {file.filename} to Markdown format")
+ except Exception as e:
+ error_message = ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content=f"Failed to convert {file.filename} to Markdown.",
+ )
+ yield error_message
+ logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
+ return
+ async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
+ yield message
+ return
+
+ try:
+ async def to_json(method):
+ try:
+ async for message in method:
+ json_data = message.model_dump(mode='json', by_alias=True)
+ json_str = json.dumps(json_data)
+ yield f"data: {json_str}\n\n".encode("utf-8")
+ except Exception as e:
+ logger.error(backstory_traceback.format_exc())
+ logger.error(f"Error in to_json conversion: {e}")
+ return
+
+ return StreamingResponse(
+ to_json(upload_stream_generator(file_content)),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache, no-store, must-revalidate",
+ "Connection": "keep-alive",
+ "X-Accel-Buffering": "no", # Nginx
+ "X-Content-Type-Options": "nosniff",
+ "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
+ "Transfer-Encoding": "chunked",
+ },
+ )
+ except Exception as e:
+ logger.error(backstory_traceback.format_exc())
+ logger.error(f"❌ Document upload error: {e}")
+ return StreamingResponse(
+ iter([ChatMessageError(
+ session_id=MOCK_UUID, # No session ID for document uploads
+ content="Failed to upload document"
+ )]),
+ media_type="text/event-stream"
+ )
+
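+# Illustrative client usage for the upload route above (the mount prefix and auth scheme are
+# deployment-specific assumptions; adjust for your setup):
+#   curl -N -F "file=@job_description.pdf" -H "Authorization: Bearer <token>" \
+#        http://<host>/<api-prefix>/jobs/upload
+# The response is a text/event-stream of ChatMessageStatus / ChatMessageError /
+# JobRequirementsMessage objects serialized as "data: {...}" frames.
+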
@api_router.get("/jobs/{job_id}")
async def get_job(
job_id: str = Path(...),
@@ -2931,6 +3070,49 @@ async def search_jobs(
content=create_error_response("SEARCH_FAILED", str(e))
)
+
+@api_router.delete("/jobs/{job_id}")
+async def delete_job(
+ job_id: str = Path(...),
+ admin_user = Depends(get_current_admin),
+ database: RedisDatabase = Depends(get_database)
+):
+ """Delete a Job"""
+ try:
+ # Check if admin user
+ if not admin_user.is_admin:
+ logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
+ return JSONResponse(
+ status_code=403,
+ content=create_error_response("FORBIDDEN", "Only admins can delete")
+ )
+
+        # Get job data
+        job_data = await database.get_job(job_id)
+        if not job_data:
+            logger.warning(f"⚠️ Job not found for deletion: {job_id}")
+ return JSONResponse(
+ status_code=404,
+ content=create_error_response("NOT_FOUND", "Job not found")
+ )
+
+ # Delete job from database
+ await database.delete_job(job_id)
+
+ logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
+
+ return create_success_response({
+ "message": "Job deleted successfully",
+ "jobId": job_id
+ })
+
+ except Exception as e:
+ logger.error(f"❌ Delete job error: {e}")
+ return JSONResponse(
+ status_code=500,
+ content=create_error_response("DELETE_ERROR", "Failed to delete job")
+ )
+
# ============================
# Chat Endpoints
# ============================
@@ -3541,7 +3723,7 @@ async def get_candidate_skill_match(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
- """Get skill match for a candidate against a requirement"""
+ """Get skill match for a candidate against a requirement with caching"""
try:
# Find candidate by ID
candidate_data = await database.get_candidate(candidate_id)
@@ -3553,34 +3735,89 @@ async def get_candidate_skill_match(
candidate = Candidate.model_validate(candidate_data)
- async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
- logger.info(f"🔍 Running skill match for candidate {candidate_entity.username} against requirement: {requirement}")
- agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
- if not agent:
- return JSONResponse(
- status_code=400,
- content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
- )
- # Entity automatically released when done
- skill_match = await get_last_item(
- agent.generate(
- llm=llm_manager.get_llm(),
- model=defines.model,
- session_id=MOCK_UUID,
- prompt=requirement,
+ # Create cache key for this specific candidate + requirement combination
+ cache_key = f"skill_match:{candidate_id}:{hash(requirement)}"
+
+ # Get cached assessment if it exists
+ cached_assessment = await database.get_cached_skill_match(cache_key)
+
+ # Get the last update time for the candidate's skill information
+ candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id)
+
+ # Get the latest RAG data update time for the current user
+ user_rag_update_time = await database.get_user_rag_update_time(current_user.id)
+
+ # Determine if we need to regenerate the assessment
+ should_regenerate = True
+ cached_date = None
+
+ if cached_assessment:
+ cached_date = cached_assessment.get('cached_at')
+ if cached_date:
+ # Check if cached result is still valid
+ # Regenerate if:
+ # 1. Candidate skills were updated after cache date
+ # 2. User's RAG data was updated after cache date
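+                # e.g. an assessment cached at 2025-01-10T12:00 (illustrative) stays valid until
+                # either of those timestamps moves past the cache date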
+ if (not candidate_skill_update_time or cached_date >= candidate_skill_update_time) and \
+ (not user_rag_update_time or cached_date >= user_rag_update_time):
+ should_regenerate = False
+ logger.info(f"🔄 Using cached skill match for candidate {candidate.id}")
+
+ if should_regenerate:
+ logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}")
+
+ async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
+ agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH)
+ if not agent:
+ return JSONResponse(
+ status_code=400,
+ content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate")
+ )
+
+ # Generate new skill match
+ skill_match = await get_last_item(
+ agent.generate(
+ llm=llm_manager.get_llm(),
+ model=defines.model,
+ session_id=MOCK_UUID,
+ prompt=requirement,
),
)
- if skill_match is None:
+
+ if skill_match is None:
+ return JSONResponse(
+ status_code=500,
+ content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
+ )
+
+ skill_match_data = json.loads(skill_match.content)
+
+ # Cache the new assessment with current timestamp
+ cached_assessment = {
+ "skill_match": skill_match_data,
+ "cached_at": datetime.utcnow().isoformat(),
+ "candidate_id": candidate_id,
+ "requirement": requirement
+ }
+
+ await database.cache_skill_match(cache_key, cached_assessment)
+ logger.info(f"💾 Cached new skill match for candidate {candidate.id}")
+ logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
+ else:
+ # Use cached result - we know cached_assessment is not None here
+ if cached_assessment is None:
return JSONResponse(
status_code=500,
- content=create_error_response("NO_MATCH", "No skill match found for the given requirement")
+ content=create_error_response("CACHE_ERROR", "Unexpected cache state")
)
- skill_match = json.loads(skill_match.content)
- logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match['evidence_strength']}")
+ skill_match_data = cached_assessment["skill_match"]
+ logger.info(f"✅ Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}")
return create_success_response({
"candidateId": candidate.id,
- "skillMatch": skill_match
+ "skillMatch": skill_match_data,
+ "cached": not should_regenerate,
+ "cacheTimestamp": cached_date
})
except Exception as e:
@@ -3589,8 +3826,8 @@ async def get_candidate_skill_match(
return JSONResponse(
status_code=500,
content=create_error_response("SKILL_MATCH_ERROR", str(e))
- )
-
+ )
+
@api_router.get("/candidates/{username}/chat-sessions")
async def get_candidate_chat_sessions(
username: str = Path(...),
@@ -3911,7 +4148,6 @@ async def track_requests(request, call_next):
# FastAPI Metrics
# ============================
prometheus_collector = CollectorRegistry()
-entities.entity_manager.initialize(prometheus_collector)
# Keep the Instrumentator instance alive
instrumentator = Instrumentator(
diff --git a/src/backend/models.py b/src/backend/models.py
index d5010eb..4dfa3e0 100644
--- a/src/backend/models.py
+++ b/src/backend/models.py
@@ -768,11 +768,7 @@ class ChatOptions(BaseModel):
"populate_by_name": True # Allow both field names and aliases
}
-
-class LLMMessage(BaseModel):
- role: str = Field(default="")
- content: str = Field(default="")
- tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True)
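+# LLMMessage is now defined once in llm_proxy; importing it here keeps existing
+# `from models import LLMMessage` call sites working without a duplicate definition.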
+from llm_proxy import LLMMessage
class ApiMessage(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
diff --git a/src/backend/rag/rag.py b/src/backend/rag/rag.py
index 0c6ba48..b6646ed 100644
--- a/src/backend/rag/rag.py
+++ b/src/backend/rag/rag.py
@@ -13,7 +13,6 @@ import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
-import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
@@ -27,6 +26,7 @@ from .markdown_chunker import (
# When imported as a module, use relative imports
import defines
+from database import RedisDatabase
from models import ChromaDBGetResponse
__all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
@@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
loop,
persist_directory,
collection_name,
+ database: RedisDatabase,
+ user_id: str,
chunk_size=DEFAULT_CHUNK_SIZE,
chunk_overlap=DEFAULT_CHUNK_OVERLAP,
recreate=False,
):
self.llm = llm
+ self.database = database
+ self.user_id = user_id
self.watch_directory = watch_directory
self.persist_directory = persist_directory or defines.persist_directory
self.collection_name = collection_name
@@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
+ await self.database.update_user_rag_timestamp(self.user_id)
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
@@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
- def get_embedding(self, text: str) -> np.ndarray:
+ async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
- response = self.llm.embeddings(model=defines.embedding_model, prompt=text)
- embedding = np.array(response["embedding"])
+ response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
+ embedding = np.array(response.get_single_embedding())
except Exception as e:
+ logging.error(traceback.format_exc())
logging.error(f"Failed to get embedding: {e}")
raise
@@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return embedding
- def add_embeddings_to_collection(self, chunks: List[Chunk]):
+ async def _add_embeddings_to_collection(self, chunks: List[Chunk]):
"""Add embeddings for chunks to the collection."""
for i, chunk in enumerate(chunks):
@@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
- embedding = self.get_embedding(text)
+ embedding = await self.get_embedding(text)
try:
self.collection.add(
ids=[chunk_id],
@@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall
# 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive
# 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content
- def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
+ async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold):
"""Find similar documents to the query."""
# collection is configured with hnsw:space cosine
- query_embedding = self.get_embedding(query)
+ query_embedding = await self.get_embedding(query)
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
@@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
and existing_results["ids"]
):
self.collection.delete(ids=existing_results["ids"])
+ await self.database.update_user_rag_timestamp(self.user_id)
extensions = (".docx", ".xlsx", ".xls", ".pdf")
if file_path.endswith(extensions):
@@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# f.write(json.dumps(chunk, indent=2))
# Add chunks to collection
- self.add_embeddings_to_collection(chunks)
+ await self._add_embeddings_to_collection(chunks)
+ await self.database.update_user_rag_timestamp(self.user_id)
logging.info(f"Updated {len(chunks)} chunks for file: {file_path}")
@@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Function to start the file watcher
def start_file_watcher(
llm,
+ user_id,
watch_directory,
persist_directory,
collection_name,
+ database: RedisDatabase,
initialize=False,
recreate=False,
):
@@ -663,9 +674,11 @@ def start_file_watcher(
llm,
watch_directory=watch_directory,
loop=loop,
+ user_id=user_id,
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
+ database=database
)
# Process all files if:
diff --git a/src/multi-llm/config.md b/src/multi-llm/config.md
new file mode 100644
index 0000000..3c11bcf
--- /dev/null
+++ b/src/multi-llm/config.md
@@ -0,0 +1,193 @@
+# Environment Configuration Examples
+
+# ==== Development (Ollama only) ====
+export OLLAMA_HOST="http://localhost:11434"
+export DEFAULT_LLM_PROVIDER="ollama"
+
+# ==== Production with OpenAI ====
+export OPENAI_API_KEY="sk-your-openai-key-here"
+export DEFAULT_LLM_PROVIDER="openai"
+
+# ==== Production with Anthropic ====
+export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
+export DEFAULT_LLM_PROVIDER="anthropic"
+
+# ==== Production with multiple providers ====
+export OPENAI_API_KEY="sk-your-openai-key-here"
+export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here"
+export GEMINI_API_KEY="your-gemini-key-here"
+export OLLAMA_HOST="http://ollama-server:11434"
+export DEFAULT_LLM_PROVIDER="openai"
+
+# ==== Docker Compose Example ====
+# docker-compose.yml
+version: '3.8'
+services:
+ api:
+ build: .
+ ports:
+ - "8000:8000"
+ environment:
+ - OPENAI_API_KEY=${OPENAI_API_KEY}
+ - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
+ - DEFAULT_LLM_PROVIDER=openai
+ depends_on:
+ - ollama
+
+ ollama:
+ image: ollama/ollama
+ ports:
+ - "11434:11434"
+ volumes:
+ - ollama_data:/root/.ollama
+ environment:
+ - OLLAMA_HOST=0.0.0.0
+
+volumes:
+ ollama_data:
+
+# ==== Kubernetes ConfigMap ====
+apiVersion: v1
+kind: ConfigMap
+metadata:
+ name: llm-config
+data:
+ DEFAULT_LLM_PROVIDER: "openai"
+ OLLAMA_HOST: "http://ollama-service:11434"
+
+---
+apiVersion: v1
+kind: Secret
+metadata:
+ name: llm-secrets
+type: Opaque
+stringData:
+ OPENAI_API_KEY: "sk-your-key-here"
+ ANTHROPIC_API_KEY: "sk-ant-your-key-here"
+
+# ==== Python configuration example ====
+# config.py
+import os
+from llm_proxy import get_llm, LLMProvider
+
+def configure_llm_for_environment():
+ """Configure LLM based on deployment environment"""
+ llm = get_llm()
+
+ # Development: Use Ollama
+ if os.getenv('ENVIRONMENT') == 'development':
+ llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434')
+ llm.set_default_provider(LLMProvider.OLLAMA)
+
+    # Staging: OpenAI with retry and timeout settings
+ elif os.getenv('ENVIRONMENT') == 'staging':
+ llm.configure_provider(LLMProvider.OPENAI,
+ api_key=os.getenv('OPENAI_API_KEY'),
+ max_retries=3,
+ timeout=30)
+ llm.set_default_provider(LLMProvider.OPENAI)
+
+ # Production: Use multiple providers with fallback
+ elif os.getenv('ENVIRONMENT') == 'production':
+ # Primary: Anthropic
+ llm.configure_provider(LLMProvider.ANTHROPIC,
+ api_key=os.getenv('ANTHROPIC_API_KEY'))
+
+ # Fallback: OpenAI
+ llm.configure_provider(LLMProvider.OPENAI,
+ api_key=os.getenv('OPENAI_API_KEY'))
+
+ # Set primary provider
+ llm.set_default_provider(LLMProvider.ANTHROPIC)
+
+# ==== Usage Examples ====
+
+# Example 1: Basic usage with default provider
+async def basic_example():
+ llm = get_llm()
+
+ response = await llm.chat(
+ model="gpt-4",
+ messages=[{"role": "user", "content": "Hello!"}]
+ )
+ print(response.content)
+
+# Example 2: Specify provider explicitly
+async def provider_specific_example():
+ llm = get_llm()
+
+ # Use OpenAI specifically
+ response = await llm.chat(
+ model="gpt-4",
+ messages=[{"role": "user", "content": "Hello!"}],
+ provider=LLMProvider.OPENAI
+ )
+
+ # Use Anthropic specifically
+ response2 = await llm.chat(
+ model="claude-3-sonnet-20240229",
+ messages=[{"role": "user", "content": "Hello!"}],
+ provider=LLMProvider.ANTHROPIC
+ )
+
+# Example 3: Streaming with provider fallback
+async def streaming_with_fallback():
+ llm = get_llm()
+
+ try:
+ async for chunk in llm.chat(
+ model="claude-3-sonnet-20240229",
+ messages=[{"role": "user", "content": "Write a story"}],
+ provider=LLMProvider.ANTHROPIC,
+ stream=True
+ ):
+ print(chunk.content, end='', flush=True)
+ except Exception as e:
+ print(f"Primary provider failed: {e}")
+ # Fallback to OpenAI
+ async for chunk in llm.chat(
+ model="gpt-4",
+ messages=[{"role": "user", "content": "Write a story"}],
+ provider=LLMProvider.OPENAI,
+ stream=True
+ ):
+ print(chunk.content, end='', flush=True)
+
+# Example 4: Load balancing between providers
+import random
+
+async def load_balanced_request():
+ llm = get_llm()
+ available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]
+
+ # Simple random load balancing
+ provider = random.choice(available_providers)
+
+ # Adjust model based on provider
+ model_mapping = {
+ LLMProvider.OPENAI: "gpt-4",
+ LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229"
+ }
+
+ response = await llm.chat(
+ model=model_mapping[provider],
+ messages=[{"role": "user", "content": "Hello!"}],
+ provider=provider
+ )
+
+ print(f"Response from {provider.value}: {response.content}")
+
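+# ==== Example 5: Consuming the streaming /chat endpoint ====
+# A minimal client sketch for the SSE stream emitted by example.py: each event
+# arrives as a "data: {...}" line and the stream ends with "data: [DONE]".
+# Assumes the API is reachable at http://localhost:8000 and that httpx is
+# installed (httpx is not part of this change); adjust the URL, model, and
+# payload for your deployment.
+import json
+import httpx
+
+async def stream_chat_client():
+    payload = {
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Write a haiku"}],
+        "stream": True,
+    }
+    async with httpx.AsyncClient(timeout=None) as client:
+        async with client.stream("POST", "http://localhost:8000/chat", json=payload) as response:
+            async for line in response.aiter_lines():
+                if not line.startswith("data: "):
+                    continue  # skip the blank separators between SSE events
+                data = line[len("data: "):]
+                if data == "[DONE]":
+                    break
+                chunk = json.loads(data)
+                print(chunk.get("content", ""), end="", flush=True)
+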
+# ==== FastAPI Startup Configuration ====
+from fastapi import FastAPI
+
+app = FastAPI()
+
+@app.on_event("startup")
+async def startup_event():
+ """Configure LLM providers on application startup"""
+ configure_llm_for_environment()
+
+ # Verify configuration
+ llm = get_llm()
+ print(f"Configured providers: {[p.value for p in llm._initialized_providers]}")
+ print(f"Default provider: {llm.default_provider.value}")
\ No newline at end of file
diff --git a/src/multi-llm/example.py b/src/multi-llm/example.py
new file mode 100644
index 0000000..30f17cc
--- /dev/null
+++ b/src/multi-llm/example.py
@@ -0,0 +1,238 @@
+from fastapi import FastAPI, HTTPException # type: ignore
+from fastapi.responses import StreamingResponse # type: ignore
+from pydantic import BaseModel # type: ignore
+from typing import List, Optional, Dict, Any
+import json
+import asyncio
+from llm_proxy import get_llm, LLMProvider, ChatMessage
+
+app = FastAPI(title="Unified LLM API")
+
+# Pydantic models for request/response
+class ChatRequest(BaseModel):
+ model: str
+ messages: List[Dict[str, str]]
+ provider: Optional[str] = None
+ stream: bool = False
+ max_tokens: Optional[int] = None
+ temperature: Optional[float] = None
+
+class GenerateRequest(BaseModel):
+ model: str
+ prompt: str
+ provider: Optional[str] = None
+ stream: bool = False
+ max_tokens: Optional[int] = None
+ temperature: Optional[float] = None
+
+class ChatResponseModel(BaseModel):
+ content: str
+ model: str
+ provider: str
+ finish_reason: Optional[str] = None
+ usage: Optional[Dict[str, int]] = None
+
+@app.post("/chat")
+async def chat_endpoint(request: ChatRequest):
+ """Chat endpoint that works with any configured LLM provider"""
+ try:
+ llm = get_llm()
+
+ # Convert provider string to enum if provided
+ provider = None
+ if request.provider:
+ try:
+ provider = LLMProvider(request.provider.lower())
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unsupported provider: {request.provider}"
+ )
+
+ # Prepare kwargs
+ kwargs = {}
+        if request.max_tokens is not None:
+            kwargs['max_tokens'] = request.max_tokens
+        if request.temperature is not None:
+            kwargs['temperature'] = request.temperature
+
+ if request.stream:
+ async def generate_response():
+ try:
+ # Use the type-safe streaming method
+ async for chunk in llm.chat_stream(
+ model=request.model,
+ messages=request.messages,
+ provider=provider,
+ **kwargs
+ ):
+ # Format as Server-Sent Events
+ data = {
+ "content": chunk.content,
+ "model": chunk.model,
+ "provider": provider.value if provider else llm.default_provider.value,
+ "finish_reason": chunk.finish_reason
+ }
+ yield f"data: {json.dumps(data)}\n\n"
+ except Exception as e:
+ error_data = {"error": str(e)}
+ yield f"data: {json.dumps(error_data)}\n\n"
+ finally:
+ yield "data: [DONE]\n\n"
+
+ return StreamingResponse(
+ generate_response(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive"
+ }
+ )
+ else:
+ # Use the type-safe single response method
+ response = await llm.chat_single(
+ model=request.model,
+ messages=request.messages,
+ provider=provider,
+ **kwargs
+ )
+
+ return ChatResponseModel(
+ content=response.content,
+ model=response.model,
+ provider=provider.value if provider else llm.default_provider.value,
+ finish_reason=response.finish_reason,
+ usage=response.usage
+ )
+
+    except HTTPException:
+        # Re-raise deliberate 4xx errors instead of converting them to 500s
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/generate")
+async def generate_endpoint(request: GenerateRequest):
+ """Generate endpoint for simple text generation"""
+ try:
+ llm = get_llm()
+
+ provider = None
+ if request.provider:
+ try:
+ provider = LLMProvider(request.provider.lower())
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unsupported provider: {request.provider}"
+ )
+
+ kwargs = {}
+        if request.max_tokens is not None:
+            kwargs['max_tokens'] = request.max_tokens
+        if request.temperature is not None:
+            kwargs['temperature'] = request.temperature
+
+ if request.stream:
+ async def generate_response():
+ try:
+ # Use the type-safe streaming method
+ async for chunk in llm.generate_stream(
+ model=request.model,
+ prompt=request.prompt,
+ provider=provider,
+ **kwargs
+ ):
+ data = {
+ "content": chunk.content,
+ "model": chunk.model,
+ "provider": provider.value if provider else llm.default_provider.value,
+ "finish_reason": chunk.finish_reason
+ }
+ yield f"data: {json.dumps(data)}\n\n"
+ except Exception as e:
+ error_data = {"error": str(e)}
+ yield f"data: {json.dumps(error_data)}\n\n"
+ finally:
+ yield "data: [DONE]\n\n"
+
+ return StreamingResponse(
+ generate_response(),
+ media_type="text/event-stream"
+ )
+ else:
+ # Use the type-safe single response method
+ response = await llm.generate_single(
+ model=request.model,
+ prompt=request.prompt,
+ provider=provider,
+ **kwargs
+ )
+
+ return ChatResponseModel(
+ content=response.content,
+ model=response.model,
+ provider=provider.value if provider else llm.default_provider.value,
+ finish_reason=response.finish_reason,
+ usage=response.usage
+ )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/models")
+async def list_models(provider: Optional[str] = None):
+ """List available models for a provider"""
+ try:
+ llm = get_llm()
+
+ provider_enum = None
+ if provider:
+ try:
+ provider_enum = LLMProvider(provider.lower())
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unsupported provider: {provider}"
+ )
+
+ models = await llm.list_models(provider_enum)
+ return {
+ "provider": provider_enum.value if provider_enum else llm.default_provider.value,
+ "models": models
+ }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/providers")
+async def list_providers():
+ """List all configured providers"""
+ llm = get_llm()
+ return {
+ "providers": [provider.value for provider in llm._initialized_providers],
+ "default": llm.default_provider.value
+ }
+
+@app.post("/providers/{provider}/set-default")
+async def set_default_provider(provider: str):
+ """Set the default provider"""
+ try:
+ llm = get_llm()
+ provider_enum = LLMProvider(provider.lower())
+ llm.set_default_provider(provider_enum)
+ return {"message": f"Default provider set to {provider}", "default": provider}
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+# Health check endpoint
+@app.get("/health")
+async def health_check():
+ """Health check endpoint"""
+ llm = get_llm()
+ return {
+ "status": "healthy",
+ "providers_configured": len(llm._initialized_providers),
+ "default_provider": llm.default_provider.value
+ }
+
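+# Rough local usage sketch (model names are illustrative and depend on which
+# providers you configured via environment variables; see src/multi-llm/config.md):
+#
+#   uvicorn example:app --host 0.0.0.0 --port 8000
+#   POST /chat      {"model": "gpt-4", "messages": [{"role": "user", "content": "Hi"}]}
+#   POST /generate  {"model": "llama3", "prompt": "Hi", "provider": "ollama", "stream": true}
+#   GET  /models?provider=ollama
+#   GET  /providers
+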
+if __name__ == "__main__":
+ import uvicorn # type: ignore
+ uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/src/multi-llm/llm_proxy.py b/src/multi-llm/llm_proxy.py
new file mode 100644
index 0000000..2057760
--- /dev/null
+++ b/src/multi-llm/llm_proxy.py
@@ -0,0 +1,606 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Any, AsyncGenerator, Optional, Union
+from enum import Enum
+import asyncio
+import json
+from dataclasses import dataclass
+import os
+
+# Standard message format for all providers
+@dataclass
+class ChatMessage:
+ role: str # "user", "assistant", "system"
+ content: str
+
+ def __getitem__(self, key: str):
+ """Allow dictionary-style access for backward compatibility"""
+ if key == 'role':
+ return self.role
+ elif key == 'content':
+ return self.content
+ else:
+ raise KeyError(f"'{key}' not found in ChatMessage")
+
+ def __setitem__(self, key: str, value: str):
+ """Allow dictionary-style assignment"""
+ if key == 'role':
+ self.role = value
+ elif key == 'content':
+ self.content = value
+ else:
+ raise KeyError(f"'{key}' not found in ChatMessage")
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage':
+ """Create ChatMessage from dictionary"""
+ return cls(role=data['role'], content=data['content'])
+
+ def to_dict(self) -> Dict[str, str]:
+ """Convert ChatMessage to dictionary"""
+ return {'role': self.role, 'content': self.content}
+
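+# Brief illustration (not executed): ChatMessage interoperates with plain dicts,
+# so callers may pass either form to UnifiedLLMProxy.chat below.
+#   msg = ChatMessage.from_dict({"role": "user", "content": "hi"})
+#   assert msg["content"] == msg.content
+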
+@dataclass
+class ChatResponse:
+ content: str
+ model: str
+ finish_reason: Optional[str] = None
+ usage: Optional[Dict[str, int]] = None
+
+class LLMProvider(Enum):
+ OLLAMA = "ollama"
+ OPENAI = "openai"
+ ANTHROPIC = "anthropic"
+ GEMINI = "gemini"
+ GROK = "grok"
+
+class BaseLLMAdapter(ABC):
+ """Abstract base class for all LLM adapters"""
+
+ def __init__(self, **config):
+ self.config = config
+
+ @abstractmethod
+ async def chat(
+ self,
+ model: str,
+ messages: List[ChatMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Send chat messages and get response"""
+ pass
+
+ @abstractmethod
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Generate text from prompt"""
+ pass
+
+ @abstractmethod
+ async def list_models(self) -> List[str]:
+ """List available models"""
+ pass
+
+class OllamaAdapter(BaseLLMAdapter):
+ """Adapter for Ollama"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import ollama
+ self.client = ollama.AsyncClient( # type: ignore
+ host=config.get('host', 'http://localhost:11434')
+ )
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[ChatMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert ChatMessage objects to Ollama format
+ ollama_messages = []
+ for msg in messages:
+ ollama_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ if stream:
+ return self._stream_chat(model, ollama_messages, **kwargs)
+ else:
+ response = await self.client.chat(
+ model=model,
+ messages=ollama_messages,
+ stream=False,
+ **kwargs
+ )
+ return ChatResponse(
+ content=response['message']['content'],
+ model=model,
+ finish_reason=response.get('done_reason')
+ )
+
+ async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
+        async for chunk in await self.client.chat(
+ model=model,
+ messages=messages,
+ stream=True,
+ **kwargs
+ ):
+ if chunk.get('message', {}).get('content'):
+ yield ChatResponse(
+ content=chunk['message']['content'],
+ model=model,
+ finish_reason=chunk.get('done_reason')
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ if stream:
+ return self._stream_generate(model, prompt, **kwargs)
+ else:
+ response = await self.client.generate(
+ model=model,
+ prompt=prompt,
+ stream=False,
+ **kwargs
+ )
+ return ChatResponse(
+ content=response['response'],
+ model=model,
+ finish_reason=response.get('done_reason')
+ )
+
+ async def _stream_generate(self, model: str, prompt: str, **kwargs):
+        async for chunk in await self.client.generate(
+ model=model,
+ prompt=prompt,
+ stream=True,
+ **kwargs
+ ):
+ if chunk.get('response'):
+ yield ChatResponse(
+ content=chunk['response'],
+ model=model,
+ finish_reason=chunk.get('done_reason')
+ )
+
+ async def list_models(self) -> List[str]:
+ models = await self.client.list()
+ return [model['name'] for model in models['models']]
+
+class OpenAIAdapter(BaseLLMAdapter):
+ """Adapter for OpenAI"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import openai # type: ignore
+ self.client = openai.AsyncOpenAI(
+ api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
+ )
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[ChatMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert ChatMessage objects to OpenAI format
+ openai_messages = []
+ for msg in messages:
+ openai_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ if stream:
+ return self._stream_chat(model, openai_messages, **kwargs)
+ else:
+ response = await self.client.chat.completions.create(
+ model=model,
+ messages=openai_messages,
+ stream=False,
+ **kwargs
+ )
+ return ChatResponse(
+ content=response.choices[0].message.content,
+ model=model,
+ finish_reason=response.choices[0].finish_reason,
+ usage=response.usage.model_dump() if response.usage else None
+ )
+
+ async def _stream_chat(self, model: str, messages: List[Dict], **kwargs):
+ async for chunk in await self.client.chat.completions.create(
+ model=model,
+ messages=messages,
+ stream=True,
+ **kwargs
+ ):
+ if chunk.choices[0].delta.content:
+ yield ChatResponse(
+ content=chunk.choices[0].delta.content,
+ model=model,
+ finish_reason=chunk.choices[0].finish_reason
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Convert to chat format for OpenAI
+ messages = [ChatMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def list_models(self) -> List[str]:
+ models = await self.client.models.list()
+ return [model.id for model in models.data]
+
+class AnthropicAdapter(BaseLLMAdapter):
+ """Adapter for Anthropic Claude"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import anthropic # type: ignore
+ self.client = anthropic.AsyncAnthropic(
+ api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
+ )
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[ChatMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ # Anthropic requires system message to be separate
+ system_message = None
+ anthropic_messages = []
+
+ for msg in messages:
+ if msg.role == "system":
+ system_message = msg.content
+ else:
+ anthropic_messages.append({
+ "role": msg.role,
+ "content": msg.content
+ })
+
+ request_kwargs = {
+ "model": model,
+ "messages": anthropic_messages,
+ "max_tokens": kwargs.pop('max_tokens', 1000),
+ **kwargs
+ }
+
+ if system_message:
+ request_kwargs["system"] = system_message
+
+ if stream:
+ return self._stream_chat(**request_kwargs)
+ else:
+ response = await self.client.messages.create(
+ stream=False,
+ **request_kwargs
+ )
+ return ChatResponse(
+ content=response.content[0].text,
+ model=model,
+ finish_reason=response.stop_reason,
+ usage={
+ "input_tokens": response.usage.input_tokens,
+ "output_tokens": response.usage.output_tokens
+ }
+ )
+
+ async def _stream_chat(self, **kwargs):
+ async with self.client.messages.stream(**kwargs) as stream:
+ async for text in stream.text_stream:
+ yield ChatResponse(
+ content=text,
+ model=kwargs['model']
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ messages = [ChatMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def list_models(self) -> List[str]:
+ # Anthropic doesn't have a list models endpoint, return known models
+ return [
+ "claude-3-5-sonnet-20241022",
+ "claude-3-5-haiku-20241022",
+ "claude-3-opus-20240229"
+ ]
+
+class GeminiAdapter(BaseLLMAdapter):
+ """Adapter for Google Gemini"""
+
+ def __init__(self, **config):
+ super().__init__(**config)
+ import google.generativeai as genai # type: ignore
+ genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
+ self.genai = genai
+
+ async def chat(
+ self,
+ model: str,
+ messages: List[ChatMessage],
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ model_instance = self.genai.GenerativeModel(model)
+
+ # Convert messages to Gemini format
+ chat_history = []
+ current_message = None
+
+ for i, msg in enumerate(messages[:-1]): # All but last message go to history
+ if msg.role == "user":
+ chat_history.append({"role": "user", "parts": [msg.content]})
+ elif msg.role == "assistant":
+ chat_history.append({"role": "model", "parts": [msg.content]})
+
+ # Last message is the current prompt
+ if messages:
+ current_message = messages[-1].content
+
+ chat = model_instance.start_chat(history=chat_history)
+
+ if not current_message:
+ raise ValueError("No current message provided for chat")
+
+ if stream:
+ return self._stream_chat(chat, current_message, **kwargs)
+ else:
+ response = await chat.send_message_async(current_message, **kwargs)
+ return ChatResponse(
+ content=response.text,
+ model=model,
+ finish_reason=response.candidates[0].finish_reason.name if response.candidates else None
+ )
+
+ async def _stream_chat(self, chat, message: str, **kwargs):
+        async for chunk in await chat.send_message_async(message, stream=True, **kwargs):
+ if chunk.text:
+ yield ChatResponse(
+ content=chunk.text,
+ model=chat.model.model_name
+ )
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+
+ messages = [ChatMessage(role="user", content=prompt)]
+ return await self.chat(model, messages, stream, **kwargs)
+
+ async def list_models(self) -> List[str]:
+ models = self.genai.list_models()
+ return [model.name for model in models if 'generateContent' in model.supported_generation_methods]
+
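+# Speculative sketch: LLMProvider.GROK is declared above but has no adapter yet.
+# This assumes xAI exposes an OpenAI-compatible API (the base URL and the
+# GROK_API_KEY variable below are assumptions; verify against current xAI docs).
+# To enable it, the class would also need to be added to the adapter_classes
+# mapping in UnifiedLLMProxy.configure_provider.
+class GrokAdapter(OpenAIAdapter):
+    """Adapter for xAI Grok via an OpenAI-compatible endpoint (sketch)"""
+
+    def __init__(self, **config):
+        # Bypass OpenAIAdapter.__init__ so OPENAI_API_KEY is not required;
+        # build a client pointed at the assumed Grok endpoint instead.
+        BaseLLMAdapter.__init__(self, **config)
+        import openai  # type: ignore
+        self.client = openai.AsyncOpenAI(
+            api_key=config.get('api_key', os.getenv('GROK_API_KEY')),
+            base_url=config.get('base_url', 'https://api.x.ai/v1'),
+        )
+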
+class UnifiedLLMProxy:
+ """Main proxy class that provides unified interface to all LLM providers"""
+
+ def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA):
+ self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {}
+ self.default_provider = default_provider
+ self._initialized_providers = set()
+
+ def configure_provider(self, provider: LLMProvider, **config):
+ """Configure a specific provider with its settings"""
+ adapter_classes = {
+ LLMProvider.OLLAMA: OllamaAdapter,
+ LLMProvider.OPENAI: OpenAIAdapter,
+ LLMProvider.ANTHROPIC: AnthropicAdapter,
+ LLMProvider.GEMINI: GeminiAdapter,
+ # Add other providers as needed
+ }
+
+ if provider in adapter_classes:
+ self.adapters[provider] = adapter_classes[provider](**config)
+ self._initialized_providers.add(provider)
+ else:
+ raise ValueError(f"Unsupported provider: {provider}")
+
+ def set_default_provider(self, provider: LLMProvider):
+ """Set the default provider for requests"""
+ if provider not in self._initialized_providers:
+ raise ValueError(f"Provider {provider} not configured")
+ self.default_provider = provider
+
+ async def chat(
+ self,
+ model: str,
+ messages: Union[List[ChatMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Send chat messages using specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ # Normalize messages to ChatMessage objects
+ normalized_messages = []
+ for msg in messages:
+ if isinstance(msg, ChatMessage):
+ normalized_messages.append(msg)
+ elif isinstance(msg, dict):
+ normalized_messages.append(ChatMessage.from_dict(msg))
+ else:
+ raise ValueError(f"Invalid message type: {type(msg)}")
+
+ return await adapter.chat(model, normalized_messages, stream, **kwargs)
+
+ async def chat_stream(
+ self,
+ model: str,
+ messages: Union[List[ChatMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> AsyncGenerator[ChatResponse, None]:
+ """Stream chat messages using specified or default provider"""
+
+ result = await self.chat(model, messages, provider, stream=True, **kwargs)
+ # Type checker now knows this is an AsyncGenerator due to stream=True
+ async for chunk in result: # type: ignore
+ yield chunk
+
+ async def chat_single(
+ self,
+ model: str,
+ messages: Union[List[ChatMessage], List[Dict[str, str]]],
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> ChatResponse:
+ """Get single chat response using specified or default provider"""
+
+ result = await self.chat(model, messages, provider, stream=False, **kwargs)
+ # Type checker now knows this is a ChatResponse due to stream=False
+ return result # type: ignore
+
+ async def generate(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ stream: bool = False,
+ **kwargs
+ ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]:
+ """Generate text using specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ return await adapter.generate(model, prompt, stream, **kwargs)
+
+ async def generate_stream(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> AsyncGenerator[ChatResponse, None]:
+ """Stream text generation using specified or default provider"""
+
+ result = await self.generate(model, prompt, provider, stream=True, **kwargs)
+ async for chunk in result: # type: ignore
+ yield chunk
+
+ async def generate_single(
+ self,
+ model: str,
+ prompt: str,
+ provider: Optional[LLMProvider] = None,
+ **kwargs
+ ) -> ChatResponse:
+ """Get single generation response using specified or default provider"""
+
+ result = await self.generate(model, prompt, provider, stream=False, **kwargs)
+ return result # type: ignore
+
+ async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
+ """List available models for specified or default provider"""
+
+ provider = provider or self.default_provider
+ adapter = self._get_adapter(provider)
+
+ return await adapter.list_models()
+
+ def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter:
+ """Get adapter for specified provider"""
+ if provider not in self.adapters:
+ raise ValueError(f"Provider {provider} not configured")
+ return self.adapters[provider]
+
+# Example usage and configuration
+class LLMManager:
+ """Singleton manager for the unified LLM proxy"""
+
+ _instance = None
+ _proxy = None
+
+ @classmethod
+ def get_instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+ def __init__(self):
+ if LLMManager._proxy is None:
+ LLMManager._proxy = UnifiedLLMProxy()
+ self._configure_from_environment()
+
+ def _configure_from_environment(self):
+ """Configure providers based on environment variables"""
+
+ if not self._proxy:
+ raise RuntimeError("LLM proxy not initialized")
+
+ # Configure Ollama if available
+ ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434')
+ self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host)
+
+ # Configure OpenAI if API key is available
+ if os.getenv('OPENAI_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.OPENAI)
+
+ # Configure Anthropic if API key is available
+ if os.getenv('ANTHROPIC_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.ANTHROPIC)
+
+ # Configure Gemini if API key is available
+ if os.getenv('GEMINI_API_KEY'):
+ self._proxy.configure_provider(LLMProvider.GEMINI)
+
+ # Set default provider from environment or use Ollama
+ default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama')
+ try:
+ self._proxy.set_default_provider(LLMProvider(default_provider))
+ except ValueError:
+ # Fallback to Ollama if specified provider not available
+ self._proxy.set_default_provider(LLMProvider.OLLAMA)
+
+ def get_proxy(self) -> UnifiedLLMProxy:
+ """Get the unified LLM proxy instance"""
+ if not self._proxy:
+ raise RuntimeError("LLM proxy not initialized")
+ return self._proxy
+
+# Convenience function for easy access
+def get_llm() -> UnifiedLLMProxy:
+ """Get the configured LLM proxy"""
+ return LLMManager.get_instance().get_proxy()
\ No newline at end of file