Guest seems to work!

This commit is contained in:
James Ketr 2025-06-08 20:42:23 -07:00
parent 35ef9898f1
commit 20f8d7bd32
81 changed files with 3284 additions and 830 deletions

View File

@ -188,7 +188,7 @@ RUN pip install prometheus-client prometheus-fastapi-instrumentator
RUN pip install "redis[hiredis]>=4.5.0" RUN pip install "redis[hiredis]>=4.5.0"
# New backend implementation # New backend implementation
RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart schedule
# Needed for email verification # Needed for email verification
RUN pip install pyyaml user-agents cryptography RUN pip install pyyaml user-agents cryptography

View File

@ -91,11 +91,11 @@ services:
# Optional: Redis Commander for GUI management # Optional: Redis Commander for GUI management
redis-commander: redis-commander:
image: rediscommander/redis-commander:latest image: rediscommander/redis-commander:latest
container_name: backstory-redis-commander container_name: redis-commander
ports: ports:
- "8081:8081" - "8081:8081"
environment: environment:
- REDIS_HOSTS=redis:6379 - REDIS_HOSTS=redis:redis:6379
networks: networks:
- internal - internal
depends_on: depends_on:

View File

@ -20,8 +20,8 @@ import '@fontsource/roboto/700.css';
const BackstoryApp = () => { const BackstoryApp = () => {
const navigate = useNavigate(); const navigate = useNavigate();
const location = useLocation(); const location = useLocation();
const snackRef = useRef<any>(null);
const chatRef = useRef<ConversationHandle>(null); const chatRef = useRef<ConversationHandle>(null);
const snackRef = useRef<any>(null);
const setSnack = useCallback((message: string, severity?: SeverityType) => { const setSnack = useCallback((message: string, severity?: SeverityType) => {
snackRef.current?.setSnack(message, severity); snackRef.current?.setSnack(message, severity);
}, [snackRef]); }, [snackRef]);

View File

@ -12,7 +12,6 @@ import { Message } from './Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation'; import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
import { BackstoryElementProps } from './BackstoryTab'; import { BackstoryElementProps } from './BackstoryTab';
import { connectionBase } from 'utils/Global';
import { useAuth } from "hooks/AuthContext"; import { useAuth } from "hooks/AuthContext";
import { StreamingResponse } from 'services/api-client'; import { StreamingResponse } from 'services/api-client';
import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types'; import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
@ -22,7 +21,7 @@ import './Conversation.css';
import { useAppState } from 'hooks/GlobalContext'; import { useAppState } from 'hooks/GlobalContext';
const defaultMessage: ChatMessage = { const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant" status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
}; };
const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." }; const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
@ -325,16 +324,16 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
<Box sx={{ p: 1, mt: 0, ...sx }}> <Box sx={{ p: 1, mt: 0, ...sx }}>
{ {
filteredConversation.map((message, index) => filteredConversation.map((message, index) =>
<Message key={index} {...{ chatSession, sendQuery: processQuery, message, connectionBase, }} /> <Message key={index} {...{ chatSession, sendQuery: processQuery, message, }} />
) )
} }
{ {
processingMessage !== undefined && processingMessage !== undefined &&
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: processingMessage, }} /> <Message {...{ chatSession, sendQuery: processQuery, message: processingMessage, }} />
} }
{ {
streamingMessage !== undefined && streamingMessage !== undefined &&
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: streamingMessage }} /> <Message {...{ chatSession, sendQuery: processQuery, message: streamingMessage }} />
} }
<Box sx={{ <Box sx={{
display: "flex", display: "flex",

View File

@ -29,7 +29,7 @@ import {
} from '@mui/icons-material'; } from '@mui/icons-material';
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { BackstoryPageProps } from './BackstoryTab'; import { BackstoryPageProps } from './BackstoryTab';
import { useNavigate } from 'react-router-dom'; import { Navigate, useNavigate } from 'react-router-dom';
// Email Verification Component // Email Verification Component
const EmailVerificationPage = (props: BackstoryPageProps) => { const EmailVerificationPage = (props: BackstoryPageProps) => {
@ -488,10 +488,11 @@ const RegistrationSuccessDialog = ({
// Enhanced Login Component with MFA Support // Enhanced Login Component with MFA Support
const LoginForm = () => { const LoginForm = () => {
const { login, mfaResponse, isLoading, error } = useAuth(); const { login, mfaResponse, isLoading, error, user } = useAuth();
const [email, setEmail] = useState(''); const [email, setEmail] = useState('');
const [password, setPassword] = useState(''); const [password, setPassword] = useState('');
const [errorMessage, setErrorMessage] = useState<string | null>(null); const [errorMessage, setErrorMessage] = useState<string | null>(null);
const navigate = useNavigate();
useEffect(() => { useEffect(() => {
if (!error) { if (!error) {
@ -524,10 +525,13 @@ const LoginForm = () => {
handleLoginSuccess(); handleLoginSuccess();
}; };
const handleLoginSuccess = () => { const handleLoginSuccess = () => {
// This could be handled by a router or parent component if (!user) {
// For now, just showing the pattern navigate('/');
console.log('Login successful - redirect to dashboard'); } else {
navigate(`/${user.userType}/dashboard`);
}
console.log('Login successful - redirect to dashboard');
}; };
return ( return (

View File

@ -39,7 +39,7 @@ interface JobAnalysisProps extends BackstoryPageProps {
} }
const defaultMessage: ChatMessage = { const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant" status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
}; };
const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => { const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {

View File

@ -101,6 +101,7 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"
color: theme.palette.text.primary, color: theme.palette.text.primary,
opacity: 0.95, opacity: 0.95,
}, },
info: 'information',
preparing: 'status', preparing: 'status',
processing: 'status', processing: 'status',
qualifications: { qualifications: {

View File

@ -17,7 +17,6 @@ import TableContainer from '@mui/material/TableContainer';
import TableRow from '@mui/material/TableRow'; import TableRow from '@mui/material/TableRow';
import { Scrollable } from './Scrollable'; import { Scrollable } from './Scrollable';
import { connectionBase } from '../utils/Global';
import './VectorVisualizer.css'; import './VectorVisualizer.css';
import { BackstoryPageProps } from './BackstoryTab'; import { BackstoryPageProps } from './BackstoryTab';
@ -195,7 +194,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualiz
const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 }); const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
const navigate = useNavigate(); const navigate = useNavigate();
const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user : null; const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;
/* Force resize of Plotly as it tends to not be the correct size if it is initially rendered /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
* off screen (eg., the VectorVisualizer is not on the tab the app loads to) */ * off screen (eg., the VectorVisualizer is not on the tab the app loads to) */

View File

@ -84,7 +84,6 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
const navigate = useNavigate(); const navigate = useNavigate();
const location = useLocation(); const location = useLocation();
const { guest, user } = useAuth(); const { guest, user } = useAuth();
const { selectedCandidate } = useSelectedCandidate();
const [navigationItems, setNavigationItems] = useState<NavigationItem[]>([]); const [navigationItems, setNavigationItems] = useState<NavigationItem[]>([]);
useEffect(() => { useEffect(() => {
@ -92,9 +91,13 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false)); setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false));
}, [user]); }, [user]);
useEffect(() => {
console.log({ guest, user });
}, [guest, user]);
// Generate dynamic routes from navigation config // Generate dynamic routes from navigation config
const generateRoutes = () => { const generateRoutes = () => {
if (!guest) return null; if (!guest && !user) return null;
const userType = user?.userType || null; const userType = user?.userType || null;
const isAdmin = user?.isAdmin ? true : false; const isAdmin = user?.isAdmin ? true : false;
@ -161,7 +164,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
}} }}
> >
<BackstoryPageContainer> <BackstoryPageContainer>
{!guest && ( {!guest && !user && (
<Box> <Box>
<LoadingComponent <LoadingComponent
loadingText="Creating session..." loadingText="Creating session..."
@ -171,7 +174,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
/> />
</Box> </Box>
)} )}
{guest && ( {(guest || user) && (
<> <>
<Outlet /> <Outlet />
<Routes> <Routes>

View File

@ -156,7 +156,7 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
id: 'profile', id: 'profile',
label: 'Profile', label: 'Profile',
icon: <Person fontSize="small" />, icon: <Person fontSize="small" />,
action: () => navigate(`/${user?.userType}/dashboard/profile`) action: () => navigate(`/${user?.userType}/profile`)
}, },
{ {
id: 'dashboard', id: 'dashboard',

View File

@ -74,7 +74,7 @@ const CandidateInfo: React.FC<CandidateInfoProps> = (props: CandidateInfoProps)
maxWidth: "80px" maxWidth: "80px"
}}> }}>
<Avatar <Avatar
src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}?timestamp=${Date.now()}` : ''} src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}` : ''}
alt={`${candidate.fullName}'s profile`} alt={`${candidate.fullName}'s profile`}
sx={{ sx={{
alignSelf: "flex-start", alignSelf: "flex-start",

View File

@ -46,6 +46,7 @@ import { JobPicker } from 'components/ui/JobPicker';
import { DocumentManager } from 'components/DocumentManager'; import { DocumentManager } from 'components/DocumentManager';
import { VectorVisualizer } from 'components/VectorVisualizer'; import { VectorVisualizer } from 'components/VectorVisualizer';
import { ComingSoon } from 'components/ui/ComingSoon'; import { ComingSoon } from 'components/ui/ComingSoon';
import { Beta } from 'components/ui/Beta';
// Beta page components for placeholder routes // Beta page components for placeholder routes
const SearchPage = () => (<BetaPage><Typography variant="h4">Search</Typography></BetaPage>); const SearchPage = () => (<BetaPage><Typography variant="h4">Search</Typography></BetaPage>);
@ -58,20 +59,19 @@ const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typogra
export const navigationConfig: NavigationConfig = { export const navigationConfig: NavigationConfig = {
items: [ items: [
{ id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, }, { id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, },
{ id: 'chat', label: 'Chat', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], }, { id: 'chat', label: 'Chat about a Candidate', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], },
{ {
id: 'candidate-menu', label: 'Tools', icon: <PersonIcon />, userTypes: ['candidate'], children: [ id: 'candidate-menu', label: 'Tools', icon: <PersonIcon />, userTypes: ['candidate'], children: [
{ id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: <DashboardIcon />, component: <CandidateDashboard />, userTypes: ['candidate'] }, { id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: <DashboardIcon />, component: <CandidateDashboard />, userTypes: ['candidate'] },
{ id: 'candidate-profile', label: 'Profile', icon: <PersonIcon />, path: '/candidate/profile', component: <CandidateProfile />, userTypes: ['candidate'] }, { id: 'candidate-profile', label: 'Profile', icon: <PersonIcon />, path: '/candidate/profile', component: <CandidateProfile />, userTypes: ['candidate'] },
{ id: 'candidate-qa-setup', label: 'Q&A Setup', icon: <QuizIcon />, path: '/candidate/qa-setup', component: <BetaPage><Box>Candidate q&a setup page</Box></BetaPage>, userTypes: ['candidate'] }, { id: 'candidate-qa-setup', label: 'Q&A Setup', icon: <QuizIcon />, path: '/candidate/qa-setup', component: <BetaPage><Box>Candidate q&a setup page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] }, { id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-jobs', label: 'Jobs', icon: <WorkIcon />, path: '/candidate/jobs', component: <JobPicker />, userTypes: ['candidate'] },
{ id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: <WorkIcon />, component: <JobAnalysisPage />, userTypes: ['candidate'], }, { id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: <WorkIcon />, component: <JobAnalysisPage />, userTypes: ['candidate'], },
{ id: 'candidate-resumes', label: 'Resumes', icon: <DescriptionIcon />, path: '/candidate/resumes', component: <BetaPage><Box>Candidate resumes page</Box></BetaPage>, userTypes: ['candidate'] }, { id: 'candidate-resumes', label: 'Resumes', icon: <DescriptionIcon />, path: '/candidate/resumes', component: <BetaPage><Box>Candidate resumes page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: <DescriptionIcon />, component: <ResumeBuilderPage />, userTypes: ['candidate'], }, { id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: <DescriptionIcon />, component: <ResumeBuilderPage />, userTypes: ['candidate'], },
{ id: 'candidate-content', label: 'Content', icon: <BubbleChart />, path: '/candidate/content', component: <Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}><VectorVisualizer /><DocumentManager /></Box>, userTypes: ['candidate'] }, { id: 'candidate-content', label: 'Content', icon: <BubbleChart />, path: '/candidate/content', component: <Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}><VectorVisualizer /><DocumentManager /></Box>, userTypes: ['candidate'] },
{ id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <ComingSoon><Settings /></ComingSoon>, userTypes: ['candidate'], }, { id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <Settings />, userTypes: ['candidate'], },
], ],
}, },
{ {
@ -87,7 +87,7 @@ export const navigationConfig: NavigationConfig = {
{ id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: <SettingsIcon />, component: <SettingsPage />, userTypes: ['employer'], }, { id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: <SettingsIcon />, component: <SettingsPage />, userTypes: ['employer'], },
], ],
}, },
{ id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], }, // { id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], },
{ id: 'docs', label: 'Docs', path: '/docs/*', icon: <LibraryBooksIcon />, component: <DocsPage />, userTypes: ['guest', 'candidate', 'employer'], }, { id: 'docs', label: 'Docs', path: '/docs/*', icon: <LibraryBooksIcon />, component: <DocsPage />, userTypes: ['guest', 'candidate', 'employer'], },
{ {
id: 'admin-menu', id: 'admin-menu',

View File

@ -1,17 +1,19 @@
// Replace the existing AuthContext.tsx with these enhancements
import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react'; import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react';
import * as Types from '../types/types'; import * as Types from '../types/types';
import { ApiClient, CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client'; import { ApiClient, CreateCandidateRequest, CreateEmployerRequest, GuestConversionRequest } from '../services/api-client';
import { formatApiRequest, toCamelCase } from '../types/conversion'; import { formatApiRequest, toCamelCase } from '../types/conversion';
// ============================ // ============================
// Types and Interfaces // Enhanced Types and Interfaces
// ============================ // ============================
interface AuthState { interface AuthState {
user: Types.User | null; user: Types.User | null;
guest: Types.Guest | null; guest: Types.Guest | null;
isAuthenticated: boolean; isAuthenticated: boolean;
isGuest: boolean;
isLoading: boolean; isLoading: boolean;
isInitializing: boolean; isInitializing: boolean;
error: string | null; error: string | null;
@ -36,7 +38,7 @@ interface PasswordResetRequest {
} }
// ============================ // ============================
// Token Storage Constants // Enhanced Token Storage Constants
// ============================ // ============================
const TOKEN_STORAGE = { const TOKEN_STORAGE = {
@ -44,7 +46,8 @@ const TOKEN_STORAGE = {
REFRESH_TOKEN: 'refreshToken', REFRESH_TOKEN: 'refreshToken',
USER_DATA: 'userData', USER_DATA: 'userData',
TOKEN_EXPIRY: 'tokenExpiry', TOKEN_EXPIRY: 'tokenExpiry',
GUEST_DATA: 'guestData', USER_TYPE: 'userType',
IS_GUEST: 'isGuest',
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail' PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
} as const; } as const;
@ -84,7 +87,7 @@ function isTokenExpired(token: string): boolean {
} }
// ============================ // ============================
// Storage Utilities with Date Conversion // Enhanced Storage Utilities
// ============================ // ============================
function clearStoredAuth(): void { function clearStoredAuth(): void {
@ -92,6 +95,8 @@ function clearStoredAuth(): void {
localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN); localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN);
localStorage.removeItem(TOKEN_STORAGE.USER_DATA); localStorage.removeItem(TOKEN_STORAGE.USER_DATA);
localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY); localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY);
localStorage.removeItem(TOKEN_STORAGE.USER_TYPE);
localStorage.removeItem(TOKEN_STORAGE.IS_GUEST);
} }
function prepareUserDataForStorage(user: Types.User): string { function prepareUserDataForStorage(user: Types.User): string {
@ -119,11 +124,21 @@ function parseStoredUserData(userDataStr: string): Types.User | null {
} }
} }
function storeAuthData(authResponse: Types.AuthResponse): void { function updateStoredUserData(user: Types.User): void {
try {
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user));
} catch (error) {
console.error('Failed to update stored user data:', error);
}
}
function storeAuthData(authResponse: Types.AuthResponse, isGuest: boolean = false): void {
localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken); localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken);
localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken); localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken);
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user)); localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user));
localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString()); localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString());
localStorage.setItem(TOKEN_STORAGE.USER_TYPE, authResponse.user.userType);
localStorage.setItem(TOKEN_STORAGE.IS_GUEST, isGuest.toString());
} }
function getStoredAuthData(): { function getStoredAuthData(): {
@ -131,11 +146,15 @@ function getStoredAuthData(): {
refreshToken: string | null; refreshToken: string | null;
userData: Types.User | null; userData: Types.User | null;
expiresAt: number | null; expiresAt: number | null;
userType: string | null;
isGuest: boolean;
} { } {
const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN); const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN);
const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN); const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN);
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA); const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY); const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY);
const userType = localStorage.getItem(TOKEN_STORAGE.USER_TYPE);
const isGuestStr = localStorage.getItem(TOKEN_STORAGE.IS_GUEST);
let userData: Types.User | null = null; let userData: Types.User | null = null;
let expiresAt: number | null = null; let expiresAt: number | null = null;
@ -152,55 +171,18 @@ function getStoredAuthData(): {
clearStoredAuth(); clearStoredAuth();
} }
return { accessToken, refreshToken, userData, expiresAt }; return {
} accessToken,
refreshToken,
function updateStoredUserData(user: Types.User): void { userData,
try { expiresAt,
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user)); userType,
} catch (error) { isGuest: isGuestStr === 'true'
console.error('Failed to update stored user data:', error);
}
}
// ============================
// Guest Session Utilities
// ============================
function createGuestSession(): Types.Guest {
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const guest: Types.Guest = {
sessionId,
createdAt: new Date(),
lastActivity: new Date(),
ipAddress: 'unknown',
userAgent: navigator.userAgent
}; };
try {
localStorage.setItem(TOKEN_STORAGE.GUEST_DATA, JSON.stringify(formatApiRequest(guest)));
} catch (error) {
console.error('Failed to store guest session:', error);
}
return guest;
}
function getStoredGuestData(): Types.Guest | null {
try {
const guestDataStr = localStorage.getItem(TOKEN_STORAGE.GUEST_DATA);
if (guestDataStr) {
return toCamelCase<Types.Guest>(JSON.parse(guestDataStr));
}
} catch (error) {
console.error('Failed to parse stored guest data:', error);
localStorage.removeItem(TOKEN_STORAGE.GUEST_DATA);
}
return null;
} }
// ============================ // ============================
// Main Authentication Hook // Enhanced Authentication Hook
// ============================ // ============================
function useAuthenticationLogic() { function useAuthenticationLogic() {
@ -208,6 +190,7 @@ function useAuthenticationLogic() {
user: null, user: null,
guest: null, guest: null,
isAuthenticated: false, isAuthenticated: false,
isGuest: false,
isLoading: false, isLoading: false,
isInitializing: true, isInitializing: true,
error: null, error: null,
@ -216,6 +199,7 @@ function useAuthenticationLogic() {
const [apiClient] = useState(() => new ApiClient()); const [apiClient] = useState(() => new ApiClient());
const initializationCompleted = useRef(false); const initializationCompleted = useRef(false);
const guestCreationAttempted = useRef(false);
// Token refresh function // Token refresh function
const refreshAccessToken = useCallback(async (refreshToken: string): Promise<Types.AuthResponse | null> => { const refreshAccessToken = useCallback(async (refreshToken: string): Promise<Types.AuthResponse | null> => {
@ -228,6 +212,58 @@ function useAuthenticationLogic() {
} }
}, [apiClient]); }, [apiClient]);
// Create guest session
const createGuestSession = useCallback(async (): Promise<boolean> => {
if (guestCreationAttempted.current) {
return false;
}
guestCreationAttempted.current = true;
try {
console.log('🔄 Creating guest session...');
const guestAuth = await apiClient.createGuestSession();
if (guestAuth && guestAuth.user && guestAuth.user.userType === 'guest') {
storeAuthData(guestAuth, true);
apiClient.setAuthToken(guestAuth.accessToken);
setAuthState({
user: null,
guest: guestAuth.user as Types.Guest,
isAuthenticated: true,
isGuest: true,
isLoading: false,
isInitializing: false,
error: null,
mfaResponse: null,
});
console.log('👤 Guest session created successfully:', guestAuth.user);
return true;
}
return false;
} catch (error) {
console.error('❌ Failed to create guest session:', error);
guestCreationAttempted.current = false;
// Set to unauthenticated state if guest creation fails
setAuthState(prev => ({
...prev,
user: null,
guest: null,
isAuthenticated: false,
isGuest: false,
isLoading: false,
isInitializing: false,
error: 'Failed to create guest session',
}));
return false;
}
}, [apiClient]);
// Initialize authentication state // Initialize authentication state
const initializeAuth = useCallback(async () => { const initializeAuth = useCallback(async () => {
if (initializationCompleted.current) { if (initializationCompleted.current) {
@ -235,99 +271,94 @@ function useAuthenticationLogic() {
} }
try { try {
// Initialize guest session first
let guest = getStoredGuestData();
if (!guest) {
guest = createGuestSession();
}
const stored = getStoredAuthData(); const stored = getStoredAuthData();
// If no stored tokens, user is not authenticated but has guest session // If no stored tokens, create guest session
if (!stored.accessToken || !stored.refreshToken || !stored.userData) { if (!stored.accessToken || !stored.refreshToken || !stored.userData) {
setAuthState({ console.log('🔄 No stored auth found, creating guest session...');
user: null, await createGuestSession();
guest,
isAuthenticated: false,
isLoading: false,
isInitializing: false,
error: null,
mfaResponse: null,
});
return; return;
} }
// For guests, always verify the session exists on server
if (stored.userType === 'guest' && stored.userData) {
console.log(stored.userData);
try {
// Make a quick API call to verify guest still exists
const response = await fetch(`${apiClient.getBaseUrl()}/users/${stored.userData.id}`, {
headers: { 'Authorization': `Bearer ${stored.accessToken}` }
});
if (!response.ok) {
console.log('🔄 Guest session invalid, creating new guest session...');
clearStoredAuth();
await createGuestSession();
return;
}
} catch (error) {
console.log('🔄 Guest verification failed, creating new guest session...');
clearStoredAuth();
await createGuestSession();
return;
}
}
// Check if access token is expired // Check if access token is expired
if (isTokenExpired(stored.accessToken)) { if (isTokenExpired(stored.accessToken)) {
console.log('Access token expired, attempting refresh...'); console.log('🔄 Access token expired, attempting refresh...');
const refreshResult = await refreshAccessToken(stored.refreshToken); const refreshResult = await refreshAccessToken(stored.refreshToken);
if (refreshResult) { if (refreshResult) {
storeAuthData(refreshResult); const isGuest = stored.userType === 'guest';
storeAuthData(refreshResult, isGuest);
apiClient.setAuthToken(refreshResult.accessToken); apiClient.setAuthToken(refreshResult.accessToken);
setAuthState({ setAuthState({
user: refreshResult.user, user: isGuest ? null : refreshResult.user,
guest, guest: isGuest ? refreshResult.user as Types.Guest : null,
isAuthenticated: true, isAuthenticated: true,
isGuest,
isLoading: false, isLoading: false,
isInitializing: false, isInitializing: false,
error: null, error: null,
mfaResponse: null mfaResponse: null
}); });
console.log('Token refreshed successfully'); console.log('Token refreshed successfully');
} else { } else {
console.log('Token refresh failed, clearing stored auth'); console.log('❌ Token refresh failed, creating new guest session...');
clearStoredAuth(); clearStoredAuth();
apiClient.clearAuthToken(); apiClient.clearAuthToken();
await createGuestSession();
setAuthState({
user: null,
guest,
isAuthenticated: false,
isLoading: false,
isInitializing: false,
error: null,
mfaResponse: null
});
} }
} else { } else {
// Access token is still valid // Access token is still valid
apiClient.setAuthToken(stored.accessToken); apiClient.setAuthToken(stored.accessToken);
const isGuest = stored.userType === 'guest';
setAuthState({ setAuthState({
user: stored.userData, user: isGuest ? null : stored.userData,
guest, guest: isGuest ? stored.userData as Types.Guest : null,
isAuthenticated: true, isAuthenticated: true,
isGuest,
isLoading: false, isLoading: false,
isInitializing: false, isInitializing: false,
error: null, error: null,
mfaResponse: null mfaResponse: null
}); });
console.log('Restored authentication from stored tokens'); console.log('Restored authentication from stored tokens');
} }
} catch (error) { } catch (error) {
console.error('Error initializing auth:', error); console.error('Error initializing auth:', error);
clearStoredAuth(); clearStoredAuth();
apiClient.clearAuthToken(); apiClient.clearAuthToken();
await createGuestSession();
const guest = createGuestSession();
setAuthState({
user: null,
guest,
isAuthenticated: false,
isLoading: false,
isInitializing: false,
error: null,
mfaResponse: null
});
} finally { } finally {
initializationCompleted.current = true; initializationCompleted.current = true;
} }
}, [apiClient, refreshAccessToken]); }, [apiClient, refreshAccessToken, createGuestSession]);
// Run initialization on mount // Run initialization on mount
useEffect(() => { useEffect(() => {
@ -355,7 +386,7 @@ function useAuthenticationLogic() {
} }
const refreshTimer = setTimeout(() => { const refreshTimer = setTimeout(() => {
console.log('Auto-refreshing token before expiry...'); console.log('🔄 Auto-refreshing token before expiry...');
initializeAuth(); initializeAuth();
}, timeUntilExpiry); }, timeUntilExpiry);
@ -364,7 +395,7 @@ function useAuthenticationLogic() {
// Enhanced login with MFA support // Enhanced login with MFA support
const login = useCallback(async (loginData: LoginRequest): Promise<boolean> => { const login = useCallback(async (loginData: LoginRequest): Promise<boolean> => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null, mfaData: null })); setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null }));
try { try {
const result = await apiClient.login({ const result = await apiClient.login({
@ -381,21 +412,23 @@ function useAuthenticationLogic() {
})); }));
return false; // Login not complete yet return false; // Login not complete yet
} else { } else {
// Normal login success // Normal login success - convert from guest to authenticated user
const authResponse: Types.AuthResponse = result; const authResponse: Types.AuthResponse = result;
storeAuthData(authResponse); storeAuthData(authResponse, false);
apiClient.setAuthToken(authResponse.accessToken); apiClient.setAuthToken(authResponse.accessToken);
setAuthState(prev => ({ setAuthState(prev => ({
...prev, ...prev,
user: authResponse.user, user: authResponse.user,
guest: null,
isAuthenticated: true, isAuthenticated: true,
isGuest: false,
isLoading: false, isLoading: false,
error: null, error: null,
mfaResponse: null, mfaResponse: null,
})); }));
console.log('Login successful'); console.log('Login successful, converted from guest to authenticated user');
return true; return true;
} }
} catch (error: any) { } catch (error: any) {
@ -410,6 +443,44 @@ function useAuthenticationLogic() {
} }
}, [apiClient]); }, [apiClient]);
// Convert guest to permanent user
const convertGuestToUser = useCallback(async (registrationData: GuestConversionRequest): Promise<boolean> => {
if (!authState.isGuest || !authState.guest) {
throw new Error('Not currently a guest user');
}
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
try {
const result = await apiClient.convertGuestToUser(registrationData);
// Store new authentication
storeAuthData(result.auth, false);
apiClient.setAuthToken(result.auth.accessToken);
setAuthState(prev => ({
...prev,
user: result.auth.user,
guest: null,
isAuthenticated: true,
isGuest: false,
isLoading: false,
error: null,
}));
console.log('✅ Guest successfully converted to permanent user');
return true;
} catch (error: any) {
const errorMessage = error instanceof Error ? error.message : 'Failed to convert guest account';
setAuthState(prev => ({
...prev,
isLoading: false,
error: errorMessage,
}));
return false;
}
}, [apiClient, authState.isGuest, authState.guest]);
// MFA verification // MFA verification
const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise<boolean> => { const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise<boolean> => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null })); setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
@ -419,26 +490,27 @@ function useAuthenticationLogic() {
if (result.accessToken) { if (result.accessToken) {
const authResponse: Types.AuthResponse = result; const authResponse: Types.AuthResponse = result;
storeAuthData(authResponse); storeAuthData(authResponse, false);
apiClient.setAuthToken(authResponse.accessToken); apiClient.setAuthToken(authResponse.accessToken);
setAuthState(prev => ({ setAuthState(prev => ({
...prev, ...prev,
user: authResponse.user, user: authResponse.user,
guest: null,
isAuthenticated: true, isAuthenticated: true,
isGuest: false,
isLoading: false, isLoading: false,
error: null, error: null,
mfaResponse: null, mfaResponse: null,
})); }));
console.log('MFA verification successful'); console.log('MFA verification successful, converted from guest');
return true; return true;
} }
return false; return false;
} catch (error) { } catch (error) {
const errorMessage = error instanceof Error ? error.message : 'MFA verification failed'; const errorMessage = error instanceof Error ? error.message : 'MFA verification failed';
console.log(errorMessage);
setAuthState(prev => ({ setAuthState(prev => ({
...prev, ...prev,
isLoading: false, isLoading: false,
@ -448,42 +520,48 @@ function useAuthenticationLogic() {
} }
}, [apiClient]); }, [apiClient]);
// Resend MFA code // Logout - returns to guest session
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => { const logout = useCallback(async () => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
try { try {
await apiClient.requestMFA({ // If authenticated, try to logout gracefully
email, if (authState.isAuthenticated && !authState.isGuest) {
password: '', // This would need to be stored securely or re-entered const stored = getStoredAuthData();
deviceId, if (stored.accessToken && stored.refreshToken) {
deviceName, try {
}); await apiClient.logout(stored.accessToken, stored.refreshToken);
} catch (error) {
setAuthState(prev => ({ ...prev, isLoading: false })); console.warn('Logout request failed, proceeding with local cleanup');
return true; }
}
}
} catch (error) { } catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code'; console.warn('Error during logout:', error);
setAuthState(prev => ({ } finally {
...prev, // Always clear stored auth and create new guest session
isLoading: false, clearStoredAuth();
error: errorMessage apiClient.clearAuthToken();
})); guestCreationAttempted.current = false;
return false;
}
}, [apiClient]);
// Clear MFA state // Create new guest session
const clearMFA = useCallback(() => { await createGuestSession();
console.log('🔄 Logged out, created new guest session');
}
}, [apiClient, authState.isAuthenticated, authState.isGuest, createGuestSession]);
// Update user data
const updateUserData = useCallback((updatedUser: Types.User) => {
updateStoredUserData(updatedUser);
setAuthState(prev => ({ setAuthState(prev => ({
...prev, ...prev,
mfaResponse: null, user: authState.isGuest ? null : updatedUser,
error: null guest: authState.isGuest ? updatedUser as Types.Guest : prev.guest
})); }));
}, []); console.log('✅ User data updated');
}, [authState.isGuest]);
// Email verification // Email verification functions (unchanged)
const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest): Promise<{ message: string; userType: string } | null> => { const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest) => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null })); setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
try { try {
@ -504,7 +582,7 @@ function useAuthenticationLogic() {
} }
}, [apiClient]); }, [apiClient]);
// Resend email verification // Other existing methods remain the same...
const resendEmailVerification = useCallback(async (email: string): Promise<boolean> => { const resendEmailVerification = useCallback(async (email: string): Promise<boolean> => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null })); setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
@ -523,53 +601,21 @@ function useAuthenticationLogic() {
} }
}, [apiClient]); }, [apiClient]);
// Store pending verification email
const setPendingVerificationEmail = useCallback((email: string) => { const setPendingVerificationEmail = useCallback((email: string) => {
localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email); localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email);
}, []); }, []);
// Get pending verification email
const getPendingVerificationEmail = useCallback((): string | null => { const getPendingVerificationEmail = useCallback((): string | null => {
return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL); return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL);
}, []); }, []);
const logout = useCallback(() => {
clearStoredAuth();
apiClient.clearAuthToken();
// Create new guest session after logout
const guest = createGuestSession();
setAuthState(prev => ({
...prev,
user: null,
guest,
isAuthenticated: false,
isLoading: false,
error: null,
mfaResponse: null,
}));
console.log('User logged out');
}, [apiClient]);
const updateUserData = useCallback((updatedUser: Types.User) => {
updateStoredUserData(updatedUser);
setAuthState(prev => ({
...prev,
user: updatedUser
}));
console.log('User data updated', updatedUser);
}, []);
const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise<boolean> => { const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise<boolean> => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null })); setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
try { try {
const employer = await apiClient.createEmployer(employerData); const employer = await apiClient.createEmployer(employerData);
console.log('Employer created:', employer); console.log('✅ Employer created:', employer);
// Store email for potential verification resend
setPendingVerificationEmail(employerData.email); setPendingVerificationEmail(employerData.email);
setAuthState(prev => ({ ...prev, isLoading: false })); setAuthState(prev => ({ ...prev, isLoading: false }));
@ -614,24 +660,61 @@ function useAuthenticationLogic() {
const refreshResult = await refreshAccessToken(stored.refreshToken); const refreshResult = await refreshAccessToken(stored.refreshToken);
if (refreshResult) { if (refreshResult) {
storeAuthData(refreshResult); const isGuest = stored.userType === 'guest';
storeAuthData(refreshResult, isGuest);
apiClient.setAuthToken(refreshResult.accessToken); apiClient.setAuthToken(refreshResult.accessToken);
setAuthState(prev => ({ setAuthState(prev => ({
...prev, ...prev,
user: refreshResult.user, user: isGuest ? null : refreshResult.user,
guest: isGuest ? refreshResult.user as Types.Guest : null,
isAuthenticated: true, isAuthenticated: true,
isGuest,
isLoading: false, isLoading: false,
error: null error: null
})); }));
return true; return true;
} else { } else {
logout(); await logout();
return false; return false;
} }
}, [refreshAccessToken, logout]); }, [refreshAccessToken, logout]);
// Resend MFA code
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => {
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
try {
await apiClient.requestMFA({
email,
password: '', // This would need to be stored securely or re-entered
deviceId,
deviceName,
});
setAuthState(prev => ({ ...prev, isLoading: false }));
return true;
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code';
setAuthState(prev => ({
...prev,
isLoading: false,
error: errorMessage
}));
return false;
}
}, [apiClient]);
// Clear MFA state
const clearMFA = useCallback(() => {
setAuthState(prev => ({
...prev,
mfaResponse: null,
error: null
}));
}, []);
return { return {
...authState, ...authState,
apiClient, apiClient,
@ -647,12 +730,14 @@ function useAuthenticationLogic() {
createEmployerAccount, createEmployerAccount,
requestPasswordReset, requestPasswordReset,
refreshAuth, refreshAuth,
updateUserData updateUserData,
convertGuestToUser,
createGuestSession
}; };
} }
// ============================ // ============================
// Context Provider // Enhanced Context Provider
// ============================ // ============================
const AuthContext = createContext<ReturnType<typeof useAuthenticationLogic> | null>(null); const AuthContext = createContext<ReturnType<typeof useAuthenticationLogic> | null>(null);
@ -676,34 +761,41 @@ function useAuth() {
} }
// ============================ // ============================
// Protected Route Component // Enhanced Protected Route Component
// ============================ // ============================
interface ProtectedRouteProps { interface ProtectedRouteProps {
children: React.ReactNode; children: React.ReactNode;
fallback?: React.ReactNode; fallback?: React.ReactNode;
requiredUserType?: Types.UserType; requiredUserType?: Types.UserType;
allowGuests?: boolean;
} }
function ProtectedRoute({ function ProtectedRoute({
children, children,
fallback = <div>Please log in to access this page.</div>, fallback = <div>Please log in to access this page.</div>,
requiredUserType requiredUserType,
allowGuests = false
}: ProtectedRouteProps) { }: ProtectedRouteProps) {
const { isAuthenticated, isInitializing, user } = useAuth(); const { isAuthenticated, isInitializing, user, isGuest } = useAuth();
// Show loading while checking stored tokens // Show loading while checking stored tokens
if (isInitializing) { if (isInitializing) {
return <div>Loading...</div>; return <div>Loading...</div>;
} }
// Not authenticated // Not authenticated at all (shouldn't happen with guest sessions)
if (!isAuthenticated) { if (!isAuthenticated) {
return <>{fallback}</>; return <>{fallback}</>;
} }
// Check user type if required // Guest access control
if (requiredUserType && user?.userType !== requiredUserType) { if (isGuest && !allowGuests) {
return <div>Please create an account or log in to access this page.</div>;
}
// Check user type if required (only for non-guests)
if (requiredUserType && !isGuest && user?.userType !== requiredUserType) {
return <div>Access denied. Required user type: {requiredUserType}</div>; return <div>Access denied. Required user type: {requiredUserType}</div>;
} }
@ -711,11 +803,19 @@ function ProtectedRoute({
} }
export type { export type {
AuthState, LoginRequest, EmailVerificationRequest, ResendVerificationRequest, PasswordResetRequest AuthState,
LoginRequest,
EmailVerificationRequest,
ResendVerificationRequest,
PasswordResetRequest,
GuestConversionRequest
} }
export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client'; export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
export { export {
useAuthenticationLogic, AuthProvider, useAuth, ProtectedRoute useAuthenticationLogic,
AuthProvider,
useAuth,
ProtectedRoute
} }

View File

@ -23,15 +23,16 @@ import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader'; import PropagateLoader from 'react-spinners/PropagateLoader';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
import { BackstoryQuery } from 'components/BackstoryQuery'; import { BackstoryQuery } from 'components/BackstoryQuery';
import { CandidatePicker } from 'components/ui/CandidatePicker';
const defaultMessage: ChatMessage = { const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user" status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
}; };
const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => { const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
const { apiClient } = useAuth(); const { apiClient } = useAuth();
const navigate = useNavigate(); const navigate = useNavigate();
const { selectedCandidate } = useSelectedCandidate() const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
const theme = useTheme(); const theme = useTheme();
const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null); const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null); const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
@ -92,7 +93,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
timestamp: new Date() timestamp: new Date()
}; };
setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` }); setProcessingMessage({ ...defaultMessage, status: 'status', activity: "info", content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
setMessages(prev => { setMessages(prev => {
const filtered = prev.filter((m: any) => m.id !== chatMessage.id); const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
@ -123,7 +124,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
}, },
onStreaming: (chunk: ChatMessageStreaming) => { onStreaming: (chunk: ChatMessageStreaming) => {
// console.log("onStreaming:", chunk); // console.log("onStreaming:", chunk);
setStreamingMessage({ ...chunk, role: 'assistant' }); setStreamingMessage({ ...chunk, role: 'assistant', metadata: null as any });
}, },
onStatus: (status: ChatMessageStatus) => { onStatus: (status: ChatMessageStatus) => {
setProcessingMessage(status); setProcessingMessage(status);
@ -171,8 +172,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
}, [chatSession]); }, [chatSession]);
if (!selectedCandidate) { if (!selectedCandidate) {
navigate('/find-a-candidate'); return <CandidatePicker />;
return (<></>);
} }
const welcomeMessage: ChatMessage = { const welcomeMessage: ChatMessage = {
@ -181,7 +181,8 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
type: "text", type: "text",
status: "done", status: "done",
timestamp: new Date(), timestamp: new Date(),
content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.` content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`,
metadata: null as any
}; };
return ( return (
@ -192,12 +193,14 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
gap: 1, gap: 1,
}}> }}>
<CandidateInfo <CandidateInfo
key={selectedCandidate.username}
action={`Chat with Backstory about ${selectedCandidate.firstName}`} action={`Chat with Backstory about ${selectedCandidate.firstName}`}
elevation={4} elevation={4}
candidate={selectedCandidate} candidate={selectedCandidate}
variant="small" variant="small"
sx={{ flexShrink: 0 }} // Prevent header from shrinking sx={{ flexShrink: 0 }} // Prevent header from shrinking
/> />
<Button onClick={() => { setSelectedCandidate(null); }} variant="contained">Change Candidates</Button>
{/* Chat Interface */} {/* Chat Interface */}
<Paper <Paper

View File

@ -23,7 +23,7 @@ import { Message } from 'components/Message';
import { useAppState } from 'hooks/GlobalContext'; import { useAppState } from 'hooks/GlobalContext';
const defaultMessage: ChatMessage = { const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user" status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
}; };
const GenerateCandidate = (props: BackstoryElementProps) => { const GenerateCandidate = (props: BackstoryElementProps) => {

View File

@ -20,6 +20,7 @@ import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer';
import DescriptionIcon from '@mui/icons-material/Description'; import DescriptionIcon from '@mui/icons-material/Description';
import professionalConversationPng from './Conversation.png'; import professionalConversationPng from './Conversation.png';
import { ComingSoon } from 'components/ui/ComingSoon'; import { ComingSoon } from 'components/ui/ComingSoon';
import { useAuth } from 'hooks/AuthContext';
// Placeholder for Testimonials component // Placeholder for Testimonials component
const Testimonials = () => { const Testimonials = () => {
@ -135,6 +136,15 @@ const FeatureCard = ({
const HomePage = () => { const HomePage = () => {
const testimonials = false; const testimonials = false;
const { isGuest, guest, user } = useAuth();
if (isGuest) {
// Show guest-specific UI
console.log('Guest session:', guest?.sessionId || "No guest");
} else {
// Show authenticated user UI
console.log('Authenticated user:', user?.email || "No user");
}
return (<Box sx={{display: "flex", flexDirection: "column"}}> return (<Box sx={{display: "flex", flexDirection: "column"}}>
{/* Hero Section */} {/* Hero Section */}

View File

@ -10,7 +10,8 @@ const LoadingPage = (props: BackstoryPageProps) => {
status: 'done', status: 'done',
sessionId: '', sessionId: '',
content: 'Please wait while connecting to Backstory...', content: 'Please wait while connecting to Backstory...',
timestamp: new Date() timestamp: new Date(),
metadata: null as any
} }
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}> return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>

View File

@ -26,6 +26,7 @@ import { LoginForm } from "components/EmailVerificationComponents";
import { CandidateRegistrationForm } from "components/RegistrationForms"; import { CandidateRegistrationForm } from "components/RegistrationForms";
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { useAppState } from 'hooks/GlobalContext'; import { useAppState } from 'hooks/GlobalContext';
import * as Types from 'types/types';
const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => { const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const navigate = useNavigate(); const navigate = useNavigate();
@ -34,7 +35,7 @@ const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [success, setSuccess] = useState<string | null>(null); const [success, setSuccess] = useState<string | null>(null);
const { guest, user, login, isLoading, error } = useAuth(); const { guest, user, login, isLoading, error } = useAuth();
const name = (user?.userType === 'candidate') ? user.username : user?.email || ''; const name = (user?.userType === 'candidate') ? (user as Types.Candidate).username : user?.email || '';
const [errorMessage, setErrorMessage] = useState<string | null>(null); const [errorMessage, setErrorMessage] = useState<string | null>(null);
const showGuest: boolean = false; const showGuest: boolean = false;

View File

@ -10,7 +10,8 @@ const LoginRequired = (props: BackstoryPageProps) => {
status: 'done', status: 'done',
sessionId: '', sessionId: '',
content: 'You must be logged to view this feature.', content: 'You must be logged to view this feature.',
timestamp: new Date() timestamp: new Date(),
metadata: null as any
} }
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}> return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>

View File

@ -11,6 +11,7 @@ import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext'; import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types'; import { Candidate } from 'types/types';
import { useAppState } from 'hooks/GlobalContext'; import { useAppState } from 'hooks/GlobalContext';
import * as Types from 'types/types';
const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => { const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
const { setSnack } = useAppState(); const { setSnack } = useAppState();
@ -18,7 +19,7 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
const theme = useTheme(); const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('md')); const isMobile = useMediaQuery(theme.breakpoints.down('md'));
const [questions, setQuestions] = useState<React.ReactElement[]>([]); const [questions, setQuestions] = useState<React.ReactElement[]>([]);
const candidate: Candidate | null = user?.userType === 'candidate' ? user : null; const candidate: Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;
// console.log("ChatPage candidate =>", candidate); // console.log("ChatPage candidate =>", candidate);
useEffect(() => { useEffect(() => {

View File

@ -74,7 +74,7 @@ const CandidateDashboard = (props: CandidateDashboardProps) => {
variant="contained" variant="contained"
color="primary" color="primary"
sx={{ mt: 1 }} sx={{ mt: 1 }}
onClick={(e) => {e.stopPropagation(); navigate('/candidate/dashboard/profile'); }} onClick={(e) => { e.stopPropagation(); navigate('/candidate/profile'); }}
> >
Complete Your Profile Complete Your Profile
</Button> </Button>

View File

@ -14,9 +14,10 @@ import Typography from '@mui/material/Typography';
// import ResetIcon from '@mui/icons-material/History'; // import ResetIcon from '@mui/icons-material/History';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import { connectionBase } from '../../utils/Global';
import { BackstoryPageProps } from '../../components/BackstoryTab'; import { BackstoryPageProps } from '../../components/BackstoryTab';
import { useAppState } from 'hooks/GlobalContext'; import { useAppState } from 'hooks/GlobalContext';
import { useAuth } from 'hooks/AuthContext';
import * as Types from 'types/types';
interface ServerTunables { interface ServerTunables {
system_prompt: string, system_prompt: string,
@ -33,19 +34,7 @@ type Tool = {
returns?: any returns?: any
}; };
type GPUInfo = { const SystemInfoComponent: React.FC<{ systemInfo: Types.SystemInfo | undefined }> = ({ systemInfo }) => {
name: string,
memory: number,
discrete: boolean
}
type SystemInfo = {
"Installed RAM": string,
"Graphics Card": GPUInfo[],
"CPU": string
};
const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ systemInfo }) => {
const [systemElements, setSystemElements] = useState<ReactElement[]>([]); const [systemElements, setSystemElements] = useState<ReactElement[]>([]);
const convertToSymbols = (text: string) => { const convertToSymbols = (text: string) => {
@ -86,91 +75,92 @@ const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({
}; };
const Settings = (props: BackstoryPageProps) => { const Settings = (props: BackstoryPageProps) => {
const { apiClient } = useAuth();
const { setSnack } = useAppState(); const { setSnack } = useAppState();
const [editSystemPrompt, setEditSystemPrompt] = useState<string>(""); const [editSystemPrompt, setEditSystemPrompt] = useState<string>("");
const [systemInfo, setSystemInfo] = useState<SystemInfo | undefined>(undefined); const [systemInfo, setSystemInfo] = useState<Types.SystemInfo | undefined>(undefined);
const [tools, setTools] = useState<Tool[]>([]); const [tools, setTools] = useState<Tool[]>([]);
const [rags, setRags] = useState<Tool[]>([]); const [rags, setRags] = useState<Tool[]>([]);
const [systemPrompt, setSystemPrompt] = useState<string>(""); const [systemPrompt, setSystemPrompt] = useState<string>("");
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5); const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined); const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
useEffect(() => { // useEffect(() => {
if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) { // if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) {
return; // return;
} // }
const sendSystemPrompt = async (prompt: string) => { // const sendSystemPrompt = async (prompt: string) => {
try { // try {
const response = await fetch(connectionBase + `/api/1.0/tunables`, { // const response = await fetch(connectionBase + `/api/1.0/tunables`, {
method: 'PUT', // method: 'PUT',
headers: { // headers: {
'Content-Type': 'application/json', // 'Content-Type': 'application/json',
'Accept': 'application/json', // 'Accept': 'application/json',
}, // },
body: JSON.stringify({ "system_prompt": prompt }), // body: JSON.stringify({ "system_prompt": prompt }),
}); // });
const tunables = await response.json(); // const tunables = await response.json();
serverTunables.system_prompt = tunables.system_prompt; // serverTunables.system_prompt = tunables.system_prompt;
console.log(tunables); // console.log(tunables);
setSystemPrompt(tunables.system_prompt) // setSystemPrompt(tunables.system_prompt)
setSnack("System prompt updated", "success"); // setSnack("System prompt updated", "success");
} catch (error) { // } catch (error) {
console.error('Fetch error:', error); // console.error('Fetch error:', error);
setSnack("System prompt update failed", "error"); // setSnack("System prompt update failed", "error");
} // }
}; // };
sendSystemPrompt(systemPrompt); // sendSystemPrompt(systemPrompt);
}, [systemPrompt, setSnack, serverTunables]); // }, [systemPrompt, setSnack, serverTunables]);
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => { // const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
try { // try {
const response = await fetch(connectionBase + `/api/1.0/reset/`, { // const response = await fetch(connectionBase + `/api/1.0/reset/`, {
method: 'PUT', // method: 'PUT',
headers: { // headers: {
'Content-Type': 'application/json', // 'Content-Type': 'application/json',
'Accept': 'application/json', // 'Accept': 'application/json',
}, // },
body: JSON.stringify({ "reset": types }), // body: JSON.stringify({ "reset": types }),
}); // });
if (!response.ok) { // if (!response.ok) {
throw new Error(`Server responded with ${response.status}: ${response.statusText}`); // throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
} // }
if (!response.body) { // if (!response.body) {
throw new Error('Response body is null'); // throw new Error('Response body is null');
} // }
const data = await response.json(); // const data = await response.json();
if (data.error) { // if (data.error) {
throw Error(data.error); // throw Error(data.error);
} // }
for (const [key, value] of Object.entries(data)) { // for (const [key, value] of Object.entries(data)) {
switch (key) { // switch (key) {
case "rags": // case "rags":
setRags(value as Tool[]); // setRags(value as Tool[]);
break; // break;
case "tools": // case "tools":
setTools(value as Tool[]); // setTools(value as Tool[]);
break; // break;
case "system_prompt": // case "system_prompt":
setSystemPrompt((value as ServerTunables)["system_prompt"].trim()); // setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
break; // break;
case "history": // case "history":
console.log('TODO: handle history reset'); // console.log('TODO: handle history reset');
break; // break;
} // }
} // }
setSnack(message, "success"); // setSnack(message, "success");
} catch (error) { // } catch (error) {
console.error('Fetch error:', error); // console.error('Fetch error:', error);
setSnack("Unable to restore defaults", "error"); // setSnack("Unable to restore defaults", "error");
} // }
}; // };
// Get the system information // Get the system information
useEffect(() => { useEffect(() => {
@ -179,27 +169,8 @@ const Settings = (props: BackstoryPageProps) => {
} }
const fetchSystemInfo = async () => { const fetchSystemInfo = async () => {
try { try {
const response = await fetch(connectionBase + `/api/1.0/system-info`, { const response: Types.SystemInfo = await apiClient.getSystemInfo();
method: 'GET', setSystemInfo(response);
headers: {
'Content-Type': 'application/json',
},
})
if (!response.ok) {
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
}
if (!response.body) {
throw new Error('Response body is null');
}
const data = await response.json();
if (data.error) {
throw Error(data.error);
}
setSystemInfo(data);
} catch (error) { } catch (error) {
console.error('Error obtaining system information:', error); console.error('Error obtaining system information:', error);
setSnack("Unable to obtain system information.", "error"); setSnack("Unable to obtain system information.", "error");
@ -217,101 +188,101 @@ const Settings = (props: BackstoryPageProps) => {
setEditSystemPrompt(systemPrompt.trim()); setEditSystemPrompt(systemPrompt.trim());
}, [systemPrompt, setEditSystemPrompt]); }, [systemPrompt, setEditSystemPrompt]);
const toggleRag = async (tool: Tool) => { // const toggleRag = async (tool: Tool) => {
tool.enabled = !tool.enabled // tool.enabled = !tool.enabled
try { // try {
const response = await fetch(connectionBase + `/api/1.0/tunables`, { // const response = await fetch(connectionBase + `/api/1.0/tunables`, {
method: 'PUT', // method: 'PUT',
headers: { // headers: {
'Content-Type': 'application/json', // 'Content-Type': 'application/json',
'Accept': 'application/json', // 'Accept': 'application/json',
}, // },
body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }), // body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
}); // });
const tunables: ServerTunables = await response.json(); // const tunables: ServerTunables = await response.json();
setRags(tunables.rags) // setRags(tunables.rags)
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`); // setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
} catch (error) { // } catch (error) {
console.error('Fetch error:', error); // console.error('Fetch error:', error);
setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); // setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
tool.enabled = !tool.enabled // tool.enabled = !tool.enabled
} // }
}; // };
const toggleTool = async (tool: Tool) => { // const toggleTool = async (tool: Tool) => {
tool.enabled = !tool.enabled // tool.enabled = !tool.enabled
try { // try {
const response = await fetch(connectionBase + `/api/1.0/tunables`, { // const response = await fetch(connectionBase + `/api/1.0/tunables`, {
method: 'PUT', // method: 'PUT',
headers: { // headers: {
'Content-Type': 'application/json', // 'Content-Type': 'application/json',
'Accept': 'application/json', // 'Accept': 'application/json',
}, // },
body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }), // body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
}); // });
const tunables: ServerTunables = await response.json(); // const tunables: ServerTunables = await response.json();
setTools(tunables.tools) // setTools(tunables.tools)
setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`); // setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
} catch (error) { // } catch (error) {
console.error('Fetch error:', error); // console.error('Fetch error:', error);
setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); // setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
tool.enabled = !tool.enabled // tool.enabled = !tool.enabled
} // }
}; // };
// If the systemPrompt has not been set, fetch it from the server // If the systemPrompt has not been set, fetch it from the server
useEffect(() => { // useEffect(() => {
if (serverTunables !== undefined) { // if (serverTunables !== undefined) {
return; // return;
} // }
const fetchTunables = async () => { // const fetchTunables = async () => {
try { // try {
// Make the fetch request with proper headers // // Make the fetch request with proper headers
const response = await fetch(connectionBase + `/api/1.0/tunables`, { // const response = await fetch(connectionBase + `/api/1.0/tunables`, {
method: 'GET', // method: 'GET',
headers: { // headers: {
'Content-Type': 'application/json', // 'Content-Type': 'application/json',
'Accept': 'application/json', // 'Accept': 'application/json',
}, // },
}); // });
const data = await response.json(); // const data = await response.json();
// console.log("Server tunables: ", data); // // console.log("Server tunables: ", data);
setServerTunables(data); // setServerTunables(data);
setSystemPrompt(data["system_prompt"]); // setSystemPrompt(data["system_prompt"]);
setTools(data["tools"]); // setTools(data["tools"]);
setRags(data["rags"]); // setRags(data["rags"]);
} catch (error) { // } catch (error) {
console.error('Fetch error:', error); // console.error('Fetch error:', error);
setSnack("System prompt update failed", "error"); // setSnack("System prompt update failed", "error");
} // }
} // }
fetchTunables(); // fetchTunables();
}, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]); // }, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
const toggle = async (type: string, index: number) => { // const toggle = async (type: string, index: number) => {
switch (type) { // switch (type) {
case "rag": // case "rag":
if (rags === undefined) { // if (rags === undefined) {
return; // return;
} // }
toggleRag(rags[index]) // toggleRag(rags[index])
break; // break;
case "tool": // case "tool":
if (tools === undefined) { // if (tools === undefined) {
return; // return;
} // }
toggleTool(tools[index]); // toggleTool(tools[index]);
} // }
}; // };
const handleKeyPress = (event: any) => { // const handleKeyPress = (event: any) => {
if (event.key === 'Enter' && event.ctrlKey) { // if (event.key === 'Enter' && event.ctrlKey) {
setSystemPrompt(editSystemPrompt); // setSystemPrompt(editSystemPrompt);
} // }
}; // };
return (<div className="Controls"> return (<div className="Controls">
{/* <Typography component="span" sx={{ mb: 1 }}> {/* <Typography component="span" sx={{ mb: 1 }}>

View File

@ -6,6 +6,7 @@ import { SetSnackType } from '../components/Snack';
import { LoadingComponent } from "../components/LoadingComponent"; import { LoadingComponent } from "../components/LoadingComponent";
import { User, Guest, Candidate } from 'types/types'; import { User, Guest, Candidate } from 'types/types';
import { useAuth } from "hooks/AuthContext"; import { useAuth } from "hooks/AuthContext";
import { useSelectedCandidate } from "hooks/GlobalContext";
interface CandidateRouteProps { interface CandidateRouteProps {
guest?: Guest | null; guest?: Guest | null;
@ -15,19 +16,19 @@ interface CandidateRouteProps {
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => { const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
const { apiClient } = useAuth(); const { apiClient } = useAuth();
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate();
const { setSnack } = props; const { setSnack } = props;
const { username } = useParams<{ username: string }>(); const { username } = useParams<{ username: string }>();
const [candidate, setCandidate] = useState<Candidate|null>(null);
const navigate = useNavigate(); const navigate = useNavigate();
useEffect(() => { useEffect(() => {
if (candidate?.username === username || !username) { if (selectedCandidate?.username === username || !username) {
return; return;
} }
const getCandidate = async (reference: string) => { const getCandidate = async (reference: string) => {
try { try {
const result: Candidate = await apiClient.getCandidate(reference); const result: Candidate = await apiClient.getCandidate(reference);
setCandidate(result); setSelectedCandidate(result);
navigate('/chat'); navigate('/chat');
} catch { } catch {
setSnack(`Unable to obtain information for ${username}.`, "error"); setSnack(`Unable to obtain information for ${username}.`, "error");
@ -36,9 +37,9 @@ const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProp
} }
getCandidate(username); getCandidate(username);
}, [candidate, username, setCandidate, navigate, setSnack, apiClient]); }, [selectedCandidate, username, selectedCandidate, navigate, setSnack, apiClient]);
if (candidate === null) { if (selectedCandidate?.username !== username) {
return (<Box> return (<Box>
<LoadingComponent <LoadingComponent
loadingText="Fetching candidate information..." loadingText="Fetching candidate information..."

View File

@ -29,10 +29,33 @@ import {
convertArrayFromApi convertArrayFromApi
} from 'types/types'; } from 'types/types';
const TOKEN_STORAGE = {
ACCESS_TOKEN: 'accessToken',
REFRESH_TOKEN: 'refreshToken',
USER_DATA: 'userData',
TOKEN_EXPIRY: 'tokenExpiry',
USER_TYPE: 'userType',
IS_GUEST: 'isGuest',
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
} as const;
// ============================ // ============================
// Streaming Types and Interfaces // Streaming Types and Interfaces
// ============================ // ============================
export interface GuestConversionRequest extends CreateCandidateRequest {
accountType: 'candidate';
}
export class RateLimitError extends Error {
constructor(
message: string,
public retryAfterSeconds: number,
public remainingRequests: Record<string, number>
) {
super(message);
this.name = 'RateLimitError';
}
}
interface StreamingOptions<T = Types.ChatMessage> { interface StreamingOptions<T = Types.ChatMessage> {
method?: string, method?: string,
headers?: Record<string, any>, headers?: Record<string, any>,
@ -777,7 +800,7 @@ class ApiClient {
async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> { async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
const result = await this.getCandidateChatSessions(candidate.username); const result = await this.getCandidateChatSessions(candidate.username);
/* Find the 'candidate_chat' session if it exists, otherwise create it */ /* Find the 'candidate_chat' session if it exists, otherwise create it */
let session = result.sessions.data.find(session => session.title === 'candidate_chat'); let session = result.sessions.data.find(session => session.title === title);
if (!session) { if (!session) {
session = await this.createCandidateChatSession( session = await this.createCandidateChatSession(
candidate.username, candidate.username,
@ -788,6 +811,17 @@ class ApiClient {
return session; return session;
} }
async getSystemInfo() : Promise<Types.SystemInfo> {
const response = await fetch(`${this.baseUrl}/system-info`, {
method: 'GET',
headers: this.defaultHeaders,
});
const result = await handleApiResponse<Types.SystemInfo>(response);
return result;
}
async getCandidateSimilarContent(query: string async getCandidateSimilarContent(query: string
): Promise<Types.ChromaDBGetResponse> { ): Promise<Types.ChromaDBGetResponse> {
const response = await fetch(`${this.baseUrl}/candidates/rag-search`, { const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
@ -1003,6 +1037,202 @@ class ApiClient {
return handleApiResponse<{ success: boolean; message: string }>(response); return handleApiResponse<{ success: boolean; message: string }>(response);
} }
// ============================
// Guest Authentication Methods
// ============================
/**
* Create a guest session with authentication
*/
async createGuestSession(): Promise<Types.AuthResponse> {
const response = await fetch(`${this.baseUrl}/auth/guest`, {
method: 'POST',
headers: this.defaultHeaders
});
const result = await handleApiResponse<Types.AuthResponse>(response);
// Convert guest data if needed
if (result.user && result.user.userType === 'guest') {
result.user = convertFromApi<Types.Guest>(result.user, "Guest");
}
return result;
}
/**
* Convert guest account to permanent user account
*/
async convertGuestToUser(
registrationData: CreateCandidateRequest & { accountType: 'candidate' }
): Promise<{
message: string;
auth: Types.AuthResponse;
conversionType: string
}> {
const response = await fetch(`${this.baseUrl}/auth/guest/convert`, {
method: 'POST',
headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest(registrationData))
});
const result = await handleApiResponse<{
message: string;
auth: Types.AuthResponse;
conversionType: string;
}>(response);
// Convert the auth user data
if (result.auth?.user) {
result.auth.user = convertFromApi<Types.Candidate>(result.auth.user, "Candidate");
}
return result;
}
/**
* Check if current session is a guest
*/
isGuestSession(): boolean {
try {
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
if (userDataStr) {
const userData = JSON.parse(userDataStr);
return userData.userType === 'guest';
}
return false;
} catch {
return false;
}
}
/**
* Get guest session info
*/
getGuestSessionInfo(): Types.Guest | null {
try {
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
if (userDataStr) {
const userData = JSON.parse(userDataStr);
if (userData.userType === 'guest') {
return convertFromApi<Types.Guest>(userData, "Guest");
}
}
return null;
} catch {
return null;
}
}
/**
* Get rate limit status for current user
*/
async getRateLimitStatus(): Promise<{
user_id: string;
user_type: string;
is_admin: boolean;
current_usage: Record<string, number>;
limits: Record<string, number>;
remaining: Record<string, number>;
reset_times: Record<string, string>;
config: any;
}> {
const response = await fetch(`${this.baseUrl}/admin/rate-limits/info`, {
headers: this.defaultHeaders
});
return handleApiResponse<any>(response);
}
/**
* Get guest statistics (admin only)
*/
async getGuestStatistics(): Promise<{
total_guests: number;
active_last_hour: number;
active_last_day: number;
converted_guests: number;
by_ip: Record<string, number>;
creation_timeline: Record<string, number>;
}> {
const response = await fetch(`${this.baseUrl}/admin/guests/statistics`, {
headers: this.defaultHeaders
});
return handleApiResponse<any>(response);
}
/**
* Cleanup inactive guests (admin only)
*/
async cleanupInactiveGuests(inactiveHours: number = 24): Promise<{
message: string;
cleaned_count: number;
}> {
const response = await fetch(`${this.baseUrl}/admin/guests/cleanup`, {
method: 'POST',
headers: this.defaultHeaders,
body: JSON.stringify({ inactive_hours: inactiveHours })
});
return handleApiResponse<{
message: string;
cleaned_count: number;
}>(response);
}
// ============================
// Enhanced Error Handling for Rate Limits
// ============================
/**
* Enhanced API response handler with rate limit handling
*/
private async handleApiResponseWithRateLimit<T>(response: Response): Promise<T> {
if (response.status === 429) {
const rateLimitData = await response.json();
const retryAfter = response.headers.get('Retry-After');
throw new RateLimitError(
rateLimitData.detail?.message || 'Rate limit exceeded',
parseInt(retryAfter || '60'),
rateLimitData.detail?.remaining || {}
);
}
return this.handleApiResponseWithConversion<T>(response);
}
/**
* Retry mechanism for rate-limited requests
*/
private async retryWithBackoff<T>(
requestFn: () => Promise<Response>,
maxRetries: number = 3
): Promise<T> {
let lastError: Error;
for (let attempt = 0; attempt <= maxRetries; attempt++) {
try {
const response = await requestFn();
return await this.handleApiResponseWithRateLimit<T>(response);
} catch (error) {
lastError = error as Error;
if (error instanceof RateLimitError && attempt < maxRetries) {
const delayMs = Math.min(error.retryAfterSeconds * 1000, 60000); // Max 1 minute
console.warn(`Rate limited, retrying in ${delayMs}ms (attempt ${attempt + 1}/${maxRetries + 1})`);
await new Promise(resolve => setTimeout(resolve, delayMs));
continue;
}
throw error;
}
}
throw lastError!;
}
async resetChatSession(id: string): Promise<{ success: boolean; message: string }> { async resetChatSession(id: string): Promise<{ success: boolean; message: string }> {
const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, { const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, {
method: 'PATCH', method: 'PATCH',
@ -1384,5 +1614,5 @@ export interface PendingVerification {
attempts: number; attempts: number;
} }
export { ApiClient } export { ApiClient, TOKEN_STORAGE }
export type { StreamingOptions, StreamingResponse } export type { StreamingOptions, StreamingResponse }

View File

@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models // Generated TypeScript types from Pydantic models
// Source: src/backend/models.py // Source: src/backend/models.py
// Generated on: 2025-06-07T20:43:58.855207 // Generated on: 2025-06-09T03:32:01.335483
// DO NOT EDIT MANUALLY - This file is auto-generated // DO NOT EDIT MANUALLY - This file is auto-generated
// ============================ // ============================
@ -128,8 +128,10 @@ export interface Attachment {
export interface AuthResponse { export interface AuthResponse {
accessToken: string; accessToken: string;
refreshToken: string; refreshToken: string;
user: any; user: Candidate | Employer | Guest;
expiresAt: number; expiresAt: number;
userType?: string;
isGuest?: boolean;
} }
export interface Authentication { export interface Authentication {
@ -148,7 +150,9 @@ export interface Authentication {
} }
export interface BaseUser { export interface BaseUser {
userType: "candidate" | "employer" | "guest";
id?: string; id?: string;
lastActivity?: Date;
email: string; email: string;
firstName: string; firstName: string;
lastName: string; lastName: string;
@ -164,24 +168,15 @@ export interface BaseUser {
} }
export interface BaseUserWithType { export interface BaseUserWithType {
id?: string;
email: string;
firstName: string;
lastName: string;
fullName: string;
phone?: string;
location?: Location;
createdAt: Date;
updatedAt: Date;
lastLogin?: Date;
profileImage?: string;
status: "active" | "inactive" | "pending" | "banned";
isAdmin: boolean;
userType: "candidate" | "employer" | "guest"; userType: "candidate" | "employer" | "guest";
id?: string;
lastActivity?: Date;
} }
export interface Candidate { export interface Candidate {
userType: "candidate" | "employer" | "guest";
id?: string; id?: string;
lastActivity?: Date;
email: string; email: string;
firstName: string; firstName: string;
lastName: string; lastName: string;
@ -194,7 +189,6 @@ export interface Candidate {
profileImage?: string; profileImage?: string;
status: "active" | "inactive" | "pending" | "banned"; status: "active" | "inactive" | "pending" | "banned";
isAdmin: boolean; isAdmin: boolean;
userType: "candidate";
username: string; username: string;
description?: string; description?: string;
resume?: string; resume?: string;
@ -214,7 +208,9 @@ export interface Candidate {
} }
export interface CandidateAI { export interface CandidateAI {
userType: "candidate" | "employer" | "guest";
id?: string; id?: string;
lastActivity?: Date;
email: string; email: string;
firstName: string; firstName: string;
lastName: string; lastName: string;
@ -227,7 +223,6 @@ export interface CandidateAI {
profileImage?: string; profileImage?: string;
status: "active" | "inactive" | "pending" | "banned"; status: "active" | "inactive" | "pending" | "banned";
isAdmin: boolean; isAdmin: boolean;
userType: "candidate";
username: string; username: string;
description?: string; description?: string;
resume?: string; resume?: string;
@ -301,7 +296,7 @@ export interface ChatMessage {
role: "user" | "assistant" | "system" | "information" | "warning" | "error"; role: "user" | "assistant" | "system" | "information" | "warning" | "error";
content: string; content: string;
tunables?: Tunables; tunables?: Tunables;
metadata?: ChatMessageMetaData; metadata: ChatMessageMetaData;
} }
export interface ChatMessageError { export interface ChatMessageError {
@ -319,9 +314,9 @@ export interface ChatMessageMetaData {
temperature: number; temperature: number;
maxTokens: number; maxTokens: number;
topP: number; topP: number;
frequencyPenalty?: number; frequencyPenalty: number;
presencePenalty?: number; presencePenalty: number;
stopSequences?: Array<string>; stopSequences: Array<string>;
ragResults?: Array<ChromaDBGetResponse>; ragResults?: Array<ChromaDBGetResponse>;
llmHistory?: Array<any>; llmHistory?: Array<any>;
evalCount: number; evalCount: number;
@ -392,7 +387,6 @@ export interface ChatQuery {
export interface ChatSession { export interface ChatSession {
id?: string; id?: string;
userId?: string; userId?: string;
guestId?: string;
createdAt?: Date; createdAt?: Date;
lastActivity?: Date; lastActivity?: Date;
title?: string; title?: string;
@ -507,7 +501,7 @@ export interface DocumentMessage {
} }
export interface DocumentOptions { export interface DocumentOptions {
includeInRAG?: boolean; includeInRAG: boolean;
isJobDocument?: boolean; isJobDocument?: boolean;
overwrite?: boolean; overwrite?: boolean;
} }
@ -541,7 +535,9 @@ export interface EmailVerificationRequest {
} }
export interface Employer { export interface Employer {
userType: "candidate" | "employer" | "guest";
id?: string; id?: string;
lastActivity?: Date;
email: string; email: string;
firstName: string; firstName: string;
lastName: string; lastName: string;
@ -554,7 +550,6 @@ export interface Employer {
profileImage?: string; profileImage?: string;
status: "active" | "inactive" | "pending" | "banned"; status: "active" | "inactive" | "pending" | "banned";
isAdmin: boolean; isAdmin: boolean;
userType: "employer";
companyName: string; companyName: string;
industry: string; industry: string;
description?: string; description?: string;
@ -580,14 +575,71 @@ export interface ErrorDetail {
details?: any; details?: any;
} }
export interface GPUInfo {
name: string;
memory: number;
discrete: boolean;
}
export interface Guest { export interface Guest {
userType: "candidate" | "employer" | "guest";
id?: string; id?: string;
sessionId: string; lastActivity?: Date;
email: string;
firstName: string;
lastName: string;
fullName: string;
phone?: string;
location?: Location;
createdAt: Date; createdAt: Date;
lastActivity: Date; updatedAt: Date;
lastLogin?: Date;
profileImage?: string;
status: "active" | "inactive" | "pending" | "banned";
isAdmin: boolean;
sessionId: string;
username: string;
convertedToUserId?: string; convertedToUserId?: string;
ipAddress?: string; ipAddress?: string;
userAgent?: string; userAgent?: string;
ragContentSize: number;
}
export interface GuestCleanupRequest {
inactiveHours: number;
}
export interface GuestConversionRequest {
accountType: "candidate" | "employer";
email: string;
username: string;
password: string;
firstName: string;
lastName: string;
phone?: string;
companyName?: string;
industry?: string;
companySize?: string;
companyDescription?: string;
websiteUrl?: string;
}
export interface GuestSessionResponse {
accessToken: string;
refreshToken: string;
user: Guest;
expiresAt: number;
userType: "guest";
isGuest: boolean;
}
export interface GuestStatistics {
totalGuests: number;
activeLastHour: number;
activeLastDay: number;
convertedGuests: number;
byIp: Record<string, number>;
creationTimeline: Record<string, number>;
} }
export interface InterviewFeedback { export interface InterviewFeedback {
@ -746,6 +798,7 @@ export interface MFARequest {
password: string; password: string;
deviceId: string; deviceId: string;
deviceName: string; deviceName: string;
email: string;
} }
export interface MFARequestResponse { export interface MFARequestResponse {
@ -841,6 +894,33 @@ export interface RagEntry {
enabled: boolean; enabled: boolean;
} }
export interface RateLimitConfig {
requestsPerMinute: number;
requestsPerHour: number;
requestsPerDay: number;
burstLimit: number;
burstWindowSeconds: number;
}
export interface RateLimitResult {
allowed: boolean;
reason?: string;
retryAfterSeconds?: number;
remainingRequests?: Record<string, number>;
resetTimes?: Record<string, Date>;
}
export interface RateLimitStatus {
userId: string;
userType: string;
isAdmin: boolean;
currentUsage: Record<string, number>;
limits: Record<string, number>;
remaining: Record<string, number>;
resetTimes: Record<string, Date>;
config: RateLimitConfig;
}
export interface RefreshToken { export interface RefreshToken {
token: string; token: string;
expiresAt: Date; expiresAt: Date;
@ -915,6 +995,15 @@ export interface SocialLink {
url: string; url: string;
} }
export interface SystemInfo {
installedRAM: string;
graphicsCards: Array<GPUInfo>;
CPU: string;
llmModel: string;
embeddingModel: string;
maxContextLength: number;
}
export interface Tunables { export interface Tunables {
enableRAG: boolean; enableRAG: boolean;
enableTools: boolean; enableTools: boolean;
@ -1035,13 +1124,15 @@ export function convertAuthenticationFromApi(data: any): Authentication {
} }
/** /**
* Convert BaseUser from API response, parsing date fields * Convert BaseUser from API response, parsing date fields
* Date fields: createdAt, updatedAt, lastLogin * Date fields: lastActivity, createdAt, updatedAt, lastLogin
*/ */
export function convertBaseUserFromApi(data: any): BaseUser { export function convertBaseUserFromApi(data: any): BaseUser {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert lastActivity from ISO string to Date
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert createdAt from ISO string to Date // Convert createdAt from ISO string to Date
createdAt: new Date(data.createdAt), createdAt: new Date(data.createdAt),
// Convert updatedAt from ISO string to Date // Convert updatedAt from ISO string to Date
@ -1052,30 +1143,28 @@ export function convertBaseUserFromApi(data: any): BaseUser {
} }
/** /**
* Convert BaseUserWithType from API response, parsing date fields * Convert BaseUserWithType from API response, parsing date fields
* Date fields: createdAt, updatedAt, lastLogin * Date fields: lastActivity
*/ */
export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType { export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert createdAt from ISO string to Date // Convert lastActivity from ISO string to Date
createdAt: new Date(data.createdAt), lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert updatedAt from ISO string to Date
updatedAt: new Date(data.updatedAt),
// Convert lastLogin from ISO string to Date
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
}; };
} }
/** /**
* Convert Candidate from API response, parsing date fields * Convert Candidate from API response, parsing date fields
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate * Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
*/ */
export function convertCandidateFromApi(data: any): Candidate { export function convertCandidateFromApi(data: any): Candidate {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert lastActivity from ISO string to Date
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert createdAt from ISO string to Date // Convert createdAt from ISO string to Date
createdAt: new Date(data.createdAt), createdAt: new Date(data.createdAt),
// Convert updatedAt from ISO string to Date // Convert updatedAt from ISO string to Date
@ -1088,13 +1177,15 @@ export function convertCandidateFromApi(data: any): Candidate {
} }
/** /**
* Convert CandidateAI from API response, parsing date fields * Convert CandidateAI from API response, parsing date fields
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate * Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
*/ */
export function convertCandidateAIFromApi(data: any): CandidateAI { export function convertCandidateAIFromApi(data: any): CandidateAI {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert lastActivity from ISO string to Date
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert createdAt from ISO string to Date // Convert createdAt from ISO string to Date
createdAt: new Date(data.createdAt), createdAt: new Date(data.createdAt),
// Convert updatedAt from ISO string to Date // Convert updatedAt from ISO string to Date
@ -1282,13 +1373,15 @@ export function convertEducationFromApi(data: any): Education {
} }
/** /**
* Convert Employer from API response, parsing date fields * Convert Employer from API response, parsing date fields
* Date fields: createdAt, updatedAt, lastLogin * Date fields: lastActivity, createdAt, updatedAt, lastLogin
*/ */
export function convertEmployerFromApi(data: any): Employer { export function convertEmployerFromApi(data: any): Employer {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert lastActivity from ISO string to Date
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert createdAt from ISO string to Date // Convert createdAt from ISO string to Date
createdAt: new Date(data.createdAt), createdAt: new Date(data.createdAt),
// Convert updatedAt from ISO string to Date // Convert updatedAt from ISO string to Date
@ -1299,17 +1392,21 @@ export function convertEmployerFromApi(data: any): Employer {
} }
/** /**
* Convert Guest from API response, parsing date fields * Convert Guest from API response, parsing date fields
* Date fields: createdAt, lastActivity * Date fields: lastActivity, createdAt, updatedAt, lastLogin
*/ */
export function convertGuestFromApi(data: any): Guest { export function convertGuestFromApi(data: any): Guest {
if (!data) return data; if (!data) return data;
return { return {
...data, ...data,
// Convert lastActivity from ISO string to Date
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
// Convert createdAt from ISO string to Date // Convert createdAt from ISO string to Date
createdAt: new Date(data.createdAt), createdAt: new Date(data.createdAt),
// Convert lastActivity from ISO string to Date // Convert updatedAt from ISO string to Date
lastActivity: new Date(data.lastActivity), updatedAt: new Date(data.updatedAt),
// Convert lastLogin from ISO string to Date
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
}; };
} }
/** /**
@ -1415,6 +1512,32 @@ export function convertRAGConfigurationFromApi(data: any): RAGConfiguration {
updatedAt: new Date(data.updatedAt), updatedAt: new Date(data.updatedAt),
}; };
} }
/**
* Convert RateLimitResult from API response, parsing date fields
* Date fields: resetTimes
*/
export function convertRateLimitResultFromApi(data: any): RateLimitResult {
if (!data) return data;
return {
...data,
// Convert resetTimes from ISO string to Date
resetTimes: data.resetTimes ? new Date(data.resetTimes) : undefined,
};
}
/**
* Convert RateLimitStatus from API response, parsing date fields
* Date fields: resetTimes
*/
export function convertRateLimitStatusFromApi(data: any): RateLimitStatus {
if (!data) return data;
return {
...data,
// Convert resetTimes from ISO string to Date
resetTimes: new Date(data.resetTimes),
};
}
/** /**
* Convert RefreshToken from API response, parsing date fields * Convert RefreshToken from API response, parsing date fields
* Date fields: expiresAt * Date fields: expiresAt
@ -1527,6 +1650,10 @@ export function convertFromApi<T>(data: any, modelType: string): T {
return convertMessageReactionFromApi(data) as T; return convertMessageReactionFromApi(data) as T;
case 'RAGConfiguration': case 'RAGConfiguration':
return convertRAGConfigurationFromApi(data) as T; return convertRAGConfigurationFromApi(data) as T;
case 'RateLimitResult':
return convertRateLimitResultFromApi(data) as T;
case 'RateLimitStatus':
return convertRateLimitStatusFromApi(data) as T;
case 'RefreshToken': case 'RefreshToken':
return convertRefreshTokenFromApi(data) as T; return convertRefreshTokenFromApi(data) as T;
case 'UserActivity': case 'UserActivity':

View File

@ -1,15 +0,0 @@
const getConnectionBase = (loc: any): string => {
if (!loc.host.match(/.*battle-linux.*/)
// && !loc.host.match(/.*backstory-beta.*/)
) {
return loc.protocol + "//" + loc.host;
} else {
return loc.protocol + "//battle-linux.ketrenos.com:8912";
}
}
const connectionBase = getConnectionBase(window.location);
export {
connectionBase
};

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import traceback import traceback
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import ( from typing import (
Literal, Literal,
get_args, get_args,
@ -17,7 +17,7 @@ from typing import (
import importlib import importlib
import pathlib import pathlib
import inspect import inspect
from prometheus_client import CollectorRegistry # type: ignore from prometheus_client import CollectorRegistry
from . base import Agent from . base import Agent
from logger import logger from logger import logger
@ -90,7 +90,7 @@ for path in package_dir.glob("*.py"):
class_registry[name] = (full_module_name, name) class_registry[name] = (full_module_name, name)
globals()[name] = obj globals()[name] = obj
logger.info(f"Adding agent: {name}") logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore __all__.append(name)
except ImportError as e: except ImportError as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(f"Error importing {full_module_name}: {e}") logger.error(f"Error importing {full_module_name}: {e}")

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from typing import ( from typing import (
Literal, Literal,
get_args, get_args,
@ -20,8 +20,8 @@ import re
from abc import ABC from abc import ABC
import asyncio import asyncio
from datetime import datetime, UTC from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore from prometheus_client import Counter, Summary, CollectorRegistry
import numpy as np # type: ignore import numpy as np
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate) from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
from logger import logger from logger import logger
@ -491,7 +491,7 @@ Content: {content}
session_id: str, prompt: str, session_id: str, prompt: str,
tunables: Optional[Tunables] = None, tunables: Optional[Tunables] = None,
temperature=0.7 temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessageRagSearch, None]:
if not self.user: if not self.user:
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=session_id, session_id=session_id,

View File

@ -28,7 +28,7 @@ class CandidateChat(Agent):
CandidateChat Agent CandidateChat Agent
""" """
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore agent_type: Literal["candidate_chat"] = "candidate_chat"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message system_prompt: str = system_message

View File

@ -53,7 +53,7 @@ class Chat(Agent):
Chat Agent Chat Agent
""" """
agent_type: Literal["general"] = "general" # type: ignore agent_type: Literal["general"] = "general"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message system_prompt: str = system_message

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from pydantic import model_validator, Field, BaseModel # type: ignore from pydantic import model_validator, Field, BaseModel
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -37,7 +37,7 @@ seed = int(time.time())
random.seed(seed) random.seed(seed)
class ImageGenerator(Agent): class ImageGenerator(Agent):
agent_type: Literal["generate_image"] = "generate_image" # type: ignore agent_type: Literal["generate_image"] = "generate_image"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False agent_persist: bool = False

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from pydantic import model_validator, Field, BaseModel # type: ignore from pydantic import model_validator, Field, BaseModel
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -23,7 +23,7 @@ import asyncio
import time import time
import os import os
import random import random
from names_dataset import NameDataset, NameWrapper # type: ignore from names_dataset import NameDataset, NameWrapper
from .base import Agent, agent_registry, LLMMessage from .base import Agent, agent_registry, LLMMessage
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
class EthnicNameGenerator: class EthnicNameGenerator:
def __init__(self): def __init__(self):
try: try:
from names_dataset import NameDataset # type: ignore from names_dataset import NameDataset
self.nd = NameDataset() self.nd = NameDataset()
except ImportError: except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset") logger.error("NameDataset not available. Please install: pip install names-dataset")
@ -292,7 +292,7 @@ class EthnicNameGenerator:
return names return names
class GeneratePersona(Agent): class GeneratePersona(Agent):
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore agent_type: Literal["generate_persona"] = "generate_persona"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False agent_persist: bool = False

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator, Field # type: ignore from pydantic import model_validator, Field
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -16,7 +16,7 @@ import json
import asyncio import asyncio
import time import time
import asyncio import asyncio
import numpy as np # type: ignore import numpy as np
from .base import Agent, agent_registry, LLMMessage from .base import Agent, agent_registry, LLMMessage
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
@ -26,7 +26,7 @@ import defines
import backstory_traceback as traceback import backstory_traceback as traceback
class JobRequirementsAgent(Agent): class JobRequirementsAgent(Agent):
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore agent_type: Literal["job_requirements"] = "job_requirements"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
# Stage 1A: Job Analysis Implementation # Stage 1A: Job Analysis Implementation

View File

@ -15,7 +15,7 @@ class Chat(Agent):
Chat Agent Chat Agent
""" """
agent_type: Literal["rag_search"] = "rag_search" # type: ignore agent_type: Literal["rag_search"] = "rag_search"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
async def generate( async def generate(

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator, Field # type: ignore from pydantic import model_validator, Field
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -16,7 +16,7 @@ import json
import asyncio import asyncio
import time import time
import asyncio import asyncio
import numpy as np # type: ignore import numpy as np
from .base import Agent, agent_registry, LLMMessage from .base import Agent, agent_registry, LLMMessage
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
@ -25,7 +25,7 @@ from logger import logger
import defines import defines
class SkillMatchAgent(Agent): class SkillMatchAgent(Agent):
agent_type: Literal["skill_match"] = "skill_match" # type: ignore agent_type: Literal["skill_match"] = "skill_match"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_skill_assessment_prompt(self, skill, rag_context): def generate_skill_assessment_prompt(self, skill, rag_context):

View File

@ -5,12 +5,12 @@ Provides password hashing, verification, and security features
""" """
import traceback import traceback
import bcrypt # type: ignore import bcrypt
import secrets import secrets
import logging import logging
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, Tuple from typing import Dict, Any, Optional, Tuple
from pydantic import BaseModel # type: ignore from pydantic import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -0,0 +1,175 @@
"""
Background tasks for guest cleanup and system maintenance
"""
import asyncio
import schedule # type: ignore
import threading
import time
from datetime import datetime, timedelta, UTC
from typing import Optional
from logger import logger
from database import DatabaseManager
class BackgroundTaskManager:
"""Manages background tasks for the application"""
def __init__(self, database_manager: DatabaseManager):
self.database_manager = database_manager
self.running = False
self.tasks = []
self.scheduler_thread: Optional[threading.Thread] = None
async def cleanup_inactive_guests(self, inactive_hours: int = 24):
"""Clean up inactive guest sessions"""
try:
database = self.database_manager.get_database()
cleaned_count = await database.cleanup_inactive_guests(inactive_hours)
if cleaned_count > 0:
logger.info(f"🧹 Background cleanup: removed {cleaned_count} inactive guest sessions")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error in guest cleanup: {e}")
return 0
async def cleanup_expired_verification_tokens(self):
"""Clean up expired email verification tokens"""
try:
database = self.database_manager.get_database()
cleaned_count = await database.cleanup_expired_verification_tokens()
if cleaned_count > 0:
logger.info(f"🧹 Background cleanup: removed {cleaned_count} expired verification tokens")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error in verification token cleanup: {e}")
return 0
async def update_guest_statistics(self):
"""Update guest usage statistics"""
try:
database = self.database_manager.get_database()
stats = await database.get_guest_statistics()
# Log interesting statistics
if stats.get('total_guests', 0) > 0:
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted")
return stats
except Exception as e:
logger.error(f"❌ Error updating guest statistics: {e}")
return {}
async def cleanup_old_rate_limit_data(self, days_old: int = 7):
"""Clean up old rate limiting data"""
try:
database = self.database_manager.get_database()
redis = database.redis
# Clean up rate limit keys older than specified days
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
pattern = "rate_limit:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=100)
for key in keys:
# Check if key is old enough to delete
try:
ttl = await redis.ttl(key)
if ttl == -1: # No expiration set, check creation time
# For simplicity, delete keys without TTL
await redis.delete(key)
deleted_count += 1
except Exception:
continue
if cursor == 0:
break
if deleted_count > 0:
logger.info(f"🧹 Cleaned up {deleted_count} old rate limit keys")
return deleted_count
except Exception as e:
logger.error(f"❌ Error cleaning up rate limit data: {e}")
return 0
def schedule_periodic_tasks(self):
"""Schedule periodic background tasks with safer intervals"""
# Guest cleanup - every 6 hours instead of every hour (less aggressive)
schedule.every(6).hours.do(self._run_async_task, self.cleanup_inactive_guests, 48) # 48 hours instead of 24
# Verification token cleanup - every 12 hours
schedule.every(12).hours.do(self._run_async_task, self.cleanup_expired_verification_tokens)
# Guest statistics update - every hour
schedule.every().hour.do(self._run_async_task, self.update_guest_statistics)
# Rate limit data cleanup - daily at 3 AM
schedule.every().day.at("03:00").do(self._run_async_task, self.cleanup_old_rate_limit_data, 7)
logger.info("📅 Background tasks scheduled with safer intervals")
def _run_async_task(self, coro_func, *args, **kwargs):
"""Run an async task in the background"""
try:
# Create new event loop for this thread if needed
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the coroutine
loop.run_until_complete(coro_func(*args, **kwargs))
except Exception as e:
logger.error(f"❌ Error running background task {coro_func.__name__}: {e}")
def _scheduler_worker(self):
"""Worker thread for running scheduled tasks"""
while self.running:
try:
schedule.run_pending()
time.sleep(60) # Check every minute
except Exception as e:
logger.error(f"❌ Error in scheduler worker: {e}")
time.sleep(60)
def start(self):
"""Start the background task manager"""
if self.running:
logger.warning("⚠️ Background task manager already running")
return
self.running = True
self.schedule_periodic_tasks()
# Start scheduler thread
self.scheduler_thread = threading.Thread(target=self._scheduler_worker, daemon=True)
self.scheduler_thread.start()
logger.info("🚀 Background task manager started")
def stop(self):
"""Stop the background task manager"""
self.running = False
if self.scheduler_thread and self.scheduler_thread.is_alive():
self.scheduler_thread.join(timeout=5)
# Clear scheduled tasks
schedule.clear()
logger.info("🛑 Background task manager stopped")

View File

@ -9,6 +9,7 @@ from models import (
# User models # User models
Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse, Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
) )
import backstory_traceback as traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -198,6 +199,7 @@ class RedisDatabase:
try: try:
return json.loads(data) return json.loads(data)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(traceback.format_exc())
logger.error(f"Failed to deserialize data: {data}") logger.error(f"Failed to deserialize data: {data}")
return None return None
@ -254,43 +256,44 @@ class RedisDatabase:
raise raise
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]: async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Retrieve cached skill match assessment""" """Get cached skill match assessment"""
try: try:
cached_data = await self.redis.get(cache_key) data = await self.redis.get(cache_key)
if cached_data: if data:
return json.loads(cached_data) return json.loads(data)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error retrieving cached skill match: {e}") logger.error(f"❌ Error getting cached skill match: {e}")
return None return None
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool: async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any]) -> None:
"""Cache skill match assessment with TTL (default 30 days)""" """Cache skill match assessment"""
try: try:
# Cache for 1 hour by default
await self.redis.setex( await self.redis.setex(
cache_key, cache_key,
ttl, 3600,
json.dumps(assessment_data, default=str) json.dumps(assessment_data)
) )
return True logger.debug(f"💾 Skill match cached: {cache_key}")
except Exception as e: except Exception as e:
logger.error(f"Error caching skill match: {e}") logger.error(f"❌ Error caching skill match: {e}")
return False
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]: async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
"""Get the last time candidate's skill information was updated""" """Get the last time candidate skills were updated"""
try: try:
# This assumes you track skill update timestamps in your candidate data
candidate_data = await self.get_candidate(candidate_id) candidate_data = await self.get_candidate(candidate_id)
if candidate_data and 'skills_updated_at' in candidate_data: if candidate_data:
return datetime.fromisoformat(candidate_data['skills_updated_at']) updated_at_str = candidate_data.get("updated_at")
if updated_at_str:
return datetime.fromisoformat(updated_at_str.replace('Z', '+00:00'))
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting candidate skill update time: {e}") logger.error(f"Error getting candidate skill update time: {e}")
return None return None
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
"""Get the timestamp of the latest RAG data update for a specific user""" """Get the last time user's RAG data was updated"""
try: try:
rag_update_key = f"user:{user_id}:rag_last_update" rag_update_key = f"user:{user_id}:rag_last_update"
timestamp_str = await self.redis.get(rag_update_key) timestamp_str = await self.redis.get(rag_update_key)
@ -298,7 +301,7 @@ class RedisDatabase:
return datetime.fromisoformat(timestamp_str.decode('utf-8')) return datetime.fromisoformat(timestamp_str.decode('utf-8'))
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting user RAG update time for user {user_id}: {e}") logger.error(f"Error getting user RAG update time: {e}")
return None return None
async def update_user_rag_timestamp(self, user_id: str) -> bool: async def update_user_rag_timestamp(self, user_id: str) -> bool:
@ -1707,7 +1710,11 @@ class RedisDatabase:
result = {} result = {}
for key, value in zip(keys, values): for key, value in zip(keys, values):
email = key.replace(self.KEY_PREFIXES['users'], '') email = key.replace(self.KEY_PREFIXES['users'], '')
result[email] = self._deserialize(value) logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
if type(value) == str:
result[email] = value
else:
result[email] = self._deserialize(value)
return result return result
@ -1850,15 +1857,14 @@ class RedisDatabase:
return False return False
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve user data by ID""" """Get user lookup data by user ID"""
try: try:
key = f"user_by_id:{user_id}" data = await self.redis.hget("user_lookup_by_id", user_id)
data = await self.redis.get(key)
if data: if data:
return json.loads(data) return json.loads(data)
return None return None
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving user by ID {user_id}: {e}") logger.error(f"❌ Error getting user by ID {user_id}: {e}")
return None return None
async def user_exists_by_email(self, email: str) -> bool: async def user_exists_by_email(self, email: str) -> bool:
@ -2104,6 +2110,208 @@ class RedisDatabase:
logger.error(f"❌ Error retrieving security log for {user_id}: {e}") logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
return [] return []
# ============================
# Guest Management Methods
# ============================
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
"""Store guest data with enhanced persistence"""
try:
# Ensure last_activity is always set
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Store in Redis with both hash and individual key for redundancy
await self.redis.hset("guests", guest_id, json.dumps(guest_data))
# Also store with a longer TTL as backup
await self.redis.setex(
f"guest_backup:{guest_id}",
86400 * 7, # 7 days TTL
json.dumps(guest_data)
)
logger.debug(f"💾 Guest stored with backup: {guest_id}")
except Exception as e:
logger.error(f"❌ Error storing guest {guest_id}: {e}")
raise
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data with fallback to backup"""
try:
# Try primary storage first
data = await self.redis.hget("guests", guest_id)
if data:
guest_data = json.loads(data)
# Update last activity when accessed
guest_data["last_activity"] = datetime.now(UTC).isoformat()
await self.set_guest(guest_id, guest_data)
logger.debug(f"🔍 Guest found in primary storage: {guest_id}")
return guest_data
# Fallback to backup storage
backup_data = await self.redis.get(f"guest_backup:{guest_id}")
if backup_data:
guest_data = json.loads(backup_data)
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Restore to primary storage
await self.set_guest(guest_id, guest_data)
logger.info(f"🔄 Guest restored from backup: {guest_id}")
return guest_data
logger.warning(f"⚠️ Guest not found: {guest_id}")
return None
except Exception as e:
logger.error(f"❌ Error getting guest {guest_id}: {e}")
return None
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data by session ID"""
try:
all_guests = await self.get_all_guests()
for guest_data in all_guests.values():
if guest_data.get("session_id") == session_id:
return guest_data
return None
except Exception as e:
logger.error(f"❌ Error getting guest by session ID {session_id}: {e}")
return None
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
"""Get all guests"""
try:
data = await self.redis.hgetall("guests")
return {
guest_id: json.loads(guest_json)
for guest_id, guest_json in data.items()
}
except Exception as e:
logger.error(f"❌ Error getting all guests: {e}")
return {}
async def delete_guest(self, guest_id: str) -> bool:
"""Delete a guest"""
try:
result = await self.redis.hdel("guests", guest_id)
if result:
logger.info(f"🗑️ Guest deleted: {guest_id}")
return True
return False
except Exception as e:
logger.error(f"❌ Error deleting guest {guest_id}: {e}")
return False
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
"""Clean up inactive guest sessions with safety checks"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
cutoff_time = current_time - timedelta(hours=inactive_hours)
deleted_count = 0
preserved_count = 0
for guest_id, guest_data in all_guests.items():
try:
last_activity_str = guest_data.get("last_activity")
created_at_str = guest_data.get("created_at")
# Skip cleanup if guest is very new (less than 1 hour old)
if created_at_str:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
if current_time - created_at < timedelta(hours=1):
preserved_count += 1
logger.debug(f"🛡️ Preserving new guest: {guest_id}")
continue
# Check last activity
should_delete = False
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
if last_activity < cutoff_time:
should_delete = True
except ValueError:
# Invalid date format, but don't delete if guest is new
if not created_at_str:
should_delete = True
else:
# No last activity, but don't delete if guest is new
if not created_at_str:
should_delete = True
if should_delete:
await self.delete_guest(guest_id)
deleted_count += 1
else:
preserved_count += 1
except Exception as e:
logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}")
preserved_count += 1 # Preserve on error
if deleted_count > 0:
logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}")
return deleted_count
except Exception as e:
logger.error(f"❌ Error in guest cleanup: {e}")
return 0
async def get_guest_statistics(self) -> Dict[str, Any]:
"""Get guest usage statistics"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
stats = {
"total_guests": len(all_guests),
"active_last_hour": 0,
"active_last_day": 0,
"converted_guests": 0,
"by_ip": {},
"creation_timeline": {}
}
hour_ago = current_time - timedelta(hours=1)
day_ago = current_time - timedelta(days=1)
for guest_data in all_guests.values():
# Check activity
last_activity_str = guest_data.get("last_activity")
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
if last_activity > hour_ago:
stats["active_last_hour"] += 1
if last_activity > day_ago:
stats["active_last_day"] += 1
except ValueError:
pass
# Check conversions
if guest_data.get("converted_to_user_id"):
stats["converted_guests"] += 1
# IP tracking
ip = guest_data.get("ip_address", "unknown")
stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1
# Creation timeline
created_at_str = guest_data.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
date_key = created_at.strftime('%Y-%m-%d')
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except ValueError:
pass
return stats
except Exception as e:
logger.error(f"❌ Error getting guest statistics: {e}")
return {}
# Global Redis manager instance # Global Redis manager instance
redis_manager = _RedisManager() redis_manager = _RedisManager()

View File

@ -1,9 +1,9 @@
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks # type: ignore from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks
from database import RedisDatabase from database import RedisDatabase
import hashlib import hashlib
from logger import logger from logger import logger
from datetime import datetime, timezone from datetime import datetime, timezone
from user_agents import parse # type: ignore from user_agents import parse
import json import json
class DeviceManager: class DeviceManager:

View File

@ -1,8 +1,8 @@
import os import os
from typing import Tuple from typing import Tuple
from logger import logger from logger import logger
from email.mime.text import MIMEText # type: ignore from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart # type: ignore from email.mime.multipart import MIMEMultipart
import smtplib import smtplib
import asyncio import asyncio
from email_templates import EMAIL_TEMPLATES from email_templates import EMAIL_TEMPLATES

View File

@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
from typing_extensions import Annotated, Union from typing_extensions import Annotated, Union
import numpy as np # type: ignore import numpy as np
from uuid import uuid4 from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore from prometheus_client import CollectorRegistry, Counter
import traceback import traceback
import os import os
import json import json

View File

@ -3,12 +3,12 @@ import weakref
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Dict, Optional, Any from typing import Dict, Optional, Any
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from models import ( Candidate ) from models import ( Candidate )
from .candidate_entity import CandidateEntity from .candidate_entity import CandidateEntity
from database import RedisDatabase from database import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore from prometheus_client import CollectorRegistry
class EntityManager: class EntityManager:
"""Manages lifecycle of CandidateEntity instances""" """Manages lifecycle of CandidateEntity instances"""

View File

@ -64,7 +64,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, current_dir) sys.path.insert(0, current_dir)
try: try:
from pydantic import BaseModel # type: ignore from pydantic import BaseModel
except ImportError as e: except ImportError as e:
print(f"Error importing pydantic: {e}") print(f"Error importing pydantic: {e}")
print("Make sure pydantic is installed: pip install pydantic") print("Make sure pydantic is installed: pip install pydantic")

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
import json import json
from typing import Any, List, Set from typing import Any, List, Set

View File

@ -4,8 +4,8 @@ import re
import time import time
from typing import Any from typing import Any
import torch # type: ignore import torch
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore from diffusers import StableDiffusionPipeline, FluxPipeline
class ImageModelCache: # Stay loaded for 3 hours class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60): def __init__(self, timeout_seconds: float = 3 * 60 * 60):

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import Dict, Literal, Any, AsyncGenerator, Optional from typing import Dict, Literal, Any, AsyncGenerator, Optional
import inspect import inspect
import random import random
@ -13,7 +13,7 @@ import os
import gc import gc
import tempfile import tempfile
import uuid import uuid
import torch # type: ignore import torch
import asyncio import asyncio
import time import time
import json import json

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Any, AsyncGenerator, Optional, Union from typing import Dict, List, Any, AsyncGenerator, Optional, Union
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from enum import Enum from enum import Enum
import asyncio import asyncio
import json import json
@ -179,7 +179,7 @@ class OllamaAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import ollama import ollama
self.client = ollama.AsyncClient( # type: ignore self.client = ollama.AsyncClient(
host=config.get('host', defines.ollama_api_url) host=config.get('host', defines.ollama_api_url)
) )
@ -386,7 +386,7 @@ class OpenAIAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import openai # type: ignore import openai
self.client = openai.AsyncOpenAI( self.client = openai.AsyncOpenAI(
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
) )
@ -522,7 +522,7 @@ class AnthropicAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import anthropic # type: ignore import anthropic
self.client = anthropic.AsyncAnthropic( self.client = anthropic.AsyncAnthropic(
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
) )
@ -656,7 +656,7 @@ class GeminiAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import google.generativeai as genai # type: ignore import google.generativeai as genai
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
self.genai = genai self.genai = genai
@ -867,7 +867,7 @@ class UnifiedLLMProxy:
raise ValueError("stream must be True for chat_stream") raise ValueError("stream must be True for chat_stream")
result = await self.chat(model, messages, provider, stream=True, **kwargs) result = await self.chat(model, messages, provider, stream=True, **kwargs)
# Type checker now knows this is an AsyncGenerator due to stream=True # Type checker now knows this is an AsyncGenerator due to stream=True
async for chunk in result: # type: ignore async for chunk in result:
yield chunk yield chunk
async def chat_single( async def chat_single(
@ -881,7 +881,7 @@ class UnifiedLLMProxy:
result = await self.chat(model, messages, provider, stream=False, **kwargs) result = await self.chat(model, messages, provider, stream=False, **kwargs)
# Type checker now knows this is a ChatResponse due to stream=False # Type checker now knows this is a ChatResponse due to stream=False
return result # type: ignore return result
async def generate( async def generate(
self, self,
@ -908,7 +908,7 @@ class UnifiedLLMProxy:
"""Stream text generation using specified or default provider""" """Stream text generation using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=True, **kwargs) result = await self.generate(model, prompt, provider, stream=True, **kwargs)
async for chunk in result: # type: ignore async for chunk in result:
yield chunk yield chunk
async def generate_single( async def generate_single(
@ -921,7 +921,7 @@ class UnifiedLLMProxy:
"""Get single generation response using specified or default provider""" """Get single generation response using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=False, **kwargs) result = await self.generate(model, prompt, provider, stream=False, **kwargs)
return result # type: ignore return result
async def embeddings( async def embeddings(
self, self,
@ -1148,7 +1148,7 @@ async def example_embeddings_usage():
# Calculate similarity between first two texts (requires numpy) # Calculate similarity between first two texts (requires numpy)
try: try:
import numpy as np # type: ignore import numpy as np
emb1 = np.array(response.data[0].embedding) emb1 = np.array(response.data[0].embedding)
emb2 = np.array(response.data[1].embedding) emb2 = np.array(response.data[1].embedding)
similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)) similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,14 @@
from typing import Type, TypeVar from typing import Type, TypeVar
from pydantic import BaseModel # type: ignore from pydantic import BaseModel
import copy import copy
from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType
# Ensure all user models inherit from BaseUserWithType
assert issubclass(Candidate, BaseUserWithType), "Candidate must inherit from BaseUserWithType"
assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType"
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
T = TypeVar('T', bound=BaseModel) T = TypeVar('T', bound=BaseModel)
@ -12,3 +19,35 @@ def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T: def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__} data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
return model_cls(**data) return model_cls(**data)
def cast_to_base_user_with_type(user) -> BaseUserWithType:
"""
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
This is useful for FastAPI dependencies that expect a common user type.
"""
if isinstance(user, BaseUserWithType):
return user
# If it's a dict, try to detect type
if isinstance(user, dict):
user_type = user.get("user_type") or user.get("type")
if user_type == "candidate":
if user.get("is_AI"):
return CandidateAI.model_validate(user)
return Candidate.model_validate(user)
elif user_type == "employer":
return Employer.model_validate(user)
elif user_type == "guest":
return Guest.model_validate(user)
else:
raise ValueError(f"Unknown user_type: {user_type}")
# If it's a model, check its type
if hasattr(user, "user_type"):
if getattr(user, "user_type", None) == "candidate":
if getattr(user, "is_AI", False):
return CandidateAI.model_validate(user.model_dump())
return Candidate.model_validate(user.model_dump())
elif getattr(user, "user_type", None) == "employer":
return Employer.model_validate(user.model_dump())
elif getattr(user, "user_type", None) == "guest":
return Guest.model_validate(user.model_dump())
raise TypeError(f"Cannot cast object of type {type(user)} to BaseUserWithType")

View File

@ -10,6 +10,7 @@ from auth_utils import (
sanitize_login_input, sanitize_login_input,
SecurityConfig SecurityConfig
) )
import defines
# Generic type variable # Generic type variable
T = TypeVar('T') T = TypeVar('T')
@ -256,6 +257,7 @@ class MFARequest(BaseModel):
password: str password: str
device_id: str = Field(..., alias="deviceId") device_id: str = Field(..., alias="deviceId")
device_name: str = Field(..., alias="deviceName") device_name: str = Field(..., alias="deviceName")
email: str = Field(..., alias="email")
model_config = { model_config = {
"populate_by_name": True, # Allow both field names and aliases "populate_by_name": True, # Allow both field names and aliases
} }
@ -464,9 +466,18 @@ class ErrorDetail(BaseModel):
# Main Models # Main Models
# ============================ # ============================
# Base user model without user_type field # Generic base user with user_type for API responses
class BaseUser(BaseModel): class BaseUserWithType(BaseModel):
user_type: UserType = Field(..., alias="userType")
id: str = Field(default_factory=lambda: str(uuid.uuid4())) id: str = Field(default_factory=lambda: str(uuid.uuid4()))
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
model_config = {
"populate_by_name": True, # Allow both field names and aliases
"use_enum_values": True # Use enum values instead of names
}
# Base user model without user_type field
class BaseUser(BaseUserWithType):
email: EmailStr email: EmailStr
first_name: str = Field(..., alias="firstName") first_name: str = Field(..., alias="firstName")
last_name: str = Field(..., alias="lastName") last_name: str = Field(..., alias="lastName")
@ -485,9 +496,6 @@ class BaseUser(BaseModel):
"use_enum_values": True # Use enum values instead of names "use_enum_values": True # Use enum values instead of names
} }
# Generic base user with user_type for API responses
class BaseUserWithType(BaseUser):
user_type: UserType = Field(..., alias="userType")
class RagEntry(BaseModel): class RagEntry(BaseModel):
name: str name: str
@ -519,9 +527,9 @@ class DocumentType(str, Enum):
IMAGE = "image" IMAGE = "image"
class DocumentOptions(BaseModel): class DocumentOptions(BaseModel):
include_in_RAG: Optional[bool] = Field(True, alias="includeInRAG") include_in_RAG: bool = Field(default=True, alias="includeInRAG")
is_job_document: Optional[bool] = Field(False, alias="isJobDocument") is_job_document: Optional[bool] = Field(default=False, alias="isJobDocument")
overwrite: Optional[bool] = Field(False, alias="overwrite") overwrite: Optional[bool] = Field(default=False, alias="overwrite")
model_config = { model_config = {
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True # Allow both field names and aliases
} }
@ -534,7 +542,7 @@ class Document(BaseModel):
type: DocumentType type: DocumentType
size: int size: int
upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate") upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate")
options: DocumentOptions = Field(default_factory=DocumentOptions, alias="options") options: DocumentOptions = Field(default_factory=lambda: DocumentOptions(), alias="options")
rag_chunks: Optional[int] = Field(default=0, alias="ragChunks") rag_chunks: Optional[int] = Field(default=0, alias="ragChunks")
model_config = { model_config = {
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True # Allow both field names and aliases
@ -565,7 +573,7 @@ class DocumentUpdateRequest(BaseModel):
} }
class Candidate(BaseUser): class Candidate(BaseUser):
user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType") user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
username: str username: str
description: Optional[str] = None description: Optional[str] = None
resume: Optional[str] = None resume: Optional[str] = None
@ -584,14 +592,14 @@ class Candidate(BaseUser):
rag_content_size : int = 0 rag_content_size : int = 0
class CandidateAI(Candidate): class CandidateAI(Candidate):
user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType") user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
is_AI: bool = Field(True, alias="isAI") is_AI: bool = Field(True, alias="isAI")
age: Optional[int] = None age: Optional[int] = None
gender: Optional[UserGender] = None gender: Optional[UserGender] = None
ethnicity: Optional[str] = None ethnicity: Optional[str] = None
class Employer(BaseUser): class Employer(BaseUser):
user_type: Literal[UserType.EMPLOYER] = Field(UserType.EMPLOYER, alias="userType") user_type: UserType = Field(UserType.EMPLOYER, alias="userType")
company_name: str = Field(..., alias="companyName") company_name: str = Field(..., alias="companyName")
industry: str industry: str
description: Optional[str] = None description: Optional[str] = None
@ -603,16 +611,18 @@ class Employer(BaseUser):
social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks") social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks")
poc: Optional[PointOfContact] = None poc: Optional[PointOfContact] = None
class Guest(BaseModel): class Guest(BaseUser):
id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_type: UserType = Field(UserType.GUEST, alias="userType")
session_id: str = Field(..., alias="sessionId") session_id: str = Field(..., alias="sessionId")
created_at: datetime = Field(..., alias="createdAt") username: str # Add username for consistency with other user types
last_activity: datetime = Field(..., alias="lastActivity")
converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId") converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId")
ip_address: Optional[str] = Field(None, alias="ipAddress") ip_address: Optional[str] = Field(None, alias="ipAddress")
created_at: datetime = Field(..., alias="createdAt")
user_agent: Optional[str] = Field(None, alias="userAgent") user_agent: Optional[str] = Field(None, alias="userAgent")
rag_content_size: int = 0
model_config = { model_config = {
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True, # Allow both field names and aliases
"use_enum_values": True # Use enum values instead of names
} }
class Authentication(BaseModel): class Authentication(BaseModel):
@ -635,10 +645,21 @@ class Authentication(BaseModel):
class AuthResponse(BaseModel): class AuthResponse(BaseModel):
access_token: str = Field(..., alias="accessToken") access_token: str = Field(..., alias="accessToken")
refresh_token: str = Field(..., alias="refreshToken") refresh_token: str = Field(..., alias="refreshToken")
user: Candidate | Employer user: Union[Candidate, Employer, Guest] # Add Guest support
expires_at: int = Field(..., alias="expiresAt") expires_at: int = Field(..., alias="expiresAt")
user_type: Optional[str] = Field(default=UserType.GUEST, alias="userType") # Explicit user type
is_guest: Optional[bool] = Field(default=True, alias="isGuest") # Guest indicator
model_config = { model_config = {
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True
}
class GuestCleanupRequest(BaseModel):
"""Request to cleanup inactive guests"""
inactive_hours: int = Field(24, alias="inactiveHours")
model_config = {
"populate_by_name": True
} }
class JobRequirements(BaseModel): class JobRequirements(BaseModel):
@ -751,6 +772,19 @@ class ChromaDBGetResponse(BaseModel):
umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D") umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D") umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")
class GuestSessionResponse(BaseModel):
"""Response for guest session creation"""
access_token: str = Field(..., alias="accessToken")
refresh_token: str = Field(..., alias="refreshToken")
user: Guest
expires_at: int = Field(..., alias="expiresAt")
user_type: Literal["guest"] = Field("guest", alias="userType")
is_guest: bool = Field(True, alias="isGuest")
model_config = {
"populate_by_name": True
}
class ChatContext(BaseModel): class ChatContext(BaseModel):
type: ChatContextType type: ChatContextType
related_entity_id: Optional[str] = Field(None, alias="relatedEntityId") related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
@ -768,6 +802,98 @@ class ChatOptions(BaseModel):
"populate_by_name": True # Allow both field names and aliases "populate_by_name": True # Allow both field names and aliases
} }
# Add rate limiting configuration models
class RateLimitConfig(BaseModel):
"""Rate limit configuration"""
requests_per_minute: int = Field(..., alias="requestsPerMinute")
requests_per_hour: int = Field(..., alias="requestsPerHour")
requests_per_day: int = Field(..., alias="requestsPerDay")
burst_limit: int = Field(..., alias="burstLimit")
burst_window_seconds: int = Field(60, alias="burstWindowSeconds")
model_config = {
"populate_by_name": True
}
class RateLimitResult(BaseModel):
"""Result of rate limit check"""
allowed: bool
reason: Optional[str] = None
retry_after_seconds: Optional[int] = Field(None, alias="retryAfterSeconds")
remaining_requests: Dict[str, int] = Field(default_factory=dict, alias="remainingRequests")
reset_times: Dict[str, datetime] = Field(default_factory=dict, alias="resetTimes")
model_config = {
"populate_by_name": True
}
class RateLimitStatus(BaseModel):
"""Rate limit status for a user"""
user_id: str = Field(..., alias="userId")
user_type: str = Field(..., alias="userType")
is_admin: bool = Field(..., alias="isAdmin")
current_usage: Dict[str, int] = Field(..., alias="currentUsage")
limits: Dict[str, int] = Field(..., alias="limits")
remaining: Dict[str, int] = Field(..., alias="remaining")
reset_times: Dict[str, datetime] = Field(..., alias="resetTimes")
config: RateLimitConfig
model_config = {
"populate_by_name": True
}
# Add guest conversion request models
class GuestConversionRequest(BaseModel):
"""Request to convert guest to permanent user"""
account_type: Literal["candidate", "employer"] = Field(..., alias="accountType")
email: EmailStr
username: str
password: str
first_name: str = Field(..., alias="firstName")
last_name: str = Field(..., alias="lastName")
phone: Optional[str] = None
# Employer-specific fields (optional)
company_name: Optional[str] = Field(None, alias="companyName")
industry: Optional[str] = None
company_size: Optional[str] = Field(None, alias="companySize")
company_description: Optional[str] = Field(None, alias="companyDescription")
website_url: Optional[HttpUrl] = Field(None, alias="websiteUrl")
model_config = {
"populate_by_name": True
}
@field_validator('username')
def validate_username(cls, v):
if not v or len(v.strip()) < 3:
raise ValueError('Username must be at least 3 characters long')
return v.strip().lower()
@field_validator('password')
def validate_password_strength(cls, v):
# Import here to avoid circular imports
from auth_utils import validate_password_strength
is_valid, issues = validate_password_strength(v)
if not is_valid:
raise ValueError('; '.join(issues))
return v
# Add guest statistics response model
class GuestStatistics(BaseModel):
"""Guest usage statistics"""
total_guests: int = Field(..., alias="totalGuests")
active_last_hour: int = Field(..., alias="activeLastHour")
active_last_day: int = Field(..., alias="activeLastDay")
converted_guests: int = Field(..., alias="convertedGuests")
by_ip: Dict[str, int] = Field(..., alias="byIp")
creation_timeline: Dict[str, int] = Field(..., alias="creationTimeline")
model_config = {
"populate_by_name": True
}
from llm_proxy import (LLMMessage) from llm_proxy import (LLMMessage)
class ApiMessage(BaseModel): class ApiMessage(BaseModel):
@ -800,12 +926,14 @@ class ApiActivityType(str, Enum):
HEARTBEAT = "heartbeat" # Used for periodic updates HEARTBEAT = "heartbeat" # Used for periodic updates
class ChatMessageStatus(ApiMessage): class ChatMessageStatus(ApiMessage):
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
status: ApiStatusType = ApiStatusType.STATUS status: ApiStatusType = ApiStatusType.STATUS
type: ApiMessageType = ApiMessageType.TEXT type: ApiMessageType = ApiMessageType.TEXT
activity: ApiActivityType activity: ApiActivityType
content: Any content: Any
class ChatMessageError(ApiMessage): class ChatMessageError(ApiMessage):
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
status: ApiStatusType = ApiStatusType.ERROR status: ApiStatusType = ApiStatusType.ERROR
type: ApiMessageType = ApiMessageType.TEXT type: ApiMessageType = ApiMessageType.TEXT
content: str content: str
@ -825,6 +953,7 @@ class JobRequirementsMessage(ApiMessage):
class DocumentMessage(ApiMessage): class DocumentMessage(ApiMessage):
type: ApiMessageType = ApiMessageType.JSON type: ApiMessageType = ApiMessageType.JSON
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
document: Document = Field(..., alias="document") document: Document = Field(..., alias="document")
content: Optional[str] = "" content: Optional[str] = ""
converted: bool = Field(False, alias="converted") converted: bool = Field(False, alias="converted")
@ -837,9 +966,9 @@ class ChatMessageMetaData(BaseModel):
temperature: float = 0.7 temperature: float = 0.7
max_tokens: int = Field(default=8092, alias="maxTokens") max_tokens: int = Field(default=8092, alias="maxTokens")
top_p: float = Field(default=1, alias="topP") top_p: float = Field(default=1, alias="topP")
frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty") frequency_penalty: float = Field(default=0, alias="frequencyPenalty")
presence_penalty: Optional[float] = Field(None, alias="presencePenalty") presence_penalty: float = Field(default=0, alias="presencePenalty")
stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences") stop_sequences: List[str] = Field(default=[], alias="stopSequences")
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults") rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory") llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
eval_count: int = 0 eval_count: int = 0
@ -862,16 +991,31 @@ class ChatMessageUser(ApiMessage):
class ChatMessage(ChatMessageUser): class ChatMessage(ChatMessageUser):
role: ChatSenderType = ChatSenderType.ASSISTANT role: ChatSenderType = ChatSenderType.ASSISTANT
metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData) metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData())
#attachments: Optional[List[Attachment]] = None #attachments: Optional[List[Attachment]] = None
#reactions: Optional[List[MessageReaction]] = None #reactions: Optional[List[MessageReaction]] = None
#is_edited: bool = Field(False, alias="isEdited") #is_edited: bool = Field(False, alias="isEdited")
#edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory") #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
class GPUInfo(BaseModel):
name: str
memory: int
discrete: bool
class SystemInfo(BaseModel):
installed_RAM: str = Field(..., alias="installedRAM")
graphics_cards: List[GPUInfo] = Field(..., alias="graphicsCards")
CPU: str
llm_model: str = Field(default=defines.model, alias="llmModel")
embedding_model: str = Field(default=defines.embedding_model, alias="embeddingModel")
max_context_length: int = Field(default=defines.max_context, alias="maxContextLength")
model_config = {
"populate_by_name": True # Allow both field names and aliases
}
class ChatSession(BaseModel): class ChatSession(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4())) id: str = Field(default_factory=lambda: str(uuid.uuid4()))
user_id: Optional[str] = Field(None, alias="userId") user_id: Optional[str] = Field(None, alias="userId")
guest_id: Optional[str] = Field(None, alias="guestId")
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt") created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity") last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
title: Optional[str] = None title: Optional[str] = None
@ -937,7 +1081,7 @@ class UserActivity(BaseModel):
} }
@model_validator(mode="after") @model_validator(mode="after")
def check_user_or_guest(self) -> "ChatSession": def check_user_or_guest(self):
if not self.user_id and not self.guest_id: if not self.user_id and not self.guest_id:
raise ValueError("Either user_id or guest_id must be provided") raise ValueError("Either user_id or guest_id must be provided")
return self return self
@ -1102,6 +1246,8 @@ class JobListResponse(BaseModel):
error: Optional[ErrorDetail] = None error: Optional[ErrorDetail] = None
meta: Optional[Dict[str, Any]] = None meta: Optional[Dict[str, Any]] = None
User = Union[Candidate, CandidateAI, Employer, Guest]
# Forward references resolution # Forward references resolution
Candidate.update_forward_refs() Candidate.update_forward_refs()
Employer.update_forward_refs() Employer.update_forward_refs()

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
from typing import List, Optional, Dict, Any, Union from typing import List, Optional, Dict, Any, Union
import os import os
import glob import glob
@ -9,15 +9,15 @@ import hashlib
import asyncio import asyncio
import logging import logging
import json import json
import numpy as np # type: ignore import numpy as np
import traceback import traceback
import chromadb # type: ignore import chromadb
from watchdog.observers import Observer # type: ignore from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler # type: ignore from watchdog.events import FileSystemEventHandler
import umap # type: ignore import umap
from markitdown import MarkItDown # type: ignore from markitdown import MarkItDown
from chromadb.api.models.Collection import Collection # type: ignore from chromadb.api.models.Collection import Collection
from .markdown_chunker import ( from .markdown_chunker import (
MarkdownChunker, MarkdownChunker,
@ -351,9 +351,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
os.makedirs(self.persist_directory) os.makedirs(self.persist_directory)
# Initialize ChromaDB client # Initialize ChromaDB client
chroma_client = chromadb.PersistentClient( # type: ignore chroma_client = chromadb.PersistentClient(
path=self.persist_directory, path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore settings=chromadb.Settings(anonymized_telemetry=False),
) )
# Check if the collection exists # Check if the collection exists

287
src/backend/rate_limiter.py Normal file
View File

@ -0,0 +1,287 @@
"""
Rate limiting utilities for guest and authenticated users
"""
import json
import time
from datetime import datetime, timedelta, UTC
from typing import Dict, Optional, Tuple, Any
from pydantic import BaseModel # type: ignore
from database import RedisDatabase
from logger import logger
class RateLimitConfig(BaseModel):
"""Rate limit configuration"""
requests_per_minute: int
requests_per_hour: int
requests_per_day: int
burst_limit: int # Maximum requests in a short burst
burst_window_seconds: int = 60 # Window for burst detection
class GuestRateLimitConfig(RateLimitConfig):
"""Rate limits for guest users - more restrictive"""
requests_per_minute: int = 10
requests_per_hour: int = 100
requests_per_day: int = 500
burst_limit: int = 15
burst_window_seconds: int = 60
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
"""Rate limits for authenticated users - more generous"""
requests_per_minute: int = 60
requests_per_hour: int = 1000
requests_per_day: int = 10000
burst_limit: int = 100
burst_window_seconds: int = 60
class PremiumUserRateLimitConfig(RateLimitConfig):
"""Rate limits for premium/admin users - most generous"""
requests_per_minute: int = 120
requests_per_hour: int = 5000
requests_per_day: int = 50000
burst_limit: int = 200
burst_window_seconds: int = 60
class RateLimitResult(BaseModel):
"""Result of rate limit check"""
allowed: bool
reason: Optional[str] = None
retry_after_seconds: Optional[int] = None
remaining_requests: Dict[str, int] = {}
reset_times: Dict[str, datetime] = {}
class RateLimiter:
"""Rate limiter using Redis for distributed rate limiting"""
def __init__(self, database: RedisDatabase):
self.database = database
self.redis = database.redis
# Rate limit configurations
self.guest_config = GuestRateLimitConfig()
self.user_config = AuthenticatedUserRateLimitConfig()
self.premium_config = PremiumUserRateLimitConfig()
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
"""Get rate limit configuration based on user type"""
if user_type == "guest":
return self.guest_config
elif is_admin:
return self.premium_config
else:
return self.user_config
async def check_rate_limit(
self,
user_id: str,
user_type: str,
is_admin: bool = False,
endpoint: Optional[str] = None
) -> RateLimitResult:
"""
Check if user has exceeded rate limits
Args:
user_id: Unique identifier for the user (guest session ID or user ID)
user_type: "guest", "candidate", or "employer"
is_admin: Whether user has admin privileges
endpoint: Optional endpoint-specific rate limiting
Returns:
RateLimitResult indicating if request is allowed
"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
# Create Redis keys for different time windows
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
}
# Add endpoint-specific limiting if provided
if endpoint:
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
try:
# Use Redis pipeline for atomic operations
pipe = self.redis.pipeline()
# Get current counts
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
}
# Check limits
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
}
# Check each limit
for window, current_count in current_counts.items():
limit = limits[window]
if current_count >= limit:
# Calculate retry after time
if window == "minute":
retry_after = 60 - current_time.second
elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
else: # burst
retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
return RateLimitResult(
allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time)
)
# If we get here, request is allowed - increment counters
pipe = self.redis.pipeline()
# Increment minute counter (expires after 2 minutes)
pipe.incr(keys["minute"])
pipe.expire(keys["minute"], 120)
# Increment hour counter (expires after 2 hours)
pipe.incr(keys["hour"])
pipe.expire(keys["hour"], 7200)
# Increment day counter (expires after 2 days)
pipe.incr(keys["day"])
pipe.expire(keys["day"], 172800)
# Increment burst counter (expires after burst window)
pipe.incr(keys["burst"])
pipe.expire(keys["burst"], config.burst_window_seconds)
await pipe.execute()
# Calculate remaining requests
remaining = {
k: max(0, limits[k] - (current_counts[k] + 1))
for k in current_counts.keys()
}
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
return RateLimitResult(
allowed=True,
remaining_requests=remaining,
reset_times=self._calculate_reset_times(current_time)
)
except Exception as e:
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
# Fail open - allow request if rate limiting system fails
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
"""Calculate when each rate limit window resets"""
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
return {
"minute": next_minute,
"hour": next_hour,
"day": next_day
}
async def get_user_rate_limit_status(
self,
user_id: str,
user_type: str,
is_admin: bool = False
) -> Dict[str, Any]:
"""Get current rate limit status for a user"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
}
try:
pipe = self.redis.pipeline()
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
}
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
}
return {
"user_id": user_id,
"user_type": user_type,
"is_admin": is_admin,
"current_usage": current_counts,
"limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump()
}
except Exception as e:
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
return {"error": str(e)}
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
"""Reset all rate limits for a user (admin function)"""
try:
base_key = f"rate_limit:{user_type}:{user_id}"
pattern = f"{base_key}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
if keys:
await self.redis.delete(*keys)
deleted_count += len(keys)
if cursor == 0:
break
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
return True
except Exception as e:
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
return False

View File

@ -2,6 +2,7 @@ import defines
import re import re
import subprocess import subprocess
import math import math
from models import SystemInfo
def get_installed_ram(): def get_installed_ram():
try: try:
@ -70,12 +71,18 @@ def get_cpu_info():
except Exception as e: except Exception as e:
return f"Error retrieving CPU info: {e}" return f"Error retrieving CPU info: {e}"
def system_info(): def system_info() -> SystemInfo:
return { """
"System RAM": get_installed_ram(), Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
"Graphics Card": get_graphics_cards(), Returns:
"CPU": get_cpu_info(), SystemInfo: An object containing the collected system information.
"LLM Model": defines.model, """
"Embedding Model": defines.embedding_model, system = SystemInfo(
"Context length": defines.max_context, installed_RAM=get_installed_ram(),
} graphics_cards=get_graphics_cards(),
CPU=get_cpu_info(),
llm_model=defines.model,
embedding_model=defines.embedding_model,
max_context_length=defines.max_context,
)
return system

View File

@ -1,5 +1,5 @@
import os import os
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from typing import List, Optional, Generator, ClassVar, Any, Dict from typing import List, Optional, Generator, ClassVar, Any, Dict
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
@ -7,12 +7,12 @@ from typing import (
) )
from typing_extensions import Annotated from typing_extensions import Annotated
from bs4 import BeautifulSoup # type: ignore from bs4 import BeautifulSoup
from geopy.geocoders import Nominatim # type: ignore from geopy.geocoders import Nominatim
import pytz # type: ignore import pytz
import requests import requests
import yfinance as yf # type: ignore import yfinance as yf
import logging import logging
@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__) # type: ignore

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI, HTTPException # type: ignore from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse # type: ignore from fastapi.responses import StreamingResponse
from pydantic import BaseModel # type: ignore from pydantic import BaseModel
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
import json import json
import asyncio import asyncio
@ -234,5 +234,5 @@ async def health_check():
} }
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn # type: ignore import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@ -92,7 +92,7 @@ class OllamaAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import ollama import ollama
self.client = ollama.AsyncClient( # type: ignore self.client = ollama.AsyncClient(
host=config.get('host', 'http://localhost:11434') host=config.get('host', 'http://localhost:11434')
) )
@ -187,7 +187,7 @@ class OpenAIAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import openai # type: ignore import openai
self.client = openai.AsyncOpenAI( self.client = openai.AsyncOpenAI(
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
) )
@ -259,7 +259,7 @@ class AnthropicAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import anthropic # type: ignore import anthropic
self.client = anthropic.AsyncAnthropic( self.client = anthropic.AsyncAnthropic(
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
) )
@ -344,7 +344,7 @@ class GeminiAdapter(BaseLLMAdapter):
def __init__(self, **config): def __init__(self, **config):
super().__init__(**config) super().__init__(**config)
import google.generativeai as genai # type: ignore import google.generativeai as genai
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
self.genai = genai self.genai = genai
@ -476,7 +476,7 @@ class UnifiedLLMProxy:
result = await self.chat(model, messages, provider, stream=True, **kwargs) result = await self.chat(model, messages, provider, stream=True, **kwargs)
# Type checker now knows this is an AsyncGenerator due to stream=True # Type checker now knows this is an AsyncGenerator due to stream=True
async for chunk in result: # type: ignore async for chunk in result:
yield chunk yield chunk
async def chat_single( async def chat_single(
@ -490,7 +490,7 @@ class UnifiedLLMProxy:
result = await self.chat(model, messages, provider, stream=False, **kwargs) result = await self.chat(model, messages, provider, stream=False, **kwargs)
# Type checker now knows this is a ChatResponse due to stream=False # Type checker now knows this is a ChatResponse due to stream=False
return result # type: ignore return result
async def generate( async def generate(
self, self,
@ -517,7 +517,7 @@ class UnifiedLLMProxy:
"""Stream text generation using specified or default provider""" """Stream text generation using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=True, **kwargs) result = await self.generate(model, prompt, provider, stream=True, **kwargs)
async for chunk in result: # type: ignore async for chunk in result:
yield chunk yield chunk
async def generate_single( async def generate_single(
@ -530,7 +530,7 @@ class UnifiedLLMProxy:
"""Get single generation response using specified or default provider""" """Get single generation response using specified or default provider"""
result = await self.generate(model, prompt, provider, stream=False, **kwargs) result = await self.generate(model, prompt, provider, stream=False, **kwargs)
return result # type: ignore return result
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
"""List available models for specified or default provider""" """List available models for specified or default provider"""

View File

@ -1,8 +1,8 @@
LLM_TIMEOUT = 600 LLM_TIMEOUT = 600
from utils import logger from utils import logger
from pydantic import BaseModel, Field, ValidationError # type: ignore from pydantic import BaseModel, Field, ValidationError
from pydantic_core import PydanticSerializationError # type: ignore from pydantic_core import PydanticSerializationError
from typing import List from typing import List
from typing import AsyncGenerator, Dict, Optional from typing import AsyncGenerator, Dict, Optional
@ -48,18 +48,18 @@ try_import("prometheus_fastapi_instrumentator")
import ollama import ollama
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException, Depends # type: ignore from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware # type: ignore from fastapi.middleware.cors import CORSMiddleware
import uvicorn # type: ignore import uvicorn
import numpy as np # type: ignore import numpy as np
from utils import redis_manager from utils import redis_manager
import redis.asyncio as redis # type: ignore import redis.asyncio as redis
# Prometheus # Prometheus
from prometheus_client import Summary # type: ignore from prometheus_client import Summary
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_client import CollectorRegistry, Counter # type: ignore from prometheus_client import CollectorRegistry, Counter
from utils import ( from utils import (
rag as Rag, rag as Rag,
@ -1308,7 +1308,7 @@ def main():
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*") warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
llm = ollama.Client(host=args.ollama_server) # type: ignore llm = ollama.Client(host=args.ollama_server)
web_server = WebServer(llm, args.ollama_model) web_server = WebServer(llm, args.ollama_model)
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False) web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)

View File

@ -5,7 +5,7 @@ import sys
import os import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils"))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
from markdown_chunker import MarkdownChunker # type: ignore from markdown_chunker import MarkdownChunker
chunker = MarkdownChunker() chunker = MarkdownChunker()
chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md") chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md")

View File

@ -1,10 +1,10 @@
from fastapi import FastAPI, Request, Depends, Query # type: ignore from fastapi import FastAPI, Request, Depends, Query
from fastapi.responses import RedirectResponse, JSONResponse # type: ignore from fastapi.responses import RedirectResponse, JSONResponse
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import logging import logging
import traceback import traceback
from typing import Callable, Optional from typing import Callable, Optional
from anyio.to_thread import run_sync # type: ignore from anyio.to_thread import run_sync
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,7 +69,7 @@ class ContextRouteManager:
logger.info(f"Invalid UUID, redirecting to {redirect_url}") logger.info(f"Invalid UUID, redirecting to {redirect_url}")
raise RedirectToContext(redirect_url) raise RedirectToContext(redirect_url)
return _ensure_context_dependency # type: ignore return _ensure_context_dependency
def route_pattern(self, path: str, *dependencies, **kwargs): def route_pattern(self, path: str, *dependencies, **kwargs):
logger.info(f"Registering route: {path}") logger.info(f"Registering route: {path}")
@ -134,6 +134,6 @@ async def redirect_history(
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn # type: ignore import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8900) uvicorn.run(app, host="0.0.0.0", port=8900)

View File

@ -1,9 +1,9 @@
# From /opt/backstory run: # From /opt/backstory run:
# python -m src.tests.test-embedding # python -m src.tests.test-embedding
import numpy as np # type: ignore import numpy as np
import logging import logging
import argparse import argparse
from ollama import Client # type: ignore from ollama import Client
from ..utils import defines from ..utils import defines
# Configure logging # Configure logging

View File

@ -1,11 +1,11 @@
# From /opt/backstory run: # From /opt/backstory run:
# python -m src.tests.test-rag # python -m src.tests.test-rag
from ..utils import logger from ..utils import logger
from pydantic import BaseModel, field_validator # type: ignore from pydantic import BaseModel, field_validator
from prometheus_client import CollectorRegistry # type: ignore from prometheus_client import CollectorRegistry
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import ollama import ollama
import numpy as np # type: ignore import numpy as np
from ..utils import (rag as Rag, ChromaDBGetResponse) from ..utils import (rag as Rag, ChromaDBGetResponse)
from ..utils import Context from ..utils import Context
from ..utils import defines from ..utils import defines
@ -44,7 +44,7 @@ rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
json_str = rag.model_dump(mode="json") json_str = rag.model_dump(mode="json")
logger.info(json_str) logger.info(json_str)
rag = ChromaDBGetResponse.model_validate(json_str) rag = ChromaDBGetResponse.model_validate(json_str)
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore llm = ollama.Client(host=defines.ollama_api_url)
prometheus_collector = CollectorRegistry() prometheus_collector = CollectorRegistry()
observer, file_watcher = Rag.start_file_watcher( observer, file_watcher = Rag.start_file_watcher(
llm=llm, llm=llm,

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel # type: ignore from pydantic import BaseModel
from typing import ( from typing import (
Any, Any,
Set Set
@ -39,7 +39,7 @@ __all__ = [
"generate_image", "ImageRequest" "generate_image", "ImageRequest"
] ]
__all__.extend(agents_all) # type: ignore __all__.extend(agents_all)
logger = setup_logging(level=defines.logging_level) logger = setup_logging(level=defines.logging_level)

View File

@ -44,7 +44,7 @@ for path in package_dir.glob("*.py"):
class_registry[name] = (full_module_name, name) class_registry[name] = (full_module_name, name)
globals()[name] = obj globals()[name] = obj
logger.info(f"Adding agent: {name}") logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore __all__.append(name)
except ImportError as e: except ImportError as e:
logger.error(f"Error importing {full_module_name}: {e}") logger.error(f"Error importing {full_module_name}: {e}")
raise e raise e

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import ( from typing import (
Literal, Literal,
get_args, get_args,
@ -19,7 +19,7 @@ import inspect
from abc import ABC from abc import ABC
import asyncio import asyncio
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore from prometheus_client import Counter, Summary, CollectorRegistry
from ..setup_logging import setup_logging from ..setup_logging import setup_logging
@ -33,7 +33,7 @@ from .types import agent_registry
from .. import defines from .. import defines
from ..message import Message, Tunables from ..message import Message, Tunables
from ..metrics import Metrics from ..metrics import Metrics
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools # type: ignore -- dynamically added to __all__ from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools
from ..conversation import Conversation from ..conversation import Conversation
class LLMMessage(BaseModel): class LLMMessage(BaseModel):

View File

@ -53,7 +53,7 @@ class Chat(Agent):
Chat Agent Chat Agent
""" """
agent_type: Literal["chat"] = "chat" # type: ignore agent_type: Literal["chat"] = "chat"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message system_prompt: str = system_message

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator # type: ignore from pydantic import model_validator
from typing import ( from typing import (
Literal, Literal,
ClassVar, ClassVar,
@ -31,7 +31,7 @@ When answering queries, follow these steps:
class FactCheck(Agent): class FactCheck(Agent):
agent_type: Literal["fact_check"] = "fact_check" # type: ignore agent_type: Literal["fact_check"] = "fact_check"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_fact_check system_prompt: str = system_fact_check

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore from pydantic import model_validator, Field, BaseModel
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -37,7 +37,7 @@ seed = int(time.time())
random.seed(seed) random.seed(seed)
class ImageGenerator(Agent): class ImageGenerator(Agent):
agent_type: Literal["image"] = "image" # type: ignore agent_type: Literal["image"] = "image"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False agent_persist: bool = False

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator, Field # type: ignore from pydantic import model_validator, Field
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -17,7 +17,7 @@ import traceback
import asyncio import asyncio
import time import time
import asyncio import asyncio
import numpy as np # type: ignore import numpy as np
from . base import Agent, agent_registry, LLMMessage from . base import Agent, agent_registry, LLMMessage
from .. message import Message from .. message import Message
@ -36,7 +36,7 @@ Answer questions about the job description.
class JobDescription(Agent): class JobDescription(Agent):
agent_type: Literal["job_description"] = "job_description" # type: ignore agent_type: Literal["job_description"] = "job_description"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_generate_resume system_prompt: str = system_generate_resume

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore from pydantic import model_validator, Field, BaseModel
from typing import ( from typing import (
Dict, Dict,
Literal, Literal,
@ -23,7 +23,7 @@ import asyncio
import time import time
import os import os
import random import random
from names_dataset import NameDataset, NameWrapper # type: ignore from names_dataset import NameDataset, NameWrapper
from .base import Agent, agent_registry, LLMMessage from .base import Agent, agent_registry, LLMMessage
from ..message import Message from ..message import Message
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
class EthnicNameGenerator: class EthnicNameGenerator:
def __init__(self): def __init__(self):
try: try:
from names_dataset import NameDataset # type: ignore from names_dataset import NameDataset
self.nd = NameDataset() self.nd = NameDataset()
except ImportError: except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset") logger.error("NameDataset not available. Please install: pip install names-dataset")
@ -292,7 +292,7 @@ class EthnicNameGenerator:
return names return names
class PersonaGenerator(Agent): class PersonaGenerator(Agent):
agent_type: Literal["persona"] = "persona" # type: ignore agent_type: Literal["persona"] = "persona"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False agent_persist: bool = False

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import model_validator # type: ignore from pydantic import model_validator
from typing import ( from typing import (
Literal, Literal,
ClassVar, ClassVar,
@ -46,7 +46,7 @@ When answering queries, follow these steps:
class Resume(Agent): class Resume(Agent):
agent_type: Literal["resume"] = "resume" # type: ignore agent_type: Literal["resume"] = "resume"
_agent_type: ClassVar[str] = agent_type # Add this for registration _agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_fact_check system_prompt: str = system_fact_check

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
import json import json
from typing import Any, List, Set from typing import Any, List, Set

View File

@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
from typing_extensions import Annotated, Union from typing_extensions import Annotated, Union
import numpy as np # type: ignore import numpy as np
import logging import logging
from uuid import uuid4 from uuid import uuid4
import traceback import traceback
@ -44,7 +44,7 @@ class Context(BaseModel):
user_facts: Optional[str] = None user_facts: Optional[str] = None
# Class managed fields # Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
default_factory=list default_factory=list
) )

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import List from typing import List
from .message import Message from .message import Message

View File

@ -4,8 +4,8 @@ import re
import time import time
from typing import Any from typing import Any
import torch # type: ignore import torch
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore from diffusers import StableDiffusionPipeline, FluxPipeline
class ImageModelCache: # Stay loaded for 3 hours class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60): def __init__(self, timeout_seconds: float = 3 * 60 * 60):

View File

@ -1,8 +1,8 @@
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Any, Union, Mapping from typing import Dict, List, Optional, Any, Union, Mapping
from datetime import datetime, timezone from datetime import datetime, timezone
from . rag import ChromaDBGetResponse from . rag import ChromaDBGetResponse
from ollama._types import Options # type: ignore from ollama._types import Options
class Tunables(BaseModel): class Tunables(BaseModel):
enable_rag: bool = True # Enable RAG collection chromadb matching enable_rag: bool = True # Enable RAG collection chromadb matching

View File

@ -1,4 +1,4 @@
from prometheus_client import Counter, Histogram # type: ignore from prometheus_client import Counter, Histogram
from threading import Lock from threading import Lock
def singleton(cls): def singleton(cls):

View File

@ -1,5 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field # type: ignore from pydantic import BaseModel, Field
from typing import Dict, Literal, Any, AsyncGenerator, Optional from typing import Dict, Literal, Any, AsyncGenerator, Optional
import inspect import inspect
import random import random
@ -12,7 +12,7 @@ import os
import gc import gc
import tempfile import tempfile
import uuid import uuid
import torch # type: ignore import torch
import asyncio import asyncio
import time import time
import json import json

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
from typing import List, Optional, Dict, Any, Union from typing import List, Optional, Dict, Any, Union
import os import os
import glob import glob
@ -8,16 +8,16 @@ import hashlib
import asyncio import asyncio
import logging import logging
import json import json
import numpy as np # type: ignore import numpy as np
import traceback import traceback
import chromadb # type: ignore import chromadb
import ollama import ollama
from watchdog.observers import Observer # type: ignore from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler # type: ignore from watchdog.events import FileSystemEventHandler
import umap # type: ignore import umap
from markitdown import MarkItDown # type: ignore from markitdown import MarkItDown
from chromadb.api.models.Collection import Collection # type: ignore from chromadb.api.models.Collection import Collection
from .markdown_chunker import ( from .markdown_chunker import (
MarkdownChunker, MarkdownChunker,
@ -388,9 +388,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
os.makedirs(self.persist_directory) os.makedirs(self.persist_directory)
# Initialize ChromaDB client # Initialize ChromaDB client
chroma_client = chromadb.PersistentClient( # type: ignore chroma_client = chromadb.PersistentClient(
path=self.persist_directory, path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore settings=chromadb.Settings(anonymized_telemetry=False),
) )
# Check if the collection exists # Check if the collection exists

View File

@ -1,4 +1,4 @@
import redis.asyncio as redis # type: ignore import redis.asyncio as redis
from typing import Optional from typing import Optional
import os import os
import logging import logging

View File

@ -1,5 +1,5 @@
import os import os
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from typing import List, Optional, Generator, ClassVar, Any, Dict from typing import List, Optional, Generator, ClassVar, Any, Dict
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
@ -7,12 +7,12 @@ from typing import (
) )
from typing_extensions import Annotated from typing_extensions import Annotated
from bs4 import BeautifulSoup # type: ignore from bs4 import BeautifulSoup
from geopy.geocoders import Nominatim # type: ignore from geopy.geocoders import Nominatim
import pytz # type: ignore import pytz
import requests import requests
import yfinance as yf # type: ignore import yfinance as yf
import logging import logging
@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__) # type: ignore

View File

@ -1,13 +1,13 @@
from __future__ import annotations from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore from pydantic import BaseModel, Field, model_validator
from uuid import uuid4 from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
from typing_extensions import Annotated, Union from typing_extensions import Annotated, Union
import numpy as np # type: ignore import numpy as np
import logging import logging
from uuid import uuid4 from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore from prometheus_client import CollectorRegistry, Counter
import traceback import traceback
import os import os
import json import json