diff --git a/frontend/src/components/JobCreator.tsx b/frontend/src/components/JobCreator.tsx new file mode 100644 index 0000000..dfe936e --- /dev/null +++ b/frontend/src/components/JobCreator.tsx @@ -0,0 +1,541 @@ +import React, { useState, useEffect, useRef, JSX } from 'react'; +import { + Box, + Button, + Typography, + Paper, + TextField, + Grid, + Dialog, + DialogTitle, + DialogContent, + DialogContentText, + DialogActions, + IconButton, + useTheme, + useMediaQuery, + Chip, + Divider, + Card, + CardContent, + CardHeader, + LinearProgress, + Stack, + Alert +} from '@mui/material'; +import { + SyncAlt, + Favorite, + Settings, + Info, + Search, + AutoFixHigh, + Image, + Psychology, + Build, + CloudUpload, + Description, + Business, + LocationOn, + Work, + CheckCircle, + Star +} from '@mui/icons-material'; +import { styled } from '@mui/material/styles'; +import DescriptionIcon from '@mui/icons-material/Description'; +import FileUploadIcon from '@mui/icons-material/FileUpload'; + +import { useAuth } from 'hooks/AuthContext'; +import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext'; +import { BackstoryElementProps } from './BackstoryTab'; +import { LoginRequired } from 'components/ui/LoginRequired'; + +import * as Types from 'types/types'; + +const VisuallyHiddenInput = styled('input')({ + clip: 'rect(0 0 0 0)', + clipPath: 'inset(50%)', + height: 1, + overflow: 'hidden', + position: 'absolute', + bottom: 0, + left: 0, + whiteSpace: 'nowrap', + width: 1, +}); + +const UploadBox = styled(Box)(({ theme }) => ({ + border: `2px dashed ${theme.palette.primary.main}`, + borderRadius: theme.shape.borderRadius * 2, + padding: theme.spacing(4), + textAlign: 'center', + backgroundColor: theme.palette.action.hover, + transition: 'all 0.3s ease', + cursor: 'pointer', + '&:hover': { + backgroundColor: theme.palette.action.selected, + borderColor: theme.palette.primary.dark, + }, +})); + +const StatusBox = styled(Box)(({ theme }) => ({ + display: 'flex', + alignItems: 'center', + gap: theme.spacing(1), + padding: theme.spacing(1, 2), + backgroundColor: theme.palette.background.paper, + borderRadius: theme.shape.borderRadius, + border: `1px solid ${theme.palette.divider}`, + minHeight: 48, +})); + +const getIcon = (type: Types.ApiActivityType) => { + switch (type) { + case 'converting': + return ; + case 'heartbeat': + return ; + case 'system': + return ; + case 'info': + return ; + case 'searching': + return ; + case 'generating': + return ; + case 'generating_image': + return ; + case 'thinking': + return ; + case 'tooling': + return ; + default: + return ; + } +}; + +interface JobCreator extends BackstoryElementProps { + onSave?: (job: Types.Job) => void; +} +const JobCreator = (props: JobCreator) => { + const { user, apiClient } = useAuth(); + const { onSave } = props; + const { selectedCandidate } = useSelectedCandidate(); + const { selectedJob, setSelectedJob } = useSelectedJob(); + const { setSnack, submitQuery } = props; + const backstoryProps = { setSnack, submitQuery }; + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down('sm')); + const isTablet = useMediaQuery(theme.breakpoints.down('md')); + + const [openUploadDialog, setOpenUploadDialog] = useState(false); + const [jobDescription, setJobDescription] = useState(''); + const [jobRequirements, setJobRequirements] = useState(null); + const [jobTitle, setJobTitle] = useState(''); + const [company, setCompany] = useState(''); + const [summary, setSummary] = useState(''); + const [jobLocation, 
setJobLocation] = useState(''); + const [jobId, setJobId] = useState(''); + const [jobStatus, setJobStatus] = useState(''); + const [jobStatusIcon, setJobStatusIcon] = useState(<>); + const [isProcessing, setIsProcessing] = useState(false); + + useEffect(() => { + + }, [jobTitle, jobDescription, company]); + + const fileInputRef = useRef(null); + + + if (!user?.id) { + return ( + + ); + } + + const jobStatusHandlers = { + onStatus: (status: Types.ChatMessageStatus) => { + console.log('status:', status.content); + setJobStatusIcon(getIcon(status.activity)); + setJobStatus(status.content); + }, + onMessage: (job: Types.Job) => { + console.log('onMessage - job', job); + setCompany(job.company || ''); + setJobDescription(job.description); + setSummary(job.summary || ''); + setJobTitle(job.title || ''); + setJobRequirements(job.requirements || null); + setJobStatusIcon(<>); + setJobStatus(''); + }, + onError: (error: Types.ChatMessageError) => { + console.log('onError', error); + setSnack(error.content, "error"); + setIsProcessing(false); + }, + onComplete: () => { + setJobStatusIcon(<>); + setJobStatus(''); + setIsProcessing(false); + } + }; + + const handleJobUpload = async (e: React.ChangeEvent) => { + if (e.target.files && e.target.files[0]) { + const file = e.target.files[0]; + const fileExtension = '.' + file.name.split('.').pop()?.toLowerCase(); + let docType: Types.DocumentType | null = null; + switch (fileExtension.substring(1)) { + case "pdf": + docType = "pdf"; + break; + case "docx": + docType = "docx"; + break; + case "md": + docType = "markdown"; + break; + case "txt": + docType = "txt"; + break; + } + + if (!docType) { + setSnack('Invalid file type. Please upload .txt, .md, .docx, or .pdf files only.', 'error'); + return; + } + + try { + setIsProcessing(true); + setJobDescription(''); + setJobTitle(''); + setJobRequirements(null); + setSummary(''); + const controller = apiClient.createJobFromFile(file, jobStatusHandlers); + const job = await controller.promise; + if (!job) { + return; + } + console.log(`Job id: ${job.id}`); + e.target.value = ''; + } catch (error) { + console.error(error); + setSnack('Failed to upload document', 'error'); + setIsProcessing(false); + } + } + }; + + const handleUploadClick = () => { + fileInputRef.current?.click(); + }; + + const renderRequirementSection = (title: string, items: string[] | undefined, icon: JSX.Element, required = false) => { + if (!items || items.length === 0) return null; + + return ( + + + {icon} + + {title} + + {required && } + + + {items.map((item, index) => ( + + ))} + + + ); + }; + + const renderJobRequirements = () => { + if (!jobRequirements) return null; + + return ( + + } + sx={{ pb: 1 }} + /> + + {renderRequirementSection( + "Technical Skills (Required)", + jobRequirements.technicalSkills.required, + , + true + )} + {renderRequirementSection( + "Technical Skills (Preferred)", + jobRequirements.technicalSkills.preferred, + + )} + {renderRequirementSection( + "Experience Requirements (Required)", + jobRequirements.experienceRequirements.required, + , + true + )} + {renderRequirementSection( + "Experience Requirements (Preferred)", + jobRequirements.experienceRequirements.preferred, + + )} + {renderRequirementSection( + "Soft Skills", + jobRequirements.softSkills, + + )} + {renderRequirementSection( + "Experience", + jobRequirements.experience, + + )} + {renderRequirementSection( + "Education", + jobRequirements.education, + + )} + {renderRequirementSection( + "Certifications", + jobRequirements.certifications, + + )} + 
{renderRequirementSection( + "Preferred Attributes", + jobRequirements.preferredAttributes, + + )} + + + ); + }; + + const handleSave = async () => { + const newJob: Types.Job = { + ownerId: user?.id || '', + ownerType: 'candidate', + description: jobDescription, + company: company, + summary: summary, + title: jobTitle, + requirements: jobRequirements || undefined + }; + setIsProcessing(true); + const job = await apiClient.createJob(newJob); + setIsProcessing(false); + onSave ? onSave(job) : setSelectedJob(job); + }; + + const handleExtractRequirements = () => { + // Implement requirements extraction logic here + setIsProcessing(true); + // This would call your API to extract requirements from the job description + }; + + const renderJobCreation = () => { + if (!user) { + return You must be logged in; + } + + return ( + + {/* Upload Section */} + + } + /> + + + + + + Upload Job Description + + + + + Drop your job description here + + + Supported formats: PDF, DOCX, TXT, MD + + + + + + + + + + Or Enter Manually + + setJobDescription(e.target.value)} + disabled={isProcessing} + sx={{ mb: 2 }} + /> + {jobRequirements === null && jobDescription && ( + + )} + + + + {(jobStatus || isProcessing) && ( + + + {jobStatusIcon} + + {jobStatus || 'Processing...'} + + + {isProcessing && } + + )} + + + + {/* Job Details Section */} + + } + /> + + + + setJobTitle(e.target.value)} + required + disabled={isProcessing} + InputProps={{ + startAdornment: + }} + /> + + + + setCompany(e.target.value)} + required + disabled={isProcessing} + InputProps={{ + startAdornment: + }} + /> + + + {/* + setJobLocation(e.target.value)} + disabled={isProcessing} + InputProps={{ + startAdornment: + }} + /> + */} + + + + + + + + + + + {/* Job Summary */} + {summary !== '' && + + } + sx={{ pb: 1 }} + /> + + {summary} + + + } + + {/* Requirements Display */} + {renderJobRequirements()} + + + ); + }; + + return ( + + {selectedJob === null && renderJobCreation()} + + ); +}; + +export { JobCreator }; \ No newline at end of file diff --git a/frontend/src/components/JobManagement.tsx b/frontend/src/components/JobManagement.tsx index 87a7ad5..59834b5 100644 --- a/frontend/src/components/JobManagement.tsx +++ b/frontend/src/components/JobManagement.tsx @@ -115,34 +115,23 @@ const getIcon = (type: Types.ApiActivityType) => { }; const JobManagement = (props: BackstoryElementProps) => { - const { user, apiClient } = useAuth(); - const { selectedCandidate } = useSelectedCandidate(); + const { user, apiClient } = useAuth(); const { selectedJob, setSelectedJob } = useSelectedJob(); - const { setSnack, submitQuery } = props; - const backstoryProps = { setSnack, submitQuery }; + const { setSnack, submitQuery } = props; const theme = useTheme(); - const isMobile = useMediaQuery(theme.breakpoints.down('sm')); - const isTablet = useMediaQuery(theme.breakpoints.down('md')); + const isMobile = useMediaQuery(theme.breakpoints.down('sm')); - const [openUploadDialog, setOpenUploadDialog] = useState(false); const [jobDescription, setJobDescription] = useState(''); const [jobRequirements, setJobRequirements] = useState(null); const [jobTitle, setJobTitle] = useState(''); const [company, setCompany] = useState(''); const [summary, setSummary] = useState(''); - const [jobLocation, setJobLocation] = useState(''); - const [jobId, setJobId] = useState(''); const [jobStatus, setJobStatus] = useState(''); const [jobStatusIcon, setJobStatusIcon] = useState(<>); const [isProcessing, setIsProcessing] = useState(false); - useEffect(() => { - - }, [jobTitle, 
jobDescription, company]); - const fileInputRef = useRef(null); - if (!user?.id) { return ( @@ -223,12 +212,12 @@ const JobManagement = (props: BackstoryElementProps) => { setJobTitle(''); setJobRequirements(null); setSummary(''); - const controller = apiClient.uploadCandidateDocument(file, { isJobDocument: true, overwrite: true }, documentStatusHandlers); - const document = await controller.promise; - if (!document) { + const controller = apiClient.createJobFromFile(file, jobStatusHandlers); + const job = await controller.promise; + if (!job) { return; } - console.log(`Document id: ${document.id}`); + console.log(`Job id: ${job.id}`); e.target.value = ''; } catch (error) { console.error(error); @@ -354,10 +343,6 @@ const JobManagement = (props: BackstoryElementProps) => { // This would call your API to extract requirements from the job description }; - const loadJob = async () => { - const job = await apiClient.getJob("7594e989-a926-45a2-9b07-ae553d2e0d0d"); - setSelectedJob(job); - } const renderJobCreation = () => { if (!user) { return You must be logged in; @@ -367,7 +352,6 @@ const JobManagement = (props: BackstoryElementProps) => { - {/* Upload Section */} = (props: JobAnalysisProps) = const [overallScore, setOverallScore] = useState(0); const [requirementsSession, setRequirementsSession] = useState(null); const [statusMessage, setStatusMessage] = useState(null); + const [startAnalysis, setStartAnalysis] = useState(false); + const [analyzing, setAnalyzing] = useState(false); + const isMobile = useMediaQuery(theme.breakpoints.down('sm')); // Handle accordion expansion @@ -63,7 +69,7 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = setExpanded(isExpanded ? panel : false); }; - useEffect(() => { + const initializeRequirements = (job: Job) => { if (!job || !job.requirements) { return; } @@ -106,116 +112,19 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = setSkillMatches(initialSkillMatches); setStatusMessage(null); setLoadingRequirements(false); - - }, [job, setRequirements]); + setOverallScore(0); + } useEffect(() => { - if (requirementsSession || creatingSession) { - return; - } - - try { - setCreatingSession(true); - apiClient.getOrCreateChatSession(candidate, `Generate requirements for ${candidate.fullName}`, 'job_requirements') - .then(session => { - setRequirementsSession(session); - setCreatingSession(false); - }); - } catch (error) { - setSnack('Unable to load chat session', 'error'); - } finally { - setCreatingSession(false); - } - - }, [requirementsSession, apiClient, candidate]); - - // Fetch initial requirements - useEffect(() => { - if (!job.description || !requirementsSession || loadingRequirements) { - return; - } - - const getRequirements = async () => { - setLoadingRequirements(true); - try { - const chatMessage: ChatMessageUser = { ...defaultMessage, sessionId: requirementsSession.id || '', content: job.description }; - apiClient.sendMessageStream(chatMessage, { - onMessage: (msg: ChatMessage) => { - console.log(`onMessage: ${msg.type}`, msg); - const job: Job = toCamelCase(JSON.parse(msg.content || '')); - const requirements: { requirement: string, domain: string }[] = []; - if (job.requirements?.technicalSkills) { - job.requirements.technicalSkills.required?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (required)' })); - job.requirements.technicalSkills.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Technical Skills (preferred)' })); - } - if 
(job.requirements?.experienceRequirements) { - job.requirements.experienceRequirements.required?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (required)' })); - job.requirements.experienceRequirements.preferred?.forEach(req => requirements.push({ requirement: req, domain: 'Experience (preferred)' })); - } - if (job.requirements?.softSkills) { - job.requirements.softSkills.forEach(req => requirements.push({ requirement: req, domain: 'Soft Skills' })); - } - if (job.requirements?.experience) { - job.requirements.experience.forEach(req => requirements.push({ requirement: req, domain: 'Experience' })); - } - if (job.requirements?.education) { - job.requirements.education.forEach(req => requirements.push({ requirement: req, domain: 'Education' })); - } - if (job.requirements?.certifications) { - job.requirements.certifications.forEach(req => requirements.push({ requirement: req, domain: 'Certifications' })); - } - if (job.requirements?.preferredAttributes) { - job.requirements.preferredAttributes.forEach(req => requirements.push({ requirement: req, domain: 'Preferred Attributes' })); - } - - const initialSkillMatches = requirements.map(req => ({ - requirement: req.requirement, - domain: req.domain, - status: 'waiting' as const, - matchScore: 0, - assessment: '', - description: '', - citations: [] - })); - - setRequirements(requirements); - setSkillMatches(initialSkillMatches); - setStatusMessage(null); - setLoadingRequirements(false); - }, - onError: (error: string | ChatMessageError) => { - console.log("onError:", error); - // Type-guard to determine if this is a ChatMessageBase or a string - if (typeof error === "object" && error !== null && "content" in error) { - setSnack(error.content || 'Error obtaining requirements from job description.', "error"); - } else { - setSnack(error as string, "error"); - } - setLoadingRequirements(false); - }, - onStreaming: (chunk: ChatMessageStreaming) => { - // console.log("onStreaming:", chunk); - }, - onStatus: (status: ChatMessageStatus) => { - console.log(`onStatus: ${status}`); - }, - onComplete: () => { - console.log("onComplete"); - setStatusMessage(null); - setLoadingRequirements(false); - } - }); - } catch (error) { - console.error('Failed to send message:', error); - setLoadingRequirements(false); - } - }; - - getRequirements(); - }, [job, requirementsSession]); + initializeRequirements(job); + }, [job]); // Fetch match data for each requirement useEffect(() => { + if (!startAnalysis || analyzing || !job.requirements) { + return; + } + const fetchMatchData = async () => { if (requirements.length === 0) return; @@ -279,10 +188,9 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = } }; - if (!loadingRequirements) { - fetchMatchData(); - } - }, [requirements, loadingRequirements]); + setAnalyzing(true); + fetchMatchData().then(() => { setAnalyzing(false); setStartAnalysis(false) }); + }, [job, startAnalysis, analyzing, requirements, loadingRequirements]); // Get color based on match score const getMatchColor = (score: number): string => { @@ -301,6 +209,11 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = return ; }; + const beginAnalysis = () => { + initializeRequirements(job); + setStartAnalysis(true); + }; + return ( @@ -344,7 +257,9 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = - + + {} + {overallScore !== 0 && <> Overall Match: @@ -391,6 +306,7 @@ const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) = fontWeight: 'bold' }} /> + } diff --git 
a/frontend/src/components/CandidateInfo.tsx b/frontend/src/components/ui/CandidateInfo.tsx similarity index 98% rename from frontend/src/components/CandidateInfo.tsx rename to frontend/src/components/ui/CandidateInfo.tsx index 95b866e..9b1e15c 100644 --- a/frontend/src/components/CandidateInfo.tsx +++ b/frontend/src/components/ui/CandidateInfo.tsx @@ -13,7 +13,7 @@ import { CopyBubble } from "components/CopyBubble"; import { rest } from 'lodash'; import { AIBanner } from 'components/ui/AIBanner'; import { useAuth } from 'hooks/AuthContext'; -import { DeleteConfirmation } from './DeleteConfirmation'; +import { DeleteConfirmation } from '../DeleteConfirmation'; interface CandidateInfoProps { candidate: Candidate; diff --git a/frontend/src/components/ui/CandidatePicker.tsx b/frontend/src/components/ui/CandidatePicker.tsx index 1a895f0..decf152 100644 --- a/frontend/src/components/ui/CandidatePicker.tsx +++ b/frontend/src/components/ui/CandidatePicker.tsx @@ -4,7 +4,7 @@ import Button from '@mui/material/Button'; import Box from '@mui/material/Box'; import { BackstoryElementProps } from 'components/BackstoryTab'; -import { CandidateInfo } from 'components/CandidateInfo'; +import { CandidateInfo } from 'components/ui/CandidateInfo'; import { Candidate } from "types/types"; import { useAuth } from 'hooks/AuthContext'; import { useSelectedCandidate } from 'hooks/GlobalContext'; diff --git a/frontend/src/components/ui/JobInfo.tsx b/frontend/src/components/ui/JobInfo.tsx new file mode 100644 index 0000000..c850c01 --- /dev/null +++ b/frontend/src/components/ui/JobInfo.tsx @@ -0,0 +1,97 @@ +import React from 'react'; +import { Box, Link, Typography, Avatar, Grid, SxProps, CardActions } from '@mui/material'; +import { + Card, + CardContent, + Divider, + useTheme, +} from '@mui/material'; +import DeleteIcon from '@mui/icons-material/Delete'; +import { useMediaQuery } from '@mui/material'; +import { JobFull } from 'types/types'; +import { CopyBubble } from "components/CopyBubble"; +import { rest } from 'lodash'; +import { AIBanner } from 'components/ui/AIBanner'; +import { useAuth } from 'hooks/AuthContext'; +import { DeleteConfirmation } from '../DeleteConfirmation'; + +interface JobInfoProps { + job: JobFull; + sx?: SxProps; + action?: string; + elevation?: number; + variant?: "small" | "normal" | null +}; + +const JobInfo: React.FC = (props: JobInfoProps) => { + const { job } = props; + const { user, apiClient } = useAuth(); + const { + sx, + action = '', + elevation = 1, + variant = "normal" + } = props; + const theme = useTheme(); + const isMobile = useMediaQuery(theme.breakpoints.down('md')); + const isAdmin = user?.isAdmin; + + const deleteJob = async (jobId: string | undefined) => { + if (jobId) { + await apiClient.deleteJob(jobId); + } + } + + if (!job) { + return No user loaded.; + } + + return ( + + + {variant !== "small" && <> + {job.location && + + Location: {job.location.city}, {job.location.state || job.location.country} + + } + {job.company && + + Company: {job.company} + + } + {job.summary && + Summary: {job.summary} + + } + } + + + {isAdmin && + { deleteJob(job.id); }} + sx={{ minWidth: 'auto', px: 2, maxHeight: "min-content", color: "red" }} + action="delete" + label="job" + title="Delete job" + icon= + message={`Are you sure you want to delete ${job.id}? 
This action cannot be undone.`} + />} + + + ); +}; + +export { JobInfo }; diff --git a/frontend/src/components/ui/JobPicker.tsx b/frontend/src/components/ui/JobPicker.tsx new file mode 100644 index 0000000..9ae5057 --- /dev/null +++ b/frontend/src/components/ui/JobPicker.tsx @@ -0,0 +1,69 @@ +import React, { useEffect, useState } from 'react'; +import { useNavigate } from "react-router-dom"; +import Button from '@mui/material/Button'; +import Box from '@mui/material/Box'; + +import { BackstoryElementProps } from 'components/BackstoryTab'; +import { JobInfo } from 'components/ui/JobInfo'; +import { Job, JobFull } from "types/types"; +import { useAuth } from 'hooks/AuthContext'; +import { useSelectedJob } from 'hooks/GlobalContext'; + +interface JobPickerProps extends BackstoryElementProps { + onSelect?: (job: JobFull) => void +}; + +const JobPicker = (props: JobPickerProps) => { + const { onSelect } = props; + const { apiClient } = useAuth(); + const { selectedJob, setSelectedJob } = useSelectedJob(); + const { setSnack } = props; + const [jobs, setJobs] = useState(null); + + useEffect(() => { + if (jobs !== null) { + return; + } + const getJobs = async () => { + try { + const results = await apiClient.getJobs(); + const jobs: JobFull[] = results.data; + jobs.sort((a, b) => { + let result = a.company?.localeCompare(b.company || ''); + if (result === 0) { + result = a.title?.localeCompare(b.title || ''); + } + return result || 0; + }); + setJobs(jobs); + } catch (err) { + setSnack("" + err); + } + }; + + getJobs(); + }, [jobs, setSnack]); + + return ( + + + {jobs?.map((j, i) => + { onSelect ? onSelect(j) : setSelectedJob(j); }} + sx={{ cursor: "pointer" }}> + {selectedJob?.id === j.id && + + } + {selectedJob?.id !== j.id && + + } + + )} + + + ); +}; + +export { + JobPicker +}; \ No newline at end of file diff --git a/frontend/src/pages/CandidateChatPage.tsx b/frontend/src/pages/CandidateChatPage.tsx index 132a235..95f8eda 100644 --- a/frontend/src/pages/CandidateChatPage.tsx +++ b/frontend/src/pages/CandidateChatPage.tsx @@ -17,7 +17,7 @@ import { ConversationHandle } from 'components/Conversation'; import { BackstoryPageProps } from 'components/BackstoryTab'; import { Message } from 'components/Message'; import { DeleteConfirmation } from 'components/DeleteConfirmation'; -import { CandidateInfo } from 'components/CandidateInfo'; +import { CandidateInfo } from 'components/ui/CandidateInfo'; import { useNavigate } from 'react-router-dom'; import { useSelectedCandidate } from 'hooks/GlobalContext'; import PropagateLoader from 'react-spinners/PropagateLoader'; diff --git a/frontend/src/pages/GenerateCandidate.tsx b/frontend/src/pages/GenerateCandidate.tsx index a38630b..2b8efbf 100644 --- a/frontend/src/pages/GenerateCandidate.tsx +++ b/frontend/src/pages/GenerateCandidate.tsx @@ -9,7 +9,7 @@ import CancelIcon from '@mui/icons-material/Cancel'; import SendIcon from '@mui/icons-material/Send'; import PropagateLoader from 'react-spinners/PropagateLoader'; -import { CandidateInfo } from '../components/CandidateInfo'; +import { CandidateInfo } from '../components/ui/CandidateInfo'; import { Quote } from 'components/Quote'; import { BackstoryElementProps } from 'components/BackstoryTab'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; diff --git a/frontend/src/pages/JobAnalysisPage.tsx b/frontend/src/pages/JobAnalysisPage.tsx index dcfdfac..a7b7123 100644 --- a/frontend/src/pages/JobAnalysisPage.tsx +++ b/frontend/src/pages/JobAnalysisPage.tsx @@ -7,31 
+7,74 @@ import { Button, Typography, Paper, - Avatar, useTheme, Snackbar, + Container, + Grid, Alert, + Tabs, + Tab, + Card, + CardContent, + Divider, + Avatar, + Badge, } from '@mui/material'; +import { + Person, + PersonAdd, + AccountCircle, + Add, + WorkOutline, + AddCircle, +} from '@mui/icons-material'; import PersonIcon from '@mui/icons-material/Person'; import WorkIcon from '@mui/icons-material/Work'; import AssessmentIcon from '@mui/icons-material/Assessment'; import { JobMatchAnalysis } from 'components/JobMatchAnalysis'; -import { Candidate } from "types/types"; +import { Candidate, Job, JobFull } from "types/types"; import { useNavigate } from 'react-router-dom'; import { BackstoryPageProps } from 'components/BackstoryTab'; import { useAuth } from 'hooks/AuthContext'; import { useSelectedCandidate, useSelectedJob } from 'hooks/GlobalContext'; -import { CandidateInfo } from 'components/CandidateInfo'; +import { CandidateInfo } from 'components/ui/CandidateInfo'; import { ComingSoon } from 'components/ui/ComingSoon'; import { JobManagement } from 'components/JobManagement'; import { LoginRequired } from 'components/ui/LoginRequired'; import { Scrollable } from 'components/Scrollable'; import { CandidatePicker } from 'components/ui/CandidatePicker'; +import { JobPicker } from 'components/ui/JobPicker'; +import { JobCreator } from 'components/JobCreator'; + +function WorkAddIcon() { + return ( + + + + + ); +} // Main component const JobAnalysisPage: React.FC = (props: BackstoryPageProps) => { const theme = useTheme(); - const { user } = useAuth(); + const { user, apiClient } = useAuth(); const navigate = useNavigate(); const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate() const { selectedJob, setSelectedJob } = useSelectedJob() @@ -41,12 +84,21 @@ const JobAnalysisPage: React.FC = (props: BackstoryPageProps const [activeStep, setActiveStep] = useState(0); const [analysisStarted, setAnalysisStarted] = useState(false); const [error, setError] = useState(null); + const [jobTab, setJobTab] = useState('load'); useEffect(() => { - if (selectedJob && activeStep === 1) { - setActiveStep(2); + console.log({ activeStep, selectedCandidate, selectedJob }); + + if (!selectedCandidate) { + if (activeStep !== 0) { + setActiveStep(0); + } + } else if (!selectedJob) { + if (activeStep !== 1) { + setActiveStep(1); + } } - }, [selectedJob, activeStep]); + }, [selectedCandidate, selectedJob, activeStep]) // Steps in our process const steps = [ @@ -93,7 +145,6 @@ const JobAnalysisPage: React.FC = (props: BackstoryPageProps setSelectedJob(null); break; case 1: /* Select Job */ - setSelectedCandidate(null); setSelectedJob(null); break; case 2: /* Job Analysis */ @@ -109,6 +160,11 @@ const JobAnalysisPage: React.FC = (props: BackstoryPageProps setActiveStep(1); } + const onJobSelect = (job: Job) => { + setSelectedJob(job) + setActiveStep(2); + } + // Render function for the candidate selection step const renderCandidateSelection = () => ( @@ -120,16 +176,35 @@ const JobAnalysisPage: React.FC = (props: BackstoryPageProps ); + const handleTabChange = (event: React.SyntheticEvent, value: string) => { + setJobTab(value); + }; + // Render function for the job description step - const renderJobDescription = () => ( - - {selectedCandidate && ( - { + if (!selectedCandidate) { + return; + } + + return ( + + + } label="Load" /> + } label="Create" /> + + + + {jobTab === 'load' && + + } + {jobTab === 'create' && + - )} + onSave={onJobSelect} + />} - ); + ); + } // Render function for the analysis 
step const renderAnalysis = () => ( @@ -232,7 +307,7 @@ const JobAnalysisPage: React.FC = (props: BackstoryPageProps ) : ( )} diff --git a/frontend/src/pages/OldChatPage.tsx b/frontend/src/pages/OldChatPage.tsx index c58f549..a4ec2cc 100644 --- a/frontend/src/pages/OldChatPage.tsx +++ b/frontend/src/pages/OldChatPage.tsx @@ -7,7 +7,7 @@ import MuiMarkdown from 'mui-markdown'; import { BackstoryPageProps } from '../components/BackstoryTab'; import { Conversation, ConversationHandle } from '../components/Conversation'; import { BackstoryQuery } from '../components/BackstoryQuery'; -import { CandidateInfo } from 'components/CandidateInfo'; +import { CandidateInfo } from 'components/ui/CandidateInfo'; import { useAuth } from 'hooks/AuthContext'; import { Candidate } from 'types/types'; diff --git a/frontend/src/services/api-client.ts b/frontend/src/services/api-client.ts index 1aef0fe..36b7b06 100644 --- a/frontend/src/services/api-client.ts +++ b/frontend/src/services/api-client.ts @@ -52,7 +52,7 @@ interface StreamingOptions { signal?: AbortSignal; } -interface DeleteCandidateResponse { +interface DeleteResponse { success: boolean; message: string; } @@ -530,14 +530,24 @@ class ApiClient { return this.handleApiResponseWithConversion(response, 'Candidate'); } - async deleteCandidate(id: string): Promise { + async deleteCandidate(id: string): Promise { const response = await fetch(`${this.baseUrl}/candidates/${id}`, { method: 'DELETE', headers: this.defaultHeaders, body: JSON.stringify({ id }) }); - return handleApiResponse(response); + return handleApiResponse(response); + } + + async deleteJob(id: string): Promise { + const response = await fetch(`${this.baseUrl}/jobs/${id}`, { + method: 'DELETE', + headers: this.defaultHeaders, + body: JSON.stringify({ id }) + }); + + return handleApiResponse(response); } async uploadCandidateProfile(file: File): Promise { @@ -641,7 +651,7 @@ class ApiClient { return this.handleApiResponseWithConversion(response, 'Job'); } - async getJobs(request: Partial = {}): Promise> { + async getJobs(request: Partial = {}): Promise> { const paginatedRequest = createPaginatedRequest(request); const params = toUrlParams(formatApiRequest(paginatedRequest)); @@ -649,7 +659,7 @@ class ApiClient { headers: this.defaultHeaders }); - return this.handlePaginatedApiResponseWithConversion(response, 'Job'); + return this.handlePaginatedApiResponseWithConversion(response, 'JobFull'); } async getJobsByEmployer(employerId: string, request: Partial = {}): Promise> { @@ -844,18 +854,28 @@ class ApiClient { } }; return this.streamify('/candidates/documents/upload', formData, streamingOptions); - // { - // method: 'POST', - // headers: { - // // Don't set Content-Type - browser will set it automatically with boundary - // 'Authorization': this.defaultHeaders['Authorization'] - // }, - // body: formData - // }); + } - // const result = await handleApiResponse(response); - - // return result; + createJobFromFile(file: File, streamingOptions?: StreamingOptions): StreamingResponse { + const formData = new FormData() + formData.append('file', file); + formData.append('filename', file.name); + streamingOptions = { + ...streamingOptions, + headers: { + // Don't set Content-Type - browser will set it automatically with boundary + 'Authorization': this.defaultHeaders['Authorization'] + } + }; + return this.streamify('/jobs/upload', formData, streamingOptions); + } + + getJobRequirements(jobId: string, streamingOptions?: StreamingOptions): StreamingResponse { + streamingOptions = { + 
...streamingOptions, + headers: this.defaultHeaders, + }; + return this.streamify(`/jobs/requirements/${jobId}`, null, streamingOptions); } async candidateMatchForRequirement(candidate_id: string, requirement: string) : Promise { @@ -1001,7 +1021,7 @@ class ApiClient { * @param options callbacks, headers, and method * @returns */ - streamify(api: string, data: BodyInit, options: StreamingOptions = {}) : StreamingResponse { + streamify(api: string, data: BodyInit | null, options: StreamingOptions = {}) : StreamingResponse { const abortController = new AbortController(); const signal = options.signal || abortController.signal; const headers = options.headers || null; diff --git a/frontend/src/types/types.ts b/frontend/src/types/types.ts index dfc452e..a422f84 100644 --- a/frontend/src/types/types.ts +++ b/frontend/src/types/types.ts @@ -1,6 +1,6 @@ // Generated TypeScript types from Pydantic models // Source: src/backend/models.py -// Generated on: 2025-06-05T22:02:22.004513 +// Generated on: 2025-06-07T20:43:58.855207 // DO NOT EDIT MANUALLY - This file is auto-generated // ============================ @@ -323,7 +323,7 @@ export interface ChatMessageMetaData { presencePenalty?: number; stopSequences?: Array; ragResults?: Array; - llmHistory?: Array; + llmHistory?: Array; evalCount: number; evalDuration: number; promptEvalCount: number; @@ -711,12 +711,6 @@ export interface JobResponse { meta?: Record; } -export interface LLMMessage { - role: string; - content: string; - toolCalls?: Array>; -} - export interface Language { language: string; proficiency: "basic" | "conversational" | "fluent" | "native"; diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index 12a7af8..e74fed8 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -137,7 +137,7 @@ class Agent(BaseModel, ABC): # llm: Any, # model: str, # message: ChatMessage, - # tool_message: Any, # llama response message + # tool_message: Any, # llama response # messages: List[LLMMessage], # ) -> AsyncGenerator[ChatMessage, None]: # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") @@ -270,15 +270,15 @@ class Agent(BaseModel, ABC): # }, # stream=True, # ): - # # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}") + # # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}") # message.status = "streaming" - # message.chunk = response.message.content + # message.chunk = response.content # message.content += message.chunk - # if not response.done: + # if not response.finish_reason: # yield message - # if response.done: + # if response.finish_reason: # self.collect_metrics(response) # message.metadata.eval_count += response.eval_count # message.metadata.eval_duration += response.eval_duration @@ -296,9 +296,9 @@ class Agent(BaseModel, ABC): def collect_metrics(self, response): self.metrics.tokens_prompt.labels(agent=self.agent_type).inc( - response.prompt_eval_count + response.usage.prompt_eval_count ) - self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count) + self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count) def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str: """ @@ -361,7 +361,7 @@ Content: {content} yield status_message try: - chroma_results = user.file_watcher.find_similar( + chroma_results = await user.file_watcher.find_similar( query=prompt, top_k=top_k, threshold=threshold ) if not chroma_results: @@ -430,7 +430,7 @@ Content: 
{content} logger.info(f"Message options: {options.model_dump(exclude_unset=True)}") response = None content = "" - for response in llm.chat( + async for response in llm.chat_stream( model=model, messages=messages, options={ @@ -446,12 +446,12 @@ Content: {content} yield error_message return - content += response.message.content + content += response.content - if not response.done: + if not response.finish_reason: streaming_message = ChatMessageStreaming( session_id=session_id, - content=response.message.content, + content=response.content, status=ApiStatusType.STREAMING, ) yield streaming_message @@ -466,7 +466,7 @@ Content: {content} self.collect_metrics(response) self.context_tokens = ( - response.prompt_eval_count + response.eval_count + response.usage.prompt_eval_count + response.usage.eval_count ) chat_message = ChatMessage( @@ -476,10 +476,10 @@ Content: {content} content=content, metadata = ChatMessageMetaData( options=options, - eval_count=response.eval_count, - eval_duration=response.eval_duration, - prompt_eval_count=response.prompt_eval_count, - prompt_eval_duration=response.prompt_eval_duration, + eval_count=response.usage.eval_count, + eval_duration=response.usage.eval_duration, + prompt_eval_count=response.usage.prompt_eval_count, + prompt_eval_duration=response.usage.prompt_eval_duration, ) ) @@ -588,12 +588,12 @@ Content: {content} # end_time = time.perf_counter() # message.metadata.timers["tool_check"] = end_time - start_time - # if not response.message.tool_calls: + # if not response.tool_calls: # logger.info("LLM indicates tools will not be used") # # The LLM will not use tools, so disable use_tools so we can stream the full response # use_tools = False # else: - # tool_metadata["attempted"] = response.message.tool_calls + # tool_metadata["attempted"] = response.tool_calls # if use_tools: # logger.info("LLM indicates tools will be used") @@ -626,15 +626,15 @@ Content: {content} # yield message # return - # if response.message.tool_calls: - # tool_metadata["used"] = response.message.tool_calls + # if response.tool_calls: + # tool_metadata["used"] = response.tool_calls # # Process all yielded items from the handler # start_time = time.perf_counter() # async for message in self.process_tool_calls( # llm=llm, # model=model, # message=message, - # tool_message=response.message, + # tool_message=response, # messages=messages, # ): # if message.status == "error": @@ -647,7 +647,7 @@ Content: {content} # return # logger.info("LLM indicated tools will be used, and then they weren't") - # message.content = response.message.content + # message.content = response.content # message.status = "done" # yield message # return @@ -674,7 +674,7 @@ Content: {content} content = "" start_time = time.perf_counter() response = None - for response in llm.chat( + async for response in llm.chat_stream( model=model, messages=messages, options={ @@ -690,12 +690,12 @@ Content: {content} yield error_message return - content += response.message.content + content += response.content - if not response.done: + if not response.finish_reason: streaming_message = ChatMessageStreaming( session_id=session_id, - content=response.message.content, + content=response.content, ) yield streaming_message @@ -709,7 +709,7 @@ Content: {content} self.collect_metrics(response) self.context_tokens = ( - response.prompt_eval_count + response.eval_count + response.usage.prompt_eval_count + response.usage.eval_count ) end_time = time.perf_counter() @@ -720,10 +720,10 @@ Content: {content} content=content, metadata = 
ChatMessageMetaData( options=options, - eval_count=response.eval_count, - eval_duration=response.eval_duration, - prompt_eval_count=response.prompt_eval_count, - prompt_eval_duration=response.prompt_eval_duration, + eval_count=response.usage.eval_count, + eval_duration=response.usage.eval_duration, + prompt_eval_count=response.usage.prompt_eval_count, + prompt_eval_duration=response.usage.prompt_eval_duration, timers={ "llm_streamed": end_time - start_time, "llm_with_tools": 0, # Placeholder for tool processing time diff --git a/src/backend/database.py b/src/backend/database.py index e42fa82..739e6b0 100644 --- a/src/backend/database.py +++ b/src/backend/database.py @@ -178,12 +178,13 @@ class RedisDatabase: 'jobs': 'job:', 'job_applications': 'job_application:', 'chat_sessions': 'chat_session:', - 'chat_messages': 'chat_messages:', # This will store lists + 'chat_messages': 'chat_messages:', 'ai_parameters': 'ai_parameters:', 'users': 'user:', - 'candidate_documents': 'candidate_documents:', + 'candidate_documents': 'candidate_documents:', + 'job_requirements': 'job_requirements:', # Add this line } - + def _serialize(self, data: Any) -> str: """Serialize data to JSON string for Redis storage""" if data is None: @@ -236,8 +237,9 @@ class RedisDatabase: # Delete each document's metadata for doc_id in document_ids: pipe.delete(f"document:{doc_id}") + pipe.delete(f"{self.KEY_PREFIXES['job_requirements']}{doc_id}") deleted_count += 1 - + # Delete the candidate's document list pipe.delete(key) @@ -250,7 +252,110 @@ class RedisDatabase: except Exception as e: logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}") raise - + + async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]: + """Retrieve cached skill match assessment""" + try: + cached_data = await self.redis.get(cache_key) + if cached_data: + return json.loads(cached_data) + return None + except Exception as e: + logger.error(f"Error retrieving cached skill match: {e}") + return None + + async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool: + """Cache skill match assessment with TTL (default 30 days)""" + try: + await self.redis.setex( + cache_key, + ttl, + json.dumps(assessment_data, default=str) + ) + return True + except Exception as e: + logger.error(f"Error caching skill match: {e}") + return False + + async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]: + """Get the last time candidate's skill information was updated""" + try: + # This assumes you track skill update timestamps in your candidate data + candidate_data = await self.get_candidate(candidate_id) + if candidate_data and 'skills_updated_at' in candidate_data: + return datetime.fromisoformat(candidate_data['skills_updated_at']) + return None + except Exception as e: + logger.error(f"Error getting candidate skill update time: {e}") + return None + + async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: + """Get the timestamp of the latest RAG data update for a specific user""" + try: + rag_update_key = f"user:{user_id}:rag_last_update" + timestamp_str = await self.redis.get(rag_update_key) + if timestamp_str: + return datetime.fromisoformat(timestamp_str.decode('utf-8')) + return None + except Exception as e: + logger.error(f"Error getting user RAG update time for user {user_id}: {e}") + return None + + async def update_user_rag_timestamp(self, user_id: str) -> bool: + """Update the RAG data timestamp 
for a specific user (call this when user's RAG data is updated)""" + try: + rag_update_key = f"user:{user_id}:rag_last_update" + current_time = datetime.utcnow().isoformat() + await self.redis.set(rag_update_key, current_time) + return True + except Exception as e: + logger.error(f"Error updating RAG timestamp for user {user_id}: {e}") + return False + + async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: + """Invalidate all cached skill matches for a specific candidate""" + try: + pattern = f"skill_match:{candidate_id}:*" + keys = await self.redis.keys(pattern) + if keys: + return await self.redis.delete(*keys) + return 0 + except Exception as e: + logger.error(f"Error invalidating candidate skill cache: {e}") + return 0 + + async def clear_all_skill_match_cache(self) -> int: + """Clear all skill match cache (useful after major system updates)""" + try: + pattern = "skill_match:*" + keys = await self.redis.keys(pattern) + if keys: + return await self.redis.delete(*keys) + return 0 + except Exception as e: + logger.error(f"Error clearing skill match cache: {e}") + return 0 + + async def invalidate_user_skill_cache(self, user_id: str) -> int: + """Invalidate all cached skill matches when a user's RAG data is updated""" + try: + # This assumes all candidates belonging to this user need cache invalidation + # You might need to adjust the pattern based on how you associate candidates with users + pattern = f"skill_match:*" + keys = await self.redis.keys(pattern) + + # Filter keys that belong to candidates owned by this user + # This would require additional logic to determine candidate ownership + # For now, you might want to clear all cache when any user's RAG data updates + # or implement a more sophisticated mapping + + if keys: + return await self.redis.delete(*keys) + return 0 + except Exception as e: + logger.error(f"Error invalidating user skill cache for user {user_id}: {e}") + return 0 + async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: """Get all documents for a specific candidate""" key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}" @@ -330,7 +435,134 @@ class RedisDatabase: if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower()) ] - + + async def get_job_requirements(self, document_id: str) -> Optional[Dict]: + """Get cached job requirements analysis for a document""" + try: + key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}" + data = await self.redis.get(key) + if data: + requirements_data = self._deserialize(data) + logger.debug(f"📋 Retrieved cached job requirements for document {document_id}") + return requirements_data + logger.debug(f"📋 No cached job requirements found for document {document_id}") + return None + except Exception as e: + logger.error(f"❌ Error retrieving job requirements for document {document_id}: {e}") + return None + + async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: + """Save job requirements analysis results for a document""" + try: + key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}" + + # Add metadata to the requirements + requirements_with_meta = { + **requirements, + "cached_at": datetime.now(UTC).isoformat(), + "document_id": document_id + } + + await self.redis.set(key, self._serialize(requirements_with_meta)) + + # Optional: Set expiration (e.g., 30 days) to prevent indefinite storage + # await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days + + logger.debug(f"📋 Saved job 
requirements for document {document_id}") + return True + except Exception as e: + logger.error(f"❌ Error saving job requirements for document {document_id}: {e}") + return False + + async def delete_job_requirements(self, document_id: str) -> bool: + """Delete cached job requirements for a document""" + try: + key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}" + result = await self.redis.delete(key) + if result > 0: + logger.debug(f"📋 Deleted job requirements for document {document_id}") + return True + return False + except Exception as e: + logger.error(f"❌ Error deleting job requirements for document {document_id}: {e}") + return False + + async def get_all_job_requirements(self) -> Dict[str, Any]: + """Get all cached job requirements""" + try: + pattern = f"{self.KEY_PREFIXES['job_requirements']}*" + keys = await self.redis.keys(pattern) + + if not keys: + return {} + + pipe = self.redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + result = {} + for key, value in zip(keys, values): + document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '') + if value: + result[document_id] = self._deserialize(value) + + return result + except Exception as e: + logger.error(f"❌ Error retrieving all job requirements: {e}") + return {} + + async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: + """Get all job requirements analysis for documents belonging to a candidate""" + try: + # Get all documents for the candidate + candidate_documents = await self.get_candidate_documents(candidate_id) + + if not candidate_documents: + return [] + + # Get job requirements for each document + job_requirements = [] + for doc in candidate_documents: + doc_id = doc.get("id") + if doc_id: + requirements = await self.get_job_requirements(doc_id) + if requirements: + # Add document metadata to requirements + requirements["document_filename"] = doc.get("filename") + requirements["document_original_name"] = doc.get("originalName") + job_requirements.append(requirements) + + return job_requirements + except Exception as e: + logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}") + return [] + + async def invalidate_job_requirements_cache(self, document_id: str) -> bool: + """Invalidate (delete) cached job requirements for a document""" + # This is an alias for delete_job_requirements for semantic clarity + return await self.delete_job_requirements(document_id) + + async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: + """Delete job requirements for multiple documents and return count of deleted items""" + try: + deleted_count = 0 + pipe = self.redis.pipeline() + + for doc_id in document_ids: + key = f"{self.KEY_PREFIXES['job_requirements']}{doc_id}" + pipe.delete(key) + deleted_count += 1 + + results = await pipe.execute() + actual_deleted = sum(1 for result in results if result > 0) + + logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents") + return actual_deleted + except Exception as e: + logger.error(f"❌ Error bulk deleting job requirements: {e}") + return 0 + # Viewer operations async def get_viewer(self, viewer_id: str) -> Optional[Dict]: """Get viewer by ID""" @@ -1484,6 +1716,74 @@ class RedisDatabase: key = f"{self.KEY_PREFIXES['users']}{email}" await self.redis.delete(key) + async def get_job_requirements_stats(self) -> Dict[str, Any]: + """Get statistics about cached job requirements""" + try: + pattern = 
f"{self.KEY_PREFIXES['job_requirements']}*" + keys = await self.redis.keys(pattern) + + stats = { + "total_cached_requirements": len(keys), + "cache_dates": {}, + "documents_with_requirements": [] + } + + if keys: + # Get cache dates for analysis + pipe = self.redis.pipeline() + for key in keys: + pipe.get(key) + values = await pipe.execute() + + for key, value in zip(keys, values): + if value: + requirements_data = self._deserialize(value) + if requirements_data: + document_id = key.replace(self.KEY_PREFIXES['job_requirements'], '') + stats["documents_with_requirements"].append(document_id) + + # Track cache dates + cached_at = requirements_data.get("cached_at") + if cached_at: + cache_date = cached_at[:10] # Extract date part + stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1 + + return stats + except Exception as e: + logger.error(f"❌ Error getting job requirements stats: {e}") + return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []} + + async def cleanup_orphaned_job_requirements(self) -> int: + """Clean up job requirements for documents that no longer exist""" + try: + # Get all job requirements + all_requirements = await self.get_all_job_requirements() + + if not all_requirements: + return 0 + + orphaned_count = 0 + pipe = self.redis.pipeline() + + for document_id in all_requirements.keys(): + # Check if the document still exists + document_exists = await self.get_document(document_id) + if not document_exists: + # Document no longer exists, delete its job requirements + key = f"{self.KEY_PREFIXES['job_requirements']}{document_id}" + pipe.delete(key) + orphaned_count += 1 + logger.debug(f"📋 Queued orphaned job requirements for deletion: {document_id}") + + if orphaned_count > 0: + await pipe.execute() + logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements") + + return orphaned_count + except Exception as e: + logger.error(f"❌ Error cleaning up orphaned job requirements: {e}") + return 0 + # Utility methods async def clear_all_data(self): """Clear all data from Redis (use with caution!)""" diff --git a/src/backend/entities/candidate_entity.py b/src/backend/entities/candidate_entity.py index 91a83a1..a35da0a 100644 --- a/src/backend/entities/candidate_entity.py +++ b/src/backend/entities/candidate_entity.py @@ -19,8 +19,9 @@ import defines from logger import logger import agents as agents from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType) -from llm_manager import llm_manager +import llm_proxy as llm_manager from agents.base import Agent +from database import RedisDatabase class CandidateEntity(Candidate): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc @@ -115,7 +116,7 @@ class CandidateEntity(Candidate): raise ValueError("initialize() has not been called.") return self.CandidateEntity__observer - async def initialize(self, prometheus_collector: CollectorRegistry): + async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase): if self.CandidateEntity__initialized: # Initialization can only be attempted once; if there are multiple attempts, it means # a subsystem is failing or there is a logic bug in the code. 
@@ -140,9 +141,11 @@ class CandidateEntity(Candidate):

         self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
             llm=llm_manager.get_llm(),
+            user_id=self.id,
             collection_name=self.username,
             persist_directory=vector_db_dir,
             watch_directory=rag_content_dir,
+            database=database,
             recreate=False,  # Don't recreate if exists
         )
         has_username_rag = any(item["name"] == self.username for item in self.rags)
diff --git a/src/backend/entities/entity_manager.py b/src/backend/entities/entity_manager.py
index cebb79d..4b535b4 100644
--- a/src/backend/entities/entity_manager.py
+++ b/src/backend/entities/entity_manager.py
@@ -7,6 +7,7 @@ from pydantic import BaseModel, Field  # type: ignore
 from models import (
     Candidate
 )
 from .candidate_entity import CandidateEntity
+from database import RedisDatabase
 from prometheus_client import CollectorRegistry  # type: ignore

 class EntityManager:
@@ -34,9 +35,10 @@ class EntityManager:
             pass
         self._cleanup_task = None

-    def initialize(self, prometheus_collector: CollectorRegistry):
+    def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
         """Initialize the EntityManager with Prometheus collector"""
         self._prometheus_collector = prometheus_collector
+        self._database = database

     async def get_entity(self, candidate: Candidate) -> CandidateEntity:
         """Get or create CandidateEntity with proper reference tracking"""
@@ -49,7 +51,7 @@ class EntityManager:
             return entity

         entity = CandidateEntity(candidate=candidate)
-        await entity.initialize(prometheus_collector=self._prometheus_collector)
+        await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)

         # Store with reference tracking
         self._entities[candidate.id] = entity
diff --git a/src/backend/llm_manager.py b/src/backend/llm_manager.py
deleted file mode 100644
index 43281c8..0000000
--- a/src/backend/llm_manager.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import ollama
-import defines
-
-_llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore
-
-class llm_manager:
-    """
-    A class to manage LLM operations using the Ollama client.
-    """
-    @staticmethod
-    def get_llm() -> ollama.Client:  # type: ignore
-        """
-        Get the Ollama client instance.
-
-        Returns:
-            An instance of the Ollama client.
-        """
-        return _llm
-
-    @staticmethod
-    def get_models() -> list[str]:
-        """
-        Get a list of available models from the Ollama client.
-
-        Returns:
-            List of model names.
-        """
-        return _llm.models()
-
-    @staticmethod
-    def get_model_info(model_name: str) -> dict:
-        """
-        Get information about a specific model.
-
-        Args:
-            model_name: The name of the model to retrieve information for.
-
-        Returns:
-            A dictionary containing model information.
- """ - return _llm.model(model_name) \ No newline at end of file diff --git a/src/backend/llm_proxy.py b/src/backend/llm_proxy.py new file mode 100644 index 0000000..3e0f77c --- /dev/null +++ b/src/backend/llm_proxy.py @@ -0,0 +1,1203 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Any, AsyncGenerator, Optional, Union +from pydantic import BaseModel, Field # type: ignore +from enum import Enum +import asyncio +import json +from dataclasses import dataclass +import os +import defines +from logger import logger + +# Standard message format for all providers +@dataclass +class LLMMessage: + role: str # "user", "assistant", "system" + content: str + + def __getitem__(self, key: str): + """Allow dictionary-style access for backward compatibility""" + if key == 'role': + return self.role + elif key == 'content': + return self.content + else: + raise KeyError(f"'{key}' not found in LLMMessage") + + def __setitem__(self, key: str, value: str): + """Allow dictionary-style assignment""" + if key == 'role': + self.role = value + elif key == 'content': + self.content = value + else: + raise KeyError(f"'{key}' not found in LLMMessage") + + @classmethod + def from_dict(cls, data: Dict[str, str]) -> 'LLMMessage': + """Create LLMMessage from dictionary""" + return cls(role=data['role'], content=data['content']) + + def to_dict(self) -> Dict[str, str]: + """Convert LLMMessage to dictionary""" + return {'role': self.role, 'content': self.content} + +# Enhanced usage statistics model +class UsageStats(BaseModel): + """Comprehensive usage statistics across all providers""" + # Token counts (standardized across providers) + prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt") + completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion") + total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used") + + # Ollama-specific detailed stats + prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt") + prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)") + eval_count: Optional[int] = Field(default=None, description="Number of tokens generated") + eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)") + total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)") + + # Performance metrics + tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second") + prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed") + + # Additional provider-specific stats + extra_stats: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Provider-specific additional statistics") + + def calculate_derived_stats(self) -> None: + """Calculate derived statistics where possible""" + # Calculate tokens per second for Ollama + if self.eval_count and self.eval_duration and self.eval_duration > 0: + # Convert nanoseconds to seconds and calculate tokens/sec + duration_seconds = self.eval_duration / 1_000_000_000 + self.tokens_per_second = self.eval_count / duration_seconds + + # Calculate prompt processing speed for Ollama + if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0: + duration_seconds = self.prompt_eval_duration / 1_000_000_000 + 
self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds + + # Standardize token counts across providers + if not self.total_tokens and self.prompt_tokens and self.completion_tokens: + self.total_tokens = self.prompt_tokens + self.completion_tokens + + # Map Ollama counts to standard format if not already set + if not self.prompt_tokens and self.prompt_eval_count: + self.prompt_tokens = self.prompt_eval_count + if not self.completion_tokens and self.eval_count: + self.completion_tokens = self.eval_count + +# Embedding response models +class EmbeddingData(BaseModel): + """Single embedding result""" + embedding: List[float] = Field(description="The embedding vector") + index: int = Field(description="Index in the input list") + +class EmbeddingResponse(BaseModel): + """Response from embeddings API""" + data: List[EmbeddingData] = Field(description="List of embedding results") + model: str = Field(description="Model used for embeddings") + usage: Optional[UsageStats] = Field(default=None, description="Usage statistics") + + def get_single_embedding(self) -> List[float]: + """Get the first embedding for single-text requests""" + if not self.data: + raise ValueError("No embeddings in response") + return self.data[0].embedding + + def get_embeddings(self) -> List[List[float]]: + """Get all embeddings as a list of vectors""" + return [item.embedding for item in self.data] + +class ChatResponse(BaseModel): + content: str + model: str + finish_reason: Optional[str] = Field(default="") + usage: Optional[UsageStats] = Field(default=None) + # Keep legacy usage field for backward compatibility + usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict") + + def get_usage_dict(self) -> Dict[str, Any]: + """Get usage statistics as dictionary for backward compatibility""" + if self.usage: + return self.usage.model_dump(exclude_none=True) + return self.usage_legacy or {} + +class LLMProvider(str, Enum): + OLLAMA = "ollama" + OPENAI = "openai" + ANTHROPIC = "anthropic" + GEMINI = "gemini" + GROK = "grok" + +class BaseLLMAdapter(ABC): + """Abstract base class for all LLM adapters""" + + def __init__(self, **config): + self.config = config + + @abstractmethod + async def chat( + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Send chat messages and get response""" + pass + + @abstractmethod + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Generate text from prompt""" + pass + + @abstractmethod + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + **kwargs + ) -> EmbeddingResponse: + """Generate embeddings for input text(s)""" + pass + + @abstractmethod + async def list_models(self) -> List[str]: + """List available models""" + pass + +class OllamaAdapter(BaseLLMAdapter): + """Adapter for Ollama with enhanced statistics""" + + def __init__(self, **config): + super().__init__(**config) + import ollama + self.client = ollama.AsyncClient( # type: ignore + host=config.get('host', defines.ollama_api_url) + ) + + def _create_usage_stats(self, response_data: Dict[str, Any]) -> UsageStats: + """Create comprehensive usage statistics from Ollama response""" + usage_stats = UsageStats() + + # Extract Ollama-specific stats + if 'prompt_eval_count' in response_data: + usage_stats.prompt_eval_count = response_data['prompt_eval_count'] + 
if 'prompt_eval_duration' in response_data: + usage_stats.prompt_eval_duration = response_data['prompt_eval_duration'] + if 'eval_count' in response_data: + usage_stats.eval_count = response_data['eval_count'] + if 'eval_duration' in response_data: + usage_stats.eval_duration = response_data['eval_duration'] + if 'total_duration' in response_data: + usage_stats.total_duration = response_data['total_duration'] + + # Store any additional stats + extra_stats = {} + if type(response_data) is dict: + for key, value in response_data.items(): + if key not in ['prompt_eval_count', 'prompt_eval_duration', 'eval_count', + 'eval_duration', 'total_duration', 'message', 'done_reason', 'response']: + extra_stats[key] = value + + if extra_stats: + usage_stats.extra_stats = extra_stats + + # Calculate derived statistics + usage_stats.calculate_derived_stats() + + return usage_stats + + async def chat( + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Convert LLMMessage objects to Ollama format + ollama_messages = [] + for msg in messages: + ollama_messages.append({ + "role": msg.role, + "content": msg.content + }) + + if stream: + return self._stream_chat(model, ollama_messages, **kwargs) + else: + response = await self.client.chat( + model=model, + messages=ollama_messages, + stream=False, + **kwargs + ) + + usage_stats = self._create_usage_stats(response) + + return ChatResponse( + content=response['message']['content'], + model=model, + finish_reason=response.get('done_reason'), + usage=usage_stats + ) + + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): + # Await the chat call first, then iterate over the result + stream = await self.client.chat( + model=model, + messages=messages, + stream=True, + **kwargs + ) + + # Accumulate stats for final chunk + accumulated_stats = {} + + async for chunk in stream: + # Update accumulated stats + accumulated_stats.update(chunk) + content = chunk.get('message', {}).get('content', '') + usage_stats = None + if chunk.done: + usage_stats = self._create_usage_stats(accumulated_stats) + yield ChatResponse( + content=content, + model=model, + finish_reason=chunk.get('done_reason'), + usage=usage_stats + ) + else: + yield ChatResponse( + content=content, + model=model, + finish_reason=None, + usage=None + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + if stream: + return self._stream_generate(model, prompt, **kwargs) + else: + response = await self.client.generate( + model=model, + prompt=prompt, + stream=False, + **kwargs + ) + + usage_stats = self._create_usage_stats(response) + + return ChatResponse( + content=response['response'], + model=model, + finish_reason=response.get('done_reason'), + usage=usage_stats + ) + + async def _stream_generate(self, model: str, prompt: str, **kwargs): + # Await the generate call first, then iterate over the result + stream = await self.client.generate( + model=model, + prompt=prompt, + stream=True, + **kwargs + ) + + accumulated_stats = {} + + async for chunk in stream: + if chunk.get('response'): + accumulated_stats.update(chunk) + + usage_stats = None + if chunk.get('done', False): # Only create stats for final chunk + usage_stats = self._create_usage_stats(accumulated_stats) + + yield ChatResponse( + content=chunk['response'], + model=model, + finish_reason=chunk.get('done_reason'), + 
usage=usage_stats + ) + + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + **kwargs + ) -> EmbeddingResponse: + """Generate embeddings using Ollama""" + # Normalize input to list + if isinstance(input_texts, str): + texts = [input_texts] + else: + texts = input_texts + + # Ollama embeddings API typically handles one text at a time + results = [] + final_response = None + + for i, text in enumerate(texts): + response = await self.client.embeddings( + model=model, + prompt=text, + **kwargs + ) + + results.append(EmbeddingData( + embedding=response['embedding'], + index=i + )) + final_response = response + + # Create usage stats if available from the last response + usage_stats = None + if final_response and len(results) == 1: + usage_stats = self._create_usage_stats(final_response) + + return EmbeddingResponse( + data=results, + model=model, + usage=usage_stats + ) + + async def list_models(self) -> List[str]: + models = await self.client.list() + return [model['name'] for model in models['models']] + +class OpenAIAdapter(BaseLLMAdapter): + """Adapter for OpenAI with enhanced statistics""" + + def __init__(self, **config): + super().__init__(**config) + import openai # type: ignore + self.client = openai.AsyncOpenAI( + api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) + ) + + def _create_usage_stats(self, usage_data: Any) -> UsageStats: + """Create usage statistics from OpenAI response""" + if not usage_data: + return UsageStats() + + usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data + + return UsageStats( + prompt_tokens=usage_dict.get('prompt_tokens'), + completion_tokens=usage_dict.get('completion_tokens'), + total_tokens=usage_dict.get('total_tokens'), + extra_stats={k: v for k, v in usage_dict.items() + if k not in ['prompt_tokens', 'completion_tokens', 'total_tokens']} + ) + + async def chat( + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Convert LLMMessage objects to OpenAI format + openai_messages = [] + for msg in messages: + openai_messages.append({ + "role": msg.role, + "content": msg.content + }) + + if stream: + return self._stream_chat(model, openai_messages, **kwargs) + else: + response = await self.client.chat.completions.create( + model=model, + messages=openai_messages, + stream=False, + **kwargs + ) + + usage_stats = self._create_usage_stats(response.usage) + + return ChatResponse( + content=response.choices[0].message.content, + model=model, + finish_reason=response.choices[0].finish_reason, + usage=usage_stats + ) + + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): + # Await the stream creation first, then iterate + stream = await self.client.chat.completions.create( + model=model, + messages=messages, + stream=True, + **kwargs + ) + + async for chunk in stream: + if chunk.choices[0].delta.content: + # Usage stats are only available in the final chunk for OpenAI streaming + usage_stats = None + if (chunk.choices[0].finish_reason is not None and + hasattr(chunk, 'usage') and chunk.usage): + usage_stats = self._create_usage_stats(chunk.usage) + + yield ChatResponse( + content=chunk.choices[0].delta.content, + model=model, + finish_reason=chunk.choices[0].finish_reason, + usage=usage_stats + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # 
Convert to chat format for OpenAI + messages = [LLMMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + **kwargs + ) -> EmbeddingResponse: + """Generate embeddings using OpenAI""" + # Normalize input to list + if isinstance(input_texts, str): + texts = [input_texts] + else: + texts = input_texts + + response = await self.client.embeddings.create( + model=model, + input=texts, + **kwargs + ) + + # Convert OpenAI response to our format + results = [] + for item in response.data: + results.append(EmbeddingData( + embedding=item.embedding, + index=item.index + )) + + # Create usage stats + usage_stats = self._create_usage_stats(response.usage) + + return EmbeddingResponse( + data=results, + model=model, + usage=usage_stats + ) + + async def list_models(self) -> List[str]: + models = await self.client.models.list() + return [model.id for model in models.data] + +class AnthropicAdapter(BaseLLMAdapter): + """Adapter for Anthropic Claude with enhanced statistics""" + + def __init__(self, **config): + super().__init__(**config) + import anthropic # type: ignore + self.client = anthropic.AsyncAnthropic( + api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) + ) + + def _create_usage_stats(self, usage_data: Any) -> UsageStats: + """Create usage statistics from Anthropic response""" + if not usage_data: + return UsageStats() + + usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data + + return UsageStats( + prompt_tokens=usage_dict.get('input_tokens'), + completion_tokens=usage_dict.get('output_tokens'), + total_tokens=(usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)) or None, + extra_stats={k: v for k, v in usage_dict.items() + if k not in ['input_tokens', 'output_tokens']} + ) + + async def chat( + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Anthropic requires system message to be separate + system_message = None + anthropic_messages = [] + + for msg in messages: + if msg.role == "system": + system_message = msg.content + else: + anthropic_messages.append({ + "role": msg.role, + "content": msg.content + }) + + request_kwargs = { + "model": model, + "messages": anthropic_messages, + "max_tokens": kwargs.pop('max_tokens', 1000), + **kwargs + } + + if system_message: + request_kwargs["system"] = system_message + + if stream: + return self._stream_chat(**request_kwargs) + else: + response = await self.client.messages.create( + stream=False, + **request_kwargs + ) + + usage_stats = self._create_usage_stats(response.usage) + + return ChatResponse( + content=response.content[0].text, + model=model, + finish_reason=response.stop_reason, + usage=usage_stats + ) + + async def _stream_chat(self, **kwargs): + model = kwargs['model'] + + async with self.client.messages.stream(**kwargs) as stream: + final_usage_stats = None + + async for text in stream.text_stream: + yield ChatResponse( + content=text, + model=model, + usage=None # Usage stats not available until stream completion + ) + + # Collect usage stats after stream completion + try: + if hasattr(stream, 'get_final_message'): + final_message = await stream.get_final_message() + if final_message and hasattr(final_message, 'usage'): + final_usage_stats = self._create_usage_stats(final_message.usage) + + # Yield a final empty response with usage stats + 
yield ChatResponse( + content="", + model=model, + usage=final_usage_stats, + finish_reason="stop" + ) + except Exception as e: + logger.debug(f"Could not retrieve final usage stats: {e}") + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + messages = [LLMMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + **kwargs + ) -> EmbeddingResponse: + """Anthropic doesn't provide embeddings API""" + raise NotImplementedError( + "Anthropic does not provide embeddings API. " + "Consider using OpenAI, Ollama, or Gemini for embeddings." + ) + + async def list_models(self) -> List[str]: + # Anthropic doesn't have a list models endpoint, return known models + return [ + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229" + ] + +class GeminiAdapter(BaseLLMAdapter): + """Adapter for Google Gemini with enhanced statistics""" + + def __init__(self, **config): + super().__init__(**config) + import google.generativeai as genai # type: ignore + genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) + self.genai = genai + + def _create_usage_stats(self, response: Any) -> UsageStats: + """Create usage statistics from Gemini response""" + usage_stats = UsageStats() + + # Gemini usage metadata extraction + if hasattr(response, 'usage_metadata'): + usage = response.usage_metadata + usage_stats.prompt_tokens = getattr(usage, 'prompt_token_count', None) + usage_stats.completion_tokens = getattr(usage, 'candidates_token_count', None) + usage_stats.total_tokens = getattr(usage, 'total_token_count', None) + + # Store additional Gemini-specific stats + extra_stats = {} + for attr in dir(usage): + if not attr.startswith('_') and attr not in ['prompt_token_count', 'candidates_token_count', 'total_token_count']: + extra_stats[attr] = getattr(usage, attr) + + if extra_stats: + usage_stats.extra_stats = extra_stats + + return usage_stats + + async def chat( + self, + model: str, + messages: List[LLMMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + model_instance = self.genai.GenerativeModel(model) + + # Convert messages to Gemini format + chat_history = [] + current_message = None + + for i, msg in enumerate(messages[:-1]): # All but last message go to history + if msg.role == "user": + chat_history.append({"role": "user", "parts": [msg.content]}) + elif msg.role == "assistant": + chat_history.append({"role": "model", "parts": [msg.content]}) + + # Last message is the current prompt + if messages: + current_message = messages[-1].content + + chat = model_instance.start_chat(history=chat_history) + if not current_message: + raise ValueError("No current message provided for chat") + + if stream: + return self._stream_chat(chat, current_message, **kwargs) + else: + response = await chat.send_message_async(current_message, **kwargs) + usage_stats = self._create_usage_stats(response) + + return ChatResponse( + content=response.text, + model=model, + finish_reason=response.candidates[0].finish_reason.name if response.candidates else None, + usage=usage_stats + ) + + async def _stream_chat(self, chat, message: str, **kwargs): + # Await the stream creation first, then iterate + stream = await chat.send_message_async(message, stream=True, **kwargs) + + final_chunk = None + + 
async for chunk in stream: + if chunk.text: + final_chunk = chunk # Keep reference to last chunk + + yield ChatResponse( + content=chunk.text, + model=chat.model.model_name, + usage=None # Don't include usage stats until final chunk + ) + + # After streaming is complete, yield final response with usage stats + if final_chunk: + usage_stats = self._create_usage_stats(final_chunk) + if usage_stats and any([usage_stats.prompt_tokens, usage_stats.completion_tokens, usage_stats.total_tokens]): + yield ChatResponse( + content="", + model=chat.model.model_name, + usage=usage_stats, + finish_reason=final_chunk.candidates[0].finish_reason.name if final_chunk.candidates else None + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + messages = [LLMMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + **kwargs + ) -> EmbeddingResponse: + """Generate embeddings using Google Gemini""" + # Normalize input to list + if isinstance(input_texts, str): + texts = [input_texts] + else: + texts = input_texts + + results = [] + + # Gemini embeddings - use embedding model + for i, text in enumerate(texts): + response = self.genai.embed_content( + model=model, + content=text, + **kwargs + ) + + results.append(EmbeddingData( + embedding=response['embedding'], + index=i + )) + + return EmbeddingResponse( + data=results, + model=model, + usage=None # Gemini embeddings don't typically return usage stats + ) + + async def list_models(self) -> List[str]: + models = self.genai.list_models() + return [model.name for model in models if 'generateContent' in model.supported_generation_methods] + +class UnifiedLLMProxy: + """Main proxy class that provides unified interface to all LLM providers""" + + def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA): + self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {} + self.default_provider = default_provider + self._initialized_providers = set() + + def configure_provider(self, provider: LLMProvider, **config): + """Configure a specific provider with its settings""" + adapter_classes = { + LLMProvider.OLLAMA: OllamaAdapter, + LLMProvider.OPENAI: OpenAIAdapter, + LLMProvider.ANTHROPIC: AnthropicAdapter, + LLMProvider.GEMINI: GeminiAdapter, + # Add other providers as needed + } + + if provider in adapter_classes: + self.adapters[provider] = adapter_classes[provider](**config) + self._initialized_providers.add(provider) + else: + raise ValueError(f"Unsupported provider: {provider}") + + def set_default_provider(self, provider: LLMProvider): + """Set the default provider for requests""" + if provider not in self._initialized_providers: + raise ValueError(f"Provider {provider} not configured") + self.default_provider = provider + + async def chat( + self, + model: str, + messages: Union[List[LLMMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Send chat messages using specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + # Normalize messages to LLMMessage objects + normalized_messages = [] + for msg in messages: + if isinstance(msg, LLMMessage): + normalized_messages.append(msg) + elif isinstance(msg, dict): + 
normalized_messages.append(LLMMessage.from_dict(msg)) + else: + raise ValueError(f"Invalid message type: {type(msg)}") + + return await adapter.chat(model, normalized_messages, stream, **kwargs) + + async def chat_stream( + self, + model: str, + messages: Union[List[LLMMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + stream: bool = True, + **kwargs + ) -> AsyncGenerator[ChatResponse, None]: + """Stream chat messages using specified or default provider""" + if stream is False: + raise ValueError("stream must be True for chat_stream") + result = await self.chat(model, messages, provider, stream=True, **kwargs) + # Type checker now knows this is an AsyncGenerator due to stream=True + async for chunk in result: # type: ignore + yield chunk + + async def chat_single( + self, + model: str, + messages: Union[List[LLMMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + **kwargs + ) -> ChatResponse: + """Get single chat response using specified or default provider""" + + result = await self.chat(model, messages, provider, stream=False, **kwargs) + # Type checker now knows this is a ChatResponse due to stream=False + return result # type: ignore + + async def generate( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Generate text using specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.generate(model, prompt, stream, **kwargs) + + async def generate_stream( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + **kwargs + ) -> AsyncGenerator[ChatResponse, None]: + """Stream text generation using specified or default provider""" + + result = await self.generate(model, prompt, provider, stream=True, **kwargs) + async for chunk in result: # type: ignore + yield chunk + + async def generate_single( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + **kwargs + ) -> ChatResponse: + """Get single generation response using specified or default provider""" + + result = await self.generate(model, prompt, provider, stream=False, **kwargs) + return result # type: ignore + + async def embeddings( + self, + model: str, + input_texts: Union[str, List[str]], + provider: Optional[LLMProvider] = None, + **kwargs + ) -> EmbeddingResponse: + """Generate embeddings using specified or default provider""" + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.embeddings(model, input_texts, **kwargs) + + async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: + """List available models for specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.list_models() + + async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]: + """List available embedding models for specified or default provider""" + provider = provider or self.default_provider + + # Provider-specific embedding models + embedding_models = { + LLMProvider.OLLAMA: [ + "nomic-embed-text", + "mxbai-embed-large", + "all-minilm", + "snowflake-arctic-embed" + ], + LLMProvider.OPENAI: [ + "text-embedding-3-small", + "text-embedding-3-large", + "text-embedding-ada-002" + ], + LLMProvider.ANTHROPIC: [], # No embeddings API + 
LLMProvider.GEMINI: [ + "models/embedding-001", + "models/text-embedding-004" + ] + } + + if provider == LLMProvider.ANTHROPIC: + raise NotImplementedError("Anthropic does not provide embeddings API") + + # For Ollama, check which embedding models are actually available + if provider == LLMProvider.OLLAMA: + try: + all_models = await self.list_models(provider) + available_embedding_models = [] + suggested_models = embedding_models[provider] + + for model in all_models: + # Check if it's a known embedding model + if any(emb_model in model for emb_model in suggested_models): + available_embedding_models.append(model) + + return available_embedding_models if available_embedding_models else suggested_models + except: + return embedding_models[provider] + + return embedding_models.get(provider, []) + + def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter: + """Get adapter for specified provider""" + if provider not in self.adapters: + raise ValueError(f"Provider {provider} not configured") + return self.adapters[provider] + +# Example usage and configuration +class LLMManager: + """Singleton manager for the unified LLM proxy""" + + _instance = None + _proxy = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + if LLMManager._proxy is None: + LLMManager._proxy = UnifiedLLMProxy() + self._configure_from_environment() + + def _configure_from_environment(self): + """Configure providers based on environment variables""" + if not self._proxy: + raise RuntimeError("UnifiedLLMProxy instance not initialized") + # Configure Ollama if available + ollama_host = os.getenv('OLLAMA_HOST', defines.ollama_api_url) + self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) + + # Configure OpenAI if API key is available + if os.getenv('OPENAI_API_KEY'): + self._proxy.configure_provider(LLMProvider.OPENAI) + + # Configure Anthropic if API key is available + if os.getenv('ANTHROPIC_API_KEY'): + self._proxy.configure_provider(LLMProvider.ANTHROPIC) + + # Configure Gemini if API key is available + if os.getenv('GEMINI_API_KEY'): + self._proxy.configure_provider(LLMProvider.GEMINI) + + # Set default provider from environment or use Ollama + default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama') + try: + self._proxy.set_default_provider(LLMProvider(default_provider)) + except ValueError: + # Fallback to Ollama if specified provider not available + self._proxy.set_default_provider(LLMProvider.OLLAMA) + + def get_proxy(self) -> UnifiedLLMProxy: + """Get the unified LLM proxy instance""" + if not self._proxy: + raise RuntimeError("UnifiedLLMProxy instance not initialized") + return self._proxy + +# Convenience function for easy access +def get_llm() -> UnifiedLLMProxy: + """Get the configured LLM proxy""" + return LLMManager.get_instance().get_proxy() + +# Example usage with detailed statistics +async def example_usage(): + """Example showing how to access detailed statistics""" + llm = get_llm() + + # Configure providers + llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") + + # Simple chat + messages = [ + LLMMessage(role="user", content="Explain quantum computing in one paragraph") + ] + + response = await llm.chat_single("llama2", messages) + + print(f"Content: {response.content}") + print(f"Model: {response.model}") + + if response.usage: + print(f"Usage Statistics:") + print(f" Prompt tokens: {response.usage.prompt_tokens}") + print(f" Completion tokens: 
{response.usage.completion_tokens}") + print(f" Total tokens: {response.usage.total_tokens}") + + # Ollama-specific stats + if response.usage.prompt_eval_count: + print(f" Prompt eval count: {response.usage.prompt_eval_count}") + if response.usage.eval_count: + print(f" Eval count: {response.usage.eval_count}") + if response.usage.tokens_per_second: + print(f" Generation speed: {response.usage.tokens_per_second:.2f} tokens/sec") + if response.usage.prompt_tokens_per_second: + print(f" Prompt processing speed: {response.usage.prompt_tokens_per_second:.2f} tokens/sec") + + # Access as dictionary for backward compatibility + usage_dict = response.get_usage_dict() + print(f"Usage as dict: {usage_dict}") + +async def example_embeddings_usage(): + """Example showing how to use the embeddings API""" + llm = get_llm() + + # Configure providers + llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") + if os.getenv('OPENAI_API_KEY'): + llm.configure_provider(LLMProvider.OPENAI) + + # List available embedding models + print("=== Available Embedding Models ===") + try: + ollama_embedding_models = await llm.list_embedding_models(LLMProvider.OLLAMA) + print(f"Ollama embedding models: {ollama_embedding_models}") + except Exception as e: + print(f"Could not list Ollama embedding models: {e}") + + if os.getenv('OPENAI_API_KEY'): + try: + openai_embedding_models = await llm.list_embedding_models(LLMProvider.OPENAI) + print(f"OpenAI embedding models: {openai_embedding_models}") + except Exception as e: + print(f"Could not list OpenAI embedding models: {e}") + + # Single text embedding + print("\n=== Single Text Embedding ===") + single_text = "The quick brown fox jumps over the lazy dog" + + try: + response = await llm.embeddings("nomic-embed-text", single_text, provider=LLMProvider.OLLAMA) + print(f"Model: {response.model}") + print(f"Embedding dimension: {len(response.get_single_embedding())}") + print(f"First 5 values: {response.get_single_embedding()[:5]}") + + if response.usage: + print(f"Usage: {response.usage.model_dump(exclude_none=True)}") + except Exception as e: + print(f"Ollama embedding failed: {e}") + + # Batch embeddings + print("\n=== Batch Text Embeddings ===") + texts = [ + "Machine learning is fascinating", + "Natural language processing enables AI communication", + "Deep learning uses neural networks", + "Transformers revolutionized AI" + ] + + try: + # Try OpenAI if available + if os.getenv('OPENAI_API_KEY'): + response = await llm.embeddings("text-embedding-3-small", texts, provider=LLMProvider.OPENAI) + print(f"OpenAI Model: {response.model}") + print(f"Number of embeddings: {len(response.data)}") + print(f"Embedding dimension: {len(response.data[0].embedding)}") + + # Calculate similarity between first two texts (requires numpy) + try: + import numpy as np # type: ignore + emb1 = np.array(response.data[0].embedding) + emb2 = np.array(response.data[1].embedding) + similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)) + print(f"Cosine similarity between first two texts: {similarity:.4f}") + except ImportError: + print("Install numpy to calculate similarity: pip install numpy") + + if response.usage: + print(f"Usage: {response.usage.model_dump(exclude_none=True)}") + else: + # Fallback to Ollama for batch + response = await llm.embeddings("nomic-embed-text", texts, provider=LLMProvider.OLLAMA) + print(f"Ollama Model: {response.model}") + print(f"Number of embeddings: {len(response.data)}") + print(f"Embedding dimension: 
{len(response.data[0].embedding)}") + + except Exception as e: + print(f"Batch embedding failed: {e}") + +async def example_streaming_with_stats(): + """Example showing how to collect usage stats from streaming responses""" + llm = get_llm() + + messages = [ + LLMMessage(role="user", content="Write a short story about AI") + ] + + print("Streaming response:") + final_stats = None + + async for chunk in llm.chat_stream("llama2", messages): + # Print content as it streams + if chunk.content: + print(chunk.content, end="", flush=True) + + # Collect usage stats from final chunk + if chunk.usage: + final_stats = chunk.usage + + print("\n\nFinal Usage Statistics:") + if final_stats: + print(f" Total tokens: {final_stats.total_tokens}") + print(f" Generation speed: {final_stats.tokens_per_second:.2f} tokens/sec" if final_stats.tokens_per_second else " Generation speed: N/A") + else: + print(" No usage statistics available") + +if __name__ == "__main__": + asyncio.run(example_usage()) + print("\n" + "="*50 + "\n") + asyncio.run(example_streaming_with_stats()) + print("\n" + "="*50 + "\n") + asyncio.run(example_embeddings_usage()) \ No newline at end of file diff --git a/src/backend/main.py b/src/backend/main.py index e6b1ce4..4af0067 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -13,6 +13,9 @@ import uuid import defines import pathlib +from markitdown import MarkItDown, StreamInfo # type: ignore +import io + import uvicorn # type: ignore from typing import List, Optional, Dict, Any from datetime import datetime, timedelta, UTC @@ -53,7 +56,7 @@ import defines from logger import logger from database import RedisDatabase, redis_manager, DatabaseManager from metrics import Metrics -from llm_manager import llm_manager +import llm_proxy as llm_manager import entities from email_service import VerificationEmailRateLimiter, email_service from device_manager import DeviceManager @@ -116,7 +119,8 @@ async def lifespan(app: FastAPI): try: # Initialize database await db_manager.initialize() - + entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database()) + signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) @@ -1827,7 +1831,7 @@ async def upload_candidate_document( yield error_message return - converted = False; + converted = False if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT: p = pathlib.Path(file_path) p_as_md = p.with_suffix(".md") @@ -1873,53 +1877,6 @@ async def upload_candidate_document( content=file_content, ) yield chat_message - - # If this is a job description, process it with the job requirements agent - if not options.is_job_document: - return - - status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads - content=f"Initiating connection with {candidate.first_name}'s AI agent...", - activity=ApiActivityType.INFO - ) - yield status_message - await asyncio.sleep(0) - - async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: - chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) - if not chat_agent: - error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="No agent found for job requirements chat type" - ) - yield error_message - return - message = None - status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads - content=f"Analyzing document for company and requirement details...", - 
activity=ApiActivityType.SEARCHING - ) - yield status_message - await asyncio.sleep(0) - - async for message in chat_agent.generate( - llm=llm_manager.get_llm(), - model=defines.model, - session_id=MOCK_UUID, - prompt=file_content - ): - pass - if not message or not isinstance(message, JobRequirementsMessage): - error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to process job description file" - ) - yield error_message - return - yield message - try: async def to_json(method): try: @@ -1932,7 +1889,6 @@ async def upload_candidate_document( logger.error(f"Error in to_json conversion: {e}") return -# return DebugStreamingResponse( return StreamingResponse( to_json(upload_stream_generator(file_content)), media_type="text/event-stream", @@ -1944,15 +1900,64 @@ async def upload_candidate_document( "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs "Transfer-Encoding": "chunked", }, - ) + ) except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Document upload error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("UPLOAD_ERROR", "Failed to upload document") + return StreamingResponse( + iter([ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to upload document" + )]), + media_type="text/event-stream" ) +async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str): + status_message = ChatMessageStatus( + session_id=MOCK_UUID, # No session ID for document uploads + content=f"Initiating connection with {current_user.first_name}'s AI agent...", + activity=ApiActivityType.INFO + ) + yield status_message + await asyncio.sleep(0) # Let the status message propagate + + async with entities.get_candidate_entity(candidate=current_user) as candidate_entity: + chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) + if not chat_agent: + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="No agent found for job requirements chat type" + ) + yield error_message + return + message = None + status_message = ChatMessageStatus( + session_id=MOCK_UUID, # No session ID for document uploads + content=f"Analyzing document for company and requirement details...", + activity=ApiActivityType.SEARCHING + ) + yield status_message + await asyncio.sleep(0) + + async for message in chat_agent.generate( + llm=llm_manager.get_llm(), + model=defines.model, + session_id=MOCK_UUID, + prompt=content + ): + pass + if not message or not isinstance(message, JobRequirementsMessage): + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to process job description file" + ) + yield error_message + return + + logger.info(f"✅ Successfully saved job requirements job {message.id}") + yield message + return + @api_router.post("/candidates/profile/upload") async def upload_candidate_profile( file: UploadFile = File(...), @@ -2573,6 +2578,7 @@ async def delete_candidate( status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete candidate") ) + @api_router.patch("/candidates/{candidate_id}") async def update_candidate( candidate_id: str = Path(...), @@ -2816,6 +2822,139 @@ async def create_candidate_job( content=create_error_response("CREATION_FAILED", str(e)) ) + +@api_router.post("/jobs/upload") +async def create_job_from_file( + file: UploadFile 
= File(...), + current_user = Depends(get_current_user), + database: RedisDatabase = Depends(get_database) +): + """Upload a job document for the current candidate and create a Job""" + # Check file size (limit to 10MB) + max_size = 10 * 1024 * 1024 # 10MB + file_content = await file.read() + if len(file_content) > max_size: + logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)") + return StreamingResponse( + iter([ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="File size exceeds 10MB limit" + )]), + media_type="text/event-stream" + ) + if len(file_content) == 0: + logger.info(f"⚠️ File is empty: {file.filename}") + return StreamingResponse( + iter([ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="File is empty" + )]), + media_type="text/event-stream" + ) + + """Upload a document for the current candidate""" + async def upload_stream_generator(file_content): + # Verify user is a candidate + if current_user.user_type != "candidate": + logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}") + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Only candidates can upload documents" + ) + yield error_message + return + + file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename + if not file.filename or file.filename.strip() == "": + logger.warning("⚠️ File upload attempt with missing filename") + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="File must have a valid filename" + ) + yield error_message + return + + logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'") + + # Validate file type + allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif'] + file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else "" + + if file_extension not in allowed_types: + logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}") + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content=f"File type {file_extension} not supported. 
Allowed types: {', '.join(allowed_types)}" + ) + yield error_message + return + + document_type = get_document_type_from_filename(file.filename or "unknown.txt") + + if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT: + status_message = ChatMessageStatus( + session_id=MOCK_UUID, # No session ID for document uploads + content=f"Converting content from {document_type}...", + activity=ApiActivityType.CONVERTING + ) + yield status_message + try: + md = MarkItDown(enable_plugins=False) # Set to True to enable plugins + stream = io.BytesIO(file_content) + stream_info = StreamInfo( + extension=file_extension, # e.g., ".pdf" + url=file.filename # optional, helps with logging and guessing + ) + result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown") + file_content = result.text_content + logger.info(f"✅ Converted {file.filename} to Markdown format") + except Exception as e: + error_message = ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content=f"Failed to convert {file.filename} to Markdown.", + ) + yield error_message + logger.error(f"❌ Error converting {file.filename} to Markdown: {e}") + return + async for message in create_job_from_content(database=database, current_user=current_user, content=file_content): + yield message + return + + try: + async def to_json(method): + try: + async for message in method: + json_data = message.model_dump(mode='json', by_alias=True) + json_str = json.dumps(json_data) + yield f"data: {json_str}\n\n".encode("utf-8") + except Exception as e: + logger.error(backstory_traceback.format_exc()) + logger.error(f"Error in to_json conversion: {e}") + return + + return StreamingResponse( + to_json(upload_stream_generator(file_content)), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", # Nginx + "X-Content-Type-Options": "nosniff", + "Access-Control-Allow-Origin": "*", # Adjust for your CORS needs + "Transfer-Encoding": "chunked", + }, + ) + except Exception as e: + logger.error(backstory_traceback.format_exc()) + logger.error(f"❌ Document upload error: {e}") + return StreamingResponse( + iter([ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to upload document" + )]), + media_type="text/event-stream" + ) + @api_router.get("/jobs/{job_id}") async def get_job( job_id: str = Path(...), @@ -2931,6 +3070,49 @@ async def search_jobs( content=create_error_response("SEARCH_FAILED", str(e)) ) + +@api_router.delete("/jobs/{job_id}") +async def delete_job( + job_id: str = Path(...), + admin_user = Depends(get_current_admin), + database: RedisDatabase = Depends(get_database) +): + """Delete a Job""" + try: + # Check if admin user + if not admin_user.is_admin: + logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}") + return JSONResponse( + status_code=403, + content=create_error_response("FORBIDDEN", "Only admins can delete") + ) + + # Get candidate data + job_data = await database.get_job(job_id) + if not job_data: + logger.warning(f"⚠️ Candidate not found for deletion: {job_id}") + return JSONResponse( + status_code=404, + content=create_error_response("NOT_FOUND", "Job not found") + ) + + # Delete job from database + await database.delete_job(job_id) + + logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}") + + return create_success_response({ + "message": "Job deleted successfully", + "jobId": 
job_id + }) + + except Exception as e: + logger.error(f"❌ Delete job error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("DELETE_ERROR", "Failed to delete job") + ) + # ============================ # Chat Endpoints # ============================ @@ -3541,7 +3723,7 @@ async def get_candidate_skill_match( current_user = Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): - """Get skill match for a candidate against a requirement""" + """Get skill match for a candidate against a requirement with caching""" try: # Find candidate by ID candidate_data = await database.get_candidate(candidate_id) @@ -3553,34 +3735,89 @@ async def get_candidate_skill_match( candidate = Candidate.model_validate(candidate_data) - async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: - logger.info(f"🔍 Running skill match for candidate {candidate_entity.username} against requirement: {requirement}") - agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH) - if not agent: - return JSONResponse( - status_code=400, - content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate") - ) - # Entity automatically released when done - skill_match = await get_last_item( - agent.generate( - llm=llm_manager.get_llm(), - model=defines.model, - session_id=MOCK_UUID, - prompt=requirement, + # Create cache key for this specific candidate + requirement combination + cache_key = f"skill_match:{candidate_id}:{hash(requirement)}" + + # Get cached assessment if it exists + cached_assessment = await database.get_cached_skill_match(cache_key) + + # Get the last update time for the candidate's skill information + candidate_skill_update_time = await database.get_candidate_skill_update_time(candidate_id) + + # Get the latest RAG data update time for the current user + user_rag_update_time = await database.get_user_rag_update_time(current_user.id) + + # Determine if we need to regenerate the assessment + should_regenerate = True + cached_date = None + + if cached_assessment: + cached_date = cached_assessment.get('cached_at') + if cached_date: + # Check if cached result is still valid + # Regenerate if: + # 1. Candidate skills were updated after cache date + # 2. 
User's RAG data was updated after cache date + if (not candidate_skill_update_time or cached_date >= candidate_skill_update_time) and \ + (not user_rag_update_time or cached_date >= user_rag_update_time): + should_regenerate = False + logger.info(f"🔄 Using cached skill match for candidate {candidate.id}") + + if should_regenerate: + logger.info(f"🔍 Generating new skill match for candidate {candidate.id} against requirement: {requirement}") + + async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: + agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH) + if not agent: + return JSONResponse( + status_code=400, + content=create_error_response("AGENT_NOT_FOUND", "No skill match agent found for this candidate") + ) + + # Generate new skill match + skill_match = await get_last_item( + agent.generate( + llm=llm_manager.get_llm(), + model=defines.model, + session_id=MOCK_UUID, + prompt=requirement, ), ) - if skill_match is None: + + if skill_match is None: + return JSONResponse( + status_code=500, + content=create_error_response("NO_MATCH", "No skill match found for the given requirement") + ) + + skill_match_data = json.loads(skill_match.content) + + # Cache the new assessment with current timestamp + cached_assessment = { + "skill_match": skill_match_data, + "cached_at": datetime.utcnow().isoformat(), + "candidate_id": candidate_id, + "requirement": requirement + } + + await database.cache_skill_match(cache_key, cached_assessment) + logger.info(f"💾 Cached new skill match for candidate {candidate.id}") + logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match_data['evidence_strength']}") + else: + # Use cached result - we know cached_assessment is not None here + if cached_assessment is None: return JSONResponse( status_code=500, - content=create_error_response("NO_MATCH", "No skill match found for the given requirement") + content=create_error_response("CACHE_ERROR", "Unexpected cache state") ) - skill_match = json.loads(skill_match.content) - logger.info(f"✅ Skill match found for candidate {candidate.id}: {skill_match['evidence_strength']}") + skill_match_data = cached_assessment["skill_match"] + logger.info(f"✅ Retrieved cached skill match for candidate {candidate.id}: {skill_match_data['evidence_strength']}") return create_success_response({ "candidateId": candidate.id, - "skillMatch": skill_match + "skillMatch": skill_match_data, + "cached": not should_regenerate, + "cacheTimestamp": cached_date }) except Exception as e: @@ -3589,8 +3826,8 @@ async def get_candidate_skill_match( return JSONResponse( status_code=500, content=create_error_response("SKILL_MATCH_ERROR", str(e)) - ) - + ) + @api_router.get("/candidates/{username}/chat-sessions") async def get_candidate_chat_sessions( username: str = Path(...), @@ -3911,7 +4148,6 @@ async def track_requests(request, call_next): # FastAPI Metrics # ============================ prometheus_collector = CollectorRegistry() -entities.entity_manager.initialize(prometheus_collector) # Keep the Instrumentator instance alive instrumentator = Instrumentator( diff --git a/src/backend/models.py b/src/backend/models.py index d5010eb..4dfa3e0 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -768,11 +768,7 @@ class ChatOptions(BaseModel): "populate_by_name": True # Allow both field names and aliases } - -class LLMMessage(BaseModel): - role: str = Field(default="") - content: str = Field(default="") - tool_calls: Optional[List[Dict]] = Field(default=[], exclude=True) 
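models.py now re-exports LLMMessage from llm_proxy instead of defining its own pydantic model. The dataclass keeps dict-style access (__getitem__/__setitem__ plus from_dict/to_dict, shown earlier in this patch), so callers that index messages by key should keep working; the old tool_calls field is not carried over, so anything relying on it would need updating. A minimal compatibility sketch, assuming llm_proxy is on the backend path:

from llm_proxy import LLMMessage

msg = LLMMessage(role="user", content="Summarize this job description")

# Old dict-style reads and writes still work via __getitem__/__setitem__:
assert msg["role"] == "user"
msg["content"] = "Summarize this resume instead"

# Round-trip to plain dicts, e.g. when building JSON payloads:
payload = msg.to_dict()                    # {'role': 'user', 'content': '...'}
restored = LLMMessage.from_dict(payload)
assert restored.content == msg.content

# Note: unlike the removed pydantic model, there is no tool_calls field here.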
+from llm_proxy import (LLMMessage) class ApiMessage(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) diff --git a/src/backend/rag/rag.py b/src/backend/rag/rag.py index 0c6ba48..b6646ed 100644 --- a/src/backend/rag/rag.py +++ b/src/backend/rag/rag.py @@ -13,7 +13,6 @@ import numpy as np # type: ignore import traceback import chromadb # type: ignore -import ollama from watchdog.observers import Observer # type: ignore from watchdog.events import FileSystemEventHandler # type: ignore import umap # type: ignore @@ -27,6 +26,7 @@ from .markdown_chunker import ( # When imported as a module, use relative imports import defines +from database import RedisDatabase from models import ChromaDBGetResponse __all__ = ["ChromaDBFileWatcher", "start_file_watcher"] @@ -47,11 +47,16 @@ class ChromaDBFileWatcher(FileSystemEventHandler): loop, persist_directory, collection_name, + database: RedisDatabase, + user_id: str, chunk_size=DEFAULT_CHUNK_SIZE, chunk_overlap=DEFAULT_CHUNK_OVERLAP, recreate=False, ): self.llm = llm + self.database = database + self.user_id = user_id + self.database = database self.watch_directory = watch_directory self.persist_directory = persist_directory or defines.persist_directory self.collection_name = collection_name @@ -284,6 +289,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): if results and "ids" in results and results["ids"]: self.collection.delete(ids=results["ids"]) + await self.database.update_user_rag_timestamp(self.user_id) logging.info( f"Removed {len(results['ids'])} chunks for deleted file: {file_path}" ) @@ -372,14 +378,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler): name=self.collection_name, metadata={"hnsw:space": "cosine"} ) - def get_embedding(self, text: str) -> np.ndarray: + async def get_embedding(self, text: str) -> np.ndarray: """Generate and normalize an embedding for the given text.""" # Get embedding try: - response = self.llm.embeddings(model=defines.embedding_model, prompt=text) - embedding = np.array(response["embedding"]) + response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text) + embedding = np.array(response.get_single_embedding()) except Exception as e: + logging.error(traceback.format_exc()) logging.error(f"Failed to get embedding: {e}") raise @@ -404,7 +411,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): return embedding - def add_embeddings_to_collection(self, chunks: List[Chunk]): + async def _add_embeddings_to_collection(self, chunks: List[Chunk]): """Add embeddings for chunks to the collection.""" for i, chunk in enumerate(chunks): @@ -420,7 +427,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): content_hash = hashlib.md5(text.encode()).hexdigest()[:8] chunk_id = f"{path_hash}_{i}_{content_hash}" - embedding = self.get_embedding(text) + embedding = await self.get_embedding(text) try: self.collection.add( ids=[chunk_id], @@ -458,11 +465,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # 0.5 - 0.7 0.65 - 0.75 Balanced precision/recall # 0.7 - 0.9 0.55 - 0.65 Higher recall, more inclusive # 0.9 - 1.2 0.40 - 0.55 Very inclusive, may include tangential content - def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold): + async def find_similar(self, query, top_k=defines.default_rag_top_k, threshold=defines.default_rag_threshold): """Find similar documents to the query.""" # collection is configured with hnsw:space cosine - query_embedding = self.get_embedding(query) + query_embedding = await 
self.get_embedding(query) results = self.collection.query( query_embeddings=[query_embedding], n_results=top_k, @@ -572,6 +579,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): and existing_results["ids"] ): self.collection.delete(ids=existing_results["ids"]) + await self.database.update_user_rag_timestamp(self.user_id) extensions = (".docx", ".xlsx", ".xls", ".pdf") if file_path.endswith(extensions): @@ -606,7 +614,8 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # f.write(json.dumps(chunk, indent=2)) # Add chunks to collection - self.add_embeddings_to_collection(chunks) + await self._add_embeddings_to_collection(chunks) + await self.database.update_user_rag_timestamp(self.user_id) logging.info(f"Updated {len(chunks)} chunks for file: {file_path}") @@ -640,9 +649,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # Function to start the file watcher def start_file_watcher( llm, + user_id, watch_directory, persist_directory, collection_name, + database: RedisDatabase, initialize=False, recreate=False, ): @@ -663,9 +674,11 @@ def start_file_watcher( llm, watch_directory=watch_directory, loop=loop, + user_id=user_id, persist_directory=persist_directory, collection_name=collection_name, recreate=recreate, + database=database ) # Process all files if: diff --git a/src/multi-llm/config.md b/src/multi-llm/config.md new file mode 100644 index 0000000..3c11bcf --- /dev/null +++ b/src/multi-llm/config.md @@ -0,0 +1,193 @@ +# Environment Configuration Examples + +# ==== Development (Ollama only) ==== +export OLLAMA_HOST="http://localhost:11434" +export DEFAULT_LLM_PROVIDER="ollama" + +# ==== Production with OpenAI ==== +export OPENAI_API_KEY="sk-your-openai-key-here" +export DEFAULT_LLM_PROVIDER="openai" + +# ==== Production with Anthropic ==== +export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here" +export DEFAULT_LLM_PROVIDER="anthropic" + +# ==== Production with multiple providers ==== +export OPENAI_API_KEY="sk-your-openai-key-here" +export ANTHROPIC_API_KEY="sk-ant-your-anthropic-key-here" +export GEMINI_API_KEY="your-gemini-key-here" +export OLLAMA_HOST="http://ollama-server:11434" +export DEFAULT_LLM_PROVIDER="openai" + +# ==== Docker Compose Example ==== +# docker-compose.yml +version: '3.8' +services: + api: + build: . 
+ ports: + - "8000:8000" + environment: + - OPENAI_API_KEY=${OPENAI_API_KEY} + - ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY} + - DEFAULT_LLM_PROVIDER=openai + depends_on: + - ollama + + ollama: + image: ollama/ollama + ports: + - "11434:11434" + volumes: + - ollama_data:/root/.ollama + environment: + - OLLAMA_HOST=0.0.0.0 + +volumes: + ollama_data: + +# ==== Kubernetes ConfigMap ==== +apiVersion: v1 +kind: ConfigMap +metadata: + name: llm-config +data: + DEFAULT_LLM_PROVIDER: "openai" + OLLAMA_HOST: "http://ollama-service:11434" + +--- +apiVersion: v1 +kind: Secret +metadata: + name: llm-secrets +type: Opaque +stringData: + OPENAI_API_KEY: "sk-your-key-here" + ANTHROPIC_API_KEY: "sk-ant-your-key-here" + +# ==== Python configuration example ==== +# config.py +import os +from llm_proxy import get_llm, LLMProvider + +def configure_llm_for_environment(): + """Configure LLM based on deployment environment""" + llm = get_llm() + + # Development: Use Ollama + if os.getenv('ENVIRONMENT') == 'development': + llm.configure_provider(LLMProvider.OLLAMA, host='http://localhost:11434') + llm.set_default_provider(LLMProvider.OLLAMA) + + # Staging: Use OpenAI with rate limiting + elif os.getenv('ENVIRONMENT') == 'staging': + llm.configure_provider(LLMProvider.OPENAI, + api_key=os.getenv('OPENAI_API_KEY'), + max_retries=3, + timeout=30) + llm.set_default_provider(LLMProvider.OPENAI) + + # Production: Use multiple providers with fallback + elif os.getenv('ENVIRONMENT') == 'production': + # Primary: Anthropic + llm.configure_provider(LLMProvider.ANTHROPIC, + api_key=os.getenv('ANTHROPIC_API_KEY')) + + # Fallback: OpenAI + llm.configure_provider(LLMProvider.OPENAI, + api_key=os.getenv('OPENAI_API_KEY')) + + # Set primary provider + llm.set_default_provider(LLMProvider.ANTHROPIC) + +# ==== Usage Examples ==== + +# Example 1: Basic usage with default provider +async def basic_example(): + llm = get_llm() + + response = await llm.chat( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}] + ) + print(response.content) + +# Example 2: Specify provider explicitly +async def provider_specific_example(): + llm = get_llm() + + # Use OpenAI specifically + response = await llm.chat( + model="gpt-4", + messages=[{"role": "user", "content": "Hello!"}], + provider=LLMProvider.OPENAI + ) + + # Use Anthropic specifically + response2 = await llm.chat( + model="claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Hello!"}], + provider=LLMProvider.ANTHROPIC + ) + +# Example 3: Streaming with provider fallback +async def streaming_with_fallback(): + llm = get_llm() + + try: + async for chunk in llm.chat( + model="claude-3-sonnet-20240229", + messages=[{"role": "user", "content": "Write a story"}], + provider=LLMProvider.ANTHROPIC, + stream=True + ): + print(chunk.content, end='', flush=True) + except Exception as e: + print(f"Primary provider failed: {e}") + # Fallback to OpenAI + async for chunk in llm.chat( + model="gpt-4", + messages=[{"role": "user", "content": "Write a story"}], + provider=LLMProvider.OPENAI, + stream=True + ): + print(chunk.content, end='', flush=True) + +# Example 4: Load balancing between providers +import random + +async def load_balanced_request(): + llm = get_llm() + available_providers = [LLMProvider.OPENAI, LLMProvider.ANTHROPIC] + + # Simple random load balancing + provider = random.choice(available_providers) + + # Adjust model based on provider + model_mapping = { + LLMProvider.OPENAI: "gpt-4", + LLMProvider.ANTHROPIC: "claude-3-sonnet-20240229" + } + + response = await 
llm.chat( + model=model_mapping[provider], + messages=[{"role": "user", "content": "Hello!"}], + provider=provider + ) + + print(f"Response from {provider.value}: {response.content}") + +# ==== FastAPI Startup Configuration ==== +from fastapi import FastAPI + +app = FastAPI() + +@app.on_event("startup") +async def startup_event(): + """Configure LLM providers on application startup""" + configure_llm_for_environment() + + # Verify configuration + llm = get_llm() + print(f"Configured providers: {[p.value for p in llm._initialized_providers]}") + print(f"Default provider: {llm.default_provider.value}") \ No newline at end of file diff --git a/src/multi-llm/example.py b/src/multi-llm/example.py new file mode 100644 index 0000000..30f17cc --- /dev/null +++ b/src/multi-llm/example.py @@ -0,0 +1,238 @@ +from fastapi import FastAPI, HTTPException # type: ignore +from fastapi.responses import StreamingResponse # type: ignore +from pydantic import BaseModel # type: ignore +from typing import List, Optional, Dict, Any +import json +import asyncio +from llm_proxy import get_llm, LLMProvider, ChatMessage + +app = FastAPI(title="Unified LLM API") + +# Pydantic models for request/response +class ChatRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + provider: Optional[str] = None + stream: bool = False + max_tokens: Optional[int] = None + temperature: Optional[float] = None + +class GenerateRequest(BaseModel): + model: str + prompt: str + provider: Optional[str] = None + stream: bool = False + max_tokens: Optional[int] = None + temperature: Optional[float] = None + +class ChatResponseModel(BaseModel): + content: str + model: str + provider: str + finish_reason: Optional[str] = None + usage: Optional[Dict[str, int]] = None + +@app.post("/chat") +async def chat_endpoint(request: ChatRequest): + """Chat endpoint that works with any configured LLM provider""" + try: + llm = get_llm() + + # Convert provider string to enum if provided + provider = None + if request.provider: + try: + provider = LLMProvider(request.provider.lower()) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Unsupported provider: {request.provider}" + ) + + # Prepare kwargs + kwargs = {} + if request.max_tokens: + kwargs['max_tokens'] = request.max_tokens + if request.temperature: + kwargs['temperature'] = request.temperature + + if request.stream: + async def generate_response(): + try: + # Use the type-safe streaming method + async for chunk in llm.chat_stream( + model=request.model, + messages=request.messages, + provider=provider, + **kwargs + ): + # Format as Server-Sent Events + data = { + "content": chunk.content, + "model": chunk.model, + "provider": provider.value if provider else llm.default_provider.value, + "finish_reason": chunk.finish_reason + } + yield f"data: {json.dumps(data)}\n\n" + except Exception as e: + error_data = {"error": str(e)} + yield f"data: {json.dumps(error_data)}\n\n" + finally: + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive" + } + ) + else: + # Use the type-safe single response method + response = await llm.chat_single( + model=request.model, + messages=request.messages, + provider=provider, + **kwargs + ) + + return ChatResponseModel( + content=response.content, + model=response.model, + provider=provider.value if provider else llm.default_provider.value, + finish_reason=response.finish_reason, + usage=response.usage + ) + + 
    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.post("/generate")
+async def generate_endpoint(request: GenerateRequest):
+    """Generate endpoint for simple text generation"""
+    try:
+        llm = get_llm()
+
+        provider = None
+        if request.provider:
+            try:
+                provider = LLMProvider(request.provider.lower())
+            except ValueError:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Unsupported provider: {request.provider}"
+                )
+
+        kwargs = {}
+        if request.max_tokens is not None:
+            kwargs['max_tokens'] = request.max_tokens
+        if request.temperature is not None:
+            kwargs['temperature'] = request.temperature
+
+        if request.stream:
+            async def generate_response():
+                try:
+                    # Use the type-safe streaming method
+                    async for chunk in llm.generate_stream(
+                        model=request.model,
+                        prompt=request.prompt,
+                        provider=provider,
+                        **kwargs
+                    ):
+                        data = {
+                            "content": chunk.content,
+                            "model": chunk.model,
+                            "provider": provider.value if provider else llm.default_provider.value,
+                            "finish_reason": chunk.finish_reason
+                        }
+                        yield f"data: {json.dumps(data)}\n\n"
+                except Exception as e:
+                    error_data = {"error": str(e)}
+                    yield f"data: {json.dumps(error_data)}\n\n"
+                finally:
+                    yield "data: [DONE]\n\n"
+
+            return StreamingResponse(
+                generate_response(),
+                media_type="text/event-stream"
+            )
+        else:
+            # Use the type-safe single response method
+            response = await llm.generate_single(
+                model=request.model,
+                prompt=request.prompt,
+                provider=provider,
+                **kwargs
+            )
+
+            return ChatResponseModel(
+                content=response.content,
+                model=response.model,
+                provider=provider.value if provider else llm.default_provider.value,
+                finish_reason=response.finish_reason,
+                usage=response.usage
+            )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/models")
+async def list_models(provider: Optional[str] = None):
+    """List available models for a provider"""
+    try:
+        llm = get_llm()
+
+        provider_enum = None
+        if provider:
+            try:
+                provider_enum = LLMProvider(provider.lower())
+            except ValueError:
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Unsupported provider: {provider}"
+                )
+
+        models = await llm.list_models(provider_enum)
+        return {
+            "provider": provider_enum.value if provider_enum else llm.default_provider.value,
+            "models": models
+        }
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/providers")
+async def list_providers():
+    """List all configured providers"""
+    llm = get_llm()
+    return {
+        "providers": [provider.value for provider in llm._initialized_providers],
+        "default": llm.default_provider.value
+    }
+
+@app.post("/providers/{provider}/set-default")
+async def set_default_provider(provider: str):
+    """Set the default provider"""
+    try:
+        llm = get_llm()
+        provider_enum = LLMProvider(provider.lower())
+        llm.set_default_provider(provider_enum)
+        return {"message": f"Default provider set to {provider}", "default": provider}
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
+# Health check endpoint
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    llm = get_llm()
+    return {
+        "status": "healthy",
+        "providers_configured": len(llm._initialized_providers),
+        "default_provider": llm.default_provider.value
+    }
+
+if __name__ == "__main__":
+    import uvicorn  # type: ignore
+    uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
diff --git a/src/multi-llm/llm_proxy.py b/src/multi-llm/llm_proxy.py
new file mode 100644
index 0000000..2057760
---
/dev/null +++ b/src/multi-llm/llm_proxy.py @@ -0,0 +1,606 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Any, AsyncGenerator, Optional, Union +from enum import Enum +import asyncio +import json +from dataclasses import dataclass +import os + +# Standard message format for all providers +@dataclass +class ChatMessage: + role: str # "user", "assistant", "system" + content: str + + def __getitem__(self, key: str): + """Allow dictionary-style access for backward compatibility""" + if key == 'role': + return self.role + elif key == 'content': + return self.content + else: + raise KeyError(f"'{key}' not found in ChatMessage") + + def __setitem__(self, key: str, value: str): + """Allow dictionary-style assignment""" + if key == 'role': + self.role = value + elif key == 'content': + self.content = value + else: + raise KeyError(f"'{key}' not found in ChatMessage") + + @classmethod + def from_dict(cls, data: Dict[str, str]) -> 'ChatMessage': + """Create ChatMessage from dictionary""" + return cls(role=data['role'], content=data['content']) + + def to_dict(self) -> Dict[str, str]: + """Convert ChatMessage to dictionary""" + return {'role': self.role, 'content': self.content} + +@dataclass +class ChatResponse: + content: str + model: str + finish_reason: Optional[str] = None + usage: Optional[Dict[str, int]] = None + +class LLMProvider(Enum): + OLLAMA = "ollama" + OPENAI = "openai" + ANTHROPIC = "anthropic" + GEMINI = "gemini" + GROK = "grok" + +class BaseLLMAdapter(ABC): + """Abstract base class for all LLM adapters""" + + def __init__(self, **config): + self.config = config + + @abstractmethod + async def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Send chat messages and get response""" + pass + + @abstractmethod + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Generate text from prompt""" + pass + + @abstractmethod + async def list_models(self) -> List[str]: + """List available models""" + pass + +class OllamaAdapter(BaseLLMAdapter): + """Adapter for Ollama""" + + def __init__(self, **config): + super().__init__(**config) + import ollama + self.client = ollama.AsyncClient( # type: ignore + host=config.get('host', 'http://localhost:11434') + ) + + async def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Convert ChatMessage objects to Ollama format + ollama_messages = [] + for msg in messages: + ollama_messages.append({ + "role": msg.role, + "content": msg.content + }) + + if stream: + return self._stream_chat(model, ollama_messages, **kwargs) + else: + response = await self.client.chat( + model=model, + messages=ollama_messages, + stream=False, + **kwargs + ) + return ChatResponse( + content=response['message']['content'], + model=model, + finish_reason=response.get('done_reason') + ) + + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): + async for chunk in self.client.chat( + model=model, + messages=messages, + stream=True, + **kwargs + ): + if chunk.get('message', {}).get('content'): + yield ChatResponse( + content=chunk['message']['content'], + model=model, + finish_reason=chunk.get('done_reason') + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> 
Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + if stream: + return self._stream_generate(model, prompt, **kwargs) + else: + response = await self.client.generate( + model=model, + prompt=prompt, + stream=False, + **kwargs + ) + return ChatResponse( + content=response['response'], + model=model, + finish_reason=response.get('done_reason') + ) + + async def _stream_generate(self, model: str, prompt: str, **kwargs): + async for chunk in self.client.generate( + model=model, + prompt=prompt, + stream=True, + **kwargs + ): + if chunk.get('response'): + yield ChatResponse( + content=chunk['response'], + model=model, + finish_reason=chunk.get('done_reason') + ) + + async def list_models(self) -> List[str]: + models = await self.client.list() + return [model['name'] for model in models['models']] + +class OpenAIAdapter(BaseLLMAdapter): + """Adapter for OpenAI""" + + def __init__(self, **config): + super().__init__(**config) + import openai # type: ignore + self.client = openai.AsyncOpenAI( + api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) + ) + + async def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Convert ChatMessage objects to OpenAI format + openai_messages = [] + for msg in messages: + openai_messages.append({ + "role": msg.role, + "content": msg.content + }) + + if stream: + return self._stream_chat(model, openai_messages, **kwargs) + else: + response = await self.client.chat.completions.create( + model=model, + messages=openai_messages, + stream=False, + **kwargs + ) + return ChatResponse( + content=response.choices[0].message.content, + model=model, + finish_reason=response.choices[0].finish_reason, + usage=response.usage.model_dump() if response.usage else None + ) + + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): + async for chunk in await self.client.chat.completions.create( + model=model, + messages=messages, + stream=True, + **kwargs + ): + if chunk.choices[0].delta.content: + yield ChatResponse( + content=chunk.choices[0].delta.content, + model=model, + finish_reason=chunk.choices[0].finish_reason + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Convert to chat format for OpenAI + messages = [ChatMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def list_models(self) -> List[str]: + models = await self.client.models.list() + return [model.id for model in models.data] + +class AnthropicAdapter(BaseLLMAdapter): + """Adapter for Anthropic Claude""" + + def __init__(self, **config): + super().__init__(**config) + import anthropic # type: ignore + self.client = anthropic.AsyncAnthropic( + api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) + ) + + async def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + # Anthropic requires system message to be separate + system_message = None + anthropic_messages = [] + + for msg in messages: + if msg.role == "system": + system_message = msg.content + else: + anthropic_messages.append({ + "role": msg.role, + "content": msg.content + }) + + request_kwargs = { + "model": model, + "messages": anthropic_messages, + "max_tokens": kwargs.pop('max_tokens', 1000), + **kwargs + } + + if 
system_message: + request_kwargs["system"] = system_message + + if stream: + return self._stream_chat(**request_kwargs) + else: + response = await self.client.messages.create( + stream=False, + **request_kwargs + ) + return ChatResponse( + content=response.content[0].text, + model=model, + finish_reason=response.stop_reason, + usage={ + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens + } + ) + + async def _stream_chat(self, **kwargs): + async with self.client.messages.stream(**kwargs) as stream: + async for text in stream.text_stream: + yield ChatResponse( + content=text, + model=kwargs['model'] + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + messages = [ChatMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def list_models(self) -> List[str]: + # Anthropic doesn't have a list models endpoint, return known models + return [ + "claude-3-5-sonnet-20241022", + "claude-3-5-haiku-20241022", + "claude-3-opus-20240229" + ] + +class GeminiAdapter(BaseLLMAdapter): + """Adapter for Google Gemini""" + + def __init__(self, **config): + super().__init__(**config) + import google.generativeai as genai # type: ignore + genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) + self.genai = genai + + async def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + model_instance = self.genai.GenerativeModel(model) + + # Convert messages to Gemini format + chat_history = [] + current_message = None + + for i, msg in enumerate(messages[:-1]): # All but last message go to history + if msg.role == "user": + chat_history.append({"role": "user", "parts": [msg.content]}) + elif msg.role == "assistant": + chat_history.append({"role": "model", "parts": [msg.content]}) + + # Last message is the current prompt + if messages: + current_message = messages[-1].content + + chat = model_instance.start_chat(history=chat_history) + + if not current_message: + raise ValueError("No current message provided for chat") + + if stream: + return self._stream_chat(chat, current_message, **kwargs) + else: + response = await chat.send_message_async(current_message, **kwargs) + return ChatResponse( + content=response.text, + model=model, + finish_reason=response.candidates[0].finish_reason.name if response.candidates else None + ) + + async def _stream_chat(self, chat, message: str, **kwargs): + async for chunk in chat.send_message_async(message, stream=True, **kwargs): + if chunk.text: + yield ChatResponse( + content=chunk.text, + model=chat.model.model_name + ) + + async def generate( + self, + model: str, + prompt: str, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + + messages = [ChatMessage(role="user", content=prompt)] + return await self.chat(model, messages, stream, **kwargs) + + async def list_models(self) -> List[str]: + models = self.genai.list_models() + return [model.name for model in models if 'generateContent' in model.supported_generation_methods] + +class UnifiedLLMProxy: + """Main proxy class that provides unified interface to all LLM providers""" + + def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA): + self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {} + self.default_provider = default_provider + 
self._initialized_providers = set() + + def configure_provider(self, provider: LLMProvider, **config): + """Configure a specific provider with its settings""" + adapter_classes = { + LLMProvider.OLLAMA: OllamaAdapter, + LLMProvider.OPENAI: OpenAIAdapter, + LLMProvider.ANTHROPIC: AnthropicAdapter, + LLMProvider.GEMINI: GeminiAdapter, + # Add other providers as needed + } + + if provider in adapter_classes: + self.adapters[provider] = adapter_classes[provider](**config) + self._initialized_providers.add(provider) + else: + raise ValueError(f"Unsupported provider: {provider}") + + def set_default_provider(self, provider: LLMProvider): + """Set the default provider for requests""" + if provider not in self._initialized_providers: + raise ValueError(f"Provider {provider} not configured") + self.default_provider = provider + + async def chat( + self, + model: str, + messages: Union[List[ChatMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Send chat messages using specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + # Normalize messages to ChatMessage objects + normalized_messages = [] + for msg in messages: + if isinstance(msg, ChatMessage): + normalized_messages.append(msg) + elif isinstance(msg, dict): + normalized_messages.append(ChatMessage.from_dict(msg)) + else: + raise ValueError(f"Invalid message type: {type(msg)}") + + return await adapter.chat(model, normalized_messages, stream, **kwargs) + + async def chat_stream( + self, + model: str, + messages: Union[List[ChatMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + **kwargs + ) -> AsyncGenerator[ChatResponse, None]: + """Stream chat messages using specified or default provider""" + + result = await self.chat(model, messages, provider, stream=True, **kwargs) + # Type checker now knows this is an AsyncGenerator due to stream=True + async for chunk in result: # type: ignore + yield chunk + + async def chat_single( + self, + model: str, + messages: Union[List[ChatMessage], List[Dict[str, str]]], + provider: Optional[LLMProvider] = None, + **kwargs + ) -> ChatResponse: + """Get single chat response using specified or default provider""" + + result = await self.chat(model, messages, provider, stream=False, **kwargs) + # Type checker now knows this is a ChatResponse due to stream=False + return result # type: ignore + + async def generate( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + stream: bool = False, + **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Generate text using specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.generate(model, prompt, stream, **kwargs) + + async def generate_stream( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + **kwargs + ) -> AsyncGenerator[ChatResponse, None]: + """Stream text generation using specified or default provider""" + + result = await self.generate(model, prompt, provider, stream=True, **kwargs) + async for chunk in result: # type: ignore + yield chunk + + async def generate_single( + self, + model: str, + prompt: str, + provider: Optional[LLMProvider] = None, + **kwargs + ) -> ChatResponse: + """Get single generation response using specified or default provider""" + + result = await 
self.generate(model, prompt, provider, stream=False, **kwargs) + return result # type: ignore + + async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: + """List available models for specified or default provider""" + + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.list_models() + + def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter: + """Get adapter for specified provider""" + if provider not in self.adapters: + raise ValueError(f"Provider {provider} not configured") + return self.adapters[provider] + +# Example usage and configuration +class LLMManager: + """Singleton manager for the unified LLM proxy""" + + _instance = None + _proxy = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + if LLMManager._proxy is None: + LLMManager._proxy = UnifiedLLMProxy() + self._configure_from_environment() + + def _configure_from_environment(self): + """Configure providers based on environment variables""" + + if not self._proxy: + raise RuntimeError("LLM proxy not initialized") + + # Configure Ollama if available + ollama_host = os.getenv('OLLAMA_HOST', 'http://localhost:11434') + self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) + + # Configure OpenAI if API key is available + if os.getenv('OPENAI_API_KEY'): + self._proxy.configure_provider(LLMProvider.OPENAI) + + # Configure Anthropic if API key is available + if os.getenv('ANTHROPIC_API_KEY'): + self._proxy.configure_provider(LLMProvider.ANTHROPIC) + + # Configure Gemini if API key is available + if os.getenv('GEMINI_API_KEY'): + self._proxy.configure_provider(LLMProvider.GEMINI) + + # Set default provider from environment or use Ollama + default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama') + try: + self._proxy.set_default_provider(LLMProvider(default_provider)) + except ValueError: + # Fallback to Ollama if specified provider not available + self._proxy.set_default_provider(LLMProvider.OLLAMA) + + def get_proxy(self) -> UnifiedLLMProxy: + """Get the unified LLM proxy instance""" + if not self._proxy: + raise RuntimeError("LLM proxy not initialized") + return self._proxy + +# Convenience function for easy access +def get_llm() -> UnifiedLLMProxy: + """Get the configured LLM proxy""" + return LLMManager.get_instance().get_proxy() \ No newline at end of file
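
A note on message normalization in `UnifiedLLMProxy.chat`: both plain dicts and `ChatMessage` objects are accepted, with dicts converted via `ChatMessage.from_dict`. A minimal sketch of the two call styles, assuming the Ollama default provider from `LLMManager._configure_from_environment` is reachable and that the model name `llama3` is available locally (the model name is an assumption, not taken from this patch):

```python
import asyncio

from llm_proxy import ChatMessage, get_llm


async def main() -> None:
    llm = get_llm()

    # Plain dicts are normalized via ChatMessage.from_dict ...
    r1 = await llm.chat_single(
        model="llama3",  # assumed locally pulled model
        messages=[{"role": "user", "content": "Say hi"}],
    )

    # ... and ChatMessage objects pass through unchanged.
    r2 = await llm.chat_single(
        model="llama3",
        messages=[
            ChatMessage(role="system", content="You are terse."),
            ChatMessage(role="user", content="Say hi"),
        ],
    )
    print(r1.content, r2.content)


if __name__ == "__main__":
    asyncio.run(main())
```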
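`LLMProvider` defines `GROK`, but `UnifiedLLMProxy.configure_provider` registers no adapter for it ("Add other providers as needed"). One way to fill that gap, assuming xAI's Grok endpoint is OpenAI-compatible (the base URL and the `GROK_API_KEY` variable below are assumptions, not part of this patch), is to reuse `OpenAIAdapter` with a redirected client:

```python
import os
from typing import List

import openai  # type: ignore

from llm_proxy import OpenAIAdapter


class GrokAdapter(OpenAIAdapter):
    """Hypothetical adapter for xAI Grok via an OpenAI-compatible endpoint."""

    def __init__(self, **config):
        # Bypass OpenAIAdapter.__init__ so the client points at xAI, not OpenAI.
        self.config = config
        self.client = openai.AsyncOpenAI(
            api_key=config.get('api_key', os.getenv('GROK_API_KEY')),  # assumed env var
            base_url=config.get('base_url', 'https://api.x.ai/v1'),    # assumed endpoint
        )

    async def list_models(self) -> List[str]:
        # Fall back to a static name if the /models route is unavailable.
        try:
            return await super().list_models()
        except Exception:
            return ["grok-beta"]  # assumed model name
```

Wiring it in would also require adding `LLMProvider.GROK: GrokAdapter` to the `adapter_classes` mapping inside `UnifiedLLMProxy.configure_provider`.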
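For completeness, a small client-side sketch of consuming the streaming `/chat` route in `example.py`. It assumes the app is running on `localhost:8000` and that `httpx` is installed (neither is part of this patch); the `data:` framing and `[DONE]` sentinel match what the endpoint emits:

```python
import asyncio
import json
from typing import Optional

import httpx  # assumed dependency, not part of this patch


async def stream_chat(prompt: str, model: str = "gpt-4", provider: Optional[str] = None) -> None:
    """Print a streamed /chat response chunk-by-chunk."""
    payload = {
        "model": model,                                     # must match the configured provider
        "messages": [{"role": "user", "content": prompt}],
        "provider": provider,                               # None -> server default provider
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/chat", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                if "error" in chunk:
                    raise RuntimeError(chunk["error"])
                print(chunk["content"], end="", flush=True)


if __name__ == "__main__":
    asyncio.run(stream_chat("Write a short story"))
```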