From 20f8d7bd32a83789ef86a1e4737f888ffc7c22ef Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Sun, 8 Jun 2025 20:42:23 -0700 Subject: [PATCH] Guest seems to work! --- Dockerfile | 2 +- docker-compose.yml | 4 +- frontend/src/BackstoryApp.tsx | 2 +- frontend/src/components/Conversation.tsx | 9 +- .../EmailVerificationComponents.tsx | 16 +- frontend/src/components/JobMatchAnalysis.tsx | 2 +- frontend/src/components/Message.tsx | 1 + frontend/src/components/VectorVisualizer.tsx | 3 +- .../src/components/layout/BackstoryLayout.tsx | 11 +- frontend/src/components/layout/Header.tsx | 2 +- frontend/src/components/ui/CandidateInfo.tsx | 2 +- frontend/src/config/navigationConfig.tsx | 12 +- frontend/src/hooks/AuthContext.tsx | 474 +++--- frontend/src/pages/CandidateChatPage.tsx | 17 +- frontend/src/pages/GenerateCandidate.tsx | 2 +- frontend/src/pages/HomePage.tsx | 10 + frontend/src/pages/LoadingPage.tsx | 3 +- frontend/src/pages/LoginPage.tsx | 3 +- frontend/src/pages/LoginRequired.tsx | 3 +- frontend/src/pages/OldChatPage.tsx | 3 +- frontend/src/pages/candidate/Dashboard.tsx | 2 +- frontend/src/pages/candidate/Settings.tsx | 353 ++--- frontend/src/routes/CandidateRoute.tsx | 11 +- frontend/src/services/api-client.ts | 234 ++- frontend/src/types/types.ts | 207 ++- frontend/src/utils/Global.tsx | 15 - src/backend/agents/__init__.py | 6 +- src/backend/agents/base.py | 8 +- src/backend/agents/candidate_chat.py | 2 +- src/backend/agents/general.py | 2 +- src/backend/agents/generate_image.py | 4 +- src/backend/agents/generate_persona.py | 8 +- src/backend/agents/job_requirements.py | 6 +- src/backend/agents/rag_search.py | 2 +- src/backend/agents/skill_match.py | 6 +- src/backend/auth_utils.py | 4 +- src/backend/background_tasks.py | 175 ++ src/backend/database.py | 270 +++- src/backend/device_manager.py | 4 +- src/backend/email_service.py | 4 +- src/backend/entities/candidate_entity.py | 6 +- src/backend/entities/entity_manager.py | 4 +- src/backend/generate_types.py | 2 +- src/backend/helpers/check_serializable.py | 2 +- .../image_generator/image_model_cache.py | 4 +- src/backend/image_generator/profile_image.py | 4 +- src/backend/llm_proxy.py | 20 +- src/backend/main.py | 1411 +++++++++++++++-- src/backend/model_cast.py | 41 +- src/backend/models.py | 196 ++- src/backend/rag/rag.py | 20 +- src/backend/rate_limiter.py | 287 ++++ src/backend/system_info.py | 25 +- src/backend/tools/basetools.py | 12 +- src/multi-llm/example.py | 8 +- src/multi-llm/llm_proxy.py | 16 +- src/server.py | 24 +- src/tests/test-chunker.py | 2 +- src/tests/test-context-routing.py | 10 +- src/tests/test-embedding.py | 4 +- src/tests/test-rag.py | 8 +- src/utils/__init__.py | 4 +- src/utils/agents/__init__.py | 2 +- src/utils/agents/base.py | 6 +- src/utils/agents/chat.py | 2 +- src/utils/agents/fact_check.py | 4 +- src/utils/agents/image_generator.py | 4 +- src/utils/agents/job_description.py | 6 +- src/utils/agents/persona_generator.py | 8 +- src/utils/agents/resume.py | 4 +- src/utils/check_serializable.py | 2 +- src/utils/context.py | 6 +- src/utils/conversation.py | 2 +- src/utils/image_model_cache.py | 4 +- src/utils/message.py | 4 +- src/utils/metrics.py | 2 +- src/utils/profile_image.py | 4 +- src/utils/rag.py | 20 +- src/utils/redis_client.py | 2 +- src/utils/tools/basetools.py | 12 +- src/utils/user.py | 6 +- 81 files changed, 3284 insertions(+), 830 deletions(-) delete mode 100644 frontend/src/utils/Global.tsx create mode 100644 src/backend/background_tasks.py create mode 100644 src/backend/rate_limiter.py diff --git a/Dockerfile b/Dockerfile index a77ca3b..e00a37e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -188,7 +188,7 @@ RUN pip install prometheus-client prometheus-fastapi-instrumentator RUN pip install "redis[hiredis]>=4.5.0" # New backend implementation -RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart +RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart schedule # Needed for email verification RUN pip install pyyaml user-agents cryptography diff --git a/docker-compose.yml b/docker-compose.yml index 5fbc23a..19307f9 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -91,11 +91,11 @@ services: # Optional: Redis Commander for GUI management redis-commander: image: rediscommander/redis-commander:latest - container_name: backstory-redis-commander + container_name: redis-commander ports: - "8081:8081" environment: - - REDIS_HOSTS=redis:6379 + - REDIS_HOSTS=redis:redis:6379 networks: - internal depends_on: diff --git a/frontend/src/BackstoryApp.tsx b/frontend/src/BackstoryApp.tsx index ab84fdd..af6fb17 100644 --- a/frontend/src/BackstoryApp.tsx +++ b/frontend/src/BackstoryApp.tsx @@ -20,8 +20,8 @@ import '@fontsource/roboto/700.css'; const BackstoryApp = () => { const navigate = useNavigate(); const location = useLocation(); - const snackRef = useRef(null); const chatRef = useRef(null); + const snackRef = useRef(null); const setSnack = useCallback((message: string, severity?: SeverityType) => { snackRef.current?.setSnack(message, severity); }, [snackRef]); diff --git a/frontend/src/components/Conversation.tsx b/frontend/src/components/Conversation.tsx index 7dd7d23..6ee2c43 100644 --- a/frontend/src/components/Conversation.tsx +++ b/frontend/src/components/Conversation.tsx @@ -12,7 +12,6 @@ import { Message } from './Message'; import { DeleteConfirmation } from 'components/DeleteConfirmation'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; import { BackstoryElementProps } from './BackstoryTab'; -import { connectionBase } from 'utils/Global'; import { useAuth } from "hooks/AuthContext"; import { StreamingResponse } from 'services/api-client'; import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types'; @@ -22,7 +21,7 @@ import './Conversation.css'; import { useAppState } from 'hooks/GlobalContext'; const defaultMessage: ChatMessage = { - status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant" + status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any }; const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." }; @@ -325,16 +324,16 @@ const Conversation = forwardRef((props: C { filteredConversation.map((message, index) => - + ) } { processingMessage !== undefined && - + } { streamingMessage !== undefined && - + } { @@ -488,10 +488,11 @@ const RegistrationSuccessDialog = ({ // Enhanced Login Component with MFA Support const LoginForm = () => { - const { login, mfaResponse, isLoading, error } = useAuth(); + const { login, mfaResponse, isLoading, error, user } = useAuth(); const [email, setEmail] = useState(''); const [password, setPassword] = useState(''); const [errorMessage, setErrorMessage] = useState(null); + const navigate = useNavigate(); useEffect(() => { if (!error) { @@ -524,10 +525,13 @@ const LoginForm = () => { handleLoginSuccess(); }; - const handleLoginSuccess = () => { - // This could be handled by a router or parent component - // For now, just showing the pattern - console.log('Login successful - redirect to dashboard'); + const handleLoginSuccess = () => { + if (!user) { + navigate('/'); + } else { + navigate(`/${user.userType}/dashboard`); + } + console.log('Login successful - redirect to dashboard'); }; return ( diff --git a/frontend/src/components/JobMatchAnalysis.tsx b/frontend/src/components/JobMatchAnalysis.tsx index b39af5b..83b3507 100644 --- a/frontend/src/components/JobMatchAnalysis.tsx +++ b/frontend/src/components/JobMatchAnalysis.tsx @@ -39,7 +39,7 @@ interface JobAnalysisProps extends BackstoryPageProps { } const defaultMessage: ChatMessage = { - status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant" + status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any }; const JobMatchAnalysis: React.FC = (props: JobAnalysisProps) => { diff --git a/frontend/src/components/Message.tsx b/frontend/src/components/Message.tsx index f57ee50..6ec8d33 100644 --- a/frontend/src/components/Message.tsx +++ b/frontend/src/components/Message.tsx @@ -101,6 +101,7 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error" color: theme.palette.text.primary, opacity: 0.95, }, + info: 'information', preparing: 'status', processing: 'status', qualifications: { diff --git a/frontend/src/components/VectorVisualizer.tsx b/frontend/src/components/VectorVisualizer.tsx index 8386d38..a826892 100644 --- a/frontend/src/components/VectorVisualizer.tsx +++ b/frontend/src/components/VectorVisualizer.tsx @@ -17,7 +17,6 @@ import TableContainer from '@mui/material/TableContainer'; import TableRow from '@mui/material/TableRow'; import { Scrollable } from './Scrollable'; -import { connectionBase } from '../utils/Global'; import './VectorVisualizer.css'; import { BackstoryPageProps } from './BackstoryTab'; @@ -195,7 +194,7 @@ const VectorVisualizer: React.FC = (props: VectorVisualiz const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 }); const navigate = useNavigate(); - const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user : null; + const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null; /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered * off screen (eg., the VectorVisualizer is not on the tab the app loads to) */ diff --git a/frontend/src/components/layout/BackstoryLayout.tsx b/frontend/src/components/layout/BackstoryLayout.tsx index 5d84e67..9adafcc 100644 --- a/frontend/src/components/layout/BackstoryLayout.tsx +++ b/frontend/src/components/layout/BackstoryLayout.tsx @@ -84,7 +84,6 @@ const BackstoryLayout: React.FC = (props: BackstoryLayoutP const navigate = useNavigate(); const location = useLocation(); const { guest, user } = useAuth(); - const { selectedCandidate } = useSelectedCandidate(); const [navigationItems, setNavigationItems] = useState([]); useEffect(() => { @@ -92,9 +91,13 @@ const BackstoryLayout: React.FC = (props: BackstoryLayoutP setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false)); }, [user]); + useEffect(() => { + console.log({ guest, user }); + }, [guest, user]); + // Generate dynamic routes from navigation config const generateRoutes = () => { - if (!guest) return null; + if (!guest && !user) return null; const userType = user?.userType || null; const isAdmin = user?.isAdmin ? true : false; @@ -161,7 +164,7 @@ const BackstoryLayout: React.FC = (props: BackstoryLayoutP }} > - {!guest && ( + {!guest && !user && ( = (props: BackstoryLayoutP /> )} - {guest && ( + {(guest || user) && ( <> diff --git a/frontend/src/components/layout/Header.tsx b/frontend/src/components/layout/Header.tsx index b464faf..20b94ae 100644 --- a/frontend/src/components/layout/Header.tsx +++ b/frontend/src/components/layout/Header.tsx @@ -156,7 +156,7 @@ const Header: React.FC = (props: HeaderProps) => { id: 'profile', label: 'Profile', icon: , - action: () => navigate(`/${user?.userType}/dashboard/profile`) + action: () => navigate(`/${user?.userType}/profile`) }, { id: 'dashboard', diff --git a/frontend/src/components/ui/CandidateInfo.tsx b/frontend/src/components/ui/CandidateInfo.tsx index 9b1e15c..e16faf7 100644 --- a/frontend/src/components/ui/CandidateInfo.tsx +++ b/frontend/src/components/ui/CandidateInfo.tsx @@ -74,7 +74,7 @@ const CandidateInfo: React.FC = (props: CandidateInfoProps) maxWidth: "80px" }}> (Search); @@ -58,20 +59,19 @@ const SettingsPage = () => (Settings, path: '/', component: , userTypes: ['guest', 'candidate', 'employer'], exact: true, }, - { id: 'chat', label: 'Chat', path: '/chat', icon: , component: , userTypes: ['guest', 'candidate', 'employer',], }, + { id: 'home', label: , path: '/', component: , userTypes: ['guest', 'candidate', 'employer'], exact: true, }, + { id: 'chat', label: 'Chat about a Candidate', path: '/chat', icon: , component: , userTypes: ['guest', 'candidate', 'employer',], }, { id: 'candidate-menu', label: 'Tools', icon: , userTypes: ['candidate'], children: [ { id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: , component: , userTypes: ['candidate'] }, { id: 'candidate-profile', label: 'Profile', icon: , path: '/candidate/profile', component: , userTypes: ['candidate'] }, { id: 'candidate-qa-setup', label: 'Q&A Setup', icon: , path: '/candidate/qa-setup', component: Candidate q&a setup page, userTypes: ['candidate'] }, - { id: 'candidate-analytics', label: 'Analytics', icon: , path: '/candidate/analytics', component: Candidate analytics page, userTypes: ['candidate'] }, - { id: 'candidate-jobs', label: 'Jobs', icon: , path: '/candidate/jobs', component: , userTypes: ['candidate'] }, + { id: 'candidate-analytics', label: 'Analytics', icon: , path: '/candidate/analytics', component: Candidate analytics page, userTypes: ['candidate'] }, { id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: , component: , userTypes: ['candidate'], }, { id: 'candidate-resumes', label: 'Resumes', icon: , path: '/candidate/resumes', component: Candidate resumes page, userTypes: ['candidate'] }, { id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: , component: , userTypes: ['candidate'], }, { id: 'candidate-content', label: 'Content', icon: , path: '/candidate/content', component: , userTypes: ['candidate'] }, - { id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: , component: , userTypes: ['candidate'], }, + { id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: , component: , userTypes: ['candidate'], }, ], }, { @@ -87,7 +87,7 @@ export const navigationConfig: NavigationConfig = { { id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: , component: , userTypes: ['employer'], }, ], }, - { id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: , component: , userTypes: ['guest', 'candidate', 'employer'], }, + // { id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: , component: , userTypes: ['guest', 'candidate', 'employer'], }, { id: 'docs', label: 'Docs', path: '/docs/*', icon: , component: , userTypes: ['guest', 'candidate', 'employer'], }, { id: 'admin-menu', diff --git a/frontend/src/hooks/AuthContext.tsx b/frontend/src/hooks/AuthContext.tsx index 7cbdb48..51e0568 100644 --- a/frontend/src/hooks/AuthContext.tsx +++ b/frontend/src/hooks/AuthContext.tsx @@ -1,17 +1,19 @@ +// Replace the existing AuthContext.tsx with these enhancements + import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react'; import * as Types from '../types/types'; -import { ApiClient, CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client'; +import { ApiClient, CreateCandidateRequest, CreateEmployerRequest, GuestConversionRequest } from '../services/api-client'; import { formatApiRequest, toCamelCase } from '../types/conversion'; // ============================ -// Types and Interfaces +// Enhanced Types and Interfaces // ============================ - interface AuthState { user: Types.User | null; guest: Types.Guest | null; isAuthenticated: boolean; + isGuest: boolean; isLoading: boolean; isInitializing: boolean; error: string | null; @@ -36,7 +38,7 @@ interface PasswordResetRequest { } // ============================ -// Token Storage Constants +// Enhanced Token Storage Constants // ============================ const TOKEN_STORAGE = { @@ -44,7 +46,8 @@ const TOKEN_STORAGE = { REFRESH_TOKEN: 'refreshToken', USER_DATA: 'userData', TOKEN_EXPIRY: 'tokenExpiry', - GUEST_DATA: 'guestData', + USER_TYPE: 'userType', + IS_GUEST: 'isGuest', PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail' } as const; @@ -84,7 +87,7 @@ function isTokenExpired(token: string): boolean { } // ============================ -// Storage Utilities with Date Conversion +// Enhanced Storage Utilities // ============================ function clearStoredAuth(): void { @@ -92,6 +95,8 @@ function clearStoredAuth(): void { localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN); localStorage.removeItem(TOKEN_STORAGE.USER_DATA); localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY); + localStorage.removeItem(TOKEN_STORAGE.USER_TYPE); + localStorage.removeItem(TOKEN_STORAGE.IS_GUEST); } function prepareUserDataForStorage(user: Types.User): string { @@ -119,11 +124,21 @@ function parseStoredUserData(userDataStr: string): Types.User | null { } } -function storeAuthData(authResponse: Types.AuthResponse): void { +function updateStoredUserData(user: Types.User): void { + try { + localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user)); + } catch (error) { + console.error('Failed to update stored user data:', error); + } +} + +function storeAuthData(authResponse: Types.AuthResponse, isGuest: boolean = false): void { localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken); localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken); localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user)); localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString()); + localStorage.setItem(TOKEN_STORAGE.USER_TYPE, authResponse.user.userType); + localStorage.setItem(TOKEN_STORAGE.IS_GUEST, isGuest.toString()); } function getStoredAuthData(): { @@ -131,11 +146,15 @@ function getStoredAuthData(): { refreshToken: string | null; userData: Types.User | null; expiresAt: number | null; + userType: string | null; + isGuest: boolean; } { const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN); const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN); const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA); const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY); + const userType = localStorage.getItem(TOKEN_STORAGE.USER_TYPE); + const isGuestStr = localStorage.getItem(TOKEN_STORAGE.IS_GUEST); let userData: Types.User | null = null; let expiresAt: number | null = null; @@ -152,55 +171,18 @@ function getStoredAuthData(): { clearStoredAuth(); } - return { accessToken, refreshToken, userData, expiresAt }; -} - -function updateStoredUserData(user: Types.User): void { - try { - localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user)); - } catch (error) { - console.error('Failed to update stored user data:', error); - } -} - -// ============================ -// Guest Session Utilities -// ============================ - -function createGuestSession(): Types.Guest { - const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`; - const guest: Types.Guest = { - sessionId, - createdAt: new Date(), - lastActivity: new Date(), - ipAddress: 'unknown', - userAgent: navigator.userAgent + return { + accessToken, + refreshToken, + userData, + expiresAt, + userType, + isGuest: isGuestStr === 'true' }; - - try { - localStorage.setItem(TOKEN_STORAGE.GUEST_DATA, JSON.stringify(formatApiRequest(guest))); - } catch (error) { - console.error('Failed to store guest session:', error); - } - - return guest; -} - -function getStoredGuestData(): Types.Guest | null { - try { - const guestDataStr = localStorage.getItem(TOKEN_STORAGE.GUEST_DATA); - if (guestDataStr) { - return toCamelCase(JSON.parse(guestDataStr)); - } - } catch (error) { - console.error('Failed to parse stored guest data:', error); - localStorage.removeItem(TOKEN_STORAGE.GUEST_DATA); - } - return null; } // ============================ -// Main Authentication Hook +// Enhanced Authentication Hook // ============================ function useAuthenticationLogic() { @@ -208,6 +190,7 @@ function useAuthenticationLogic() { user: null, guest: null, isAuthenticated: false, + isGuest: false, isLoading: false, isInitializing: true, error: null, @@ -216,6 +199,7 @@ function useAuthenticationLogic() { const [apiClient] = useState(() => new ApiClient()); const initializationCompleted = useRef(false); + const guestCreationAttempted = useRef(false); // Token refresh function const refreshAccessToken = useCallback(async (refreshToken: string): Promise => { @@ -228,6 +212,58 @@ function useAuthenticationLogic() { } }, [apiClient]); + // Create guest session + const createGuestSession = useCallback(async (): Promise => { + if (guestCreationAttempted.current) { + return false; + } + + guestCreationAttempted.current = true; + + try { + console.log('๐Ÿ”„ Creating guest session...'); + const guestAuth = await apiClient.createGuestSession(); + + if (guestAuth && guestAuth.user && guestAuth.user.userType === 'guest') { + storeAuthData(guestAuth, true); + apiClient.setAuthToken(guestAuth.accessToken); + + setAuthState({ + user: null, + guest: guestAuth.user as Types.Guest, + isAuthenticated: true, + isGuest: true, + isLoading: false, + isInitializing: false, + error: null, + mfaResponse: null, + }); + + console.log('๐Ÿ‘ค Guest session created successfully:', guestAuth.user); + return true; + } + + return false; + } catch (error) { + console.error('โŒ Failed to create guest session:', error); + guestCreationAttempted.current = false; + + // Set to unauthenticated state if guest creation fails + setAuthState(prev => ({ + ...prev, + user: null, + guest: null, + isAuthenticated: false, + isGuest: false, + isLoading: false, + isInitializing: false, + error: 'Failed to create guest session', + })); + + return false; + } + }, [apiClient]); + // Initialize authentication state const initializeAuth = useCallback(async () => { if (initializationCompleted.current) { @@ -235,99 +271,94 @@ function useAuthenticationLogic() { } try { - // Initialize guest session first - let guest = getStoredGuestData(); - if (!guest) { - guest = createGuestSession(); + const stored = getStoredAuthData(); + + // If no stored tokens, create guest session + if (!stored.accessToken || !stored.refreshToken || !stored.userData) { + console.log('๐Ÿ”„ No stored auth found, creating guest session...'); + await createGuestSession(); + return; } - const stored = getStoredAuthData(); - - // If no stored tokens, user is not authenticated but has guest session - if (!stored.accessToken || !stored.refreshToken || !stored.userData) { - setAuthState({ - user: null, - guest, - isAuthenticated: false, - isLoading: false, - isInitializing: false, - error: null, - mfaResponse: null, - }); - return; + // For guests, always verify the session exists on server + if (stored.userType === 'guest' && stored.userData) { + console.log(stored.userData); + try { + // Make a quick API call to verify guest still exists + const response = await fetch(`${apiClient.getBaseUrl()}/users/${stored.userData.id}`, { + headers: { 'Authorization': `Bearer ${stored.accessToken}` } + }); + + if (!response.ok) { + console.log('๐Ÿ”„ Guest session invalid, creating new guest session...'); + clearStoredAuth(); + await createGuestSession(); + return; + } + } catch (error) { + console.log('๐Ÿ”„ Guest verification failed, creating new guest session...'); + clearStoredAuth(); + await createGuestSession(); + return; + } } // Check if access token is expired if (isTokenExpired(stored.accessToken)) { - console.log('Access token expired, attempting refresh...'); + console.log('๐Ÿ”„ Access token expired, attempting refresh...'); const refreshResult = await refreshAccessToken(stored.refreshToken); if (refreshResult) { - storeAuthData(refreshResult); + const isGuest = stored.userType === 'guest'; + storeAuthData(refreshResult, isGuest); apiClient.setAuthToken(refreshResult.accessToken); setAuthState({ - user: refreshResult.user, - guest, + user: isGuest ? null : refreshResult.user, + guest: isGuest ? refreshResult.user as Types.Guest : null, isAuthenticated: true, + isGuest, isLoading: false, isInitializing: false, error: null, mfaResponse: null }); - console.log('Token refreshed successfully'); + console.log('โœ… Token refreshed successfully'); } else { - console.log('Token refresh failed, clearing stored auth'); + console.log('โŒ Token refresh failed, creating new guest session...'); clearStoredAuth(); apiClient.clearAuthToken(); - - setAuthState({ - user: null, - guest, - isAuthenticated: false, - isLoading: false, - isInitializing: false, - error: null, - mfaResponse: null - }); + await createGuestSession(); } } else { // Access token is still valid apiClient.setAuthToken(stored.accessToken); + const isGuest = stored.userType === 'guest'; setAuthState({ - user: stored.userData, - guest, + user: isGuest ? null : stored.userData, + guest: isGuest ? stored.userData as Types.Guest : null, isAuthenticated: true, + isGuest, isLoading: false, isInitializing: false, error: null, mfaResponse: null }); - console.log('Restored authentication from stored tokens'); + console.log('โœ… Restored authentication from stored tokens'); } } catch (error) { - console.error('Error initializing auth:', error); + console.error('โŒ Error initializing auth:', error); clearStoredAuth(); apiClient.clearAuthToken(); - - const guest = createGuestSession(); - setAuthState({ - user: null, - guest, - isAuthenticated: false, - isLoading: false, - isInitializing: false, - error: null, - mfaResponse: null - }); + await createGuestSession(); } finally { initializationCompleted.current = true; } - }, [apiClient, refreshAccessToken]); + }, [apiClient, refreshAccessToken, createGuestSession]); // Run initialization on mount useEffect(() => { @@ -355,7 +386,7 @@ function useAuthenticationLogic() { } const refreshTimer = setTimeout(() => { - console.log('Auto-refreshing token before expiry...'); + console.log('๐Ÿ”„ Auto-refreshing token before expiry...'); initializeAuth(); }, timeUntilExpiry); @@ -364,7 +395,7 @@ function useAuthenticationLogic() { // Enhanced login with MFA support const login = useCallback(async (loginData: LoginRequest): Promise => { - setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null, mfaData: null })); + setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null })); try { const result = await apiClient.login({ @@ -381,21 +412,23 @@ function useAuthenticationLogic() { })); return false; // Login not complete yet } else { - // Normal login success + // Normal login success - convert from guest to authenticated user const authResponse: Types.AuthResponse = result; - storeAuthData(authResponse); + storeAuthData(authResponse, false); apiClient.setAuthToken(authResponse.accessToken); setAuthState(prev => ({ ...prev, user: authResponse.user, + guest: null, isAuthenticated: true, + isGuest: false, isLoading: false, error: null, mfaResponse: null, })); - console.log('Login successful'); + console.log('โœ… Login successful, converted from guest to authenticated user'); return true; } } catch (error: any) { @@ -410,6 +443,44 @@ function useAuthenticationLogic() { } }, [apiClient]); + // Convert guest to permanent user + const convertGuestToUser = useCallback(async (registrationData: GuestConversionRequest): Promise => { + if (!authState.isGuest || !authState.guest) { + throw new Error('Not currently a guest user'); + } + + setAuthState(prev => ({ ...prev, isLoading: true, error: null })); + + try { + const result = await apiClient.convertGuestToUser(registrationData); + + // Store new authentication + storeAuthData(result.auth, false); + apiClient.setAuthToken(result.auth.accessToken); + + setAuthState(prev => ({ + ...prev, + user: result.auth.user, + guest: null, + isAuthenticated: true, + isGuest: false, + isLoading: false, + error: null, + })); + + console.log('โœ… Guest successfully converted to permanent user'); + return true; + } catch (error: any) { + const errorMessage = error instanceof Error ? error.message : 'Failed to convert guest account'; + setAuthState(prev => ({ + ...prev, + isLoading: false, + error: errorMessage, + })); + return false; + } + }, [apiClient, authState.isGuest, authState.guest]); + // MFA verification const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise => { setAuthState(prev => ({ ...prev, isLoading: true, error: null })); @@ -419,26 +490,27 @@ function useAuthenticationLogic() { if (result.accessToken) { const authResponse: Types.AuthResponse = result; - storeAuthData(authResponse); + storeAuthData(authResponse, false); apiClient.setAuthToken(authResponse.accessToken); setAuthState(prev => ({ ...prev, user: authResponse.user, + guest: null, isAuthenticated: true, + isGuest: false, isLoading: false, error: null, mfaResponse: null, })); - console.log('MFA verification successful'); + console.log('โœ… MFA verification successful, converted from guest'); return true; } return false; } catch (error) { const errorMessage = error instanceof Error ? error.message : 'MFA verification failed'; - console.log(errorMessage); setAuthState(prev => ({ ...prev, isLoading: false, @@ -448,42 +520,48 @@ function useAuthenticationLogic() { } }, [apiClient]); - // Resend MFA code - const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise => { - setAuthState(prev => ({ ...prev, isLoading: true, error: null })); - + // Logout - returns to guest session + const logout = useCallback(async () => { try { - await apiClient.requestMFA({ - email, - password: '', // This would need to be stored securely or re-entered - deviceId, - deviceName, - }); - - setAuthState(prev => ({ ...prev, isLoading: false })); - return true; + // If authenticated, try to logout gracefully + if (authState.isAuthenticated && !authState.isGuest) { + const stored = getStoredAuthData(); + if (stored.accessToken && stored.refreshToken) { + try { + await apiClient.logout(stored.accessToken, stored.refreshToken); + } catch (error) { + console.warn('Logout request failed, proceeding with local cleanup'); + } + } + } } catch (error) { - const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code'; - setAuthState(prev => ({ - ...prev, - isLoading: false, - error: errorMessage - })); - return false; - } - }, [apiClient]); + console.warn('Error during logout:', error); + } finally { + // Always clear stored auth and create new guest session + clearStoredAuth(); + apiClient.clearAuthToken(); + guestCreationAttempted.current = false; - // Clear MFA state - const clearMFA = useCallback(() => { + // Create new guest session + await createGuestSession(); + + console.log('๐Ÿ”„ Logged out, created new guest session'); + } + }, [apiClient, authState.isAuthenticated, authState.isGuest, createGuestSession]); + + // Update user data + const updateUserData = useCallback((updatedUser: Types.User) => { + updateStoredUserData(updatedUser); setAuthState(prev => ({ ...prev, - mfaResponse: null, - error: null + user: authState.isGuest ? null : updatedUser, + guest: authState.isGuest ? updatedUser as Types.Guest : prev.guest })); - }, []); + console.log('โœ… User data updated'); + }, [authState.isGuest]); - // Email verification - const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest): Promise<{ message: string; userType: string } | null> => { + // Email verification functions (unchanged) + const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest) => { setAuthState(prev => ({ ...prev, isLoading: true, error: null })); try { @@ -504,7 +582,7 @@ function useAuthenticationLogic() { } }, [apiClient]); - // Resend email verification + // Other existing methods remain the same... const resendEmailVerification = useCallback(async (email: string): Promise => { setAuthState(prev => ({ ...prev, isLoading: true, error: null })); @@ -523,53 +601,21 @@ function useAuthenticationLogic() { } }, [apiClient]); - // Store pending verification email const setPendingVerificationEmail = useCallback((email: string) => { localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email); }, []); - // Get pending verification email const getPendingVerificationEmail = useCallback((): string | null => { return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL); }, []); - const logout = useCallback(() => { - clearStoredAuth(); - apiClient.clearAuthToken(); - - // Create new guest session after logout - const guest = createGuestSession(); - - setAuthState(prev => ({ - ...prev, - user: null, - guest, - isAuthenticated: false, - isLoading: false, - error: null, - mfaResponse: null, - })); - - console.log('User logged out'); - }, [apiClient]); - - const updateUserData = useCallback((updatedUser: Types.User) => { - updateStoredUserData(updatedUser); - setAuthState(prev => ({ - ...prev, - user: updatedUser - })); - console.log('User data updated', updatedUser); - }, []); - const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise => { setAuthState(prev => ({ ...prev, isLoading: true, error: null })); try { const employer = await apiClient.createEmployer(employerData); - console.log('Employer created:', employer); - - // Store email for potential verification resend + console.log('โœ… Employer created:', employer); + setPendingVerificationEmail(employerData.email); setAuthState(prev => ({ ...prev, isLoading: false })); @@ -614,24 +660,61 @@ function useAuthenticationLogic() { const refreshResult = await refreshAccessToken(stored.refreshToken); if (refreshResult) { - storeAuthData(refreshResult); + const isGuest = stored.userType === 'guest'; + storeAuthData(refreshResult, isGuest); apiClient.setAuthToken(refreshResult.accessToken); setAuthState(prev => ({ ...prev, - user: refreshResult.user, + user: isGuest ? null : refreshResult.user, + guest: isGuest ? refreshResult.user as Types.Guest : null, isAuthenticated: true, + isGuest, isLoading: false, error: null })); return true; } else { - logout(); + await logout(); return false; } }, [refreshAccessToken, logout]); + // Resend MFA code + const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise => { + setAuthState(prev => ({ ...prev, isLoading: true, error: null })); + + try { + await apiClient.requestMFA({ + email, + password: '', // This would need to be stored securely or re-entered + deviceId, + deviceName, + }); + + setAuthState(prev => ({ ...prev, isLoading: false })); + return true; + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code'; + setAuthState(prev => ({ + ...prev, + isLoading: false, + error: errorMessage + })); + return false; + } + }, [apiClient]); + + // Clear MFA state + const clearMFA = useCallback(() => { + setAuthState(prev => ({ + ...prev, + mfaResponse: null, + error: null + })); + }, []); + return { ...authState, apiClient, @@ -647,12 +730,14 @@ function useAuthenticationLogic() { createEmployerAccount, requestPasswordReset, refreshAuth, - updateUserData + updateUserData, + convertGuestToUser, + createGuestSession }; } // ============================ -// Context Provider +// Enhanced Context Provider // ============================ const AuthContext = createContext | null>(null); @@ -676,34 +761,41 @@ function useAuth() { } // ============================ -// Protected Route Component +// Enhanced Protected Route Component // ============================ interface ProtectedRouteProps { children: React.ReactNode; fallback?: React.ReactNode; requiredUserType?: Types.UserType; + allowGuests?: boolean; } function ProtectedRoute({ children, fallback =
Please log in to access this page.
, - requiredUserType + requiredUserType, + allowGuests = false }: ProtectedRouteProps) { - const { isAuthenticated, isInitializing, user } = useAuth(); + const { isAuthenticated, isInitializing, user, isGuest } = useAuth(); // Show loading while checking stored tokens if (isInitializing) { return
Loading...
; } - // Not authenticated + // Not authenticated at all (shouldn't happen with guest sessions) if (!isAuthenticated) { return <>{fallback}; } - // Check user type if required - if (requiredUserType && user?.userType !== requiredUserType) { + // Guest access control + if (isGuest && !allowGuests) { + return
Please create an account or log in to access this page.
; + } + + // Check user type if required (only for non-guests) + if (requiredUserType && !isGuest && user?.userType !== requiredUserType) { return
Access denied. Required user type: {requiredUserType}
; } @@ -711,11 +803,19 @@ function ProtectedRoute({ } export type { - AuthState, LoginRequest, EmailVerificationRequest, ResendVerificationRequest, PasswordResetRequest + AuthState, + LoginRequest, + EmailVerificationRequest, + ResendVerificationRequest, + PasswordResetRequest, + GuestConversionRequest } export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client'; export { - useAuthenticationLogic, AuthProvider, useAuth, ProtectedRoute + useAuthenticationLogic, + AuthProvider, + useAuth, + ProtectedRoute } \ No newline at end of file diff --git a/frontend/src/pages/CandidateChatPage.tsx b/frontend/src/pages/CandidateChatPage.tsx index 346a69c..bebc583 100644 --- a/frontend/src/pages/CandidateChatPage.tsx +++ b/frontend/src/pages/CandidateChatPage.tsx @@ -23,15 +23,16 @@ import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext'; import PropagateLoader from 'react-spinners/PropagateLoader'; import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField'; import { BackstoryQuery } from 'components/BackstoryQuery'; +import { CandidatePicker } from 'components/ui/CandidatePicker'; const defaultMessage: ChatMessage = { - status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user" + status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any }; const CandidateChatPage = forwardRef((props: BackstoryPageProps, ref) => { const { apiClient } = useAuth(); const navigate = useNavigate(); - const { selectedCandidate } = useSelectedCandidate() + const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate() const theme = useTheme(); const [processingMessage, setProcessingMessage] = useState(null); const [streamingMessage, setStreamingMessage] = useState(null); @@ -92,7 +93,7 @@ const CandidateChatPage = forwardRef((pr timestamp: new Date() }; - setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` }); + setProcessingMessage({ ...defaultMessage, status: 'status', activity: "info", content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` }); setMessages(prev => { const filtered = prev.filter((m: any) => m.id !== chatMessage.id); @@ -123,7 +124,7 @@ const CandidateChatPage = forwardRef((pr }, onStreaming: (chunk: ChatMessageStreaming) => { // console.log("onStreaming:", chunk); - setStreamingMessage({ ...chunk, role: 'assistant' }); + setStreamingMessage({ ...chunk, role: 'assistant', metadata: null as any }); }, onStatus: (status: ChatMessageStatus) => { setProcessingMessage(status); @@ -171,8 +172,7 @@ const CandidateChatPage = forwardRef((pr }, [chatSession]); if (!selectedCandidate) { - navigate('/find-a-candidate'); - return (<>); + return ; } const welcomeMessage: ChatMessage = { @@ -181,7 +181,8 @@ const CandidateChatPage = forwardRef((pr type: "text", status: "done", timestamp: new Date(), - content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.` + content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`, + metadata: null as any }; return ( @@ -192,12 +193,14 @@ const CandidateChatPage = forwardRef((pr gap: 1, }}> + {/* Chat Interface */} { diff --git a/frontend/src/pages/HomePage.tsx b/frontend/src/pages/HomePage.tsx index cae1f58..3ae1e21 100644 --- a/frontend/src/pages/HomePage.tsx +++ b/frontend/src/pages/HomePage.tsx @@ -20,6 +20,7 @@ import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer'; import DescriptionIcon from '@mui/icons-material/Description'; import professionalConversationPng from './Conversation.png'; import { ComingSoon } from 'components/ui/ComingSoon'; +import { useAuth } from 'hooks/AuthContext'; // Placeholder for Testimonials component const Testimonials = () => { @@ -135,6 +136,15 @@ const FeatureCard = ({ const HomePage = () => { const testimonials = false; + const { isGuest, guest, user } = useAuth(); + + if (isGuest) { + // Show guest-specific UI + console.log('Guest session:', guest?.sessionId || "No guest"); + } else { + // Show authenticated user UI + console.log('Authenticated user:', user?.email || "No user"); + } return ( {/* Hero Section */} diff --git a/frontend/src/pages/LoadingPage.tsx b/frontend/src/pages/LoadingPage.tsx index 4655c2a..c484360 100644 --- a/frontend/src/pages/LoadingPage.tsx +++ b/frontend/src/pages/LoadingPage.tsx @@ -10,7 +10,8 @@ const LoadingPage = (props: BackstoryPageProps) => { status: 'done', sessionId: '', content: 'Please wait while connecting to Backstory...', - timestamp: new Date() + timestamp: new Date(), + metadata: null as any } return diff --git a/frontend/src/pages/LoginPage.tsx b/frontend/src/pages/LoginPage.tsx index 6e11aad..ce0c802 100644 --- a/frontend/src/pages/LoginPage.tsx +++ b/frontend/src/pages/LoginPage.tsx @@ -26,6 +26,7 @@ import { LoginForm } from "components/EmailVerificationComponents"; import { CandidateRegistrationForm } from "components/RegistrationForms"; import { useNavigate } from 'react-router-dom'; import { useAppState } from 'hooks/GlobalContext'; +import * as Types from 'types/types'; const LoginPage: React.FC = (props: BackstoryPageProps) => { const navigate = useNavigate(); @@ -34,7 +35,7 @@ const LoginPage: React.FC = (props: BackstoryPageProps) => { const [loading, setLoading] = useState(false); const [success, setSuccess] = useState(null); const { guest, user, login, isLoading, error } = useAuth(); - const name = (user?.userType === 'candidate') ? user.username : user?.email || ''; + const name = (user?.userType === 'candidate') ? (user as Types.Candidate).username : user?.email || ''; const [errorMessage, setErrorMessage] = useState(null); const showGuest: boolean = false; diff --git a/frontend/src/pages/LoginRequired.tsx b/frontend/src/pages/LoginRequired.tsx index 460df04..e475015 100644 --- a/frontend/src/pages/LoginRequired.tsx +++ b/frontend/src/pages/LoginRequired.tsx @@ -10,7 +10,8 @@ const LoginRequired = (props: BackstoryPageProps) => { status: 'done', sessionId: '', content: 'You must be logged to view this feature.', - timestamp: new Date() + timestamp: new Date(), + metadata: null as any } return diff --git a/frontend/src/pages/OldChatPage.tsx b/frontend/src/pages/OldChatPage.tsx index 12f27b8..189d062 100644 --- a/frontend/src/pages/OldChatPage.tsx +++ b/frontend/src/pages/OldChatPage.tsx @@ -11,6 +11,7 @@ import { CandidateInfo } from 'components/ui/CandidateInfo'; import { useAuth } from 'hooks/AuthContext'; import { Candidate } from 'types/types'; import { useAppState } from 'hooks/GlobalContext'; +import * as Types from 'types/types'; const ChatPage = forwardRef((props: BackstoryPageProps, ref) => { const { setSnack } = useAppState(); @@ -18,7 +19,7 @@ const ChatPage = forwardRef((props: Back const theme = useTheme(); const isMobile = useMediaQuery(theme.breakpoints.down('md')); const [questions, setQuestions] = useState([]); - const candidate: Candidate | null = user?.userType === 'candidate' ? user : null; + const candidate: Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null; // console.log("ChatPage candidate =>", candidate); useEffect(() => { diff --git a/frontend/src/pages/candidate/Dashboard.tsx b/frontend/src/pages/candidate/Dashboard.tsx index 85c0472..9a693f9 100644 --- a/frontend/src/pages/candidate/Dashboard.tsx +++ b/frontend/src/pages/candidate/Dashboard.tsx @@ -74,7 +74,7 @@ const CandidateDashboard = (props: CandidateDashboardProps) => { variant="contained" color="primary" sx={{ mt: 1 }} - onClick={(e) => {e.stopPropagation(); navigate('/candidate/dashboard/profile'); }} + onClick={(e) => { e.stopPropagation(); navigate('/candidate/profile'); }} > Complete Your Profile diff --git a/frontend/src/pages/candidate/Settings.tsx b/frontend/src/pages/candidate/Settings.tsx index 71c7a0b..17d2a47 100644 --- a/frontend/src/pages/candidate/Settings.tsx +++ b/frontend/src/pages/candidate/Settings.tsx @@ -14,9 +14,10 @@ import Typography from '@mui/material/Typography'; // import ResetIcon from '@mui/icons-material/History'; import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; -import { connectionBase } from '../../utils/Global'; import { BackstoryPageProps } from '../../components/BackstoryTab'; import { useAppState } from 'hooks/GlobalContext'; +import { useAuth } from 'hooks/AuthContext'; +import * as Types from 'types/types'; interface ServerTunables { system_prompt: string, @@ -33,19 +34,7 @@ type Tool = { returns?: any }; -type GPUInfo = { - name: string, - memory: number, - discrete: boolean -} - -type SystemInfo = { - "Installed RAM": string, - "Graphics Card": GPUInfo[], - "CPU": string -}; - -const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ systemInfo }) => { +const SystemInfoComponent: React.FC<{ systemInfo: Types.SystemInfo | undefined }> = ({ systemInfo }) => { const [systemElements, setSystemElements] = useState([]); const convertToSymbols = (text: string) => { @@ -86,91 +75,92 @@ const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ }; const Settings = (props: BackstoryPageProps) => { + const { apiClient } = useAuth(); const { setSnack } = useAppState(); const [editSystemPrompt, setEditSystemPrompt] = useState(""); - const [systemInfo, setSystemInfo] = useState(undefined); + const [systemInfo, setSystemInfo] = useState(undefined); const [tools, setTools] = useState([]); const [rags, setRags] = useState([]); const [systemPrompt, setSystemPrompt] = useState(""); const [messageHistoryLength, setMessageHistoryLength] = useState(5); const [serverTunables, setServerTunables] = useState(undefined); - useEffect(() => { - if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) { - return; - } - const sendSystemPrompt = async (prompt: string) => { - try { - const response = await fetch(connectionBase + `/api/1.0/tunables`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - body: JSON.stringify({ "system_prompt": prompt }), - }); + // useEffect(() => { + // if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) { + // return; + // } + // const sendSystemPrompt = async (prompt: string) => { + // try { + // const response = await fetch(connectionBase + `/api/1.0/tunables`, { + // method: 'PUT', + // headers: { + // 'Content-Type': 'application/json', + // 'Accept': 'application/json', + // }, + // body: JSON.stringify({ "system_prompt": prompt }), + // }); - const tunables = await response.json(); - serverTunables.system_prompt = tunables.system_prompt; - console.log(tunables); - setSystemPrompt(tunables.system_prompt) - setSnack("System prompt updated", "success"); - } catch (error) { - console.error('Fetch error:', error); - setSnack("System prompt update failed", "error"); - } - }; + // const tunables = await response.json(); + // serverTunables.system_prompt = tunables.system_prompt; + // console.log(tunables); + // setSystemPrompt(tunables.system_prompt) + // setSnack("System prompt updated", "success"); + // } catch (error) { + // console.error('Fetch error:', error); + // setSnack("System prompt update failed", "error"); + // } + // }; - sendSystemPrompt(systemPrompt); + // sendSystemPrompt(systemPrompt); - }, [systemPrompt, setSnack, serverTunables]); + // }, [systemPrompt, setSnack, serverTunables]); - const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => { - try { - const response = await fetch(connectionBase + `/api/1.0/reset/`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - body: JSON.stringify({ "reset": types }), - }); + // const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => { + // try { + // const response = await fetch(connectionBase + `/api/1.0/reset/`, { + // method: 'PUT', + // headers: { + // 'Content-Type': 'application/json', + // 'Accept': 'application/json', + // }, + // body: JSON.stringify({ "reset": types }), + // }); - if (!response.ok) { - throw new Error(`Server responded with ${response.status}: ${response.statusText}`); - } + // if (!response.ok) { + // throw new Error(`Server responded with ${response.status}: ${response.statusText}`); + // } - if (!response.body) { - throw new Error('Response body is null'); - } + // if (!response.body) { + // throw new Error('Response body is null'); + // } - const data = await response.json(); - if (data.error) { - throw Error(data.error); - } + // const data = await response.json(); + // if (data.error) { + // throw Error(data.error); + // } - for (const [key, value] of Object.entries(data)) { - switch (key) { - case "rags": - setRags(value as Tool[]); - break; - case "tools": - setTools(value as Tool[]); - break; - case "system_prompt": - setSystemPrompt((value as ServerTunables)["system_prompt"].trim()); - break; - case "history": - console.log('TODO: handle history reset'); - break; - } - } - setSnack(message, "success"); - } catch (error) { - console.error('Fetch error:', error); - setSnack("Unable to restore defaults", "error"); - } - }; + // for (const [key, value] of Object.entries(data)) { + // switch (key) { + // case "rags": + // setRags(value as Tool[]); + // break; + // case "tools": + // setTools(value as Tool[]); + // break; + // case "system_prompt": + // setSystemPrompt((value as ServerTunables)["system_prompt"].trim()); + // break; + // case "history": + // console.log('TODO: handle history reset'); + // break; + // } + // } + // setSnack(message, "success"); + // } catch (error) { + // console.error('Fetch error:', error); + // setSnack("Unable to restore defaults", "error"); + // } + // }; // Get the system information useEffect(() => { @@ -179,27 +169,8 @@ const Settings = (props: BackstoryPageProps) => { } const fetchSystemInfo = async () => { try { - const response = await fetch(connectionBase + `/api/1.0/system-info`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - - if (!response.ok) { - throw new Error(`Server responded with ${response.status}: ${response.statusText}`); - } - - if (!response.body) { - throw new Error('Response body is null'); - } - - const data = await response.json(); - if (data.error) { - throw Error(data.error); - } - - setSystemInfo(data); + const response: Types.SystemInfo = await apiClient.getSystemInfo(); + setSystemInfo(response); } catch (error) { console.error('Error obtaining system information:', error); setSnack("Unable to obtain system information.", "error"); @@ -217,101 +188,101 @@ const Settings = (props: BackstoryPageProps) => { setEditSystemPrompt(systemPrompt.trim()); }, [systemPrompt, setEditSystemPrompt]); - const toggleRag = async (tool: Tool) => { - tool.enabled = !tool.enabled - try { - const response = await fetch(connectionBase + `/api/1.0/tunables`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }), - }); + // const toggleRag = async (tool: Tool) => { + // tool.enabled = !tool.enabled + // try { + // const response = await fetch(connectionBase + `/api/1.0/tunables`, { + // method: 'PUT', + // headers: { + // 'Content-Type': 'application/json', + // 'Accept': 'application/json', + // }, + // body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }), + // }); - const tunables: ServerTunables = await response.json(); - setRags(tunables.rags) - setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`); - } catch (error) { - console.error('Fetch error:', error); - setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); - tool.enabled = !tool.enabled - } - }; + // const tunables: ServerTunables = await response.json(); + // setRags(tunables.rags) + // setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`); + // } catch (error) { + // console.error('Fetch error:', error); + // setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); + // tool.enabled = !tool.enabled + // } + // }; - const toggleTool = async (tool: Tool) => { - tool.enabled = !tool.enabled - try { - const response = await fetch(connectionBase + `/api/1.0/tunables`, { - method: 'PUT', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }), - }); + // const toggleTool = async (tool: Tool) => { + // tool.enabled = !tool.enabled + // try { + // const response = await fetch(connectionBase + `/api/1.0/tunables`, { + // method: 'PUT', + // headers: { + // 'Content-Type': 'application/json', + // 'Accept': 'application/json', + // }, + // body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }), + // }); - const tunables: ServerTunables = await response.json(); - setTools(tunables.tools) - setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`); - } catch (error) { - console.error('Fetch error:', error); - setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); - tool.enabled = !tool.enabled - } - }; + // const tunables: ServerTunables = await response.json(); + // setTools(tunables.tools) + // setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`); + // } catch (error) { + // console.error('Fetch error:', error); + // setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); + // tool.enabled = !tool.enabled + // } + // }; // If the systemPrompt has not been set, fetch it from the server - useEffect(() => { - if (serverTunables !== undefined) { - return; - } - const fetchTunables = async () => { - try { - // Make the fetch request with proper headers - const response = await fetch(connectionBase + `/api/1.0/tunables`, { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - }, - }); - const data = await response.json(); - // console.log("Server tunables: ", data); - setServerTunables(data); - setSystemPrompt(data["system_prompt"]); - setTools(data["tools"]); - setRags(data["rags"]); - } catch (error) { - console.error('Fetch error:', error); - setSnack("System prompt update failed", "error"); - } - } + // useEffect(() => { + // if (serverTunables !== undefined) { + // return; + // } + // const fetchTunables = async () => { + // try { + // // Make the fetch request with proper headers + // const response = await fetch(connectionBase + `/api/1.0/tunables`, { + // method: 'GET', + // headers: { + // 'Content-Type': 'application/json', + // 'Accept': 'application/json', + // }, + // }); + // const data = await response.json(); + // // console.log("Server tunables: ", data); + // setServerTunables(data); + // setSystemPrompt(data["system_prompt"]); + // setTools(data["tools"]); + // setRags(data["rags"]); + // } catch (error) { + // console.error('Fetch error:', error); + // setSnack("System prompt update failed", "error"); + // } + // } - fetchTunables(); - }, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]); + // fetchTunables(); + // }, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]); - const toggle = async (type: string, index: number) => { - switch (type) { - case "rag": - if (rags === undefined) { - return; - } - toggleRag(rags[index]) - break; - case "tool": - if (tools === undefined) { - return; - } - toggleTool(tools[index]); - } - }; + // const toggle = async (type: string, index: number) => { + // switch (type) { + // case "rag": + // if (rags === undefined) { + // return; + // } + // toggleRag(rags[index]) + // break; + // case "tool": + // if (tools === undefined) { + // return; + // } + // toggleTool(tools[index]); + // } + // }; - const handleKeyPress = (event: any) => { - if (event.key === 'Enter' && event.ctrlKey) { - setSystemPrompt(editSystemPrompt); - } - }; + // const handleKeyPress = (event: any) => { + // if (event.key === 'Enter' && event.ctrlKey) { + // setSystemPrompt(editSystemPrompt); + // } + // }; return (
{/* diff --git a/frontend/src/routes/CandidateRoute.tsx b/frontend/src/routes/CandidateRoute.tsx index 745fd9d..c36f417 100644 --- a/frontend/src/routes/CandidateRoute.tsx +++ b/frontend/src/routes/CandidateRoute.tsx @@ -6,6 +6,7 @@ import { SetSnackType } from '../components/Snack'; import { LoadingComponent } from "../components/LoadingComponent"; import { User, Guest, Candidate } from 'types/types'; import { useAuth } from "hooks/AuthContext"; +import { useSelectedCandidate } from "hooks/GlobalContext"; interface CandidateRouteProps { guest?: Guest | null; @@ -15,19 +16,19 @@ interface CandidateRouteProps { const CandidateRoute: React.FC = (props: CandidateRouteProps) => { const { apiClient } = useAuth(); + const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate(); const { setSnack } = props; const { username } = useParams<{ username: string }>(); - const [candidate, setCandidate] = useState(null); const navigate = useNavigate(); useEffect(() => { - if (candidate?.username === username || !username) { + if (selectedCandidate?.username === username || !username) { return; } const getCandidate = async (reference: string) => { try { const result: Candidate = await apiClient.getCandidate(reference); - setCandidate(result); + setSelectedCandidate(result); navigate('/chat'); } catch { setSnack(`Unable to obtain information for ${username}.`, "error"); @@ -36,9 +37,9 @@ const CandidateRoute: React.FC = (props: CandidateRouteProp } getCandidate(username); - }, [candidate, username, setCandidate, navigate, setSnack, apiClient]); + }, [selectedCandidate, username, selectedCandidate, navigate, setSnack, apiClient]); - if (candidate === null) { + if (selectedCandidate?.username !== username) { return ( + ) { + super(message); + this.name = 'RateLimitError'; + } +} interface StreamingOptions { method?: string, headers?: Record, @@ -777,7 +800,7 @@ class ApiClient { async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise { const result = await this.getCandidateChatSessions(candidate.username); /* Find the 'candidate_chat' session if it exists, otherwise create it */ - let session = result.sessions.data.find(session => session.title === 'candidate_chat'); + let session = result.sessions.data.find(session => session.title === title); if (!session) { session = await this.createCandidateChatSession( candidate.username, @@ -788,6 +811,17 @@ class ApiClient { return session; } + async getSystemInfo() : Promise { + const response = await fetch(`${this.baseUrl}/system-info`, { + method: 'GET', + headers: this.defaultHeaders, + }); + + const result = await handleApiResponse(response); + + return result; + } + async getCandidateSimilarContent(query: string ): Promise { const response = await fetch(`${this.baseUrl}/candidates/rag-search`, { @@ -1003,6 +1037,202 @@ class ApiClient { return handleApiResponse<{ success: boolean; message: string }>(response); } + + // ============================ + // Guest Authentication Methods + // ============================ + + /** + * Create a guest session with authentication + */ + async createGuestSession(): Promise { + const response = await fetch(`${this.baseUrl}/auth/guest`, { + method: 'POST', + headers: this.defaultHeaders + }); + + const result = await handleApiResponse(response); + + // Convert guest data if needed + if (result.user && result.user.userType === 'guest') { + result.user = convertFromApi(result.user, "Guest"); + } + + return result; + } + + /** + * Convert guest account to permanent user account + */ + async convertGuestToUser( + registrationData: CreateCandidateRequest & { accountType: 'candidate' } + ): Promise<{ + message: string; + auth: Types.AuthResponse; + conversionType: string + }> { + const response = await fetch(`${this.baseUrl}/auth/guest/convert`, { + method: 'POST', + headers: this.defaultHeaders, + body: JSON.stringify(formatApiRequest(registrationData)) + }); + + const result = await handleApiResponse<{ + message: string; + auth: Types.AuthResponse; + conversionType: string; + }>(response); + + // Convert the auth user data + if (result.auth?.user) { + result.auth.user = convertFromApi(result.auth.user, "Candidate"); + } + + return result; + } + + /** + * Check if current session is a guest + */ + isGuestSession(): boolean { + try { + const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA); + if (userDataStr) { + const userData = JSON.parse(userDataStr); + return userData.userType === 'guest'; + } + return false; + } catch { + return false; + } + } + + /** + * Get guest session info + */ + getGuestSessionInfo(): Types.Guest | null { + try { + const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA); + if (userDataStr) { + const userData = JSON.parse(userDataStr); + if (userData.userType === 'guest') { + return convertFromApi(userData, "Guest"); + } + } + return null; + } catch { + return null; + } + } + + /** + * Get rate limit status for current user + */ + async getRateLimitStatus(): Promise<{ + user_id: string; + user_type: string; + is_admin: boolean; + current_usage: Record; + limits: Record; + remaining: Record; + reset_times: Record; + config: any; + }> { + const response = await fetch(`${this.baseUrl}/admin/rate-limits/info`, { + headers: this.defaultHeaders + }); + + return handleApiResponse(response); + } + + /** + * Get guest statistics (admin only) + */ + async getGuestStatistics(): Promise<{ + total_guests: number; + active_last_hour: number; + active_last_day: number; + converted_guests: number; + by_ip: Record; + creation_timeline: Record; + }> { + const response = await fetch(`${this.baseUrl}/admin/guests/statistics`, { + headers: this.defaultHeaders + }); + + return handleApiResponse(response); + } + + /** + * Cleanup inactive guests (admin only) + */ + async cleanupInactiveGuests(inactiveHours: number = 24): Promise<{ + message: string; + cleaned_count: number; + }> { + const response = await fetch(`${this.baseUrl}/admin/guests/cleanup`, { + method: 'POST', + headers: this.defaultHeaders, + body: JSON.stringify({ inactive_hours: inactiveHours }) + }); + + return handleApiResponse<{ + message: string; + cleaned_count: number; + }>(response); + } + + // ============================ + // Enhanced Error Handling for Rate Limits + // ============================ + + /** + * Enhanced API response handler with rate limit handling + */ + private async handleApiResponseWithRateLimit(response: Response): Promise { + if (response.status === 429) { + const rateLimitData = await response.json(); + const retryAfter = response.headers.get('Retry-After'); + + throw new RateLimitError( + rateLimitData.detail?.message || 'Rate limit exceeded', + parseInt(retryAfter || '60'), + rateLimitData.detail?.remaining || {} + ); + } + + return this.handleApiResponseWithConversion(response); + } + + /** + * Retry mechanism for rate-limited requests + */ + private async retryWithBackoff( + requestFn: () => Promise, + maxRetries: number = 3 + ): Promise { + let lastError: Error; + + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + const response = await requestFn(); + return await this.handleApiResponseWithRateLimit(response); + } catch (error) { + lastError = error as Error; + + if (error instanceof RateLimitError && attempt < maxRetries) { + const delayMs = Math.min(error.retryAfterSeconds * 1000, 60000); // Max 1 minute + console.warn(`Rate limited, retrying in ${delayMs}ms (attempt ${attempt + 1}/${maxRetries + 1})`); + await new Promise(resolve => setTimeout(resolve, delayMs)); + continue; + } + + throw error; + } + } + + throw lastError!; + } async resetChatSession(id: string): Promise<{ success: boolean; message: string }> { const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, { method: 'PATCH', @@ -1384,5 +1614,5 @@ export interface PendingVerification { attempts: number; } -export { ApiClient } +export { ApiClient, TOKEN_STORAGE } export type { StreamingOptions, StreamingResponse } \ No newline at end of file diff --git a/frontend/src/types/types.ts b/frontend/src/types/types.ts index a422f84..cd896b5 100644 --- a/frontend/src/types/types.ts +++ b/frontend/src/types/types.ts @@ -1,6 +1,6 @@ // Generated TypeScript types from Pydantic models // Source: src/backend/models.py -// Generated on: 2025-06-07T20:43:58.855207 +// Generated on: 2025-06-09T03:32:01.335483 // DO NOT EDIT MANUALLY - This file is auto-generated // ============================ @@ -128,8 +128,10 @@ export interface Attachment { export interface AuthResponse { accessToken: string; refreshToken: string; - user: any; + user: Candidate | Employer | Guest; expiresAt: number; + userType?: string; + isGuest?: boolean; } export interface Authentication { @@ -148,7 +150,9 @@ export interface Authentication { } export interface BaseUser { + userType: "candidate" | "employer" | "guest"; id?: string; + lastActivity?: Date; email: string; firstName: string; lastName: string; @@ -164,24 +168,15 @@ export interface BaseUser { } export interface BaseUserWithType { - id?: string; - email: string; - firstName: string; - lastName: string; - fullName: string; - phone?: string; - location?: Location; - createdAt: Date; - updatedAt: Date; - lastLogin?: Date; - profileImage?: string; - status: "active" | "inactive" | "pending" | "banned"; - isAdmin: boolean; userType: "candidate" | "employer" | "guest"; + id?: string; + lastActivity?: Date; } export interface Candidate { + userType: "candidate" | "employer" | "guest"; id?: string; + lastActivity?: Date; email: string; firstName: string; lastName: string; @@ -194,7 +189,6 @@ export interface Candidate { profileImage?: string; status: "active" | "inactive" | "pending" | "banned"; isAdmin: boolean; - userType: "candidate"; username: string; description?: string; resume?: string; @@ -214,7 +208,9 @@ export interface Candidate { } export interface CandidateAI { + userType: "candidate" | "employer" | "guest"; id?: string; + lastActivity?: Date; email: string; firstName: string; lastName: string; @@ -227,7 +223,6 @@ export interface CandidateAI { profileImage?: string; status: "active" | "inactive" | "pending" | "banned"; isAdmin: boolean; - userType: "candidate"; username: string; description?: string; resume?: string; @@ -301,7 +296,7 @@ export interface ChatMessage { role: "user" | "assistant" | "system" | "information" | "warning" | "error"; content: string; tunables?: Tunables; - metadata?: ChatMessageMetaData; + metadata: ChatMessageMetaData; } export interface ChatMessageError { @@ -319,9 +314,9 @@ export interface ChatMessageMetaData { temperature: number; maxTokens: number; topP: number; - frequencyPenalty?: number; - presencePenalty?: number; - stopSequences?: Array; + frequencyPenalty: number; + presencePenalty: number; + stopSequences: Array; ragResults?: Array; llmHistory?: Array; evalCount: number; @@ -392,7 +387,6 @@ export interface ChatQuery { export interface ChatSession { id?: string; userId?: string; - guestId?: string; createdAt?: Date; lastActivity?: Date; title?: string; @@ -507,7 +501,7 @@ export interface DocumentMessage { } export interface DocumentOptions { - includeInRAG?: boolean; + includeInRAG: boolean; isJobDocument?: boolean; overwrite?: boolean; } @@ -541,7 +535,9 @@ export interface EmailVerificationRequest { } export interface Employer { + userType: "candidate" | "employer" | "guest"; id?: string; + lastActivity?: Date; email: string; firstName: string; lastName: string; @@ -554,7 +550,6 @@ export interface Employer { profileImage?: string; status: "active" | "inactive" | "pending" | "banned"; isAdmin: boolean; - userType: "employer"; companyName: string; industry: string; description?: string; @@ -580,14 +575,71 @@ export interface ErrorDetail { details?: any; } +export interface GPUInfo { + name: string; + memory: number; + discrete: boolean; +} + export interface Guest { + userType: "candidate" | "employer" | "guest"; id?: string; - sessionId: string; + lastActivity?: Date; + email: string; + firstName: string; + lastName: string; + fullName: string; + phone?: string; + location?: Location; createdAt: Date; - lastActivity: Date; + updatedAt: Date; + lastLogin?: Date; + profileImage?: string; + status: "active" | "inactive" | "pending" | "banned"; + isAdmin: boolean; + sessionId: string; + username: string; convertedToUserId?: string; ipAddress?: string; userAgent?: string; + ragContentSize: number; +} + +export interface GuestCleanupRequest { + inactiveHours: number; +} + +export interface GuestConversionRequest { + accountType: "candidate" | "employer"; + email: string; + username: string; + password: string; + firstName: string; + lastName: string; + phone?: string; + companyName?: string; + industry?: string; + companySize?: string; + companyDescription?: string; + websiteUrl?: string; +} + +export interface GuestSessionResponse { + accessToken: string; + refreshToken: string; + user: Guest; + expiresAt: number; + userType: "guest"; + isGuest: boolean; +} + +export interface GuestStatistics { + totalGuests: number; + activeLastHour: number; + activeLastDay: number; + convertedGuests: number; + byIp: Record; + creationTimeline: Record; } export interface InterviewFeedback { @@ -746,6 +798,7 @@ export interface MFARequest { password: string; deviceId: string; deviceName: string; + email: string; } export interface MFARequestResponse { @@ -841,6 +894,33 @@ export interface RagEntry { enabled: boolean; } +export interface RateLimitConfig { + requestsPerMinute: number; + requestsPerHour: number; + requestsPerDay: number; + burstLimit: number; + burstWindowSeconds: number; +} + +export interface RateLimitResult { + allowed: boolean; + reason?: string; + retryAfterSeconds?: number; + remainingRequests?: Record; + resetTimes?: Record; +} + +export interface RateLimitStatus { + userId: string; + userType: string; + isAdmin: boolean; + currentUsage: Record; + limits: Record; + remaining: Record; + resetTimes: Record; + config: RateLimitConfig; +} + export interface RefreshToken { token: string; expiresAt: Date; @@ -915,6 +995,15 @@ export interface SocialLink { url: string; } +export interface SystemInfo { + installedRAM: string; + graphicsCards: Array; + CPU: string; + llmModel: string; + embeddingModel: string; + maxContextLength: number; +} + export interface Tunables { enableRAG: boolean; enableTools: boolean; @@ -1035,13 +1124,15 @@ export function convertAuthenticationFromApi(data: any): Authentication { } /** * Convert BaseUser from API response, parsing date fields - * Date fields: createdAt, updatedAt, lastLogin + * Date fields: lastActivity, createdAt, updatedAt, lastLogin */ export function convertBaseUserFromApi(data: any): BaseUser { if (!data) return data; return { ...data, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, // Convert createdAt from ISO string to Date createdAt: new Date(data.createdAt), // Convert updatedAt from ISO string to Date @@ -1052,30 +1143,28 @@ export function convertBaseUserFromApi(data: any): BaseUser { } /** * Convert BaseUserWithType from API response, parsing date fields - * Date fields: createdAt, updatedAt, lastLogin + * Date fields: lastActivity */ export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType { if (!data) return data; return { ...data, - // Convert createdAt from ISO string to Date - createdAt: new Date(data.createdAt), - // Convert updatedAt from ISO string to Date - updatedAt: new Date(data.updatedAt), - // Convert lastLogin from ISO string to Date - lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, }; } /** * Convert Candidate from API response, parsing date fields - * Date fields: createdAt, updatedAt, lastLogin, availabilityDate + * Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate */ export function convertCandidateFromApi(data: any): Candidate { if (!data) return data; return { ...data, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, // Convert createdAt from ISO string to Date createdAt: new Date(data.createdAt), // Convert updatedAt from ISO string to Date @@ -1088,13 +1177,15 @@ export function convertCandidateFromApi(data: any): Candidate { } /** * Convert CandidateAI from API response, parsing date fields - * Date fields: createdAt, updatedAt, lastLogin, availabilityDate + * Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate */ export function convertCandidateAIFromApi(data: any): CandidateAI { if (!data) return data; return { ...data, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, // Convert createdAt from ISO string to Date createdAt: new Date(data.createdAt), // Convert updatedAt from ISO string to Date @@ -1282,13 +1373,15 @@ export function convertEducationFromApi(data: any): Education { } /** * Convert Employer from API response, parsing date fields - * Date fields: createdAt, updatedAt, lastLogin + * Date fields: lastActivity, createdAt, updatedAt, lastLogin */ export function convertEmployerFromApi(data: any): Employer { if (!data) return data; return { ...data, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, // Convert createdAt from ISO string to Date createdAt: new Date(data.createdAt), // Convert updatedAt from ISO string to Date @@ -1299,17 +1392,21 @@ export function convertEmployerFromApi(data: any): Employer { } /** * Convert Guest from API response, parsing date fields - * Date fields: createdAt, lastActivity + * Date fields: lastActivity, createdAt, updatedAt, lastLogin */ export function convertGuestFromApi(data: any): Guest { if (!data) return data; return { ...data, + // Convert lastActivity from ISO string to Date + lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined, // Convert createdAt from ISO string to Date createdAt: new Date(data.createdAt), - // Convert lastActivity from ISO string to Date - lastActivity: new Date(data.lastActivity), + // Convert updatedAt from ISO string to Date + updatedAt: new Date(data.updatedAt), + // Convert lastLogin from ISO string to Date + lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined, }; } /** @@ -1415,6 +1512,32 @@ export function convertRAGConfigurationFromApi(data: any): RAGConfiguration { updatedAt: new Date(data.updatedAt), }; } +/** + * Convert RateLimitResult from API response, parsing date fields + * Date fields: resetTimes + */ +export function convertRateLimitResultFromApi(data: any): RateLimitResult { + if (!data) return data; + + return { + ...data, + // Convert resetTimes from ISO string to Date + resetTimes: data.resetTimes ? new Date(data.resetTimes) : undefined, + }; +} +/** + * Convert RateLimitStatus from API response, parsing date fields + * Date fields: resetTimes + */ +export function convertRateLimitStatusFromApi(data: any): RateLimitStatus { + if (!data) return data; + + return { + ...data, + // Convert resetTimes from ISO string to Date + resetTimes: new Date(data.resetTimes), + }; +} /** * Convert RefreshToken from API response, parsing date fields * Date fields: expiresAt @@ -1527,6 +1650,10 @@ export function convertFromApi(data: any, modelType: string): T { return convertMessageReactionFromApi(data) as T; case 'RAGConfiguration': return convertRAGConfigurationFromApi(data) as T; + case 'RateLimitResult': + return convertRateLimitResultFromApi(data) as T; + case 'RateLimitStatus': + return convertRateLimitStatusFromApi(data) as T; case 'RefreshToken': return convertRefreshTokenFromApi(data) as T; case 'UserActivity': diff --git a/frontend/src/utils/Global.tsx b/frontend/src/utils/Global.tsx deleted file mode 100644 index dc7faa4..0000000 --- a/frontend/src/utils/Global.tsx +++ /dev/null @@ -1,15 +0,0 @@ -const getConnectionBase = (loc: any): string => { - if (!loc.host.match(/.*battle-linux.*/) - // && !loc.host.match(/.*backstory-beta.*/) - ) { - return loc.protocol + "//" + loc.host; - } else { - return loc.protocol + "//battle-linux.ketrenos.com:8912"; - } -} - -const connectionBase = getConnectionBase(window.location); - -export { - connectionBase -}; \ No newline at end of file diff --git a/src/backend/agents/__init__.py b/src/backend/agents/__init__.py index 27d6f40..4a204af 100644 --- a/src/backend/agents/__init__.py +++ b/src/backend/agents/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations import traceback -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import ( Literal, get_args, @@ -17,7 +17,7 @@ from typing import ( import importlib import pathlib import inspect -from prometheus_client import CollectorRegistry # type: ignore +from prometheus_client import CollectorRegistry from . base import Agent from logger import logger @@ -90,7 +90,7 @@ for path in package_dir.glob("*.py"): class_registry[name] = (full_module_name, name) globals()[name] = obj logger.info(f"Adding agent: {name}") - __all__.append(name) # type: ignore + __all__.append(name) except ImportError as e: logger.error(traceback.format_exc()) logger.error(f"Error importing {full_module_name}: {e}") diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index e74fed8..1031a2c 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from typing import ( Literal, get_args, @@ -20,8 +20,8 @@ import re from abc import ABC import asyncio from datetime import datetime, UTC -from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore -import numpy as np # type: ignore +from prometheus_client import Counter, Summary, CollectorRegistry +import numpy as np from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate) from logger import logger @@ -491,7 +491,7 @@ Content: {content} session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 - ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: + ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessageRagSearch, None]: if not self.user: error_message = ChatMessageError( session_id=session_id, diff --git a/src/backend/agents/candidate_chat.py b/src/backend/agents/candidate_chat.py index 1084b38..4c217e6 100644 --- a/src/backend/agents/candidate_chat.py +++ b/src/backend/agents/candidate_chat.py @@ -28,7 +28,7 @@ class CandidateChat(Agent): CandidateChat Agent """ - agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore + agent_type: Literal["candidate_chat"] = "candidate_chat" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message diff --git a/src/backend/agents/general.py b/src/backend/agents/general.py index 83b6a5d..445ab72 100644 --- a/src/backend/agents/general.py +++ b/src/backend/agents/general.py @@ -53,7 +53,7 @@ class Chat(Agent): Chat Agent """ - agent_type: Literal["general"] = "general" # type: ignore + agent_type: Literal["general"] = "general" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message diff --git a/src/backend/agents/generate_image.py b/src/backend/agents/generate_image.py index 97fef5b..fa0b87e 100644 --- a/src/backend/agents/generate_image.py +++ b/src/backend/agents/generate_image.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import UTC, datetime -from pydantic import model_validator, Field, BaseModel # type: ignore +from pydantic import model_validator, Field, BaseModel from typing import ( Dict, Literal, @@ -37,7 +37,7 @@ seed = int(time.time()) random.seed(seed) class ImageGenerator(Agent): - agent_type: Literal["generate_image"] = "generate_image" # type: ignore + agent_type: Literal["generate_image"] = "generate_image" _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = False diff --git a/src/backend/agents/generate_persona.py b/src/backend/agents/generate_persona.py index 67d4cf6..7fee083 100644 --- a/src/backend/agents/generate_persona.py +++ b/src/backend/agents/generate_persona.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import UTC, datetime -from pydantic import model_validator, Field, BaseModel # type: ignore +from pydantic import model_validator, Field, BaseModel from typing import ( Dict, Literal, @@ -23,7 +23,7 @@ import asyncio import time import os import random -from names_dataset import NameDataset, NameWrapper # type: ignore +from names_dataset import NameDataset, NameWrapper from .base import Agent, agent_registry, LLMMessage from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables @@ -128,7 +128,7 @@ logger = logging.getLogger(__name__) class EthnicNameGenerator: def __init__(self): try: - from names_dataset import NameDataset # type: ignore + from names_dataset import NameDataset self.nd = NameDataset() except ImportError: logger.error("NameDataset not available. Please install: pip install names-dataset") @@ -292,7 +292,7 @@ class EthnicNameGenerator: return names class GeneratePersona(Agent): - agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore + agent_type: Literal["generate_persona"] = "generate_persona" _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = False diff --git a/src/backend/agents/job_requirements.py b/src/backend/agents/job_requirements.py index 1eb79ea..b25b48c 100644 --- a/src/backend/agents/job_requirements.py +++ b/src/backend/agents/job_requirements.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator, Field # type: ignore +from pydantic import model_validator, Field from typing import ( Dict, Literal, @@ -16,7 +16,7 @@ import json import asyncio import time import asyncio -import numpy as np # type: ignore +import numpy as np from .base import Agent, agent_registry, LLMMessage from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables @@ -26,7 +26,7 @@ import defines import backstory_traceback as traceback class JobRequirementsAgent(Agent): - agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore + agent_type: Literal["job_requirements"] = "job_requirements" _agent_type: ClassVar[str] = agent_type # Add this for registration # Stage 1A: Job Analysis Implementation diff --git a/src/backend/agents/rag_search.py b/src/backend/agents/rag_search.py index 3f6abad..277d52b 100644 --- a/src/backend/agents/rag_search.py +++ b/src/backend/agents/rag_search.py @@ -15,7 +15,7 @@ class Chat(Agent): Chat Agent """ - agent_type: Literal["rag_search"] = "rag_search" # type: ignore + agent_type: Literal["rag_search"] = "rag_search" _agent_type: ClassVar[str] = agent_type # Add this for registration async def generate( diff --git a/src/backend/agents/skill_match.py b/src/backend/agents/skill_match.py index 8d617a6..13ad9db 100644 --- a/src/backend/agents/skill_match.py +++ b/src/backend/agents/skill_match.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator, Field # type: ignore +from pydantic import model_validator, Field from typing import ( Dict, Literal, @@ -16,7 +16,7 @@ import json import asyncio import time import asyncio -import numpy as np # type: ignore +import numpy as np from .base import Agent, agent_registry, LLMMessage from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables @@ -25,7 +25,7 @@ from logger import logger import defines class SkillMatchAgent(Agent): - agent_type: Literal["skill_match"] = "skill_match" # type: ignore + agent_type: Literal["skill_match"] = "skill_match" _agent_type: ClassVar[str] = agent_type # Add this for registration def generate_skill_assessment_prompt(self, skill, rag_context): diff --git a/src/backend/auth_utils.py b/src/backend/auth_utils.py index 17a487a..0a31dac 100644 --- a/src/backend/auth_utils.py +++ b/src/backend/auth_utils.py @@ -5,12 +5,12 @@ Provides password hashing, verification, and security features """ import traceback -import bcrypt # type: ignore +import bcrypt import secrets import logging from datetime import datetime, timezone, timedelta from typing import Dict, Any, Optional, Tuple -from pydantic import BaseModel # type: ignore +from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/src/backend/background_tasks.py b/src/backend/background_tasks.py new file mode 100644 index 0000000..a447a30 --- /dev/null +++ b/src/backend/background_tasks.py @@ -0,0 +1,175 @@ +""" +Background tasks for guest cleanup and system maintenance +""" + +import asyncio +import schedule # type: ignore +import threading +import time +from datetime import datetime, timedelta, UTC +from typing import Optional +from logger import logger +from database import DatabaseManager + +class BackgroundTaskManager: + """Manages background tasks for the application""" + + def __init__(self, database_manager: DatabaseManager): + self.database_manager = database_manager + self.running = False + self.tasks = [] + self.scheduler_thread: Optional[threading.Thread] = None + + async def cleanup_inactive_guests(self, inactive_hours: int = 24): + """Clean up inactive guest sessions""" + try: + database = self.database_manager.get_database() + cleaned_count = await database.cleanup_inactive_guests(inactive_hours) + + if cleaned_count > 0: + logger.info(f"๐Ÿงน Background cleanup: removed {cleaned_count} inactive guest sessions") + + return cleaned_count + except Exception as e: + logger.error(f"โŒ Error in guest cleanup: {e}") + return 0 + + async def cleanup_expired_verification_tokens(self): + """Clean up expired email verification tokens""" + try: + database = self.database_manager.get_database() + cleaned_count = await database.cleanup_expired_verification_tokens() + + if cleaned_count > 0: + logger.info(f"๐Ÿงน Background cleanup: removed {cleaned_count} expired verification tokens") + + return cleaned_count + except Exception as e: + logger.error(f"โŒ Error in verification token cleanup: {e}") + return 0 + + async def update_guest_statistics(self): + """Update guest usage statistics""" + try: + database = self.database_manager.get_database() + stats = await database.get_guest_statistics() + + # Log interesting statistics + if stats.get('total_guests', 0) > 0: + logger.info(f"๐Ÿ“Š Guest stats: {stats['total_guests']} total, " + f"{stats['active_last_hour']} active in last hour, " + f"{stats['converted_guests']} converted") + + return stats + except Exception as e: + logger.error(f"โŒ Error updating guest statistics: {e}") + return {} + + async def cleanup_old_rate_limit_data(self, days_old: int = 7): + """Clean up old rate limiting data""" + try: + database = self.database_manager.get_database() + redis = database.redis + + # Clean up rate limit keys older than specified days + cutoff_time = datetime.now(UTC) - timedelta(days=days_old) + pattern = "rate_limit:*" + + cursor = 0 + deleted_count = 0 + + while True: + cursor, keys = await redis.scan(cursor, match=pattern, count=100) + + for key in keys: + # Check if key is old enough to delete + try: + ttl = await redis.ttl(key) + if ttl == -1: # No expiration set, check creation time + # For simplicity, delete keys without TTL + await redis.delete(key) + deleted_count += 1 + except Exception: + continue + + if cursor == 0: + break + + if deleted_count > 0: + logger.info(f"๐Ÿงน Cleaned up {deleted_count} old rate limit keys") + + return deleted_count + except Exception as e: + logger.error(f"โŒ Error cleaning up rate limit data: {e}") + return 0 + + def schedule_periodic_tasks(self): + """Schedule periodic background tasks with safer intervals""" + + # Guest cleanup - every 6 hours instead of every hour (less aggressive) + schedule.every(6).hours.do(self._run_async_task, self.cleanup_inactive_guests, 48) # 48 hours instead of 24 + + # Verification token cleanup - every 12 hours + schedule.every(12).hours.do(self._run_async_task, self.cleanup_expired_verification_tokens) + + # Guest statistics update - every hour + schedule.every().hour.do(self._run_async_task, self.update_guest_statistics) + + # Rate limit data cleanup - daily at 3 AM + schedule.every().day.at("03:00").do(self._run_async_task, self.cleanup_old_rate_limit_data, 7) + + logger.info("๐Ÿ“… Background tasks scheduled with safer intervals") + + def _run_async_task(self, coro_func, *args, **kwargs): + """Run an async task in the background""" + try: + # Create new event loop for this thread if needed + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Run the coroutine + loop.run_until_complete(coro_func(*args, **kwargs)) + except Exception as e: + logger.error(f"โŒ Error running background task {coro_func.__name__}: {e}") + + def _scheduler_worker(self): + """Worker thread for running scheduled tasks""" + while self.running: + try: + schedule.run_pending() + time.sleep(60) # Check every minute + except Exception as e: + logger.error(f"โŒ Error in scheduler worker: {e}") + time.sleep(60) + + def start(self): + """Start the background task manager""" + if self.running: + logger.warning("โš ๏ธ Background task manager already running") + return + + self.running = True + self.schedule_periodic_tasks() + + # Start scheduler thread + self.scheduler_thread = threading.Thread(target=self._scheduler_worker, daemon=True) + self.scheduler_thread.start() + + logger.info("๐Ÿš€ Background task manager started") + + def stop(self): + """Stop the background task manager""" + self.running = False + + if self.scheduler_thread and self.scheduler_thread.is_alive(): + self.scheduler_thread.join(timeout=5) + + # Clear scheduled tasks + schedule.clear() + + logger.info("๐Ÿ›‘ Background task manager stopped") + + diff --git a/src/backend/database.py b/src/backend/database.py index 739e6b0..198ee72 100644 --- a/src/backend/database.py +++ b/src/backend/database.py @@ -9,6 +9,7 @@ from models import ( # User models Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse, ) +import backstory_traceback as traceback logger = logging.getLogger(__name__) @@ -198,6 +199,7 @@ class RedisDatabase: try: return json.loads(data) except json.JSONDecodeError: + logger.error(traceback.format_exc()) logger.error(f"Failed to deserialize data: {data}") return None @@ -254,43 +256,44 @@ class RedisDatabase: raise async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]: - """Retrieve cached skill match assessment""" + """Get cached skill match assessment""" try: - cached_data = await self.redis.get(cache_key) - if cached_data: - return json.loads(cached_data) + data = await self.redis.get(cache_key) + if data: + return json.loads(data) return None except Exception as e: - logger.error(f"Error retrieving cached skill match: {e}") + logger.error(f"โŒ Error getting cached skill match: {e}") return None - - async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool: - """Cache skill match assessment with TTL (default 30 days)""" + + async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any]) -> None: + """Cache skill match assessment""" try: + # Cache for 1 hour by default await self.redis.setex( cache_key, - ttl, - json.dumps(assessment_data, default=str) + 3600, + json.dumps(assessment_data) ) - return True + logger.debug(f"๐Ÿ’พ Skill match cached: {cache_key}") except Exception as e: - logger.error(f"Error caching skill match: {e}") - return False - + logger.error(f"โŒ Error caching skill match: {e}") + async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]: - """Get the last time candidate's skill information was updated""" + """Get the last time candidate skills were updated""" try: - # This assumes you track skill update timestamps in your candidate data candidate_data = await self.get_candidate(candidate_id) - if candidate_data and 'skills_updated_at' in candidate_data: - return datetime.fromisoformat(candidate_data['skills_updated_at']) + if candidate_data: + updated_at_str = candidate_data.get("updated_at") + if updated_at_str: + return datetime.fromisoformat(updated_at_str.replace('Z', '+00:00')) return None except Exception as e: - logger.error(f"Error getting candidate skill update time: {e}") + logger.error(f"โŒ Error getting candidate skill update time: {e}") return None - + async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: - """Get the timestamp of the latest RAG data update for a specific user""" + """Get the last time user's RAG data was updated""" try: rag_update_key = f"user:{user_id}:rag_last_update" timestamp_str = await self.redis.get(rag_update_key) @@ -298,8 +301,8 @@ class RedisDatabase: return datetime.fromisoformat(timestamp_str.decode('utf-8')) return None except Exception as e: - logger.error(f"Error getting user RAG update time for user {user_id}: {e}") - return None + logger.error(f"โŒ Error getting user RAG update time: {e}") + return None async def update_user_rag_timestamp(self, user_id: str) -> bool: """Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)""" @@ -359,7 +362,7 @@ class RedisDatabase: async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: """Get all documents for a specific candidate""" key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}" - document_ids = await self.redis.lrange(key, 0, -1) + document_ids = await self.redis.lrange(key, 0, -1) if not document_ids: return [] @@ -1707,7 +1710,11 @@ class RedisDatabase: result = {} for key, value in zip(keys, values): email = key.replace(self.KEY_PREFIXES['users'], '') - result[email] = self._deserialize(value) + logger.info(f"๐Ÿ” Found user key: {key}, type: {type(value)}") + if type(value) == str: + result[email] = value + else: + result[email] = self._deserialize(value) return result @@ -1850,17 +1857,16 @@ class RedisDatabase: return False async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: - """Retrieve user data by ID""" + """Get user lookup data by user ID""" try: - key = f"user_by_id:{user_id}" - data = await self.redis.get(key) + data = await self.redis.hget("user_lookup_by_id", user_id) if data: return json.loads(data) return None except Exception as e: - logger.error(f"โŒ Error retrieving user by ID {user_id}: {e}") - return None - + logger.error(f"โŒ Error getting user by ID {user_id}: {e}") + return None + async def user_exists_by_email(self, email: str) -> bool: """Check if a user exists with the given email""" try: @@ -2104,6 +2110,208 @@ class RedisDatabase: logger.error(f"โŒ Error retrieving security log for {user_id}: {e}") return [] + # ============================ + # Guest Management Methods + # ============================ + + async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: + """Store guest data with enhanced persistence""" + try: + # Ensure last_activity is always set + guest_data["last_activity"] = datetime.now(UTC).isoformat() + + # Store in Redis with both hash and individual key for redundancy + await self.redis.hset("guests", guest_id, json.dumps(guest_data)) + + # Also store with a longer TTL as backup + await self.redis.setex( + f"guest_backup:{guest_id}", + 86400 * 7, # 7 days TTL + json.dumps(guest_data) + ) + + logger.debug(f"๐Ÿ’พ Guest stored with backup: {guest_id}") + except Exception as e: + logger.error(f"โŒ Error storing guest {guest_id}: {e}") + raise + + async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: + """Get guest data with fallback to backup""" + try: + # Try primary storage first + data = await self.redis.hget("guests", guest_id) + if data: + guest_data = json.loads(data) + # Update last activity when accessed + guest_data["last_activity"] = datetime.now(UTC).isoformat() + await self.set_guest(guest_id, guest_data) + logger.debug(f"๐Ÿ” Guest found in primary storage: {guest_id}") + return guest_data + + # Fallback to backup storage + backup_data = await self.redis.get(f"guest_backup:{guest_id}") + if backup_data: + guest_data = json.loads(backup_data) + guest_data["last_activity"] = datetime.now(UTC).isoformat() + + # Restore to primary storage + await self.set_guest(guest_id, guest_data) + logger.info(f"๐Ÿ”„ Guest restored from backup: {guest_id}") + return guest_data + + logger.warning(f"โš ๏ธ Guest not found: {guest_id}") + return None + except Exception as e: + logger.error(f"โŒ Error getting guest {guest_id}: {e}") + return None + + async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: + """Get guest data by session ID""" + try: + all_guests = await self.get_all_guests() + for guest_data in all_guests.values(): + if guest_data.get("session_id") == session_id: + return guest_data + return None + except Exception as e: + logger.error(f"โŒ Error getting guest by session ID {session_id}: {e}") + return None + + async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: + """Get all guests""" + try: + data = await self.redis.hgetall("guests") + return { + guest_id: json.loads(guest_json) + for guest_id, guest_json in data.items() + } + except Exception as e: + logger.error(f"โŒ Error getting all guests: {e}") + return {} + + async def delete_guest(self, guest_id: str) -> bool: + """Delete a guest""" + try: + result = await self.redis.hdel("guests", guest_id) + if result: + logger.info(f"๐Ÿ—‘๏ธ Guest deleted: {guest_id}") + return True + return False + except Exception as e: + logger.error(f"โŒ Error deleting guest {guest_id}: {e}") + return False + + async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: + """Clean up inactive guest sessions with safety checks""" + try: + all_guests = await self.get_all_guests() + current_time = datetime.now(UTC) + cutoff_time = current_time - timedelta(hours=inactive_hours) + + deleted_count = 0 + preserved_count = 0 + + for guest_id, guest_data in all_guests.items(): + try: + last_activity_str = guest_data.get("last_activity") + created_at_str = guest_data.get("created_at") + + # Skip cleanup if guest is very new (less than 1 hour old) + if created_at_str: + created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00')) + if current_time - created_at < timedelta(hours=1): + preserved_count += 1 + logger.debug(f"๐Ÿ›ก๏ธ Preserving new guest: {guest_id}") + continue + + # Check last activity + should_delete = False + if last_activity_str: + try: + last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00')) + if last_activity < cutoff_time: + should_delete = True + except ValueError: + # Invalid date format, but don't delete if guest is new + if not created_at_str: + should_delete = True + else: + # No last activity, but don't delete if guest is new + if not created_at_str: + should_delete = True + + if should_delete: + await self.delete_guest(guest_id) + deleted_count += 1 + else: + preserved_count += 1 + + except Exception as e: + logger.error(f"โŒ Error processing guest {guest_id} for cleanup: {e}") + preserved_count += 1 # Preserve on error + + if deleted_count > 0: + logger.info(f"๐Ÿงน Guest cleanup: removed {deleted_count}, preserved {preserved_count}") + + return deleted_count + except Exception as e: + logger.error(f"โŒ Error in guest cleanup: {e}") + return 0 + + async def get_guest_statistics(self) -> Dict[str, Any]: + """Get guest usage statistics""" + try: + all_guests = await self.get_all_guests() + current_time = datetime.now(UTC) + + stats = { + "total_guests": len(all_guests), + "active_last_hour": 0, + "active_last_day": 0, + "converted_guests": 0, + "by_ip": {}, + "creation_timeline": {} + } + + hour_ago = current_time - timedelta(hours=1) + day_ago = current_time - timedelta(days=1) + + for guest_data in all_guests.values(): + # Check activity + last_activity_str = guest_data.get("last_activity") + if last_activity_str: + try: + last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00')) + if last_activity > hour_ago: + stats["active_last_hour"] += 1 + if last_activity > day_ago: + stats["active_last_day"] += 1 + except ValueError: + pass + + # Check conversions + if guest_data.get("converted_to_user_id"): + stats["converted_guests"] += 1 + + # IP tracking + ip = guest_data.get("ip_address", "unknown") + stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1 + + # Creation timeline + created_at_str = guest_data.get("created_at") + if created_at_str: + try: + created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00')) + date_key = created_at.strftime('%Y-%m-%d') + stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1 + except ValueError: + pass + + return stats + except Exception as e: + logger.error(f"โŒ Error getting guest statistics: {e}") + return {} + # Global Redis manager instance redis_manager = _RedisManager() diff --git a/src/backend/device_manager.py b/src/backend/device_manager.py index 8965ade..0dd4ea7 100644 --- a/src/backend/device_manager.py +++ b/src/backend/device_manager.py @@ -1,9 +1,9 @@ -from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks # type: ignore +from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks from database import RedisDatabase import hashlib from logger import logger from datetime import datetime, timezone -from user_agents import parse # type: ignore +from user_agents import parse import json class DeviceManager: diff --git a/src/backend/email_service.py b/src/backend/email_service.py index 46e1e22..dce48a6 100644 --- a/src/backend/email_service.py +++ b/src/backend/email_service.py @@ -1,8 +1,8 @@ import os from typing import Tuple from logger import logger -from email.mime.text import MIMEText # type: ignore -from email.mime.multipart import MIMEMultipart # type: ignore +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart import smtplib import asyncio from email_templates import EMAIL_TEMPLATES diff --git a/src/backend/entities/candidate_entity.py b/src/backend/entities/candidate_entity.py index a35da0a..691b3d2 100644 --- a/src/backend/entities/candidate_entity.py +++ b/src/backend/entities/candidate_entity.py @@ -1,13 +1,13 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from uuid import uuid4 from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal from typing_extensions import Annotated, Union -import numpy as np # type: ignore +import numpy as np from uuid import uuid4 -from prometheus_client import CollectorRegistry, Counter # type: ignore +from prometheus_client import CollectorRegistry, Counter import traceback import os import json diff --git a/src/backend/entities/entity_manager.py b/src/backend/entities/entity_manager.py index 4b535b4..ba64517 100644 --- a/src/backend/entities/entity_manager.py +++ b/src/backend/entities/entity_manager.py @@ -3,12 +3,12 @@ import weakref from datetime import datetime, timedelta from typing import Dict, Optional, Any from contextlib import asynccontextmanager -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from models import ( Candidate ) from .candidate_entity import CandidateEntity from database import RedisDatabase -from prometheus_client import CollectorRegistry # type: ignore +from prometheus_client import CollectorRegistry class EntityManager: """Manages lifecycle of CandidateEntity instances""" diff --git a/src/backend/generate_types.py b/src/backend/generate_types.py index a23c579..2de1543 100644 --- a/src/backend/generate_types.py +++ b/src/backend/generate_types.py @@ -64,7 +64,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, current_dir) try: - from pydantic import BaseModel # type: ignore + from pydantic import BaseModel except ImportError as e: print(f"Error importing pydantic: {e}") print("Make sure pydantic is installed: pip install pydantic") diff --git a/src/backend/helpers/check_serializable.py b/src/backend/helpers/check_serializable.py index 646066d..0cad6a1 100644 --- a/src/backend/helpers/check_serializable.py +++ b/src/backend/helpers/check_serializable.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field import json from typing import Any, List, Set diff --git a/src/backend/image_generator/image_model_cache.py b/src/backend/image_generator/image_model_cache.py index cca6317..d13905c 100644 --- a/src/backend/image_generator/image_model_cache.py +++ b/src/backend/image_generator/image_model_cache.py @@ -4,8 +4,8 @@ import re import time from typing import Any -import torch # type: ignore -from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore +import torch +from diffusers import StableDiffusionPipeline, FluxPipeline class ImageModelCache: # Stay loaded for 3 hours def __init__(self, timeout_seconds: float = 3 * 60 * 60): diff --git a/src/backend/image_generator/profile_image.py b/src/backend/image_generator/profile_image.py index 20423af..2d2aae7 100644 --- a/src/backend/image_generator/profile_image.py +++ b/src/backend/image_generator/profile_image.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import UTC, datetime -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import Dict, Literal, Any, AsyncGenerator, Optional import inspect import random @@ -13,7 +13,7 @@ import os import gc import tempfile import uuid -import torch # type: ignore +import torch import asyncio import time import json diff --git a/src/backend/llm_proxy.py b/src/backend/llm_proxy.py index 3e0f77c..b756c0b 100644 --- a/src/backend/llm_proxy.py +++ b/src/backend/llm_proxy.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import Dict, List, Any, AsyncGenerator, Optional, Union -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from enum import Enum import asyncio import json @@ -179,7 +179,7 @@ class OllamaAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) import ollama - self.client = ollama.AsyncClient( # type: ignore + self.client = ollama.AsyncClient( host=config.get('host', defines.ollama_api_url) ) @@ -386,7 +386,7 @@ class OpenAIAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import openai # type: ignore + import openai self.client = openai.AsyncOpenAI( api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) ) @@ -522,7 +522,7 @@ class AnthropicAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import anthropic # type: ignore + import anthropic self.client = anthropic.AsyncAnthropic( api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) ) @@ -656,7 +656,7 @@ class GeminiAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import google.generativeai as genai # type: ignore + import google.generativeai as genai genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) self.genai = genai @@ -867,7 +867,7 @@ class UnifiedLLMProxy: raise ValueError("stream must be True for chat_stream") result = await self.chat(model, messages, provider, stream=True, **kwargs) # Type checker now knows this is an AsyncGenerator due to stream=True - async for chunk in result: # type: ignore + async for chunk in result: yield chunk async def chat_single( @@ -881,7 +881,7 @@ class UnifiedLLMProxy: result = await self.chat(model, messages, provider, stream=False, **kwargs) # Type checker now knows this is a ChatResponse due to stream=False - return result # type: ignore + return result async def generate( self, @@ -908,7 +908,7 @@ class UnifiedLLMProxy: """Stream text generation using specified or default provider""" result = await self.generate(model, prompt, provider, stream=True, **kwargs) - async for chunk in result: # type: ignore + async for chunk in result: yield chunk async def generate_single( @@ -921,7 +921,7 @@ class UnifiedLLMProxy: """Get single generation response using specified or default provider""" result = await self.generate(model, prompt, provider, stream=False, **kwargs) - return result # type: ignore + return result async def embeddings( self, @@ -1148,7 +1148,7 @@ async def example_embeddings_usage(): # Calculate similarity between first two texts (requires numpy) try: - import numpy as np # type: ignore + import numpy as np emb1 = np.array(response.data[0].embedding) emb2 = np.array(response.data[1].embedding) similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)) diff --git a/src/backend/main.py b/src/backend/main.py index 2ae6041..b4b9a8e 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -1,10 +1,16 @@ -from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks, File, UploadFile, Form # type: ignore +import time +from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks, File, UploadFile, Form# type: ignore from fastapi.middleware.cors import CORSMiddleware # type: ignore -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials # type: ignore +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials# type: ignore from fastapi.exceptions import RequestValidationError # type: ignore from fastapi.responses import JSONResponse, StreamingResponse, FileResponse # type: ignore from fastapi.staticfiles import StaticFiles # type: ignore from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY # type: ignore +from functools import wraps +from typing import Callable, Any, Optional +from rate_limiter import RateLimiter, RateLimitResult + +import schedule # type: ignore import os import shutil @@ -41,6 +47,8 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore import secrets import os import backstory_traceback +from rate_limiter import RateLimiter, RateLimitResult, RateLimitConfig +from background_tasks import BackgroundTaskManager # ============================= # Import custom modules @@ -83,7 +91,7 @@ from models import ( Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse, # Supporting models - Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentResponse, ResendVerificationRequest, Skill, WorkExperience, Education, + Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Skill, SystemInfo, WorkExperience, Education, # Email EmailVerificationRequest @@ -137,6 +145,9 @@ async def lifespan(app: FastAPI): logger.info("Application shutdown requested") await db_manager.graceful_shutdown() +# Global background task manager +background_task_manager: Optional[BackgroundTaskManager] = None + app = FastAPI( lifespan=lifespan, title="Backstory API", @@ -185,6 +196,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE content=json.dumps({"detail": str(exc)}), ) + # ============================ # Authentication Utilities # ============================ @@ -203,11 +215,13 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): return encoded_jwt async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)): - """Verify token and check if it's blacklisted""" + """Enhanced token verification with guest session recovery""" try: # First decode the token payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM]) user_id: str = payload.get("sub") + token_type: str = payload.get("type", "access") + if user_id is None: raise HTTPException(status_code=401, detail="Invalid authentication credentials") @@ -220,19 +234,27 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials logger.warning(f"๐Ÿšซ Attempt to use blacklisted token for user {user_id}") raise HTTPException(status_code=401, detail="Token has been revoked") - # Optional: Check if all user tokens are revoked (for "logout from all devices") - # user_revoked_key = f"user_tokens_revoked:{user_id}" - # user_tokens_revoked_at = await redis.get(user_revoked_key) - # if user_tokens_revoked_at: - # revoked_timestamp = datetime.fromisoformat(user_tokens_revoked_at.decode()) - # token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), UTC) - # if token_issued_at < revoked_timestamp: - # raise HTTPException(status_code=401, detail="All user tokens have been revoked") + # For guest tokens, verify guest still exists and update activity + if token_type == "guest" or payload.get("type") == "guest": + database = db_manager.get_database() + guest_data = await database.get_guest(user_id) + + if not guest_data: + logger.warning(f"๐Ÿšซ Guest session not found for token: {user_id}") + raise HTTPException(status_code=401, detail="Guest session expired") + + # Update guest activity + guest_data["last_activity"] = datetime.now(UTC).isoformat() + await database.set_guest(user_id, guest_data) + logger.debug(f"๐Ÿ”„ Guest activity updated: {user_id}") return user_id - except jwt.PyJWTError: + except jwt.PyJWTError as e: + logger.warning(f"โš ๏ธ JWT decode error: {e}") raise HTTPException(status_code=401, detail="Invalid authentication credentials") + except HTTPException: + raise except Exception as e: logger.error(f"โŒ Token verification error: {e}") raise HTTPException(status_code=401, detail="Token verification failed") @@ -247,8 +269,15 @@ async def get_current_user( candidate_data = await database.get_candidate(user_id) if candidate_data: # logger.info(f"๐Ÿ”‘ Current user is candidate: {candidate['id']}") - return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) - + return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) # type: ignore[return-value] + # Check candidates + candidate_data = await database.get_candidate(user_id) + if candidate_data: + # logger.info(f"๐Ÿ”‘ Current user is candidate: {candidate['id']}") + if candidate_data.get("is_AI"): + return model_cast.cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data)) + else: + return model_cast.cast_to_base_user_with_type(Candidate.model_validate(candidate_data)) # Check employers employer = await database.get_employer(user_id) if employer: @@ -262,6 +291,34 @@ async def get_current_user( logger.error(f"โŒ Error getting current user: {e}") raise HTTPException(status_code=404, detail="User not found") +async def get_current_user_or_guest( + user_id: str = Depends(verify_token_with_blacklist), + database: RedisDatabase = Depends(lambda: db_manager.get_database()) +) -> BaseUserWithType: + """Get current user (including guests) from database""" + try: + # Check candidates first + candidate_data = await database.get_candidate(user_id) + if candidate_data: + return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) + + # Check employers + employer_data = await database.get_employer(user_id) + if employer_data: + return Employer.model_validate(employer_data) + + # Check guests + guest_data = await database.get_guest(user_id) + if guest_data: + return Guest.model_validate(guest_data) + + logger.warning(f"โš ๏ธ User {user_id} not found in database") + raise HTTPException(status_code=404, detail="User not found") + + except Exception as e: + logger.error(f"โŒ Error getting current user: {e}") + raise HTTPException(status_code=404, detail="User not found") + async def get_current_admin( user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(lambda: db_manager.get_database()) @@ -456,6 +513,108 @@ def get_document_type_from_filename(filename: str) -> DocumentType: return type_mapping.get(extension, DocumentType.TXT) +# ============================ +# Rate Limiting Dependencies +# ============================ + +async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter: + """Dependency to get rate limiter instance""" + return RateLimiter(database) + +async def apply_rate_limiting( + request: Request, + rate_limiter: RateLimiter = Depends(get_rate_limiter), + current_user: Optional[BaseUserWithType] = None +) -> RateLimitResult: + """ + Apply rate limiting based on user type + Can be used as a dependency in endpoints + """ + try: + # Determine user info for rate limiting + if current_user: + user_id = current_user.id + user_type = current_user.user_type + is_admin = getattr(current_user, 'is_admin', False) + else: + # For unauthenticated requests, use IP address as identifier + user_id = request.client.host if request.client else "unknown" + user_type = "anonymous" + is_admin = False + + # Extract endpoint for specific rate limiting if needed + endpoint = request.url.path + + # Check rate limits + result = await rate_limiter.check_rate_limit( + user_id=user_id, + user_type=user_type, + is_admin=is_admin, + endpoint=endpoint + ) + + if not result.allowed: + logger.warning(f"๐Ÿšซ Rate limit exceeded for {user_type} {user_id}: {result.reason}") + raise HTTPException( + status_code=429, + detail={ + "error": "Rate limit exceeded", + "message": result.reason, + "retryAfter": result.retry_after_seconds, + "remaining": result.remaining_requests + }, + headers={"Retry-After": str(result.retry_after_seconds or 60)} + ) + + return result + + except HTTPException: + raise + except Exception as e: + logger.error(f"โŒ Rate limiting error: {e}") + # Fail open - allow request if rate limiting fails + return RateLimitResult(allowed=True, reason="Rate limiting system error") + +async def rate_limit_dependency( + request: Request, + rate_limiter: RateLimiter = Depends(get_rate_limiter) +): + """ + Rate limiting dependency that can be applied to any endpoint + Usage: dependencies=[Depends(rate_limit_dependency)] + """ + try: + # Try to get current user from token if present + current_user = None + if "authorization" in request.headers: + try: + auth_header = request.headers["authorization"] + if auth_header.startswith("Bearer "): + token = auth_header[7:] + payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM]) + user_id = payload.get("sub") + if user_id: + database = db_manager.get_database() + # Quick user lookup for rate limiting + candidate_data = await database.get_candidate(user_id) + if candidate_data: + current_user = Candidate.model_validate(candidate_data) + else: + employer_data = await database.get_employer(user_id) + if employer_data: + current_user = Employer.model_validate(employer_data) + except: + # Ignore auth errors for rate limiting - treat as anonymous + pass + + await apply_rate_limiting(request, rate_limiter, current_user) + + except HTTPException: + raise + except Exception as e: + logger.error(f"โŒ Rate limit dependency error: {e}") + # Fail open + # ============================ # API Router Setup # ============================ @@ -466,6 +625,262 @@ api_router = APIRouter(prefix="/api/1.0") # ============================ # Authentication Endpoints # ============================ +@api_router.post("/auth/guest") +async def create_guest_session_enhanced( + request: Request, + database: RedisDatabase = Depends(get_database), + rate_limiter: RateLimiter = Depends(get_rate_limiter) +): + """Create a guest session with enhanced validation and persistence""" + try: + # Apply rate limiting for guest creation + ip_address = request.client.host if request.client else "unknown" + + # Check rate limits for guest session creation + rate_result = await rate_limiter.check_rate_limit( + user_id=ip_address, + user_type="guest_creation", + is_admin=False, + endpoint="/auth/guest" + ) + + if not rate_result.allowed: + logger.warning(f"๐Ÿšซ Guest creation rate limit exceeded for IP {ip_address}") + return JSONResponse( + status_code=429, + content=create_error_response( + "RATE_LIMITED", + rate_result.reason or "Too many guest sessions created" + ), + headers={"Retry-After": str(rate_result.retry_after_seconds or 300)} + ) + + # Generate unique guest identifier with timestamp for uniqueness + current_time = datetime.now(UTC) + guest_id = str(uuid.uuid4()) + session_id = f"guest_{int(current_time.timestamp())}_{secrets.token_hex(8)}" + guest_username = f"guest-{session_id[-12:]}" + + # Verify username is unique (unlikely but possible collision) + while True: + existing_user = await database.get_user(guest_username) + if existing_user: + # Regenerate if collision + session_id = f"guest_{int(current_time.timestamp())}_{secrets.token_hex(12)}" + guest_username = f"guest-{session_id[-16:]}" + else: + break + + # Create guest user data with comprehensive info + guest_data = { + "id": guest_id, + "session_id": session_id, + "username": guest_username, + "email": f"{guest_username}@guest.backstory.ketrenos.com", + "first_name": "Guest", + "last_name": "User", + "full_name": "Guest User", + "user_type": "guest", + "created_at": current_time.isoformat(), + "updated_at": current_time.isoformat(), + "last_activity": current_time.isoformat(), + "last_login": current_time.isoformat(), + "status": "active", + "is_admin": False, + "ip_address": ip_address, + "user_agent": request.headers.get("user-agent", "Unknown"), + "converted_to_user_id": None, + "browser_session": True, # Mark as browser session + "persistent": True, # Mark as persistent + } + + # Store guest with enhanced persistence + await database.set_guest(guest_id, guest_data) + + # Create user lookup records + user_auth_data = { + "id": guest_id, + "type": "guest", + "email": guest_data["email"], + "username": guest_username, + "session_id": session_id, + "created_at": current_time.isoformat() + } + + await database.set_user(guest_data["email"], user_auth_data) + await database.set_user(guest_username, user_auth_data) + await database.set_user_by_id(guest_id, user_auth_data) + + # Create authentication tokens with longer expiry for guests + access_token = create_access_token( + data={"sub": guest_id, "type": "guest"}, + expires_delta=timedelta(hours=48) # Longer expiry for guests + ) + refresh_token = create_access_token( + data={"sub": guest_id, "type": "refresh_guest"}, + expires_delta=timedelta(days=14) # 2 weeks refresh for guests + ) + + # Verify guest was stored correctly + verification = await database.get_guest(guest_id) + if not verification: + logger.error(f"โŒ Failed to verify guest storage: {guest_id}") + return JSONResponse( + status_code=500, + content=create_error_response("STORAGE_ERROR", "Failed to create guest session") + ) + + # Create guest object for response + guest = Guest.model_validate(guest_data) + + # Log successful creation + logger.info(f"๐Ÿ‘ค Guest session created and verified: {guest_username} (ID: {guest_id}) from IP: {ip_address}") + + # Create auth response + auth_response = { + "accessToken": access_token, + "refreshToken": refresh_token, + "user": guest.model_dump(by_alias=True), + "expiresAt": int((current_time + timedelta(hours=48)).timestamp()), + "userType": "guest", + "isGuest": True + } + + return create_success_response(auth_response) + + except Exception as e: + logger.error(f"โŒ Guest session creation error: {e}") + import traceback + logger.error(traceback.format_exc()) + return JSONResponse( + status_code=500, + content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session") + ) + +@api_router.post("/auth/guest/convert") +async def convert_guest_to_user( + registration_data: Dict[str, Any] = Body(...), + current_user = Depends(get_current_user), + database: RedisDatabase = Depends(get_database) +): + """Convert a guest session to a permanent user account""" + try: + # Verify current user is a guest + if current_user.user_type != "guest": + return JSONResponse( + status_code=400, + content=create_error_response("NOT_GUEST", "Only guest users can be converted") + ) + + guest: Guest = current_user + account_type = registration_data.get("accountType", "candidate") + + if account_type == "candidate": + # Validate candidate registration data + try: + candidate_request = CreateCandidateRequest.model_validate(registration_data) + except ValidationError as e: + return JSONResponse( + status_code=400, + content=create_error_response("VALIDATION_ERROR", str(e)) + ) + + # Check if email/username already exists + auth_manager = AuthenticationManager(database) + user_exists, conflict_field = await auth_manager.check_user_exists( + candidate_request.email, + candidate_request.username + ) + + if user_exists: + return JSONResponse( + status_code=409, + content=create_error_response( + "USER_EXISTS", + f"A user with this {conflict_field} already exists" + ) + ) + + # Create candidate + candidate_id = str(uuid.uuid4()) + current_time = datetime.now(timezone.utc) + + candidate_data = { + "id": candidate_id, + "user_type": "candidate", + "email": candidate_request.email, + "username": candidate_request.username, + "first_name": candidate_request.first_name, + "last_name": candidate_request.last_name, + "full_name": f"{candidate_request.first_name} {candidate_request.last_name}", + "phone": candidate_request.phone, + "created_at": current_time.isoformat(), + "updated_at": current_time.isoformat(), + "status": "active", + "is_admin": False, + "converted_from_guest": guest.id + } + + candidate = Candidate.model_validate(candidate_data) + + # Create authentication + await auth_manager.create_user_authentication(candidate_id, candidate_request.password) + + # Store candidate + await database.set_candidate(candidate_id, candidate.model_dump()) + + # Update user lookup records + user_auth_data = { + "id": candidate_id, + "type": "candidate", + "email": candidate.email, + "username": candidate.username + } + + await database.set_user(candidate.email, user_auth_data) + await database.set_user(candidate.username, user_auth_data) + await database.set_user_by_id(candidate_id, user_auth_data) + + # Mark guest as converted + guest_data = guest.model_dump() + guest_data["converted_to_user_id"] = candidate_id + guest_data["updated_at"] = current_time.isoformat() + await database.set_guest(guest.id, guest_data) + + # Create new tokens for the candidate + access_token = create_access_token(data={"sub": candidate_id}) + refresh_token = create_access_token( + data={"sub": candidate_id, "type": "refresh"}, + expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) + ) + + auth_response = AuthResponse( + accessToken=access_token, + refreshToken=refresh_token, + user=candidate, + expiresAt=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) + ) + + logger.info(f"โœ… Guest {guest.session_id} converted to candidate {candidate.username}") + + return create_success_response({ + "message": "Guest account successfully converted to candidate", + "auth": auth_response.model_dump(by_alias=True), + "conversionType": "candidate" + }) + + else: + return JSONResponse( + status_code=400, + content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported") + ) + + except Exception as e: + logger.error(f"โŒ Guest conversion error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account") + ) @api_router.post("/auth/logout") async def logout( @@ -576,7 +991,7 @@ async def logout( @api_router.post("/auth/logout-all") async def logout_all_devices( - current_user = Depends(get_current_user), + current_user = Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Logout from all devices by revoking all tokens for the user""" @@ -692,18 +1107,23 @@ async def create_candidate_ai( session_id=user_message.session_id, prompt=user_message.content, ): - if generated_message.status == ApiStatusType.ERROR: - logger.error(f"โŒ AI generation error: {generated_message.content}") + if isinstance(generated_message, ChatMessageError): + error_message : ChatMessageError = generated_message + logger.error(f"โŒ AI generation error: {error_message.content}") return JSONResponse( status_code=500, - content=create_error_response("AI_GENERATION_ERROR", generated_message.content) + content=create_error_response("AI_GENERATION_ERROR", error_message.content) ) + if isinstance(generated_message, ChatMessageRagSearch): + raise ValueError("AI generation returned a RAG search message instead of a persona") + if generated_message.status == ApiStatusType.DONE and state == 0: persona_message = generated_message state = 1 # Switch to resume generation elif generated_message.status == ApiStatusType.DONE and state == 1: resume_message = generated_message + if not persona_message: logger.error(f"โŒ AI generation failed: {persona_message.content if persona_message else 'No message generated'}") return JSONResponse( @@ -773,9 +1193,8 @@ async def create_candidate_ai( originalName=document_filename, type=document_type, size=len(document_content), - upload_date=datetime.now(UTC), - include_in_RAG=True, - owner_id=candidate.id + uploadDate=datetime.now(UTC), + ownerId=candidate.id ) file_path = os.path.join(defines.user_dir, candidate.username, "rag-content", document_filename) # Ensure the directory exists @@ -934,11 +1353,11 @@ async def create_employer_with_verification( employer_data = { "id": employer_id, "email": request.email, - "companyName": request.companyName, + "companyName": request.company_name, "industry": request.industry, - "companySize": request.companySize, - "companyDescription": request.companyDescription, - "websiteUrl": request.websiteUrl, + "companySize": request.company_size, + "companyDescription": request.company_description, + "websiteUrl": request.website_url, "phone": request.phone, "createdAt": current_time.isoformat(), "updatedAt": current_time.isoformat(), @@ -969,7 +1388,7 @@ async def create_employer_with_verification( email_service.send_verification_email, request.email, verification_token, - request.companyName + request.company_name ) logger.info(f"โœ… Employer registration initiated for: {request.email}") @@ -1317,15 +1736,15 @@ async def request_mfa( logger.info(f"๐Ÿ” MFA requested for {request.email} from new device {request.device_name}") mfa_data = MFAData( + message="New device detected. We've sent a security code to your email address.", + codeSent=mfa_code, email=request.email, - device_id=request.device_id, - device_name=request.device_name, - mfaCode=mfa_code + deviceId=request.device_id, + deviceName=request.device_name, ) mfa_response = MFARequestResponse( - mfa_required=True, - message="MFA code sent to your email address", - mfa_data=mfa_data + mfaRequired=True, + mfaData=mfa_data ) return create_success_response(mfa_response) @@ -1429,13 +1848,13 @@ async def login( logger.info(f"๐Ÿ” MFA code automatically sent to {request.login} for device {device_info['device_name']}") mfa_response = MFARequestResponse( - mfa_required=True, - mfa_data=MFAData( + mfaRequired=True, + mfaData=MFAData( message="New device detected. We've sent a security code to your email address.", email=email, - device_id=device_id, - device_name=device_info["device_name"], - code_sent=mfa_code + deviceId=device_id, + deviceName=device_info["device_name"], + codeSent=mfa_code ) ) return create_success_response(mfa_response.model_dump(by_alias=True)) @@ -1636,10 +2055,10 @@ async def verify_mfa( # Create response auth_response = AuthResponse( - access_token=access_token, - refresh_token=refresh_token, + accessToken=access_token, + refreshToken=refresh_token, user=user, - expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) + expiresAt=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) ) logger.info(f"โœ… MFA verified and login completed for {request.email}") @@ -1694,20 +2113,20 @@ class DebugStreamingResponse(StreamingResponse): @api_router.post("/candidates/documents/upload") async def upload_candidate_document( file: UploadFile = File(...), - options: str = Form(...), + options_data: str = Form(..., alias="options"), current_user = Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): try: # Parse the JSON string and create DocumentOptions object - options_dict = json.loads(options) - options = DocumentOptions(**options_dict) + options_dict = json.loads(options_data) + options : DocumentOptions = DocumentOptions.model_validate(**options_dict) except (json.JSONDecodeError, ValidationError) as e: return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="Invalid options format. Please provide valid JSON." - )]), + ).model_dump(mode='json', by_alias=True))]), media_type="text/event-stream" ) @@ -1717,19 +2136,19 @@ async def upload_candidate_document( if len(file_content) > max_size: logger.info(f"โš ๏ธ File too large: {file.filename} ({len(file_content)} bytes)") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="File size exceeds 10MB limit" - )]), + ).model_dump(mode='json', by_alias=True))]), media_type="text/event-stream" ) if len(file_content) == 0: logger.info(f"โš ๏ธ File is empty: {file.filename}") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="File is empty" - )]), + ).model_dump(mode='json', by_alias=True))]), media_type="text/event-stream" ) @@ -1739,7 +2158,7 @@ async def upload_candidate_document( if current_user.user_type != "candidate": logger.warning(f"โš ๏ธ Unauthorized upload attempt by user type: {current_user.user_type}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="Only candidates can upload documents" ) yield error_message @@ -1750,7 +2169,7 @@ async def upload_candidate_document( if not file.filename or file.filename.strip() == "": logger.warning("โš ๏ธ File upload attempt with missing filename") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="File must have a valid filename" ) yield error_message @@ -1770,7 +2189,7 @@ async def upload_candidate_document( if not options.overwrite: logger.warning(f"โš ๏ธ File already exists: {file_path}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"File with this name already exists in the '{directory}' directory" ) yield error_message @@ -1778,7 +2197,7 @@ async def upload_candidate_document( else: logger.info(f"๐Ÿ”„ Overwriting existing file: {file_path}") status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Overwriting existing file: {file.filename}", activity=ApiActivityType.INFO ) @@ -1791,7 +2210,7 @@ async def upload_candidate_document( if file_extension not in allowed_types: logger.warning(f"โš ๏ธ Invalid file type: {file_extension} for file {file.filename}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}" ) yield error_message @@ -1807,9 +2226,9 @@ async def upload_candidate_document( originalName=file.filename or f"document_{document_id}", type=document_type, size=len(file_content), - upload_date=datetime.now(UTC), + uploadDate=datetime.now(UTC), options=options, - owner_id=candidate.id + ownerId=candidate.id ) # Save file to disk @@ -1825,7 +2244,7 @@ async def upload_candidate_document( except Exception as e: logger.error(f"โŒ Failed to save file to disk: {e}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="Failed to save file to disk", ) yield error_message @@ -1841,13 +2260,13 @@ async def upload_candidate_document( p.stat().st_mtime > p_as_md.stat().st_mtime ): status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Converting content from {document_type}...", activity=ApiActivityType.CONVERTING ) yield status_message try: - from markitdown import MarkItDown# type: ignore + from markitdown import MarkItDown # type: ignore md = MarkItDown(enable_plugins=False) # Set to True to enable plugins result = md.convert(file_path, output_format="markdown") p_as_md.write_text(result.text_content) @@ -1857,7 +2276,7 @@ async def upload_candidate_document( file_path = p_as_md except Exception as e: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Failed to convert {file.filename} to Markdown.", ) yield error_message @@ -1869,7 +2288,7 @@ async def upload_candidate_document( await database.add_document_to_candidate(candidate.id, document_id) logger.info(f"๐Ÿ“„ Document uploaded: {file.filename} for candidate {candidate.username}") chat_message = DocumentMessage( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads type=ApiMessageType.JSON, status=ApiStatusType.DONE, document=document_data, @@ -1905,16 +2324,16 @@ async def upload_candidate_document( logger.error(backstory_traceback.format_exc()) logger.error(f"โŒ Document upload error: {e}") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="Failed to upload document" - )]), + ).model_dump(mode='json', by_alias=True))]), media_type="text/event-stream" ) async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str): status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Initiating connection with {current_user.first_name}'s AI agent...", activity=ApiActivityType.INFO ) @@ -1925,14 +2344,14 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) if not chat_agent: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="No agent found for job requirements chat type" ) yield error_message return message = None status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Analyzing document for company and requirement details...", activity=ApiActivityType.SEARCHING ) @@ -1948,7 +2367,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida pass if not message or not isinstance(message, JobRequirementsMessage): error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="Failed to process job description file" ) yield error_message @@ -2000,7 +2419,7 @@ async def upload_candidate_profile( ) # Save file to disk as "profile." - _, extension = os.path.splitext(file.filename) + _, extension = os.path.splitext(file.filename or "") file_path = os.path.join(defines.user_dir, candidate.username, f"profile{extension}") try: @@ -2164,7 +2583,7 @@ async def get_document_content( content=create_error_response("FORBIDDEN", "Cannot access another candidate's document") ) - file_path = os.path.join(defines.user_dir, candidate.username, "rag-content" if document.include_in_RAG else "files", document.originalName) + file_path = os.path.join(defines.user_dir, candidate.username, "rag-content" if document.options.include_in_RAG else "files", document.originalName) file_path = pathlib.Path(file_path) if not document.type in [DocumentType.TXT, DocumentType.MARKDOWN]: file_path = file_path.with_suffix('.md') @@ -2183,7 +2602,7 @@ async def get_document_content( response = DocumentContentResponse( documentId=document_id, filename=document.filename, - type=document.type.value, + type=document.type, content=content, size=document.size ) @@ -2238,8 +2657,8 @@ async def update_document( status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another candidate's document") ) - - if document.include_in_RAG != updates.include_in_RAG: + update_options = updates.options if updates.options else DocumentOptions() + if document.options.include_in_RAG != update_options.include_in_RAG: # If RAG status is changing, we need to handle file movement rag_dir = os.path.join(defines.user_dir, candidate.username, "rag-content") file_dir = os.path.join(defines.user_dir, candidate.username, "files") @@ -2248,7 +2667,7 @@ async def update_document( rag_path = os.path.join(rag_dir, document.originalName) file_path = os.path.join(file_dir, document.originalName) - if updates.include_in_RAG: + if update_options.include_in_RAG: src = pathlib.Path(file_path) dst = pathlib.Path(rag_path) # Move to RAG directory @@ -2276,8 +2695,8 @@ async def update_document( update_dict = {} if updates.filename is not None: update_dict["filename"] = updates.filename.strip() - if updates.include_in_RAG is not None: - update_dict["include_in_RAG"] = updates.include_in_RAG + if update_options.include_in_RAG is not None: + update_dict["include_in_RAG"] = update_options.include_in_RAG if not update_dict: return JSONResponse( @@ -2449,15 +2868,20 @@ async def post_candidate_vector_content( {"error": "No UMAP collection found"}, status_code=404 ) - if not collection.get("metadatas", None): + if not collection.metadatas or collection.ids: return JSONResponse(f"Document id {rag_document.id} not found.", 404) - for index, id in enumerate(collection.get("ids", [])): + for index, id in enumerate(collection.ids): if id == rag_document.id: - metadata = collection.get("metadatas", [])[index].copy() + metadata = collection.metadatas[index].copy() + rag_metadata = RagContentMetadata.model_validate(metadata) content = candidate_entity.file_watcher.prepare_metadata(metadata) - rag_response = RagContentResponse(id=id, content=content, metadata=metadata) - logger.info(f"โœ… Fetched RAG content for document id {id} for candidate {candidate.username}") + if content: + rag_response = RagContentResponse(id=id, content=content, metadata=rag_metadata) + logger.info(f"โœ… Fetched RAG content for document id {id} for candidate {candidate.username}") + else: + logger.warning(f"โš ๏ธ No content found for document id {id} for candidate {candidate.username}") + return JSONResponse(f"No content found for document id {rag_document.id}.", 404) return create_success_response(rag_response.model_dump(by_alias=True)) return JSONResponse(f"Document id {rag_document.id} not found.", 404) @@ -2508,9 +2932,9 @@ async def post_candidate_vectors( return create_success_response(results) result = { - "ids": collection.get("ids", []), - "metadatas": collection.get("metadatas", []), - "documents": collection.get("documents", []), + "ids": collection.ids, + "metadatas": collection.metadatas, + "documents": collection.documents, "embeddings": umap_embedding.tolist(), "size": candidate_entity.file_watcher.collection.count() } @@ -2690,7 +3114,7 @@ async def search_candidates( query_lower in c.last_name.lower() or query_lower in c.email.lower() or query_lower in c.username.lower() or - any(query_lower in skill.name.lower() for skill in c.skills)) + any(query_lower in skill.name.lower() for skill in c.skills or [])) ] paginated_candidates, total = filter_and_paginate( @@ -2835,7 +3259,7 @@ async def create_job_from_description( if current_user.user_type != "candidate": logger.warning(f"โš ๏ธ Unauthorized upload attempt by user type: {current_user.user_type}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="Only candidates can upload documents" ) yield error_message @@ -2875,10 +3299,10 @@ async def create_job_from_description( logger.error(backstory_traceback.format_exc()) logger.error(f"โŒ Document upload error: {e}") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="Failed to upload document" - )]), + ).model_dump(by_alias=True)).encode("utf-8")]), media_type="text/event-stream" ) @@ -2895,19 +3319,19 @@ async def create_job_from_file( if len(file_content) > max_size: logger.info(f"โš ๏ธ File too large: {file.filename} ({len(file_content)} bytes)") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="File size exceeds 10MB limit" - )]), + ).model_dump(by_alias=True)).encode("utf-8")]), media_type="text/event-stream" ) if len(file_content) == 0: logger.info(f"โš ๏ธ File is empty: {file.filename}") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="File is empty" - )]), + ).model_dump(by_alias=True)).encode("utf-8")]), media_type="text/event-stream" ) @@ -2917,7 +3341,7 @@ async def create_job_from_file( if current_user.user_type != "candidate": logger.warning(f"โš ๏ธ Unauthorized upload attempt by user type: {current_user.user_type}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="Only candidates can upload documents" ) yield error_message @@ -2927,7 +3351,7 @@ async def create_job_from_file( if not file.filename or file.filename.strip() == "": logger.warning("โš ๏ธ File upload attempt with missing filename") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content="File must have a valid filename" ) yield error_message @@ -2942,7 +3366,7 @@ async def create_job_from_file( if file_extension not in allowed_types: logger.warning(f"โš ๏ธ Invalid file type: {file_extension} for file {file.filename}") error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}" ) yield error_message @@ -2952,7 +3376,7 @@ async def create_job_from_file( if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT: status_message = ChatMessageStatus( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Converting content from {document_type}...", activity=ApiActivityType.CONVERTING ) @@ -2969,7 +3393,7 @@ async def create_job_from_file( logger.info(f"โœ… Converted {file.filename} to Markdown format") except Exception as e: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + sessionId=MOCK_UUID, # No session ID for document uploads content=f"Failed to convert {file.filename} to Markdown.", ) yield error_message @@ -3007,10 +3431,10 @@ async def create_job_from_file( logger.error(backstory_traceback.format_exc()) logger.error(f"โŒ Document upload error: {e}") return StreamingResponse( - iter([ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads + iter([json.dumps(ChatMessageError( + sessionId=MOCK_UUID, # No session ID for document uploads content="Failed to upload document" - )]), + ).model_dump(mode='json', by_alias=True)).encode("utf-8")]), media_type="text/event-stream" ) @@ -3106,9 +3530,9 @@ async def search_jobs( query_lower = query.lower() jobs_list = [ j for j in jobs_list - if (query_lower in j.title.lower() or - query_lower in j.description.lower() or - any(query_lower in skill.lower() for skill in (j.preferred_skills or []))) + if ((j.title and query_lower in j.title.lower()) or + (j.description and query_lower in j.description.lower()) or + any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])) ] paginated_jobs, total = filter_and_paginate( @@ -3218,13 +3642,13 @@ async def post_candidate_rag_search( content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type") ) - user_message = ChatMessageUser(sender_id=candidate.id, session_id=MOCK_UUID, content=query, timestamp=datetime.now(UTC)) - rag_message = None + user_message = ChatMessageUser(senderId=candidate.id, sessionId=MOCK_UUID, content=query, timestamp=datetime.now(UTC)) + rag_message : Any = None async for generated_message in chat_agent.generate( llm=llm_manager.get_llm(), model=defines.model, session_id=user_message.session_id, - prompt=user_message.prompt, + prompt=user_message.content, ): rag_message = generated_message @@ -3233,7 +3657,8 @@ async def post_candidate_rag_search( status_code=500, content=create_error_response("NO_RESPONSE", "No response generated for the RAG search") ) - return create_success_response(rag_message.metadata.rag_results[0].model_dump(by_alias=True)) + final_message : ChatMessageRagSearch = rag_message + return create_success_response(final_message.content[0].model_dump(by_alias=True)) except Exception as e: logger.error(f"โŒ Get candidate chat summary error: {e}") @@ -3242,6 +3667,66 @@ async def post_candidate_rag_search( content=create_error_response("SUMMARY_ERROR", str(e)) ) +# reference can be candidateId, username, or email +@api_router.get("/users/{reference}") +async def get_user( + reference: str = Path(...), + database: RedisDatabase = Depends(get_database) +): + """Get a candidate by username""" + try: + # Normalize reference to lowercase for case-insensitive search + query_lower = reference.lower() + + all_candidate_data = await database.get_all_candidates() + if not all_candidate_data: + logger.warning(f"โš ๏ธ No users found in database") + return JSONResponse( + status_code=404, + content=create_error_response("NOT_FOUND", "No users found") + ) + + user_data = None + for user in all_candidate_data.values(): + if (user.get("id", "").lower() == query_lower or + user.get("username", "").lower() == query_lower or + user.get("email", "").lower() == query_lower): + user_data = user + break + + if not user_data: + all_guest_data = await database.get_all_guests() + if not all_guest_data: + logger.warning(f"โš ๏ธ No guests found in database") + return JSONResponse( + status_code=404, + content=create_error_response("NOT_FOUND", "No users found") + ) + for user in all_guest_data.values(): + if (user.get("id", "").lower() == query_lower or + user.get("username", "").lower() == query_lower or + user.get("email", "").lower() == query_lower): + user_data = user + break + + if not user_data: + logger.warning(f"โš ๏ธ User nor Guest found for reference: {reference}") + return JSONResponse( + status_code=404, + content=create_error_response("NOT_FOUND", "User not found") + ) + + user = BaseUserWithType.model_validate(user_data) + + return create_success_response(user.model_dump(by_alias=True)) + + except Exception as e: + logger.error(f"โŒ Get user error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("FETCH_ERROR", str(e)) + ) + # reference can be candidateId, username, or email @api_router.get("/candidates/{reference}") async def get_candidate( @@ -3362,7 +3847,7 @@ async def archive_chat_session( @api_router.post("/chat/sessions") async def create_chat_session( session_data: Dict[str, Any] = Body(...), - current_user: BaseUserWithType = Depends(get_current_user), + current_user: BaseUserWithType = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Create a new chat session with optional candidate username association""" @@ -3449,7 +3934,7 @@ async def create_chat_session( @api_router.post("/chat/sessions/{session_id}/messages/stream") async def post_chat_session_message_stream( user_message: ChatMessageUser = Body(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Post a message to a chat session and stream the response with persistence""" @@ -3463,7 +3948,7 @@ async def post_chat_session_message_stream( ) chat_session = ChatSession.model_validate(chat_session_data) chat_type = chat_session.context.type - candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) + candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None # Get candidate info if this chat is about a specific candidate if candidate_info: @@ -3521,7 +4006,7 @@ async def post_chat_session_message_stream( @api_router.get("/chat/sessions/{session_id}/messages") async def get_chat_session_messages( session_id: str = Path(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), page: int = Query(1, ge=1), limit: int = Query(50, ge=1, le=100), # Increased default for chat messages database: RedisDatabase = Depends(get_database) @@ -3575,7 +4060,7 @@ async def get_chat_session_messages( async def update_chat_session( session_id: str = Path(...), updates: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Update a chat session's properties""" @@ -3672,7 +4157,7 @@ async def update_chat_session( @api_router.delete("/chat/sessions/{session_id}") async def delete_chat_session( session_id: str = Path(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Delete a chat session and all its messages""" @@ -3726,7 +4211,7 @@ async def delete_chat_session( @api_router.patch("/chat/sessions/{session_id}/reset") async def reset_chat_session( session_id: str = Path(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Delete a chat session and all its messages""" @@ -3775,11 +4260,244 @@ async def reset_chat_session( content=create_error_response("RESET_ERROR", str(e)) ) + +# ============================ +# Rate Limited Decorator +# ============================ + +def rate_limited( + guest_per_minute: int = 10, + user_per_minute: int = 60, + admin_per_minute: int = 120, + endpoint_specific: bool = True +): + """ + Decorator to easily apply rate limiting to endpoints + + Args: + guest_per_minute: Rate limit for guest users + user_per_minute: Rate limit for authenticated users + admin_per_minute: Rate limit for admin users + endpoint_specific: Whether to apply endpoint-specific limits + + Usage: + @rate_limited(guest_per_minute=5, user_per_minute=30) + @api_router.post("/my-endpoint") + async def my_endpoint( + request: Request, + current_user = Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database) + ): + return {"message": "Rate limited endpoint"} + """ + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + # Extract dependencies from function signature + import inspect + sig = inspect.signature(func) + + # Get request, current_user, and rate_limiter from kwargs or args + request = None + current_user = None + rate_limiter = None + + # Try to find dependencies in kwargs first + for param_name, param_value in kwargs.items(): + if isinstance(param_value, Request): + request = param_value + elif hasattr(param_value, 'user_type'): # User-like object + current_user = param_value + elif isinstance(param_value, RateLimiter): + rate_limiter = param_value + + # If not found in kwargs, check if they're provided via Depends + if not rate_limiter: + # Create rate limiter instance (this should ideally come from DI) + database = db_manager.get_database() + rate_limiter = RateLimiter(database) + + # Apply rate limiting if we have the required components + if request and current_user and rate_limiter: + await apply_custom_rate_limiting( + request, current_user, rate_limiter, + guest_per_minute, user_per_minute, admin_per_minute + ) + + # Call the original function + return await func(*args, **kwargs) + + return wrapper + return decorator + +async def apply_custom_rate_limiting( + request: Request, + current_user, + rate_limiter: RateLimiter, + guest_per_minute: int, + user_per_minute: int, + admin_per_minute: int +): + """Apply custom rate limiting with specified limits""" + try: + # Determine user info + user_id = current_user.id + user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) + is_admin = getattr(current_user, 'is_admin', False) + + # Determine appropriate limit + if is_admin: + requests_per_minute = admin_per_minute + elif user_type == "guest": + requests_per_minute = guest_per_minute + else: + requests_per_minute = user_per_minute + + # Create custom rate limit key + current_time = datetime.now(UTC) + custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}" + + # Check current usage + current_count = int(await rate_limiter.redis.get(custom_key) or 0) + + if current_count >= requests_per_minute: + logger.warning(f"๐Ÿšซ Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}") + raise HTTPException( + status_code=429, + detail={ + "error": "Rate limit exceeded", + "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute", + "retryAfter": 60 - current_time.second, + "userType": user_type, + "endpoint": request.url.path + }, + headers={"Retry-After": str(60 - current_time.second)} + ) + + # Increment counter + pipe = rate_limiter.redis.pipeline() + pipe.incr(custom_key) + pipe.expire(custom_key, 120) # 2 minutes TTL + await pipe.execute() + + logger.debug(f"โœ… Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}") + + except HTTPException: + raise + except Exception as e: + logger.error(f"โŒ Custom rate limiting error: {e}") + # Fail open + +# ============================ +# Alternative: FastAPI Dependency-Based Rate Limiting +# ============================ + +def create_rate_limit_dependency( + guest_per_minute: int = 10, + user_per_minute: int = 60, + admin_per_minute: int = 120 +): + """ + Create a FastAPI dependency for rate limiting + + Usage: + rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30) + + @api_router.post("/my-endpoint") + async def my_endpoint( + rate_check = Depends(rate_limit_5_30), + current_user = Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database) + ): + return {"message": "Rate limited endpoint"} + """ + async def rate_limit_dependency( + request: Request, + current_user = Depends(get_current_user_or_guest), + rate_limiter: RateLimiter = Depends(get_rate_limiter) + ): + await apply_custom_rate_limiting( + request, current_user, rate_limiter, + guest_per_minute, user_per_minute, admin_per_minute + ) + return True + + return rate_limit_dependency + +# ============================ +# Rate Limiting Utilities +# ============================ + +class EndpointRateLimiter: + """Utility class for endpoint-specific rate limiting""" + + def __init__(self, rate_limiter: RateLimiter): + self.rate_limiter = rate_limiter + self.custom_limits = {} + + def set_endpoint_limits(self, endpoint: str, limits: dict): + """Set custom limits for an endpoint""" + self.custom_limits[endpoint] = limits + + async def check_endpoint_limit(self, request: Request, current_user) -> bool: + """Check if request exceeds endpoint-specific limits""" + endpoint = request.url.path + + if endpoint not in self.custom_limits: + return True # No custom limits set + + limits = self.custom_limits[endpoint] + user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) + + if getattr(current_user, 'is_admin', False): + user_type = "admin" + + limit = limits.get(user_type, limits.get("default", 60)) + + current_time = datetime.now(UTC) + key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}" + + current_count = int(await self.rate_limiter.redis.get(key) or 0) + + if current_count >= limit: + raise HTTPException( + status_code=429, + detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}" + ) + + # Increment counter + await self.rate_limiter.redis.incr(key) + await self.rate_limiter.redis.expire(key, 120) + + return True + +# Global endpoint rate limiter instance +endpoint_rate_limiter = None + +def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter: + """Get endpoint rate limiter instance""" + global endpoint_rate_limiter + if endpoint_rate_limiter is None: + endpoint_rate_limiter = EndpointRateLimiter(rate_limiter) + + # Configure endpoint-specific limits + endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", { + "guest": 5, "candidate": 30, "employer": 30, "admin": 100 + }) + endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", { + "guest": 2, "candidate": 10, "employer": 10, "admin": 50 + }) + endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", { + "guest": 1, "candidate": 5, "employer": 20, "admin": 50 + }) + + return endpoint_rate_limiter + @api_router.post("/candidates/{candidate_id}/skill-match") async def get_candidate_skill_match( candidate_id: str = Path(...), requirement: str = Body(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), database: RedisDatabase = Depends(get_database) ): """Get skill match for a candidate against a requirement with caching""" @@ -3887,16 +4605,18 @@ async def get_candidate_skill_match( content=create_error_response("SKILL_MATCH_ERROR", str(e)) ) +@rate_limited(guest_per_minute=5, user_per_minute=30, admin_per_minute=100) @api_router.get("/candidates/{username}/chat-sessions") async def get_candidate_chat_sessions( username: str = Path(...), - current_user = Depends(get_current_user), + current_user = Depends(get_current_user_or_guest), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), database: RedisDatabase = Depends(get_database) ): """Get all chat sessions related to a specific candidate""" try: + logger.info(f"๐Ÿ” Fetching chat sessions for candidate with username: {username}") # Find candidate by username all_candidates_data = await database.get_all_candidates() candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()] @@ -3921,6 +4641,10 @@ async def get_candidate_chat_sessions( for index, session_data in enumerate(all_sessions_data.values()): try: session = ChatSession.model_validate(session_data) + if session.user_id != current_user.id: + # User can only access their own sessions + logger.info(f"๐Ÿ”— Skipping session {session.id} - not owned by user {current_user.id} (created by {session.user_id})") + continue # Check if this session is related to the candidate context = session.context if (context and @@ -3969,7 +4693,7 @@ async def get_candidate_chat_sessions( # ============================ # @api_router.get("/admin/verification-stats") async def get_verification_statistics( - current_user = Depends(get_current_user), + current_user = Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Get verification statistics (admin only)""" @@ -3993,7 +4717,7 @@ async def get_verification_statistics( @api_router.post("/admin/cleanup-verifications") async def cleanup_verification_tokens( - current_user = Depends(get_current_user), + current_user = Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Manually trigger cleanup of expired verification tokens (admin only)""" @@ -4019,7 +4743,7 @@ async def cleanup_verification_tokens( @api_router.get("/admin/pending-verifications") async def get_pending_verifications( - current_user = Depends(get_current_user), + current_user = Depends(get_current_admin), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), database: RedisDatabase = Depends(get_database) @@ -4078,7 +4802,175 @@ async def get_pending_verifications( status_code=500, content=create_error_response("FETCH_ERROR", str(e)) ) + +@api_router.get("/admin/rate-limits/info") +async def get_user_rate_limit_status( + current_user = Depends(get_current_user_or_guest), + rate_limiter: RateLimiter = Depends(get_rate_limiter), + database: RedisDatabase = Depends(get_database) +): + """Get rate limit status for a user (admin only)""" + try: + # Get user to determine type + user_data = await database.get_user_by_id(current_user.id) + if not user_data: + return JSONResponse( + status_code=404, + content=create_error_response("USER_NOT_FOUND", "User not found") + ) + user_type = user_data.get("type", "unknown") + is_admin = False + + if user_type == "candidate": + candidate_data = await database.get_candidate(current_user.id) + if candidate_data: + is_admin = candidate_data.get("is_admin", False) + elif user_type == "employer": + employer_data = await database.get_employer(current_user.id) + if employer_data: + is_admin = employer_data.get("is_admin", False) + + status = await rate_limiter.get_user_rate_limit_status(current_user.id, user_type, is_admin) + + return create_success_response(status) + + except Exception as e: + logger.error(f"โŒ Get rate limit status error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("STATUS_ERROR", str(e)) + ) + +@api_router.get("/admin/rate-limits/{user_id}") +async def get_anyone_rate_limit_status( + user_id: str = Path(...), + admin_user = Depends(get_current_admin), + rate_limiter: RateLimiter = Depends(get_rate_limiter), + database: RedisDatabase = Depends(get_database) +): + """Get rate limit status for a user (admin only)""" + try: + # Get user to determine type + user_data = await database.get_user_by_id(user_id) + if not user_data: + return JSONResponse( + status_code=404, + content=create_error_response("USER_NOT_FOUND", "User not found") + ) + + user_type = user_data.get("type", "unknown") + is_admin = False + + if user_type == "candidate": + candidate_data = await database.get_candidate(user_id) + if candidate_data: + is_admin = candidate_data.get("is_admin", False) + elif user_type == "employer": + employer_data = await database.get_employer(user_id) + if employer_data: + is_admin = employer_data.get("is_admin", False) + + status = await rate_limiter.get_user_rate_limit_status(user_id, user_type, is_admin) + + return create_success_response(status) + + except Exception as e: + logger.error(f"โŒ Get rate limit status error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("STATUS_ERROR", str(e)) + ) + +@api_router.post("/admin/rate-limits/{user_id}/reset") +async def reset_user_rate_limits( + user_id: str = Path(...), + admin_user = Depends(get_current_admin), + rate_limiter: RateLimiter = Depends(get_rate_limiter), + database: RedisDatabase = Depends(get_database) +): + """Reset rate limits for a user (admin only)""" + try: + # Get user to determine type + user_data = await database.get_user_by_id(user_id) + if not user_data: + return JSONResponse( + status_code=404, + content=create_error_response("USER_NOT_FOUND", "User not found") + ) + + user_type = user_data.get("type", "unknown") + success = await rate_limiter.reset_user_rate_limits(user_id, user_type) + + if success: + logger.info(f"๐Ÿ”„ Rate limits reset for {user_type} {user_id} by admin {admin_user.id}") + return create_success_response({ + "message": f"Rate limits reset for {user_type} {user_id}", + "resetBy": admin_user.id + }) + else: + return JSONResponse( + status_code=500, + content=create_error_response("RESET_FAILED", "Failed to reset rate limits") + ) + + except Exception as e: + logger.error(f"โŒ Reset rate limits error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("RESET_ERROR", str(e)) + ) + +# ============================ +# Debugging Endpoints +# ============================ +@api_router.get("/debug/guest/{guest_id}") +async def debug_guest_session( + guest_id: str = Path(...), + admin_user = Depends(get_current_admin), + database: RedisDatabase = Depends(get_database) +): + """Debug guest session issues (admin only)""" + try: + # Check primary storage + primary_data = await database.redis.hget("guests", guest_id) + primary_exists = primary_data is not None + + # Check backup storage + backup_data = await database.redis.get(f"guest_backup:{guest_id}") + backup_exists = backup_data is not None + + # Check user lookup + user_lookup = await database.get_user_by_id(guest_id) + + # Get TTL info + primary_ttl = await database.redis.ttl(f"guests") + backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}") + + debug_info = { + "guest_id": guest_id, + "primary_storage": { + "exists": primary_exists, + "data": json.loads(primary_data) if primary_data else None, + "ttl": primary_ttl + }, + "backup_storage": { + "exists": backup_exists, + "data": json.loads(backup_data) if backup_data else None, + "ttl": backup_ttl + }, + "user_lookup": user_lookup, + "timestamp": datetime.now(UTC).isoformat() + } + + return create_success_response(debug_info) + + except Exception as e: + logger.error(f"โŒ Debug guest session error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("DEBUG_ERROR", str(e)) + ) # ============================ # Health Check and Info Endpoints # ============================ @@ -4144,8 +5036,11 @@ async def redis_stats(redis: redis.Redis = Depends(get_redis)): @api_router.get("/system-info") async def get_system_info(request: Request): + """Get system information""" from system_info import system_info # Import system_info function from system_info module - return JSONResponse(system_info()) + system = system_info() + + return create_success_response(system.model_dump(mode='json')) @api_router.get("/") async def api_info(): @@ -4158,6 +5053,266 @@ async def api_info(): "health": f"{defines.api_prefix}/health" } +# ============================ +# Manual Task Execution Endpoints (Admin Only) +# ============================ +# Global background task manager +background_task_manager: Optional[BackgroundTaskManager] = None + +@asynccontextmanager +async def enhanced_lifespan(app: FastAPI): + # Startup + global background_task_manager + + logger.info("๐Ÿš€ Starting Backstory API with enhanced background tasks") + logger.info(f"๐Ÿ“ API Documentation available at: http://{defines.host}:{defines.port}{defines.api_prefix}/docs") + logger.info("๐Ÿ”— API endpoints prefixed with: /api/1.0") + if os.path.exists(defines.static_content): + logger.info(f"๐Ÿ“ Serving static files from: {defines.static_content}") + + try: + # Initialize database + await db_manager.initialize() + entities.entity_manager.initialize(prometheus_collector, database=db_manager.get_database()) + + # Initialize background task manager + background_task_manager = BackgroundTaskManager(db_manager) + background_task_manager.start() + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + logger.info("๐Ÿš€ Application startup completed with background tasks") + + yield # Application is running + + except Exception as e: + logger.error(f"โŒ Failed to start application: {e}") + raise + + finally: + # Shutdown + logger.info("Application shutdown requested") + + # Stop background tasks first + if background_task_manager: + background_task_manager.stop() + + await db_manager.graceful_shutdown() + +# ============================ +# Manual Task Execution Endpoints (Admin Only) +# ============================ + + +# ============================ +# Task Monitoring and Metrics +# ============================ + +@api_router.post("/admin/tasks/cleanup-guests") +async def manual_guest_cleanup( + inactive_hours: int = Body(24, embed=True), + current_user = Depends(get_current_admin), + admin_user = Depends(get_current_admin) +): + """Manually trigger guest cleanup (admin only)""" + try: + global background_task_manager + + if not background_task_manager: + return JSONResponse( + status_code=500, + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + ) + + cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours) + + logger.info(f"๐Ÿงน Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned") + + return create_success_response({ + "message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id + }) + + except Exception as e: + logger.error(f"โŒ Manual guest cleanup error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("CLEANUP_ERROR", str(e)) + ) + +@api_router.post("/admin/tasks/cleanup-tokens") +async def manual_token_cleanup( + admin_user = Depends(get_current_admin) +): + """Manually trigger verification token cleanup (admin only)""" + try: + global background_task_manager + + if not background_task_manager: + return JSONResponse( + status_code=500, + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + ) + + cleaned_count = await background_task_manager.cleanup_expired_verification_tokens() + + logger.info(f"๐Ÿงน Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned") + + return create_success_response({ + "message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id + }) + + except Exception as e: + logger.error(f"โŒ Manual token cleanup error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("CLEANUP_ERROR", str(e)) + ) + +@api_router.post("/admin/tasks/cleanup-rate-limits") +async def manual_rate_limit_cleanup( + days_old: int = Body(7, embed=True), + admin_user = Depends(get_current_admin) +): + """Manually trigger rate limit data cleanup (admin only)""" + try: + global background_task_manager + + if not background_task_manager: + return JSONResponse( + status_code=500, + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + ) + + cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old) + + logger.info(f"๐Ÿงน Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned") + + return create_success_response({ + "message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id + }) + + except Exception as e: + logger.error(f"โŒ Manual rate limit cleanup error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("CLEANUP_ERROR", str(e)) + ) + +@api_router.get("/admin/tasks/status") +async def get_background_task_status( + admin_user = Depends(get_current_admin) +): + """Get background task manager status (admin only)""" + try: + global background_task_manager + + if not background_task_manager: + return create_success_response({ + "running": False, + "message": "Background task manager not initialized" + }) + + # Get next scheduled run times + next_runs = [] + for job in schedule.jobs: + next_runs.append({ + "job": str(job.job_func), + "next_run": job.next_run.isoformat() if job.next_run else None + }) + + return create_success_response({ + "running": background_task_manager.running, + "scheduler_thread_alive": background_task_manager.scheduler_thread.is_alive() if background_task_manager.scheduler_thread else False, + "scheduled_jobs": len(schedule.jobs), + "next_runs": next_runs + }) + + except Exception as e: + logger.error(f"โŒ Get task status error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("STATUS_ERROR", str(e)) + ) + + +# ============================ +# Task Monitoring and Metrics +# ============================ + +class TaskMetrics: + """Collect metrics for background tasks""" + + def __init__(self): + self.task_runs = {} + self.task_durations = {} + self.task_errors = {} + + def record_task_run(self, task_name: str, duration: float, success: bool = True): + """Record a task execution""" + if task_name not in self.task_runs: + self.task_runs[task_name] = 0 + self.task_durations[task_name] = [] + self.task_errors[task_name] = 0 + + self.task_runs[task_name] += 1 + self.task_durations[task_name].append(duration) + + if not success: + self.task_errors[task_name] += 1 + + # Keep only last 100 durations to prevent memory growth + if len(self.task_durations[task_name]) > 100: + self.task_durations[task_name] = self.task_durations[task_name][-100:] + + def get_metrics(self) -> dict: + """Get task metrics summary""" + metrics = {} + + for task_name in self.task_runs: + durations = self.task_durations[task_name] + avg_duration = sum(durations) / len(durations) if durations else 0 + + metrics[task_name] = { + "total_runs": self.task_runs[task_name], + "total_errors": self.task_errors[task_name], + "success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0, + "average_duration": avg_duration, + "last_runs": durations[-10:] if durations else [] + } + + return metrics + +# Global task metrics +task_metrics = TaskMetrics() + +@api_router.get("/admin/tasks/metrics") +async def get_task_metrics( + admin_user = Depends(get_current_admin) +): + """Get background task metrics (admin only)""" + try: + global task_metrics + metrics = task_metrics.get_metrics() + + return create_success_response({ + "metrics": metrics, + "timestamp": datetime.now(UTC).isoformat() + }) + + except Exception as e: + logger.error(f"โŒ Get task metrics error: {e}") + return JSONResponse( + status_code=500, + content=create_error_response("METRICS_ERROR", str(e)) + ) + # ============================ # Include Router in App # ============================ @@ -4174,7 +5329,7 @@ logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}") async def log_requests(request: Request, call_next): try: if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path): - logger.info(f"๐Ÿ“ Request {request.method}: {request.url.path}, Remote: {request.client.host}") + logger.info(f"๐Ÿ“ Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}") response = await call_next(request) if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path): if response.status_code < 200 or response.status_code >= 300: diff --git a/src/backend/model_cast.py b/src/backend/model_cast.py index 19d0ea4..9a31226 100644 --- a/src/backend/model_cast.py +++ b/src/backend/model_cast.py @@ -1,7 +1,14 @@ from typing import Type, TypeVar -from pydantic import BaseModel # type: ignore +from pydantic import BaseModel import copy +from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType + +# Ensure all user models inherit from BaseUserWithType +assert issubclass(Candidate, BaseUserWithType), "Candidate must inherit from BaseUserWithType" +assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType" +assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType" +assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType" T = TypeVar('T', bound=BaseModel) @@ -12,3 +19,35 @@ def cast_to_model(model_cls: Type[T], source: BaseModel) -> T: def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T: data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__} return model_cls(**data) + +def cast_to_base_user_with_type(user) -> BaseUserWithType: + """ + Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType. + This is useful for FastAPI dependencies that expect a common user type. + """ + if isinstance(user, BaseUserWithType): + return user + # If it's a dict, try to detect type + if isinstance(user, dict): + user_type = user.get("user_type") or user.get("type") + if user_type == "candidate": + if user.get("is_AI"): + return CandidateAI.model_validate(user) + return Candidate.model_validate(user) + elif user_type == "employer": + return Employer.model_validate(user) + elif user_type == "guest": + return Guest.model_validate(user) + else: + raise ValueError(f"Unknown user_type: {user_type}") + # If it's a model, check its type + if hasattr(user, "user_type"): + if getattr(user, "user_type", None) == "candidate": + if getattr(user, "is_AI", False): + return CandidateAI.model_validate(user.model_dump()) + return Candidate.model_validate(user.model_dump()) + elif getattr(user, "user_type", None) == "employer": + return Employer.model_validate(user.model_dump()) + elif getattr(user, "user_type", None) == "guest": + return Guest.model_validate(user.model_dump()) + raise TypeError(f"Cannot cast object of type {type(user)} to BaseUserWithType") diff --git a/src/backend/models.py b/src/backend/models.py index 4dfa3e0..ee055f7 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -10,6 +10,7 @@ from auth_utils import ( sanitize_login_input, SecurityConfig ) +import defines # Generic type variable T = TypeVar('T') @@ -256,6 +257,7 @@ class MFARequest(BaseModel): password: str device_id: str = Field(..., alias="deviceId") device_name: str = Field(..., alias="deviceName") + email: str = Field(..., alias="email") model_config = { "populate_by_name": True, # Allow both field names and aliases } @@ -464,9 +466,18 @@ class ErrorDetail(BaseModel): # Main Models # ============================ -# Base user model without user_type field -class BaseUser(BaseModel): +# Generic base user with user_type for API responses +class BaseUserWithType(BaseModel): + user_type: UserType = Field(..., alias="userType") id: str = Field(default_factory=lambda: str(uuid.uuid4())) + last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity") + model_config = { + "populate_by_name": True, # Allow both field names and aliases + "use_enum_values": True # Use enum values instead of names + } + +# Base user model without user_type field +class BaseUser(BaseUserWithType): email: EmailStr first_name: str = Field(..., alias="firstName") last_name: str = Field(..., alias="lastName") @@ -485,9 +496,6 @@ class BaseUser(BaseModel): "use_enum_values": True # Use enum values instead of names } -# Generic base user with user_type for API responses -class BaseUserWithType(BaseUser): - user_type: UserType = Field(..., alias="userType") class RagEntry(BaseModel): name: str @@ -519,9 +527,9 @@ class DocumentType(str, Enum): IMAGE = "image" class DocumentOptions(BaseModel): - include_in_RAG: Optional[bool] = Field(True, alias="includeInRAG") - is_job_document: Optional[bool] = Field(False, alias="isJobDocument") - overwrite: Optional[bool] = Field(False, alias="overwrite") + include_in_RAG: bool = Field(default=True, alias="includeInRAG") + is_job_document: Optional[bool] = Field(default=False, alias="isJobDocument") + overwrite: Optional[bool] = Field(default=False, alias="overwrite") model_config = { "populate_by_name": True # Allow both field names and aliases } @@ -534,7 +542,7 @@ class Document(BaseModel): type: DocumentType size: int upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate") - options: DocumentOptions = Field(default_factory=DocumentOptions, alias="options") + options: DocumentOptions = Field(default_factory=lambda: DocumentOptions(), alias="options") rag_chunks: Optional[int] = Field(default=0, alias="ragChunks") model_config = { "populate_by_name": True # Allow both field names and aliases @@ -565,7 +573,7 @@ class DocumentUpdateRequest(BaseModel): } class Candidate(BaseUser): - user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType") + user_type: UserType = Field(UserType.CANDIDATE, alias="userType") username: str description: Optional[str] = None resume: Optional[str] = None @@ -584,14 +592,14 @@ class Candidate(BaseUser): rag_content_size : int = 0 class CandidateAI(Candidate): - user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType") + user_type: UserType = Field(UserType.CANDIDATE, alias="userType") is_AI: bool = Field(True, alias="isAI") age: Optional[int] = None gender: Optional[UserGender] = None ethnicity: Optional[str] = None class Employer(BaseUser): - user_type: Literal[UserType.EMPLOYER] = Field(UserType.EMPLOYER, alias="userType") + user_type: UserType = Field(UserType.EMPLOYER, alias="userType") company_name: str = Field(..., alias="companyName") industry: str description: Optional[str] = None @@ -603,16 +611,18 @@ class Employer(BaseUser): social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks") poc: Optional[PointOfContact] = None -class Guest(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4())) +class Guest(BaseUser): + user_type: UserType = Field(UserType.GUEST, alias="userType") session_id: str = Field(..., alias="sessionId") - created_at: datetime = Field(..., alias="createdAt") - last_activity: datetime = Field(..., alias="lastActivity") + username: str # Add username for consistency with other user types converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId") ip_address: Optional[str] = Field(None, alias="ipAddress") + created_at: datetime = Field(..., alias="createdAt") user_agent: Optional[str] = Field(None, alias="userAgent") + rag_content_size: int = 0 model_config = { - "populate_by_name": True # Allow both field names and aliases + "populate_by_name": True, # Allow both field names and aliases + "use_enum_values": True # Use enum values instead of names } class Authentication(BaseModel): @@ -635,10 +645,21 @@ class Authentication(BaseModel): class AuthResponse(BaseModel): access_token: str = Field(..., alias="accessToken") refresh_token: str = Field(..., alias="refreshToken") - user: Candidate | Employer + user: Union[Candidate, Employer, Guest] # Add Guest support expires_at: int = Field(..., alias="expiresAt") + user_type: Optional[str] = Field(default=UserType.GUEST, alias="userType") # Explicit user type + is_guest: Optional[bool] = Field(default=True, alias="isGuest") # Guest indicator + model_config = { - "populate_by_name": True # Allow both field names and aliases + "populate_by_name": True + } + +class GuestCleanupRequest(BaseModel): + """Request to cleanup inactive guests""" + inactive_hours: int = Field(24, alias="inactiveHours") + + model_config = { + "populate_by_name": True } class JobRequirements(BaseModel): @@ -751,6 +772,19 @@ class ChromaDBGetResponse(BaseModel): umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D") umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D") +class GuestSessionResponse(BaseModel): + """Response for guest session creation""" + access_token: str = Field(..., alias="accessToken") + refresh_token: str = Field(..., alias="refreshToken") + user: Guest + expires_at: int = Field(..., alias="expiresAt") + user_type: Literal["guest"] = Field("guest", alias="userType") + is_guest: bool = Field(True, alias="isGuest") + + model_config = { + "populate_by_name": True + } + class ChatContext(BaseModel): type: ChatContextType related_entity_id: Optional[str] = Field(None, alias="relatedEntityId") @@ -768,6 +802,98 @@ class ChatOptions(BaseModel): "populate_by_name": True # Allow both field names and aliases } +# Add rate limiting configuration models +class RateLimitConfig(BaseModel): + """Rate limit configuration""" + requests_per_minute: int = Field(..., alias="requestsPerMinute") + requests_per_hour: int = Field(..., alias="requestsPerHour") + requests_per_day: int = Field(..., alias="requestsPerDay") + burst_limit: int = Field(..., alias="burstLimit") + burst_window_seconds: int = Field(60, alias="burstWindowSeconds") + + model_config = { + "populate_by_name": True + } + +class RateLimitResult(BaseModel): + """Result of rate limit check""" + allowed: bool + reason: Optional[str] = None + retry_after_seconds: Optional[int] = Field(None, alias="retryAfterSeconds") + remaining_requests: Dict[str, int] = Field(default_factory=dict, alias="remainingRequests") + reset_times: Dict[str, datetime] = Field(default_factory=dict, alias="resetTimes") + + model_config = { + "populate_by_name": True + } + +class RateLimitStatus(BaseModel): + """Rate limit status for a user""" + user_id: str = Field(..., alias="userId") + user_type: str = Field(..., alias="userType") + is_admin: bool = Field(..., alias="isAdmin") + current_usage: Dict[str, int] = Field(..., alias="currentUsage") + limits: Dict[str, int] = Field(..., alias="limits") + remaining: Dict[str, int] = Field(..., alias="remaining") + reset_times: Dict[str, datetime] = Field(..., alias="resetTimes") + config: RateLimitConfig + + model_config = { + "populate_by_name": True + } + + +# Add guest conversion request models +class GuestConversionRequest(BaseModel): + """Request to convert guest to permanent user""" + account_type: Literal["candidate", "employer"] = Field(..., alias="accountType") + email: EmailStr + username: str + password: str + first_name: str = Field(..., alias="firstName") + last_name: str = Field(..., alias="lastName") + phone: Optional[str] = None + + # Employer-specific fields (optional) + company_name: Optional[str] = Field(None, alias="companyName") + industry: Optional[str] = None + company_size: Optional[str] = Field(None, alias="companySize") + company_description: Optional[str] = Field(None, alias="companyDescription") + website_url: Optional[HttpUrl] = Field(None, alias="websiteUrl") + + model_config = { + "populate_by_name": True + } + + @field_validator('username') + def validate_username(cls, v): + if not v or len(v.strip()) < 3: + raise ValueError('Username must be at least 3 characters long') + return v.strip().lower() + + @field_validator('password') + def validate_password_strength(cls, v): + # Import here to avoid circular imports + from auth_utils import validate_password_strength + is_valid, issues = validate_password_strength(v) + if not is_valid: + raise ValueError('; '.join(issues)) + return v + +# Add guest statistics response model +class GuestStatistics(BaseModel): + """Guest usage statistics""" + total_guests: int = Field(..., alias="totalGuests") + active_last_hour: int = Field(..., alias="activeLastHour") + active_last_day: int = Field(..., alias="activeLastDay") + converted_guests: int = Field(..., alias="convertedGuests") + by_ip: Dict[str, int] = Field(..., alias="byIp") + creation_timeline: Dict[str, int] = Field(..., alias="creationTimeline") + + model_config = { + "populate_by_name": True + } + from llm_proxy import (LLMMessage) class ApiMessage(BaseModel): @@ -800,12 +926,14 @@ class ApiActivityType(str, Enum): HEARTBEAT = "heartbeat" # Used for periodic updates class ChatMessageStatus(ApiMessage): + sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId") status: ApiStatusType = ApiStatusType.STATUS type: ApiMessageType = ApiMessageType.TEXT activity: ApiActivityType content: Any class ChatMessageError(ApiMessage): + sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId") status: ApiStatusType = ApiStatusType.ERROR type: ApiMessageType = ApiMessageType.TEXT content: str @@ -825,6 +953,7 @@ class JobRequirementsMessage(ApiMessage): class DocumentMessage(ApiMessage): type: ApiMessageType = ApiMessageType.JSON + sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId") document: Document = Field(..., alias="document") content: Optional[str] = "" converted: bool = Field(False, alias="converted") @@ -837,9 +966,9 @@ class ChatMessageMetaData(BaseModel): temperature: float = 0.7 max_tokens: int = Field(default=8092, alias="maxTokens") top_p: float = Field(default=1, alias="topP") - frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty") - presence_penalty: Optional[float] = Field(None, alias="presencePenalty") - stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences") + frequency_penalty: float = Field(default=0, alias="frequencyPenalty") + presence_penalty: float = Field(default=0, alias="presencePenalty") + stop_sequences: List[str] = Field(default=[], alias="stopSequences") rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults") llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory") eval_count: int = 0 @@ -862,16 +991,31 @@ class ChatMessageUser(ApiMessage): class ChatMessage(ChatMessageUser): role: ChatSenderType = ChatSenderType.ASSISTANT - metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData) + metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData()) #attachments: Optional[List[Attachment]] = None #reactions: Optional[List[MessageReaction]] = None #is_edited: bool = Field(False, alias="isEdited") #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory") +class GPUInfo(BaseModel): + name: str + memory: int + discrete: bool + +class SystemInfo(BaseModel): + installed_RAM: str = Field(..., alias="installedRAM") + graphics_cards: List[GPUInfo] = Field(..., alias="graphicsCards") + CPU: str + llm_model: str = Field(default=defines.model, alias="llmModel") + embedding_model: str = Field(default=defines.embedding_model, alias="embeddingModel") + max_context_length: int = Field(default=defines.max_context, alias="maxContextLength") + model_config = { + "populate_by_name": True # Allow both field names and aliases + } + class ChatSession(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: Optional[str] = Field(None, alias="userId") - guest_id: Optional[str] = Field(None, alias="guestId") created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt") last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity") title: Optional[str] = None @@ -937,7 +1081,7 @@ class UserActivity(BaseModel): } @model_validator(mode="after") - def check_user_or_guest(self) -> "ChatSession": + def check_user_or_guest(self): if not self.user_id and not self.guest_id: raise ValueError("Either user_id or guest_id must be provided") return self @@ -1102,6 +1246,8 @@ class JobListResponse(BaseModel): error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None +User = Union[Candidate, CandidateAI, Employer, Guest] + # Forward references resolution Candidate.update_forward_refs() Employer.update_forward_refs() diff --git a/src/backend/rag/rag.py b/src/backend/rag/rag.py index b6646ed..2907cb6 100644 --- a/src/backend/rag/rag.py +++ b/src/backend/rag/rag.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore +from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field from typing import List, Optional, Dict, Any, Union import os import glob @@ -9,15 +9,15 @@ import hashlib import asyncio import logging import json -import numpy as np # type: ignore +import numpy as np import traceback -import chromadb # type: ignore -from watchdog.observers import Observer # type: ignore -from watchdog.events import FileSystemEventHandler # type: ignore -import umap # type: ignore -from markitdown import MarkItDown # type: ignore -from chromadb.api.models.Collection import Collection # type: ignore +import chromadb +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler +import umap +from markitdown import MarkItDown +from chromadb.api.models.Collection import Collection from .markdown_chunker import ( MarkdownChunker, @@ -351,9 +351,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler): os.makedirs(self.persist_directory) # Initialize ChromaDB client - chroma_client = chromadb.PersistentClient( # type: ignore + chroma_client = chromadb.PersistentClient( path=self.persist_directory, - settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore + settings=chromadb.Settings(anonymized_telemetry=False), ) # Check if the collection exists diff --git a/src/backend/rate_limiter.py b/src/backend/rate_limiter.py new file mode 100644 index 0000000..23e1f0a --- /dev/null +++ b/src/backend/rate_limiter.py @@ -0,0 +1,287 @@ +""" +Rate limiting utilities for guest and authenticated users +""" + +import json +import time +from datetime import datetime, timedelta, UTC +from typing import Dict, Optional, Tuple, Any +from pydantic import BaseModel # type: ignore +from database import RedisDatabase +from logger import logger + +class RateLimitConfig(BaseModel): + """Rate limit configuration""" + requests_per_minute: int + requests_per_hour: int + requests_per_day: int + burst_limit: int # Maximum requests in a short burst + burst_window_seconds: int = 60 # Window for burst detection + +class GuestRateLimitConfig(RateLimitConfig): + """Rate limits for guest users - more restrictive""" + requests_per_minute: int = 10 + requests_per_hour: int = 100 + requests_per_day: int = 500 + burst_limit: int = 15 + burst_window_seconds: int = 60 + +class AuthenticatedUserRateLimitConfig(RateLimitConfig): + """Rate limits for authenticated users - more generous""" + requests_per_minute: int = 60 + requests_per_hour: int = 1000 + requests_per_day: int = 10000 + burst_limit: int = 100 + burst_window_seconds: int = 60 + +class PremiumUserRateLimitConfig(RateLimitConfig): + """Rate limits for premium/admin users - most generous""" + requests_per_minute: int = 120 + requests_per_hour: int = 5000 + requests_per_day: int = 50000 + burst_limit: int = 200 + burst_window_seconds: int = 60 + +class RateLimitResult(BaseModel): + """Result of rate limit check""" + allowed: bool + reason: Optional[str] = None + retry_after_seconds: Optional[int] = None + remaining_requests: Dict[str, int] = {} + reset_times: Dict[str, datetime] = {} + +class RateLimiter: + """Rate limiter using Redis for distributed rate limiting""" + + def __init__(self, database: RedisDatabase): + self.database = database + self.redis = database.redis + + # Rate limit configurations + self.guest_config = GuestRateLimitConfig() + self.user_config = AuthenticatedUserRateLimitConfig() + self.premium_config = PremiumUserRateLimitConfig() + + def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig: + """Get rate limit configuration based on user type""" + if user_type == "guest": + return self.guest_config + elif is_admin: + return self.premium_config + else: + return self.user_config + + async def check_rate_limit( + self, + user_id: str, + user_type: str, + is_admin: bool = False, + endpoint: Optional[str] = None + ) -> RateLimitResult: + """ + Check if user has exceeded rate limits + + Args: + user_id: Unique identifier for the user (guest session ID or user ID) + user_type: "guest", "candidate", or "employer" + is_admin: Whether user has admin privileges + endpoint: Optional endpoint-specific rate limiting + + Returns: + RateLimitResult indicating if request is allowed + """ + config = self.get_config_for_user(user_type, is_admin) + current_time = datetime.now(UTC) + + # Create Redis keys for different time windows + base_key = f"rate_limit:{user_type}:{user_id}" + keys = { + "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", + "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", + "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", + "burst": f"{base_key}:burst" + } + + # Add endpoint-specific limiting if provided + if endpoint: + keys = {k: f"{v}:{endpoint}" for k, v in keys.items()} + + try: + # Use Redis pipeline for atomic operations + pipe = self.redis.pipeline() + + # Get current counts + for key in keys.values(): + pipe.get(key) + + results = await pipe.execute() + current_counts = { + "minute": int(results[0] or 0), + "hour": int(results[1] or 0), + "day": int(results[2] or 0), + "burst": int(results[3] or 0) + } + + # Check limits + limits = { + "minute": config.requests_per_minute, + "hour": config.requests_per_hour, + "day": config.requests_per_day, + "burst": config.burst_limit + } + + # Check each limit + for window, current_count in current_counts.items(): + limit = limits[window] + if current_count >= limit: + # Calculate retry after time + if window == "minute": + retry_after = 60 - current_time.second + elif window == "hour": + retry_after = 3600 - (current_time.minute * 60 + current_time.second) + elif window == "day": + retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second) + else: # burst + retry_after = config.burst_window_seconds + + logger.warning(f"๐Ÿšซ Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}") + + return RateLimitResult( + allowed=False, + reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}", + retry_after_seconds=retry_after, + remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()}, + reset_times=self._calculate_reset_times(current_time) + ) + + # If we get here, request is allowed - increment counters + pipe = self.redis.pipeline() + + # Increment minute counter (expires after 2 minutes) + pipe.incr(keys["minute"]) + pipe.expire(keys["minute"], 120) + + # Increment hour counter (expires after 2 hours) + pipe.incr(keys["hour"]) + pipe.expire(keys["hour"], 7200) + + # Increment day counter (expires after 2 days) + pipe.incr(keys["day"]) + pipe.expire(keys["day"], 172800) + + # Increment burst counter (expires after burst window) + pipe.incr(keys["burst"]) + pipe.expire(keys["burst"], config.burst_window_seconds) + + await pipe.execute() + + # Calculate remaining requests + remaining = { + k: max(0, limits[k] - (current_counts[k] + 1)) + for k in current_counts.keys() + } + + logger.debug(f"โœ… Rate limit check passed for {user_type} {user_id}") + + return RateLimitResult( + allowed=True, + remaining_requests=remaining, + reset_times=self._calculate_reset_times(current_time) + ) + + except Exception as e: + logger.error(f"โŒ Rate limit check failed for {user_id}: {e}") + # Fail open - allow request if rate limiting system fails + return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request") + + def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]: + """Calculate when each rate limit window resets""" + next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1) + next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) + next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) + + return { + "minute": next_minute, + "hour": next_hour, + "day": next_day + } + + async def get_user_rate_limit_status( + self, + user_id: str, + user_type: str, + is_admin: bool = False + ) -> Dict[str, Any]: + """Get current rate limit status for a user""" + config = self.get_config_for_user(user_type, is_admin) + current_time = datetime.now(UTC) + + base_key = f"rate_limit:{user_type}:{user_id}" + keys = { + "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", + "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", + "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", + "burst": f"{base_key}:burst" + } + + try: + pipe = self.redis.pipeline() + for key in keys.values(): + pipe.get(key) + + results = await pipe.execute() + current_counts = { + "minute": int(results[0] or 0), + "hour": int(results[1] or 0), + "day": int(results[2] or 0), + "burst": int(results[3] or 0) + } + + limits = { + "minute": config.requests_per_minute, + "hour": config.requests_per_hour, + "day": config.requests_per_day, + "burst": config.burst_limit + } + + return { + "user_id": user_id, + "user_type": user_type, + "is_admin": is_admin, + "current_usage": current_counts, + "limits": limits, + "remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()}, + "reset_times": self._calculate_reset_times(current_time), + "config": config.model_dump() + } + + except Exception as e: + logger.error(f"โŒ Failed to get rate limit status for {user_id}: {e}") + return {"error": str(e)} + + async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool: + """Reset all rate limits for a user (admin function)""" + try: + base_key = f"rate_limit:{user_type}:{user_id}" + pattern = f"{base_key}:*" + + cursor = 0 + deleted_count = 0 + + while True: + cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) + if keys: + await self.redis.delete(*keys) + deleted_count += len(keys) + + if cursor == 0: + break + + logger.info(f"๐Ÿ”„ Reset {deleted_count} rate limit keys for {user_type} {user_id}") + return True + + except Exception as e: + logger.error(f"โŒ Failed to reset rate limits for {user_id}: {e}") + return False + + \ No newline at end of file diff --git a/src/backend/system_info.py b/src/backend/system_info.py index 53d481b..46e5dc3 100644 --- a/src/backend/system_info.py +++ b/src/backend/system_info.py @@ -2,6 +2,7 @@ import defines import re import subprocess import math +from models import SystemInfo def get_installed_ram(): try: @@ -70,12 +71,18 @@ def get_cpu_info(): except Exception as e: return f"Error retrieving CPU info: {e}" -def system_info(): - return { - "System RAM": get_installed_ram(), - "Graphics Card": get_graphics_cards(), - "CPU": get_cpu_info(), - "LLM Model": defines.model, - "Embedding Model": defines.embedding_model, - "Context length": defines.max_context, - } \ No newline at end of file +def system_info() -> SystemInfo: + """ + Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length. + Returns: + SystemInfo: An object containing the collected system information. + """ + system = SystemInfo( + installed_RAM=get_installed_ram(), + graphics_cards=get_graphics_cards(), + CPU=get_cpu_info(), + llm_model=defines.model, + embedding_model=defines.embedding_model, + max_context_length=defines.max_context, + ) + return system \ No newline at end of file diff --git a/src/backend/tools/basetools.py b/src/backend/tools/basetools.py index 6598bf5..d90360c 100644 --- a/src/backend/tools/basetools.py +++ b/src/backend/tools/basetools.py @@ -1,5 +1,5 @@ import os -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from typing import List, Optional, Generator, ClassVar, Any, Dict from datetime import datetime from typing import ( @@ -7,12 +7,12 @@ from typing import ( ) from typing_extensions import Annotated -from bs4 import BeautifulSoup # type: ignore +from bs4 import BeautifulSoup -from geopy.geocoders import Nominatim # type: ignore -import pytz # type: ignore +from geopy.geocoders import Nominatim +import pytz import requests -import yfinance as yf # type: ignore +import yfinance as yf import logging @@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] -# __all__.extend(__tool_functions__) # type: ignore + diff --git a/src/multi-llm/example.py b/src/multi-llm/example.py index 30f17cc..99c8ac1 100644 --- a/src/multi-llm/example.py +++ b/src/multi-llm/example.py @@ -1,6 +1,6 @@ -from fastapi import FastAPI, HTTPException # type: ignore -from fastapi.responses import StreamingResponse # type: ignore -from pydantic import BaseModel # type: ignore +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel from typing import List, Optional, Dict, Any import json import asyncio @@ -234,5 +234,5 @@ async def health_check(): } if __name__ == "__main__": - import uvicorn # type: ignore + import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/src/multi-llm/llm_proxy.py b/src/multi-llm/llm_proxy.py index 2057760..a130a9f 100644 --- a/src/multi-llm/llm_proxy.py +++ b/src/multi-llm/llm_proxy.py @@ -92,7 +92,7 @@ class OllamaAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) import ollama - self.client = ollama.AsyncClient( # type: ignore + self.client = ollama.AsyncClient( host=config.get('host', 'http://localhost:11434') ) @@ -187,7 +187,7 @@ class OpenAIAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import openai # type: ignore + import openai self.client = openai.AsyncOpenAI( api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) ) @@ -259,7 +259,7 @@ class AnthropicAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import anthropic # type: ignore + import anthropic self.client = anthropic.AsyncAnthropic( api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) ) @@ -344,7 +344,7 @@ class GeminiAdapter(BaseLLMAdapter): def __init__(self, **config): super().__init__(**config) - import google.generativeai as genai # type: ignore + import google.generativeai as genai genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) self.genai = genai @@ -476,7 +476,7 @@ class UnifiedLLMProxy: result = await self.chat(model, messages, provider, stream=True, **kwargs) # Type checker now knows this is an AsyncGenerator due to stream=True - async for chunk in result: # type: ignore + async for chunk in result: yield chunk async def chat_single( @@ -490,7 +490,7 @@ class UnifiedLLMProxy: result = await self.chat(model, messages, provider, stream=False, **kwargs) # Type checker now knows this is a ChatResponse due to stream=False - return result # type: ignore + return result async def generate( self, @@ -517,7 +517,7 @@ class UnifiedLLMProxy: """Stream text generation using specified or default provider""" result = await self.generate(model, prompt, provider, stream=True, **kwargs) - async for chunk in result: # type: ignore + async for chunk in result: yield chunk async def generate_single( @@ -530,7 +530,7 @@ class UnifiedLLMProxy: """Get single generation response using specified or default provider""" result = await self.generate(model, prompt, provider, stream=False, **kwargs) - return result # type: ignore + return result async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: """List available models for specified or default provider""" diff --git a/src/server.py b/src/server.py index 1ea088e..046d70d 100644 --- a/src/server.py +++ b/src/server.py @@ -1,8 +1,8 @@ LLM_TIMEOUT = 600 from utils import logger -from pydantic import BaseModel, Field, ValidationError # type: ignore -from pydantic_core import PydanticSerializationError # type: ignore +from pydantic import BaseModel, Field, ValidationError +from pydantic_core import PydanticSerializationError from typing import List from typing import AsyncGenerator, Dict, Optional @@ -48,18 +48,18 @@ try_import("prometheus_fastapi_instrumentator") import ollama from contextlib import asynccontextmanager -from fastapi import FastAPI, Request, HTTPException, Depends # type: ignore -from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore -from fastapi.middleware.cors import CORSMiddleware # type: ignore -import uvicorn # type: ignore -import numpy as np # type: ignore +from fastapi import FastAPI, Request, HTTPException, Depends +from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +import numpy as np from utils import redis_manager -import redis.asyncio as redis # type: ignore +import redis.asyncio as redis # Prometheus -from prometheus_client import Summary # type: ignore -from prometheus_fastapi_instrumentator import Instrumentator # type: ignore -from prometheus_client import CollectorRegistry, Counter # type: ignore +from prometheus_client import Summary +from prometheus_fastapi_instrumentator import Instrumentator +from prometheus_client import CollectorRegistry, Counter from utils import ( rag as Rag, @@ -1308,7 +1308,7 @@ def main(): warnings.filterwarnings("ignore", category=UserWarning, module="umap.*") - llm = ollama.Client(host=args.ollama_server) # type: ignore + llm = ollama.Client(host=args.ollama_server) web_server = WebServer(llm, args.ollama_model) web_server.run(host=args.web_host, port=args.web_port, use_reloader=False) diff --git a/src/tests/test-chunker.py b/src/tests/test-chunker.py index 6bad6f3..4d9c447 100644 --- a/src/tests/test-chunker.py +++ b/src/tests/test-chunker.py @@ -5,7 +5,7 @@ import sys import os sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils"))) -from markdown_chunker import MarkdownChunker # type: ignore +from markdown_chunker import MarkdownChunker chunker = MarkdownChunker() chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md") diff --git a/src/tests/test-context-routing.py b/src/tests/test-context-routing.py index 7ed01e2..0b14f1c 100644 --- a/src/tests/test-context-routing.py +++ b/src/tests/test-context-routing.py @@ -1,10 +1,10 @@ -from fastapi import FastAPI, Request, Depends, Query # type: ignore -from fastapi.responses import RedirectResponse, JSONResponse # type: ignore +from fastapi import FastAPI, Request, Depends, Query +from fastapi.responses import RedirectResponse, JSONResponse from uuid import UUID, uuid4 import logging import traceback from typing import Callable, Optional -from anyio.to_thread import run_sync # type: ignore +from anyio.to_thread import run_sync logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ class ContextRouteManager: logger.info(f"Invalid UUID, redirecting to {redirect_url}") raise RedirectToContext(redirect_url) - return _ensure_context_dependency # type: ignore + return _ensure_context_dependency def route_pattern(self, path: str, *dependencies, **kwargs): logger.info(f"Registering route: {path}") @@ -134,6 +134,6 @@ async def redirect_history( if __name__ == "__main__": - import uvicorn # type: ignore + import uvicorn uvicorn.run(app, host="0.0.0.0", port=8900) diff --git a/src/tests/test-embedding.py b/src/tests/test-embedding.py index 1b84544..c1a54d9 100644 --- a/src/tests/test-embedding.py +++ b/src/tests/test-embedding.py @@ -1,9 +1,9 @@ # From /opt/backstory run: # python -m src.tests.test-embedding -import numpy as np # type: ignore +import numpy as np import logging import argparse -from ollama import Client # type: ignore +from ollama import Client from ..utils import defines # Configure logging diff --git a/src/tests/test-rag.py b/src/tests/test-rag.py index 99911e8..a20fe38 100644 --- a/src/tests/test-rag.py +++ b/src/tests/test-rag.py @@ -1,11 +1,11 @@ # From /opt/backstory run: # python -m src.tests.test-rag from ..utils import logger -from pydantic import BaseModel, field_validator # type: ignore -from prometheus_client import CollectorRegistry # type: ignore +from pydantic import BaseModel, field_validator +from prometheus_client import CollectorRegistry from typing import List, Dict, Any, Optional import ollama -import numpy as np # type: ignore +import numpy as np from ..utils import (rag as Rag, ChromaDBGetResponse) from ..utils import Context from ..utils import defines @@ -44,7 +44,7 @@ rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]]) json_str = rag.model_dump(mode="json") logger.info(json_str) rag = ChromaDBGetResponse.model_validate(json_str) -llm = ollama.Client(host=defines.ollama_api_url) # type: ignore +llm = ollama.Client(host=defines.ollama_api_url) prometheus_collector = CollectorRegistry() observer, file_watcher = Rag.start_file_watcher( llm=llm, diff --git a/src/utils/__init__.py b/src/utils/__init__.py index fb32f5b..fb6dc7c 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel # type: ignore +from pydantic import BaseModel from typing import ( Any, Set @@ -39,7 +39,7 @@ __all__ = [ "generate_image", "ImageRequest" ] -__all__.extend(agents_all) # type: ignore +__all__.extend(agents_all) logger = setup_logging(level=defines.logging_level) diff --git a/src/utils/agents/__init__.py b/src/utils/agents/__init__.py index 8a3225c..0592660 100644 --- a/src/utils/agents/__init__.py +++ b/src/utils/agents/__init__.py @@ -44,7 +44,7 @@ for path in package_dir.glob("*.py"): class_registry[name] = (full_module_name, name) globals()[name] = obj logger.info(f"Adding agent: {name}") - __all__.append(name) # type: ignore + __all__.append(name) except ImportError as e: logger.error(f"Error importing {full_module_name}: {e}") raise e diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 2a3673e..c5e617d 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import ( Literal, get_args, @@ -19,7 +19,7 @@ import inspect from abc import ABC import asyncio -from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore +from prometheus_client import Counter, Summary, CollectorRegistry from ..setup_logging import setup_logging @@ -33,7 +33,7 @@ from .types import agent_registry from .. import defines from ..message import Message, Tunables from ..metrics import Metrics -from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools # type: ignore -- dynamically added to __all__ +from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools from ..conversation import Conversation class LLMMessage(BaseModel): diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index 5529407..840f969 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -53,7 +53,7 @@ class Chat(Agent): Chat Agent """ - agent_type: Literal["chat"] = "chat" # type: ignore + agent_type: Literal["chat"] = "chat" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message diff --git a/src/utils/agents/fact_check.py b/src/utils/agents/fact_check.py index 5131aa9..8f8d33c 100644 --- a/src/utils/agents/fact_check.py +++ b/src/utils/agents/fact_check.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator # type: ignore +from pydantic import model_validator from typing import ( Literal, ClassVar, @@ -31,7 +31,7 @@ When answering queries, follow these steps: class FactCheck(Agent): - agent_type: Literal["fact_check"] = "fact_check" # type: ignore + agent_type: Literal["fact_check"] = "fact_check" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_fact_check diff --git a/src/utils/agents/image_generator.py b/src/utils/agents/image_generator.py index ed5efbc..568c8bf 100644 --- a/src/utils/agents/image_generator.py +++ b/src/utils/agents/image_generator.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator, Field, BaseModel # type: ignore +from pydantic import model_validator, Field, BaseModel from typing import ( Dict, Literal, @@ -37,7 +37,7 @@ seed = int(time.time()) random.seed(seed) class ImageGenerator(Agent): - agent_type: Literal["image"] = "image" # type: ignore + agent_type: Literal["image"] = "image" _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = False diff --git a/src/utils/agents/job_description.py b/src/utils/agents/job_description.py index 3ea1a5d..bd5ab74 100644 --- a/src/utils/agents/job_description.py +++ b/src/utils/agents/job_description.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator, Field # type: ignore +from pydantic import model_validator, Field from typing import ( Dict, Literal, @@ -17,7 +17,7 @@ import traceback import asyncio import time import asyncio -import numpy as np # type: ignore +import numpy as np from . base import Agent, agent_registry, LLMMessage from .. message import Message @@ -36,7 +36,7 @@ Answer questions about the job description. class JobDescription(Agent): - agent_type: Literal["job_description"] = "job_description" # type: ignore + agent_type: Literal["job_description"] = "job_description" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_generate_resume diff --git a/src/utils/agents/persona_generator.py b/src/utils/agents/persona_generator.py index 3e697b3..e6e5e72 100644 --- a/src/utils/agents/persona_generator.py +++ b/src/utils/agents/persona_generator.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator, Field, BaseModel # type: ignore +from pydantic import model_validator, Field, BaseModel from typing import ( Dict, Literal, @@ -23,7 +23,7 @@ import asyncio import time import os import random -from names_dataset import NameDataset, NameWrapper # type: ignore +from names_dataset import NameDataset, NameWrapper from .base import Agent, agent_registry, LLMMessage from ..message import Message @@ -128,7 +128,7 @@ logger = logging.getLogger(__name__) class EthnicNameGenerator: def __init__(self): try: - from names_dataset import NameDataset # type: ignore + from names_dataset import NameDataset self.nd = NameDataset() except ImportError: logger.error("NameDataset not available. Please install: pip install names-dataset") @@ -292,7 +292,7 @@ class EthnicNameGenerator: return names class PersonaGenerator(Agent): - agent_type: Literal["persona"] = "persona" # type: ignore + agent_type: Literal["persona"] = "persona" _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = False diff --git a/src/utils/agents/resume.py b/src/utils/agents/resume.py index dd7118b..b35ce13 100644 --- a/src/utils/agents/resume.py +++ b/src/utils/agents/resume.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import model_validator # type: ignore +from pydantic import model_validator from typing import ( Literal, ClassVar, @@ -46,7 +46,7 @@ When answering queries, follow these steps: class Resume(Agent): - agent_type: Literal["resume"] = "resume" # type: ignore + agent_type: Literal["resume"] = "resume" _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_fact_check diff --git a/src/utils/check_serializable.py b/src/utils/check_serializable.py index 646066d..0cad6a1 100644 --- a/src/utils/check_serializable.py +++ b/src/utils/check_serializable.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field import json from typing import Any, List, Set diff --git a/src/utils/context.py b/src/utils/context.py index 0ff1434..8769837 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -1,9 +1,9 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from uuid import uuid4 from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING from typing_extensions import Annotated, Union -import numpy as np # type: ignore +import numpy as np import logging from uuid import uuid4 import traceback @@ -44,7 +44,7 @@ class Context(BaseModel): user_facts: Optional[str] = None # Class managed fields - agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore + agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( default_factory=list ) diff --git a/src/utils/conversation.py b/src/utils/conversation.py index 20798b4..0a4e4c6 100644 --- a/src/utils/conversation.py +++ b/src/utils/conversation.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import List from .message import Message diff --git a/src/utils/image_model_cache.py b/src/utils/image_model_cache.py index cca6317..d13905c 100644 --- a/src/utils/image_model_cache.py +++ b/src/utils/image_model_cache.py @@ -4,8 +4,8 @@ import re import time from typing import Any -import torch # type: ignore -from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore +import torch +from diffusers import StableDiffusionPipeline, FluxPipeline class ImageModelCache: # Stay loaded for 3 hours def __init__(self, timeout_seconds: float = 3 * 60 * 60): diff --git a/src/utils/message.py b/src/utils/message.py index ed5e3a9..b16a751 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -1,8 +1,8 @@ -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import Dict, List, Optional, Any, Union, Mapping from datetime import datetime, timezone from . rag import ChromaDBGetResponse -from ollama._types import Options # type: ignore +from ollama._types import Options class Tunables(BaseModel): enable_rag: bool = True # Enable RAG collection chromadb matching diff --git a/src/utils/metrics.py b/src/utils/metrics.py index 6f30521..b2e25d1 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -1,4 +1,4 @@ -from prometheus_client import Counter, Histogram # type: ignore +from prometheus_client import Counter, Histogram from threading import Lock def singleton(cls): diff --git a/src/utils/profile_image.py b/src/utils/profile_image.py index 300a913..8047b46 100644 --- a/src/utils/profile_image.py +++ b/src/utils/profile_image.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel, Field # type: ignore +from pydantic import BaseModel, Field from typing import Dict, Literal, Any, AsyncGenerator, Optional import inspect import random @@ -12,7 +12,7 @@ import os import gc import tempfile import uuid -import torch # type: ignore +import torch import asyncio import time import json diff --git a/src/utils/rag.py b/src/utils/rag.py index 4121133..ba76dec 100644 --- a/src/utils/rag.py +++ b/src/utils/rag.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore +from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field from typing import List, Optional, Dict, Any, Union import os import glob @@ -8,16 +8,16 @@ import hashlib import asyncio import logging import json -import numpy as np # type: ignore +import numpy as np import traceback -import chromadb # type: ignore +import chromadb import ollama -from watchdog.observers import Observer # type: ignore -from watchdog.events import FileSystemEventHandler # type: ignore -import umap # type: ignore -from markitdown import MarkItDown # type: ignore -from chromadb.api.models.Collection import Collection # type: ignore +from watchdog.observers import Observer +from watchdog.events import FileSystemEventHandler +import umap +from markitdown import MarkItDown +from chromadb.api.models.Collection import Collection from .markdown_chunker import ( MarkdownChunker, @@ -388,9 +388,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler): os.makedirs(self.persist_directory) # Initialize ChromaDB client - chroma_client = chromadb.PersistentClient( # type: ignore + chroma_client = chromadb.PersistentClient( path=self.persist_directory, - settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore + settings=chromadb.Settings(anonymized_telemetry=False), ) # Check if the collection exists diff --git a/src/utils/redis_client.py b/src/utils/redis_client.py index 8be6c43..88b0d55 100644 --- a/src/utils/redis_client.py +++ b/src/utils/redis_client.py @@ -1,4 +1,4 @@ -import redis.asyncio as redis # type: ignore +import redis.asyncio as redis from typing import Optional import os import logging diff --git a/src/utils/tools/basetools.py b/src/utils/tools/basetools.py index 6598bf5..d90360c 100644 --- a/src/utils/tools/basetools.py +++ b/src/utils/tools/basetools.py @@ -1,5 +1,5 @@ import os -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from typing import List, Optional, Generator, ClassVar, Any, Dict from datetime import datetime from typing import ( @@ -7,12 +7,12 @@ from typing import ( ) from typing_extensions import Annotated -from bs4 import BeautifulSoup # type: ignore +from bs4 import BeautifulSoup -from geopy.geocoders import Nominatim # type: ignore -import pytz # type: ignore +from geopy.geocoders import Nominatim +import pytz import requests -import yfinance as yf # type: ignore +import yfinance as yf import logging @@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] -# __all__.extend(__tool_functions__) # type: ignore + diff --git a/src/utils/user.py b/src/utils/user.py index 927b28c..3a1abe7 100644 --- a/src/utils/user.py +++ b/src/utils/user.py @@ -1,13 +1,13 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator # type: ignore +from pydantic import BaseModel, Field, model_validator from uuid import uuid4 from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal from typing_extensions import Annotated, Union -import numpy as np # type: ignore +import numpy as np import logging from uuid import uuid4 -from prometheus_client import CollectorRegistry, Counter # type: ignore +from prometheus_client import CollectorRegistry, Counter import traceback import os import json