Guest seems to work!
parent 35ef9898f1
commit 20f8d7bd32
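For orientation: this commit moves the frontend to always-on guest sessions that share the same token storage as regular logins. Below is a minimal, illustrative consumer sketch — not code from the commit — assuming only the useAuth surface visible in the AuthContext.tsx diff further down (isInitializing, isGuest, guest, user); the component name is made up.

import React from 'react';
import { useAuth } from 'hooks/AuthContext';

// Illustrative only: show a banner for guests and a greeting for signed-in users.
const SessionBanner: React.FC = () => {
  const { isInitializing, isGuest, guest, user } = useAuth();

  if (isInitializing) {
    // Guest session is still being created or restored from localStorage.
    return <div>Loading session...</div>;
  }
  if (isGuest) {
    return <div>Browsing as guest ({guest?.sessionId})</div>;
  }
  return <div>Signed in as {user?.email}</div>;
};

export default SessionBanner;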
@@ -188,7 +188,7 @@ RUN pip install prometheus-client prometheus-fastapi-instrumentator
RUN pip install "redis[hiredis]>=4.5.0"

# New backend implementation
RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart
RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart schedule

# Needed for email verification
RUN pip install pyyaml user-agents cryptography
@@ -91,11 +91,11 @@ services:
# Optional: Redis Commander for GUI management
redis-commander:
image: rediscommander/redis-commander:latest
container_name: backstory-redis-commander
container_name: redis-commander
ports:
- "8081:8081"
environment:
- REDIS_HOSTS=redis:6379
- REDIS_HOSTS=redis:redis:6379
networks:
- internal
depends_on:
@@ -20,8 +20,8 @@ import '@fontsource/roboto/700.css';
const BackstoryApp = () => {
const navigate = useNavigate();
const location = useLocation();
const snackRef = useRef<any>(null);
const chatRef = useRef<ConversationHandle>(null);
const snackRef = useRef<any>(null);
const setSnack = useCallback((message: string, severity?: SeverityType) => {
snackRef.current?.setSnack(message, severity);
}, [snackRef]);
@@ -12,7 +12,6 @@ import { Message } from './Message';
import { DeleteConfirmation } from 'components/DeleteConfirmation';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
import { BackstoryElementProps } from './BackstoryTab';
import { connectionBase } from 'utils/Global';
import { useAuth } from "hooks/AuthContext";
import { StreamingResponse } from 'services/api-client';
import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
@@ -22,7 +21,7 @@ import './Conversation.css';
import { useAppState } from 'hooks/GlobalContext';

const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
};

const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
@@ -325,16 +324,16 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
<Box sx={{ p: 1, mt: 0, ...sx }}>
{
filteredConversation.map((message, index) =>
<Message key={index} {...{ chatSession, sendQuery: processQuery, message, connectionBase, }} />
<Message key={index} {...{ chatSession, sendQuery: processQuery, message, }} />
)
}
{
processingMessage !== undefined &&
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: processingMessage, }} />
<Message {...{ chatSession, sendQuery: processQuery, message: processingMessage, }} />
}
{
streamingMessage !== undefined &&
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: streamingMessage }} />
<Message {...{ chatSession, sendQuery: processQuery, message: streamingMessage }} />
}
<Box sx={{
display: "flex",
@@ -29,7 +29,7 @@ import {
} from '@mui/icons-material';
import { useAuth } from 'hooks/AuthContext';
import { BackstoryPageProps } from './BackstoryTab';
import { useNavigate } from 'react-router-dom';
import { Navigate, useNavigate } from 'react-router-dom';

// Email Verification Component
const EmailVerificationPage = (props: BackstoryPageProps) => {
@@ -488,10 +488,11 @@ const RegistrationSuccessDialog = ({

// Enhanced Login Component with MFA Support
const LoginForm = () => {
const { login, mfaResponse, isLoading, error } = useAuth();
const { login, mfaResponse, isLoading, error, user } = useAuth();
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const [errorMessage, setErrorMessage] = useState<string | null>(null);
const navigate = useNavigate();

useEffect(() => {
if (!error) {
@@ -524,10 +525,13 @@ const LoginForm = () => {
handleLoginSuccess();
};

const handleLoginSuccess = () => {
// This could be handled by a router or parent component
// For now, just showing the pattern
console.log('Login successful - redirect to dashboard');
const handleLoginSuccess = () => {
if (!user) {
navigate('/');
} else {
navigate(`/${user.userType}/dashboard`);
}
console.log('Login successful - redirect to dashboard');
};

return (
@@ -39,7 +39,7 @@ interface JobAnalysisProps extends BackstoryPageProps {
}

const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
};

const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
@@ -101,6 +101,7 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"
color: theme.palette.text.primary,
opacity: 0.95,
},
info: 'information',
preparing: 'status',
processing: 'status',
qualifications: {
@@ -17,7 +17,6 @@ import TableContainer from '@mui/material/TableContainer';
import TableRow from '@mui/material/TableRow';

import { Scrollable } from './Scrollable';
import { connectionBase } from '../utils/Global';

import './VectorVisualizer.css';
import { BackstoryPageProps } from './BackstoryTab';
@@ -195,7 +194,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualiz
const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
const navigate = useNavigate();

const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user : null;
const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;

/* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
* off screen (eg., the VectorVisualizer is not on the tab the app loads to) */
@@ -84,7 +84,6 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
const navigate = useNavigate();
const location = useLocation();
const { guest, user } = useAuth();
const { selectedCandidate } = useSelectedCandidate();
const [navigationItems, setNavigationItems] = useState<NavigationItem[]>([]);

useEffect(() => {
@@ -92,9 +91,13 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false));
}, [user]);

useEffect(() => {
console.log({ guest, user });
}, [guest, user]);

// Generate dynamic routes from navigation config
const generateRoutes = () => {
if (!guest) return null;
if (!guest && !user) return null;

const userType = user?.userType || null;
const isAdmin = user?.isAdmin ? true : false;
@@ -161,7 +164,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
}}
>
<BackstoryPageContainer>
{!guest && (
{!guest && !user && (
<Box>
<LoadingComponent
loadingText="Creating session..."
@@ -171,7 +174,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
/>
</Box>
)}
{guest && (
{(guest || user) && (
<>
<Outlet />
<Routes>
@@ -156,7 +156,7 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
id: 'profile',
label: 'Profile',
icon: <Person fontSize="small" />,
action: () => navigate(`/${user?.userType}/dashboard/profile`)
action: () => navigate(`/${user?.userType}/profile`)
},
{
id: 'dashboard',
@@ -74,7 +74,7 @@ const CandidateInfo: React.FC<CandidateInfoProps> = (props: CandidateInfoProps)
maxWidth: "80px"
}}>
<Avatar
src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}?timestamp=${Date.now()}` : ''}
src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}` : ''}
alt={`${candidate.fullName}'s profile`}
sx={{
alignSelf: "flex-start",
@@ -46,6 +46,7 @@ import { JobPicker } from 'components/ui/JobPicker';
import { DocumentManager } from 'components/DocumentManager';
import { VectorVisualizer } from 'components/VectorVisualizer';
import { ComingSoon } from 'components/ui/ComingSoon';
import { Beta } from 'components/ui/Beta';

// Beta page components for placeholder routes
const SearchPage = () => (<BetaPage><Typography variant="h4">Search</Typography></BetaPage>);
@@ -58,20 +59,19 @@ const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typogra

export const navigationConfig: NavigationConfig = {
items: [
{ id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, },
{ id: 'chat', label: 'Chat', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], },
{ id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, },
{ id: 'chat', label: 'Chat about a Candidate', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], },
{
id: 'candidate-menu', label: 'Tools', icon: <PersonIcon />, userTypes: ['candidate'], children: [
{ id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: <DashboardIcon />, component: <CandidateDashboard />, userTypes: ['candidate'] },
{ id: 'candidate-profile', label: 'Profile', icon: <PersonIcon />, path: '/candidate/profile', component: <CandidateProfile />, userTypes: ['candidate'] },
{ id: 'candidate-qa-setup', label: 'Q&A Setup', icon: <QuizIcon />, path: '/candidate/qa-setup', component: <BetaPage><Box>Candidate q&a setup page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-jobs', label: 'Jobs', icon: <WorkIcon />, path: '/candidate/jobs', component: <JobPicker />, userTypes: ['candidate'] },
{ id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: <WorkIcon />, component: <JobAnalysisPage />, userTypes: ['candidate'], },
{ id: 'candidate-resumes', label: 'Resumes', icon: <DescriptionIcon />, path: '/candidate/resumes', component: <BetaPage><Box>Candidate resumes page</Box></BetaPage>, userTypes: ['candidate'] },
{ id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: <DescriptionIcon />, component: <ResumeBuilderPage />, userTypes: ['candidate'], },
{ id: 'candidate-content', label: 'Content', icon: <BubbleChart />, path: '/candidate/content', component: <Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}><VectorVisualizer /><DocumentManager /></Box>, userTypes: ['candidate'] },
{ id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <ComingSoon><Settings /></ComingSoon>, userTypes: ['candidate'], },
{ id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <Settings />, userTypes: ['candidate'], },
],
},
{
@@ -87,7 +87,7 @@ export const navigationConfig: NavigationConfig = {
{ id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: <SettingsIcon />, component: <SettingsPage />, userTypes: ['employer'], },
],
},
{ id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], },
// { id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], },
{ id: 'docs', label: 'Docs', path: '/docs/*', icon: <LibraryBooksIcon />, component: <DocsPage />, userTypes: ['guest', 'candidate', 'employer'], },
{
id: 'admin-menu',
@@ -1,17 +1,19 @@
// Replace the existing AuthContext.tsx with these enhancements

import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react';
import * as Types from '../types/types';
import { ApiClient, CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
import { ApiClient, CreateCandidateRequest, CreateEmployerRequest, GuestConversionRequest } from '../services/api-client';
import { formatApiRequest, toCamelCase } from '../types/conversion';

// ============================
// Types and Interfaces
// Enhanced Types and Interfaces
// ============================

interface AuthState {
user: Types.User | null;
guest: Types.Guest | null;
isAuthenticated: boolean;
isGuest: boolean;
isLoading: boolean;
isInitializing: boolean;
error: string | null;
@@ -36,7 +38,7 @@ interface PasswordResetRequest {
}

// ============================
// Token Storage Constants
// Enhanced Token Storage Constants
// ============================

const TOKEN_STORAGE = {
@@ -44,7 +46,8 @@ const TOKEN_STORAGE = {
REFRESH_TOKEN: 'refreshToken',
USER_DATA: 'userData',
TOKEN_EXPIRY: 'tokenExpiry',
GUEST_DATA: 'guestData',
USER_TYPE: 'userType',
IS_GUEST: 'isGuest',
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
} as const;
@ -84,7 +87,7 @@ function isTokenExpired(token: string): boolean {
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Storage Utilities with Date Conversion
|
||||
// Enhanced Storage Utilities
|
||||
// ============================
|
||||
|
||||
function clearStoredAuth(): void {
|
||||
@ -92,6 +95,8 @@ function clearStoredAuth(): void {
|
||||
localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
||||
localStorage.removeItem(TOKEN_STORAGE.USER_DATA);
|
||||
localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
||||
localStorage.removeItem(TOKEN_STORAGE.USER_TYPE);
|
||||
localStorage.removeItem(TOKEN_STORAGE.IS_GUEST);
|
||||
}
|
||||
|
||||
function prepareUserDataForStorage(user: Types.User): string {
|
||||
@ -119,11 +124,21 @@ function parseStoredUserData(userDataStr: string): Types.User | null {
|
||||
}
|
||||
}
|
||||
|
||||
function storeAuthData(authResponse: Types.AuthResponse): void {
|
||||
function updateStoredUserData(user: Types.User): void {
|
||||
try {
|
||||
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user));
|
||||
} catch (error) {
|
||||
console.error('Failed to update stored user data:', error);
|
||||
}
|
||||
}
|
||||
|
||||
function storeAuthData(authResponse: Types.AuthResponse, isGuest: boolean = false): void {
|
||||
localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken);
|
||||
localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken);
|
||||
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user));
|
||||
localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString());
|
||||
localStorage.setItem(TOKEN_STORAGE.USER_TYPE, authResponse.user.userType);
|
||||
localStorage.setItem(TOKEN_STORAGE.IS_GUEST, isGuest.toString());
|
||||
}
|
||||
|
||||
function getStoredAuthData(): {
|
||||
@ -131,11 +146,15 @@ function getStoredAuthData(): {
|
||||
refreshToken: string | null;
|
||||
userData: Types.User | null;
|
||||
expiresAt: number | null;
|
||||
userType: string | null;
|
||||
isGuest: boolean;
|
||||
} {
|
||||
const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN);
|
||||
const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
||||
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||
const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
||||
const userType = localStorage.getItem(TOKEN_STORAGE.USER_TYPE);
|
||||
const isGuestStr = localStorage.getItem(TOKEN_STORAGE.IS_GUEST);
|
||||
|
||||
let userData: Types.User | null = null;
|
||||
let expiresAt: number | null = null;
|
||||
@ -152,55 +171,18 @@ function getStoredAuthData(): {
|
||||
clearStoredAuth();
|
||||
}
|
||||
|
||||
return { accessToken, refreshToken, userData, expiresAt };
|
||||
}
|
||||
|
||||
function updateStoredUserData(user: Types.User): void {
|
||||
try {
|
||||
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user));
|
||||
} catch (error) {
|
||||
console.error('Failed to update stored user data:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Guest Session Utilities
|
||||
// ============================
|
||||
|
||||
function createGuestSession(): Types.Guest {
|
||||
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||
const guest: Types.Guest = {
|
||||
sessionId,
|
||||
createdAt: new Date(),
|
||||
lastActivity: new Date(),
|
||||
ipAddress: 'unknown',
|
||||
userAgent: navigator.userAgent
|
||||
return {
|
||||
accessToken,
|
||||
refreshToken,
|
||||
userData,
|
||||
expiresAt,
|
||||
userType,
|
||||
isGuest: isGuestStr === 'true'
|
||||
};
|
||||
|
||||
try {
|
||||
localStorage.setItem(TOKEN_STORAGE.GUEST_DATA, JSON.stringify(formatApiRequest(guest)));
|
||||
} catch (error) {
|
||||
console.error('Failed to store guest session:', error);
|
||||
}
|
||||
|
||||
return guest;
|
||||
}
|
||||
|
||||
function getStoredGuestData(): Types.Guest | null {
|
||||
try {
|
||||
const guestDataStr = localStorage.getItem(TOKEN_STORAGE.GUEST_DATA);
|
||||
if (guestDataStr) {
|
||||
return toCamelCase<Types.Guest>(JSON.parse(guestDataStr));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to parse stored guest data:', error);
|
||||
localStorage.removeItem(TOKEN_STORAGE.GUEST_DATA);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Main Authentication Hook
|
||||
// Enhanced Authentication Hook
|
||||
// ============================
|
||||
|
||||
function useAuthenticationLogic() {
|
||||
@ -208,6 +190,7 @@ function useAuthenticationLogic() {
|
||||
user: null,
|
||||
guest: null,
|
||||
isAuthenticated: false,
|
||||
isGuest: false,
|
||||
isLoading: false,
|
||||
isInitializing: true,
|
||||
error: null,
|
||||
@ -216,6 +199,7 @@ function useAuthenticationLogic() {
|
||||
|
||||
const [apiClient] = useState(() => new ApiClient());
|
||||
const initializationCompleted = useRef(false);
|
||||
const guestCreationAttempted = useRef(false);
|
||||
|
||||
// Token refresh function
|
||||
const refreshAccessToken = useCallback(async (refreshToken: string): Promise<Types.AuthResponse | null> => {
|
||||
@ -228,6 +212,58 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Create guest session
|
||||
const createGuestSession = useCallback(async (): Promise<boolean> => {
|
||||
if (guestCreationAttempted.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
guestCreationAttempted.current = true;
|
||||
|
||||
try {
|
||||
console.log('🔄 Creating guest session...');
|
||||
const guestAuth = await apiClient.createGuestSession();
|
||||
|
||||
if (guestAuth && guestAuth.user && guestAuth.user.userType === 'guest') {
|
||||
storeAuthData(guestAuth, true);
|
||||
apiClient.setAuthToken(guestAuth.accessToken);
|
||||
|
||||
setAuthState({
|
||||
user: null,
|
||||
guest: guestAuth.user as Types.Guest,
|
||||
isAuthenticated: true,
|
||||
isGuest: true,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null,
|
||||
});
|
||||
|
||||
console.log('👤 Guest session created successfully:', guestAuth.user);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
} catch (error) {
|
||||
console.error('❌ Failed to create guest session:', error);
|
||||
guestCreationAttempted.current = false;
|
||||
|
||||
// Set to unauthenticated state if guest creation fails
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: null,
|
||||
guest: null,
|
||||
isAuthenticated: false,
|
||||
isGuest: false,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: 'Failed to create guest session',
|
||||
}));
|
||||
|
||||
return false;
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Initialize authentication state
|
||||
const initializeAuth = useCallback(async () => {
|
||||
if (initializationCompleted.current) {
|
||||
@ -235,99 +271,94 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
|
||||
try {
|
||||
// Initialize guest session first
|
||||
let guest = getStoredGuestData();
|
||||
if (!guest) {
|
||||
guest = createGuestSession();
|
||||
const stored = getStoredAuthData();
|
||||
|
||||
// If no stored tokens, create guest session
|
||||
if (!stored.accessToken || !stored.refreshToken || !stored.userData) {
|
||||
console.log('🔄 No stored auth found, creating guest session...');
|
||||
await createGuestSession();
|
||||
return;
|
||||
}
|
||||
|
||||
const stored = getStoredAuthData();
|
||||
|
||||
// If no stored tokens, user is not authenticated but has guest session
|
||||
if (!stored.accessToken || !stored.refreshToken || !stored.userData) {
|
||||
setAuthState({
|
||||
user: null,
|
||||
guest,
|
||||
isAuthenticated: false,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null,
|
||||
});
|
||||
return;
|
||||
// For guests, always verify the session exists on server
|
||||
if (stored.userType === 'guest' && stored.userData) {
|
||||
console.log(stored.userData);
|
||||
try {
|
||||
// Make a quick API call to verify guest still exists
|
||||
const response = await fetch(`${apiClient.getBaseUrl()}/users/${stored.userData.id}`, {
|
||||
headers: { 'Authorization': `Bearer ${stored.accessToken}` }
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.log('🔄 Guest session invalid, creating new guest session...');
|
||||
clearStoredAuth();
|
||||
await createGuestSession();
|
||||
return;
|
||||
}
|
||||
} catch (error) {
|
||||
console.log('🔄 Guest verification failed, creating new guest session...');
|
||||
clearStoredAuth();
|
||||
await createGuestSession();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if access token is expired
|
||||
if (isTokenExpired(stored.accessToken)) {
|
||||
console.log('Access token expired, attempting refresh...');
|
||||
console.log('🔄 Access token expired, attempting refresh...');
|
||||
|
||||
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
||||
|
||||
if (refreshResult) {
|
||||
storeAuthData(refreshResult);
|
||||
const isGuest = stored.userType === 'guest';
|
||||
storeAuthData(refreshResult, isGuest);
|
||||
apiClient.setAuthToken(refreshResult.accessToken);
|
||||
|
||||
setAuthState({
|
||||
user: refreshResult.user,
|
||||
guest,
|
||||
user: isGuest ? null : refreshResult.user,
|
||||
guest: isGuest ? refreshResult.user as Types.Guest : null,
|
||||
isAuthenticated: true,
|
||||
isGuest,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null
|
||||
});
|
||||
|
||||
console.log('Token refreshed successfully');
|
||||
console.log('✅ Token refreshed successfully');
|
||||
} else {
|
||||
console.log('Token refresh failed, clearing stored auth');
|
||||
console.log('❌ Token refresh failed, creating new guest session...');
|
||||
clearStoredAuth();
|
||||
apiClient.clearAuthToken();
|
||||
|
||||
setAuthState({
|
||||
user: null,
|
||||
guest,
|
||||
isAuthenticated: false,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null
|
||||
});
|
||||
await createGuestSession();
|
||||
}
|
||||
} else {
|
||||
// Access token is still valid
|
||||
apiClient.setAuthToken(stored.accessToken);
|
||||
const isGuest = stored.userType === 'guest';
|
||||
|
||||
setAuthState({
|
||||
user: stored.userData,
|
||||
guest,
|
||||
user: isGuest ? null : stored.userData,
|
||||
guest: isGuest ? stored.userData as Types.Guest : null,
|
||||
isAuthenticated: true,
|
||||
isGuest,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null
|
||||
});
|
||||
|
||||
console.log('Restored authentication from stored tokens');
|
||||
console.log('✅ Restored authentication from stored tokens');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error initializing auth:', error);
|
||||
console.error('❌ Error initializing auth:', error);
|
||||
clearStoredAuth();
|
||||
apiClient.clearAuthToken();
|
||||
|
||||
const guest = createGuestSession();
|
||||
setAuthState({
|
||||
user: null,
|
||||
guest,
|
||||
isAuthenticated: false,
|
||||
isLoading: false,
|
||||
isInitializing: false,
|
||||
error: null,
|
||||
mfaResponse: null
|
||||
});
|
||||
await createGuestSession();
|
||||
} finally {
|
||||
initializationCompleted.current = true;
|
||||
}
|
||||
}, [apiClient, refreshAccessToken]);
|
||||
}, [apiClient, refreshAccessToken, createGuestSession]);
|
||||
|
||||
// Run initialization on mount
|
||||
useEffect(() => {
|
||||
@ -355,7 +386,7 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
|
||||
const refreshTimer = setTimeout(() => {
|
||||
console.log('Auto-refreshing token before expiry...');
|
||||
console.log('🔄 Auto-refreshing token before expiry...');
|
||||
initializeAuth();
|
||||
}, timeUntilExpiry);
|
||||
|
||||
@ -364,7 +395,7 @@ function useAuthenticationLogic() {
|
||||
|
||||
// Enhanced login with MFA support
|
||||
const login = useCallback(async (loginData: LoginRequest): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null, mfaData: null }));
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null }));
|
||||
|
||||
try {
|
||||
const result = await apiClient.login({
|
||||
@ -381,21 +412,23 @@ function useAuthenticationLogic() {
|
||||
}));
|
||||
return false; // Login not complete yet
|
||||
} else {
|
||||
// Normal login success
|
||||
// Normal login success - convert from guest to authenticated user
|
||||
const authResponse: Types.AuthResponse = result;
|
||||
storeAuthData(authResponse);
|
||||
storeAuthData(authResponse, false);
|
||||
apiClient.setAuthToken(authResponse.accessToken);
|
||||
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: authResponse.user,
|
||||
guest: null,
|
||||
isAuthenticated: true,
|
||||
isGuest: false,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
mfaResponse: null,
|
||||
}));
|
||||
|
||||
console.log('Login successful');
|
||||
console.log('✅ Login successful, converted from guest to authenticated user');
|
||||
return true;
|
||||
}
|
||||
} catch (error: any) {
|
||||
@ -410,6 +443,44 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Convert guest to permanent user
|
||||
const convertGuestToUser = useCallback(async (registrationData: GuestConversionRequest): Promise<boolean> => {
|
||||
if (!authState.isGuest || !authState.guest) {
|
||||
throw new Error('Not currently a guest user');
|
||||
}
|
||||
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
const result = await apiClient.convertGuestToUser(registrationData);
|
||||
|
||||
// Store new authentication
|
||||
storeAuthData(result.auth, false);
|
||||
apiClient.setAuthToken(result.auth.accessToken);
|
||||
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: result.auth.user,
|
||||
guest: null,
|
||||
isAuthenticated: true,
|
||||
isGuest: false,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
}));
|
||||
|
||||
console.log('✅ Guest successfully converted to permanent user');
|
||||
return true;
|
||||
} catch (error: any) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Failed to convert guest account';
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
error: errorMessage,
|
||||
}));
|
||||
return false;
|
||||
}
|
||||
}, [apiClient, authState.isGuest, authState.guest]);
|
||||
|
||||
// MFA verification
|
||||
const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
@ -419,26 +490,27 @@ function useAuthenticationLogic() {
|
||||
|
||||
if (result.accessToken) {
|
||||
const authResponse: Types.AuthResponse = result;
|
||||
storeAuthData(authResponse);
|
||||
storeAuthData(authResponse, false);
|
||||
apiClient.setAuthToken(authResponse.accessToken);
|
||||
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: authResponse.user,
|
||||
guest: null,
|
||||
isAuthenticated: true,
|
||||
isGuest: false,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
mfaResponse: null,
|
||||
}));
|
||||
|
||||
console.log('MFA verification successful');
|
||||
console.log('✅ MFA verification successful, converted from guest');
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'MFA verification failed';
|
||||
console.log(errorMessage);
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
@ -448,42 +520,48 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Resend MFA code
|
||||
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
// Logout - returns to guest session
|
||||
const logout = useCallback(async () => {
|
||||
try {
|
||||
await apiClient.requestMFA({
|
||||
email,
|
||||
password: '', // This would need to be stored securely or re-entered
|
||||
deviceId,
|
||||
deviceName,
|
||||
});
|
||||
|
||||
setAuthState(prev => ({ ...prev, isLoading: false }));
|
||||
return true;
|
||||
// If authenticated, try to logout gracefully
|
||||
if (authState.isAuthenticated && !authState.isGuest) {
|
||||
const stored = getStoredAuthData();
|
||||
if (stored.accessToken && stored.refreshToken) {
|
||||
try {
|
||||
await apiClient.logout(stored.accessToken, stored.refreshToken);
|
||||
} catch (error) {
|
||||
console.warn('Logout request failed, proceeding with local cleanup');
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code';
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
error: errorMessage
|
||||
}));
|
||||
return false;
|
||||
}
|
||||
}, [apiClient]);
|
||||
console.warn('Error during logout:', error);
|
||||
} finally {
|
||||
// Always clear stored auth and create new guest session
|
||||
clearStoredAuth();
|
||||
apiClient.clearAuthToken();
|
||||
guestCreationAttempted.current = false;
|
||||
|
||||
// Clear MFA state
|
||||
const clearMFA = useCallback(() => {
|
||||
// Create new guest session
|
||||
await createGuestSession();
|
||||
|
||||
console.log('🔄 Logged out, created new guest session');
|
||||
}
|
||||
}, [apiClient, authState.isAuthenticated, authState.isGuest, createGuestSession]);
|
||||
|
||||
// Update user data
|
||||
const updateUserData = useCallback((updatedUser: Types.User) => {
|
||||
updateStoredUserData(updatedUser);
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
mfaResponse: null,
|
||||
error: null
|
||||
user: authState.isGuest ? null : updatedUser,
|
||||
guest: authState.isGuest ? updatedUser as Types.Guest : prev.guest
|
||||
}));
|
||||
}, []);
|
||||
console.log('✅ User data updated');
|
||||
}, [authState.isGuest]);
|
||||
|
||||
// Email verification
|
||||
const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest): Promise<{ message: string; userType: string } | null> => {
|
||||
// Email verification functions (unchanged)
|
||||
const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest) => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
@ -504,7 +582,7 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Resend email verification
|
||||
// Other existing methods remain the same...
|
||||
const resendEmailVerification = useCallback(async (email: string): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
@ -523,53 +601,21 @@ function useAuthenticationLogic() {
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Store pending verification email
|
||||
const setPendingVerificationEmail = useCallback((email: string) => {
|
||||
localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email);
|
||||
}, []);
|
||||
|
||||
// Get pending verification email
|
||||
const getPendingVerificationEmail = useCallback((): string | null => {
|
||||
return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL);
|
||||
}, []);
|
||||
|
||||
const logout = useCallback(() => {
|
||||
clearStoredAuth();
|
||||
apiClient.clearAuthToken();
|
||||
|
||||
// Create new guest session after logout
|
||||
const guest = createGuestSession();
|
||||
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: null,
|
||||
guest,
|
||||
isAuthenticated: false,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
mfaResponse: null,
|
||||
}));
|
||||
|
||||
console.log('User logged out');
|
||||
}, [apiClient]);
|
||||
|
||||
const updateUserData = useCallback((updatedUser: Types.User) => {
|
||||
updateStoredUserData(updatedUser);
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: updatedUser
|
||||
}));
|
||||
console.log('User data updated', updatedUser);
|
||||
}, []);
|
||||
|
||||
const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
const employer = await apiClient.createEmployer(employerData);
|
||||
console.log('Employer created:', employer);
|
||||
|
||||
// Store email for potential verification resend
|
||||
console.log('✅ Employer created:', employer);
|
||||
|
||||
setPendingVerificationEmail(employerData.email);
|
||||
|
||||
setAuthState(prev => ({ ...prev, isLoading: false }));
|
||||
@ -614,24 +660,61 @@ function useAuthenticationLogic() {
|
||||
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
||||
|
||||
if (refreshResult) {
|
||||
storeAuthData(refreshResult);
|
||||
const isGuest = stored.userType === 'guest';
|
||||
storeAuthData(refreshResult, isGuest);
|
||||
apiClient.setAuthToken(refreshResult.accessToken);
|
||||
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
user: refreshResult.user,
|
||||
user: isGuest ? null : refreshResult.user,
|
||||
guest: isGuest ? refreshResult.user as Types.Guest : null,
|
||||
isAuthenticated: true,
|
||||
isGuest,
|
||||
isLoading: false,
|
||||
error: null
|
||||
}));
|
||||
|
||||
return true;
|
||||
} else {
|
||||
logout();
|
||||
await logout();
|
||||
return false;
|
||||
}
|
||||
}, [refreshAccessToken, logout]);
|
||||
|
||||
// Resend MFA code
|
||||
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => {
|
||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||
|
||||
try {
|
||||
await apiClient.requestMFA({
|
||||
email,
|
||||
password: '', // This would need to be stored securely or re-entered
|
||||
deviceId,
|
||||
deviceName,
|
||||
});
|
||||
|
||||
setAuthState(prev => ({ ...prev, isLoading: false }));
|
||||
return true;
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code';
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
isLoading: false,
|
||||
error: errorMessage
|
||||
}));
|
||||
return false;
|
||||
}
|
||||
}, [apiClient]);
|
||||
|
||||
// Clear MFA state
|
||||
const clearMFA = useCallback(() => {
|
||||
setAuthState(prev => ({
|
||||
...prev,
|
||||
mfaResponse: null,
|
||||
error: null
|
||||
}));
|
||||
}, []);
|
||||
|
||||
return {
|
||||
...authState,
|
||||
apiClient,
|
||||
@ -647,12 +730,14 @@ function useAuthenticationLogic() {
|
||||
createEmployerAccount,
|
||||
requestPasswordReset,
|
||||
refreshAuth,
|
||||
updateUserData
|
||||
updateUserData,
|
||||
convertGuestToUser,
|
||||
createGuestSession
|
||||
};
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Context Provider
|
||||
// Enhanced Context Provider
|
||||
// ============================
|
||||
|
||||
const AuthContext = createContext<ReturnType<typeof useAuthenticationLogic> | null>(null);
|
||||
@ -676,34 +761,41 @@ function useAuth() {
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Protected Route Component
|
||||
// Enhanced Protected Route Component
|
||||
// ============================
|
||||
|
||||
interface ProtectedRouteProps {
|
||||
children: React.ReactNode;
|
||||
fallback?: React.ReactNode;
|
||||
requiredUserType?: Types.UserType;
|
||||
allowGuests?: boolean;
|
||||
}
|
||||
|
||||
function ProtectedRoute({
|
||||
children,
|
||||
fallback = <div>Please log in to access this page.</div>,
|
||||
requiredUserType
|
||||
requiredUserType,
|
||||
allowGuests = false
|
||||
}: ProtectedRouteProps) {
|
||||
const { isAuthenticated, isInitializing, user } = useAuth();
|
||||
const { isAuthenticated, isInitializing, user, isGuest } = useAuth();
|
||||
|
||||
// Show loading while checking stored tokens
|
||||
if (isInitializing) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
// Not authenticated
|
||||
// Not authenticated at all (shouldn't happen with guest sessions)
|
||||
if (!isAuthenticated) {
|
||||
return <>{fallback}</>;
|
||||
}
|
||||
|
||||
// Check user type if required
|
||||
if (requiredUserType && user?.userType !== requiredUserType) {
|
||||
// Guest access control
|
||||
if (isGuest && !allowGuests) {
|
||||
return <div>Please create an account or log in to access this page.</div>;
|
||||
}
|
||||
|
||||
// Check user type if required (only for non-guests)
|
||||
if (requiredUserType && !isGuest && user?.userType !== requiredUserType) {
|
||||
return <div>Access denied. Required user type: {requiredUserType}</div>;
|
||||
}
|
||||
|
||||
@ -711,11 +803,19 @@ function ProtectedRoute({
|
||||
}
|
||||
|
||||
export type {
|
||||
AuthState, LoginRequest, EmailVerificationRequest, ResendVerificationRequest, PasswordResetRequest
|
||||
AuthState,
|
||||
LoginRequest,
|
||||
EmailVerificationRequest,
|
||||
ResendVerificationRequest,
|
||||
PasswordResetRequest,
|
||||
GuestConversionRequest
|
||||
}
|
||||
|
||||
export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
|
||||
|
||||
export {
|
||||
useAuthenticationLogic, AuthProvider, useAuth, ProtectedRoute
|
||||
useAuthenticationLogic,
|
||||
AuthProvider,
|
||||
useAuth,
|
||||
ProtectedRoute
|
||||
}
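For context, a hedged usage sketch of the updated ProtectedRoute with the allowGuests prop added above. The route paths and page components here are placeholders, and the import path assumes ProtectedRoute is re-exported from hooks/AuthContext as shown in the export list above.

import React from 'react';
import { Routes, Route } from 'react-router-dom';
import { ProtectedRoute } from 'hooks/AuthContext';

// Placeholder pages, for illustration only.
const PublicChat = () => <div>Chat (guests allowed)</div>;
const CandidateDashboardPage = () => <div>Candidate dashboard (candidates only)</div>;

const AppRoutes = () => (
  <Routes>
    {/* Guests may reach this page because allowGuests is set. */}
    <Route path="/chat" element={
      <ProtectedRoute allowGuests>
        <PublicChat />
      </ProtectedRoute>
    } />
    {/* Guests are asked to create an account; non-candidate users are denied. */}
    <Route path="/candidate/dashboard" element={
      <ProtectedRoute requiredUserType="candidate">
        <CandidateDashboardPage />
      </ProtectedRoute>
    } />
  </Routes>
);

export default AppRoutes;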
|
@@ -23,15 +23,16 @@ import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext';
import PropagateLoader from 'react-spinners/PropagateLoader';
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
import { BackstoryQuery } from 'components/BackstoryQuery';
import { CandidatePicker } from 'components/ui/CandidatePicker';

const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
};

const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
const { apiClient } = useAuth();
const navigate = useNavigate();
const { selectedCandidate } = useSelectedCandidate()
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
const theme = useTheme();
const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
@@ -92,7 +93,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
timestamp: new Date()
};

setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
setProcessingMessage({ ...defaultMessage, status: 'status', activity: "info", content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });

setMessages(prev => {
const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
@@ -123,7 +124,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
},
onStreaming: (chunk: ChatMessageStreaming) => {
// console.log("onStreaming:", chunk);
setStreamingMessage({ ...chunk, role: 'assistant' });
setStreamingMessage({ ...chunk, role: 'assistant', metadata: null as any });
},
onStatus: (status: ChatMessageStatus) => {
setProcessingMessage(status);
@@ -171,8 +172,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
}, [chatSession]);

if (!selectedCandidate) {
navigate('/find-a-candidate');
return (<></>);
return <CandidatePicker />;
}

const welcomeMessage: ChatMessage = {
@@ -181,7 +181,8 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
type: "text",
status: "done",
timestamp: new Date(),
content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`
content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`,
metadata: null as any
};

return (
@@ -192,12 +193,14 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
gap: 1,
}}>
<CandidateInfo
key={selectedCandidate.username}
action={`Chat with Backstory about ${selectedCandidate.firstName}`}
elevation={4}
candidate={selectedCandidate}
variant="small"
sx={{ flexShrink: 0 }} // Prevent header from shrinking
/>
<Button onClick={() => { setSelectedCandidate(null); }} variant="contained">Change Candidates</Button>

{/* Chat Interface */}
<Paper
@@ -23,7 +23,7 @@ import { Message } from 'components/Message';
import { useAppState } from 'hooks/GlobalContext';

const defaultMessage: ChatMessage = {
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
};

const GenerateCandidate = (props: BackstoryElementProps) => {
@@ -20,6 +20,7 @@ import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer';
import DescriptionIcon from '@mui/icons-material/Description';
import professionalConversationPng from './Conversation.png';
import { ComingSoon } from 'components/ui/ComingSoon';
import { useAuth } from 'hooks/AuthContext';

// Placeholder for Testimonials component
const Testimonials = () => {
@@ -135,6 +136,15 @@ const FeatureCard = ({

const HomePage = () => {
const testimonials = false;
const { isGuest, guest, user } = useAuth();

if (isGuest) {
// Show guest-specific UI
console.log('Guest session:', guest?.sessionId || "No guest");
} else {
// Show authenticated user UI
console.log('Authenticated user:', user?.email || "No user");
}

return (<Box sx={{display: "flex", flexDirection: "column"}}>
{/* Hero Section */}
@@ -10,7 +10,8 @@ const LoadingPage = (props: BackstoryPageProps) => {
status: 'done',
sessionId: '',
content: 'Please wait while connecting to Backstory...',
timestamp: new Date()
timestamp: new Date(),
metadata: null as any
}

return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
@@ -26,6 +26,7 @@ import { LoginForm } from "components/EmailVerificationComponents";
import { CandidateRegistrationForm } from "components/RegistrationForms";
import { useNavigate } from 'react-router-dom';
import { useAppState } from 'hooks/GlobalContext';
import * as Types from 'types/types';

const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const navigate = useNavigate();
@@ -34,7 +35,7 @@ const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
const [loading, setLoading] = useState(false);
const [success, setSuccess] = useState<string | null>(null);
const { guest, user, login, isLoading, error } = useAuth();
const name = (user?.userType === 'candidate') ? user.username : user?.email || '';
const name = (user?.userType === 'candidate') ? (user as Types.Candidate).username : user?.email || '';
const [errorMessage, setErrorMessage] = useState<string | null>(null);

const showGuest: boolean = false;
@@ -10,7 +10,8 @@ const LoginRequired = (props: BackstoryPageProps) => {
status: 'done',
sessionId: '',
content: 'You must be logged in to view this feature.',
timestamp: new Date()
timestamp: new Date(),
metadata: null as any
}

return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
@@ -11,6 +11,7 @@ import { CandidateInfo } from 'components/ui/CandidateInfo';
import { useAuth } from 'hooks/AuthContext';
import { Candidate } from 'types/types';
import { useAppState } from 'hooks/GlobalContext';
import * as Types from 'types/types';

const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
const { setSnack } = useAppState();
@@ -18,7 +19,7 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
const theme = useTheme();
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
const [questions, setQuestions] = useState<React.ReactElement[]>([]);
const candidate: Candidate | null = user?.userType === 'candidate' ? user : null;
const candidate: Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;

// console.log("ChatPage candidate =>", candidate);
useEffect(() => {
@@ -74,7 +74,7 @@ const CandidateDashboard = (props: CandidateDashboardProps) => {
variant="contained"
color="primary"
sx={{ mt: 1 }}
onClick={(e) => {e.stopPropagation(); navigate('/candidate/dashboard/profile'); }}
onClick={(e) => { e.stopPropagation(); navigate('/candidate/profile'); }}
>
Complete Your Profile
</Button>
@ -14,9 +14,10 @@ import Typography from '@mui/material/Typography';
|
||||
// import ResetIcon from '@mui/icons-material/History';
|
||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||
|
||||
import { connectionBase } from '../../utils/Global';
|
||||
import { BackstoryPageProps } from '../../components/BackstoryTab';
|
||||
import { useAppState } from 'hooks/GlobalContext';
|
||||
import { useAuth } from 'hooks/AuthContext';
|
||||
import * as Types from 'types/types';
|
||||
|
||||
interface ServerTunables {
|
||||
system_prompt: string,
|
||||
@ -33,19 +34,7 @@ type Tool = {
|
||||
returns?: any
|
||||
};
|
||||
|
||||
type GPUInfo = {
|
||||
name: string,
|
||||
memory: number,
|
||||
discrete: boolean
|
||||
}
|
||||
|
||||
type SystemInfo = {
|
||||
"Installed RAM": string,
|
||||
"Graphics Card": GPUInfo[],
|
||||
"CPU": string
|
||||
};
|
||||
|
||||
const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ systemInfo }) => {
|
||||
const SystemInfoComponent: React.FC<{ systemInfo: Types.SystemInfo | undefined }> = ({ systemInfo }) => {
|
||||
const [systemElements, setSystemElements] = useState<ReactElement[]>([]);
|
||||
|
||||
const convertToSymbols = (text: string) => {
|
||||
@ -86,91 +75,92 @@ const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({
|
||||
};
|
||||
|
||||
const Settings = (props: BackstoryPageProps) => {
|
||||
const { apiClient } = useAuth();
|
||||
const { setSnack } = useAppState();
|
||||
const [editSystemPrompt, setEditSystemPrompt] = useState<string>("");
|
||||
const [systemInfo, setSystemInfo] = useState<SystemInfo | undefined>(undefined);
|
||||
const [systemInfo, setSystemInfo] = useState<Types.SystemInfo | undefined>(undefined);
|
||||
const [tools, setTools] = useState<Tool[]>([]);
|
||||
const [rags, setRags] = useState<Tool[]>([]);
|
||||
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
||||
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
||||
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) {
|
||||
return;
|
||||
}
|
||||
const sendSystemPrompt = async (prompt: string) => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "system_prompt": prompt }),
|
||||
});
|
||||
// useEffect(() => {
|
||||
// if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) {
|
||||
// return;
|
||||
// }
|
||||
// const sendSystemPrompt = async (prompt: string) => {
|
||||
// try {
|
||||
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
// method: 'PUT',
|
||||
// headers: {
|
||||
// 'Content-Type': 'application/json',
|
||||
// 'Accept': 'application/json',
|
||||
// },
|
||||
// body: JSON.stringify({ "system_prompt": prompt }),
|
||||
// });
|
||||
|
||||
const tunables = await response.json();
|
||||
serverTunables.system_prompt = tunables.system_prompt;
|
||||
console.log(tunables);
|
||||
setSystemPrompt(tunables.system_prompt)
|
||||
setSnack("System prompt updated", "success");
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("System prompt update failed", "error");
|
||||
}
|
||||
};
|
||||
// const tunables = await response.json();
|
||||
// serverTunables.system_prompt = tunables.system_prompt;
|
||||
// console.log(tunables);
|
||||
// setSystemPrompt(tunables.system_prompt)
|
||||
// setSnack("System prompt updated", "success");
|
||||
// } catch (error) {
|
||||
// console.error('Fetch error:', error);
|
||||
// setSnack("System prompt update failed", "error");
|
||||
// }
|
||||
// };
|
||||
|
||||
sendSystemPrompt(systemPrompt);
|
||||
// sendSystemPrompt(systemPrompt);
|
||||
|
||||
}, [systemPrompt, setSnack, serverTunables]);
|
||||
// }, [systemPrompt, setSnack, serverTunables]);
|
||||
|
||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/1.0/reset/`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "reset": types }),
|
||||
});
|
||||
// const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
||||
// try {
|
||||
// const response = await fetch(connectionBase + `/api/1.0/reset/`, {
|
||||
// method: 'PUT',
|
||||
// headers: {
|
||||
// 'Content-Type': 'application/json',
|
||||
// 'Accept': 'application/json',
|
||||
// },
|
||||
// body: JSON.stringify({ "reset": types }),
|
||||
// });
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
// if (!response.ok) {
|
||||
// throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
// }
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
// if (!response.body) {
|
||||
// throw new Error('Response body is null');
|
||||
// }
|
||||
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error(data.error);
|
||||
}
|
||||
// const data = await response.json();
|
||||
// if (data.error) {
|
||||
// throw Error(data.error);
|
||||
// }
|
||||
|
||||
for (const [key, value] of Object.entries(data)) {
|
||||
switch (key) {
|
||||
case "rags":
|
||||
setRags(value as Tool[]);
|
||||
break;
|
||||
case "tools":
|
||||
setTools(value as Tool[]);
|
||||
break;
|
||||
case "system_prompt":
|
||||
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||
break;
|
||||
case "history":
|
||||
console.log('TODO: handle history reset');
|
||||
break;
|
||||
}
|
||||
}
|
||||
setSnack(message, "success");
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("Unable to restore defaults", "error");
|
||||
}
|
||||
};
|
||||
// for (const [key, value] of Object.entries(data)) {
|
||||
// switch (key) {
|
||||
// case "rags":
|
||||
// setRags(value as Tool[]);
|
||||
// break;
|
||||
// case "tools":
|
||||
// setTools(value as Tool[]);
|
||||
// break;
|
||||
// case "system_prompt":
|
||||
// setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||
// break;
|
||||
// case "history":
|
||||
// console.log('TODO: handle history reset');
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// setSnack(message, "success");
|
||||
// } catch (error) {
|
||||
// console.error('Fetch error:', error);
|
||||
// setSnack("Unable to restore defaults", "error");
|
||||
// }
|
||||
// };
|
||||
|
||||
// Get the system information
|
||||
useEffect(() => {
|
||||
@ -179,27 +169,8 @@ const Settings = (props: BackstoryPageProps) => {
|
||||
}
|
||||
const fetchSystemInfo = async () => {
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/1.0/system-info`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (data.error) {
|
||||
throw Error(data.error);
|
||||
}
|
||||
|
||||
setSystemInfo(data);
|
||||
const response: Types.SystemInfo = await apiClient.getSystemInfo();
|
||||
setSystemInfo(response);
|
||||
} catch (error) {
|
||||
console.error('Error obtaining system information:', error);
|
||||
setSnack("Unable to obtain system information.", "error");
|
||||
@ -217,101 +188,101 @@ const Settings = (props: BackstoryPageProps) => {
|
||||
setEditSystemPrompt(systemPrompt.trim());
|
||||
}, [systemPrompt, setEditSystemPrompt]);
|
||||
|
||||
const toggleRag = async (tool: Tool) => {
|
||||
tool.enabled = !tool.enabled
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
||||
});
|
||||
// const toggleRag = async (tool: Tool) => {
|
||||
// tool.enabled = !tool.enabled
|
||||
// try {
|
||||
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
// method: 'PUT',
|
||||
// headers: {
|
||||
// 'Content-Type': 'application/json',
|
||||
// 'Accept': 'application/json',
|
||||
// },
|
||||
// body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
||||
// });
|
||||
|
||||
const tunables: ServerTunables = await response.json();
|
||||
setRags(tunables.rags)
|
||||
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
tool.enabled = !tool.enabled
|
||||
}
|
||||
};
|
||||
// const tunables: ServerTunables = await response.json();
|
||||
// setRags(tunables.rags)
|
||||
// setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
// } catch (error) {
|
||||
// console.error('Fetch error:', error);
|
||||
// setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
// tool.enabled = !tool.enabled
|
||||
// }
|
||||
// };
|
||||
|
||||
const toggleTool = async (tool: Tool) => {
|
||||
tool.enabled = !tool.enabled
|
||||
try {
|
||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
method: 'PUT',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
||||
});
|
||||
// const toggleTool = async (tool: Tool) => {
|
||||
// tool.enabled = !tool.enabled
|
||||
// try {
|
||||
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
// method: 'PUT',
|
||||
// headers: {
|
||||
// 'Content-Type': 'application/json',
|
||||
// 'Accept': 'application/json',
|
||||
// },
|
||||
// body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
||||
// });
|
||||
|
||||
const tunables: ServerTunables = await response.json();
|
||||
setTools(tunables.tools)
|
||||
setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
tool.enabled = !tool.enabled
|
||||
}
|
||||
};
|
||||
// const tunables: ServerTunables = await response.json();
|
||||
// setTools(tunables.tools)
|
||||
// setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||
// } catch (error) {
|
||||
// console.error('Fetch error:', error);
|
||||
// setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||
// tool.enabled = !tool.enabled
|
||||
// }
|
||||
// };
|
||||
|
||||
// If the systemPrompt has not been set, fetch it from the server
|
||||
useEffect(() => {
|
||||
if (serverTunables !== undefined) {
|
||||
return;
|
||||
}
|
||||
const fetchTunables = async () => {
|
||||
try {
|
||||
// Make the fetch request with proper headers
|
||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
},
|
||||
});
|
||||
const data = await response.json();
|
||||
// console.log("Server tunables: ", data);
|
||||
setServerTunables(data);
|
||||
setSystemPrompt(data["system_prompt"]);
|
||||
setTools(data["tools"]);
|
||||
setRags(data["rags"]);
|
||||
} catch (error) {
|
||||
console.error('Fetch error:', error);
|
||||
setSnack("System prompt update failed", "error");
|
||||
}
|
||||
}
|
||||
// useEffect(() => {
|
||||
// if (serverTunables !== undefined) {
|
||||
// return;
|
||||
// }
|
||||
// const fetchTunables = async () => {
|
||||
// try {
|
||||
// // Make the fetch request with proper headers
|
||||
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||
// method: 'GET',
|
||||
// headers: {
|
||||
// 'Content-Type': 'application/json',
|
||||
// 'Accept': 'application/json',
|
||||
// },
|
||||
// });
|
||||
// const data = await response.json();
|
||||
// // console.log("Server tunables: ", data);
|
||||
// setServerTunables(data);
|
||||
// setSystemPrompt(data["system_prompt"]);
|
||||
// setTools(data["tools"]);
|
||||
// setRags(data["rags"]);
|
||||
// } catch (error) {
|
||||
// console.error('Fetch error:', error);
|
||||
// setSnack("System prompt update failed", "error");
|
||||
// }
|
||||
// }
|
||||
|
||||
fetchTunables();
|
||||
}, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
|
||||
// fetchTunables();
|
||||
// }, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
|
||||
|
||||
const toggle = async (type: string, index: number) => {
|
||||
switch (type) {
|
||||
case "rag":
|
||||
if (rags === undefined) {
|
||||
return;
|
||||
}
|
||||
toggleRag(rags[index])
|
||||
break;
|
||||
case "tool":
|
||||
if (tools === undefined) {
|
||||
return;
|
||||
}
|
||||
toggleTool(tools[index]);
|
||||
}
|
||||
};
|
||||
// const toggle = async (type: string, index: number) => {
|
||||
// switch (type) {
|
||||
// case "rag":
|
||||
// if (rags === undefined) {
|
||||
// return;
|
||||
// }
|
||||
// toggleRag(rags[index])
|
||||
// break;
|
||||
// case "tool":
|
||||
// if (tools === undefined) {
|
||||
// return;
|
||||
// }
|
||||
// toggleTool(tools[index]);
|
||||
// }
|
||||
// };
|
||||
|
||||
const handleKeyPress = (event: any) => {
|
||||
if (event.key === 'Enter' && event.ctrlKey) {
|
||||
setSystemPrompt(editSystemPrompt);
|
||||
}
|
||||
};
|
||||
// const handleKeyPress = (event: any) => {
|
||||
// if (event.key === 'Enter' && event.ctrlKey) {
|
||||
// setSystemPrompt(editSystemPrompt);
|
||||
// }
|
||||
// };
|
||||
|
||||
return (<div className="Controls">
|
||||
{/* <Typography component="span" sx={{ mb: 1 }}>
|
||||
|
@ -6,6 +6,7 @@ import { SetSnackType } from '../components/Snack';
|
||||
import { LoadingComponent } from "../components/LoadingComponent";
|
||||
import { User, Guest, Candidate } from 'types/types';
|
||||
import { useAuth } from "hooks/AuthContext";
|
||||
import { useSelectedCandidate } from "hooks/GlobalContext";
|
||||
|
||||
interface CandidateRouteProps {
|
||||
guest?: Guest | null;
|
||||
@ -15,19 +16,19 @@ interface CandidateRouteProps {
|
||||
|
||||
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
|
||||
const { apiClient } = useAuth();
|
||||
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate();
|
||||
const { setSnack } = props;
|
||||
const { username } = useParams<{ username: string }>();
|
||||
const [candidate, setCandidate] = useState<Candidate|null>(null);
|
||||
const navigate = useNavigate();
|
||||
|
||||
useEffect(() => {
|
||||
if (candidate?.username === username || !username) {
|
||||
if (selectedCandidate?.username === username || !username) {
|
||||
return;
|
||||
}
|
||||
const getCandidate = async (reference: string) => {
|
||||
try {
|
||||
const result: Candidate = await apiClient.getCandidate(reference);
|
||||
setCandidate(result);
|
||||
setSelectedCandidate(result);
|
||||
navigate('/chat');
|
||||
} catch {
|
||||
setSnack(`Unable to obtain information for ${username}.`, "error");
|
||||
@ -36,9 +37,9 @@ const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProp
|
||||
}
|
||||
|
||||
getCandidate(username);
|
||||
}, [candidate, username, setCandidate, navigate, setSnack, apiClient]);
|
||||
}, [selectedCandidate, username, setSelectedCandidate, navigate, setSnack, apiClient]);
|
||||
|
||||
if (candidate === null) {
|
||||
if (selectedCandidate?.username !== username) {
|
||||
return (<Box>
|
||||
<LoadingComponent
|
||||
loadingText="Fetching candidate information..."
|
||||
|
@ -29,10 +29,33 @@ import {
  convertArrayFromApi
} from 'types/types';

const TOKEN_STORAGE = {
  ACCESS_TOKEN: 'accessToken',
  REFRESH_TOKEN: 'refreshToken',
  USER_DATA: 'userData',
  TOKEN_EXPIRY: 'tokenExpiry',
  USER_TYPE: 'userType',
  IS_GUEST: 'isGuest',
  PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
} as const;

// ============================
// Streaming Types and Interfaces
// ============================
export interface GuestConversionRequest extends CreateCandidateRequest {
  accountType: 'candidate';
}

export class RateLimitError extends Error {
  constructor(
    message: string,
    public retryAfterSeconds: number,
    public remainingRequests: Record<string, number>
  ) {
    super(message);
    this.name = 'RateLimitError';
  }
}

interface StreamingOptions<T = Types.ChatMessage> {
  method?: string,
  headers?: Record<string, any>,
@ -777,7 +800,7 @@ class ApiClient {
|
||||
async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
|
||||
const result = await this.getCandidateChatSessions(candidate.username);
|
||||
/* Find the 'candidate_chat' session if it exists, otherwise create it */
|
||||
let session = result.sessions.data.find(session => session.title === 'candidate_chat');
|
||||
let session = result.sessions.data.find(session => session.title === title);
|
||||
if (!session) {
|
||||
session = await this.createCandidateChatSession(
|
||||
candidate.username,
|
||||
@ -788,6 +811,17 @@ class ApiClient {
|
||||
return session;
|
||||
}
|
||||
|
||||
async getSystemInfo() : Promise<Types.SystemInfo> {
|
||||
const response = await fetch(`${this.baseUrl}/system-info`, {
|
||||
method: 'GET',
|
||||
headers: this.defaultHeaders,
|
||||
});
|
||||
|
||||
const result = await handleApiResponse<Types.SystemInfo>(response);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async getCandidateSimilarContent(query: string
|
||||
): Promise<Types.ChromaDBGetResponse> {
|
||||
const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
|
||||
@ -1003,6 +1037,202 @@ class ApiClient {
|
||||
return handleApiResponse<{ success: boolean; message: string }>(response);
|
||||
}
|
||||
|
||||
|
||||
// ============================
|
||||
// Guest Authentication Methods
|
||||
// ============================
|
||||
|
||||
/**
|
||||
* Create a guest session with authentication
|
||||
*/
|
||||
async createGuestSession(): Promise<Types.AuthResponse> {
|
||||
const response = await fetch(`${this.baseUrl}/auth/guest`, {
|
||||
method: 'POST',
|
||||
headers: this.defaultHeaders
|
||||
});
|
||||
|
||||
const result = await handleApiResponse<Types.AuthResponse>(response);
|
||||
|
||||
// Convert guest data if needed
|
||||
if (result.user && result.user.userType === 'guest') {
|
||||
result.user = convertFromApi<Types.Guest>(result.user, "Guest");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert guest account to permanent user account
|
||||
*/
|
||||
async convertGuestToUser(
|
||||
registrationData: CreateCandidateRequest & { accountType: 'candidate' }
|
||||
): Promise<{
|
||||
message: string;
|
||||
auth: Types.AuthResponse;
|
||||
conversionType: string
|
||||
}> {
|
||||
const response = await fetch(`${this.baseUrl}/auth/guest/convert`, {
|
||||
method: 'POST',
|
||||
headers: this.defaultHeaders,
|
||||
body: JSON.stringify(formatApiRequest(registrationData))
|
||||
});
|
||||
|
||||
const result = await handleApiResponse<{
|
||||
message: string;
|
||||
auth: Types.AuthResponse;
|
||||
conversionType: string;
|
||||
}>(response);
|
||||
|
||||
// Convert the auth user data
|
||||
if (result.auth?.user) {
|
||||
result.auth.user = convertFromApi<Types.Candidate>(result.auth.user, "Candidate");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if current session is a guest
|
||||
*/
|
||||
isGuestSession(): boolean {
|
||||
try {
|
||||
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||
if (userDataStr) {
|
||||
const userData = JSON.parse(userDataStr);
|
||||
return userData.userType === 'guest';
|
||||
}
|
||||
return false;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get guest session info
|
||||
*/
|
||||
getGuestSessionInfo(): Types.Guest | null {
|
||||
try {
|
||||
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||
if (userDataStr) {
|
||||
const userData = JSON.parse(userDataStr);
|
||||
if (userData.userType === 'guest') {
|
||||
return convertFromApi<Types.Guest>(userData, "Guest");
|
||||
}
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rate limit status for current user
|
||||
*/
|
||||
async getRateLimitStatus(): Promise<{
|
||||
user_id: string;
|
||||
user_type: string;
|
||||
is_admin: boolean;
|
||||
current_usage: Record<string, number>;
|
||||
limits: Record<string, number>;
|
||||
remaining: Record<string, number>;
|
||||
reset_times: Record<string, string>;
|
||||
config: any;
|
||||
}> {
|
||||
const response = await fetch(`${this.baseUrl}/admin/rate-limits/info`, {
|
||||
headers: this.defaultHeaders
|
||||
});
|
||||
|
||||
return handleApiResponse<any>(response);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get guest statistics (admin only)
|
||||
*/
|
||||
async getGuestStatistics(): Promise<{
|
||||
total_guests: number;
|
||||
active_last_hour: number;
|
||||
active_last_day: number;
|
||||
converted_guests: number;
|
||||
by_ip: Record<string, number>;
|
||||
creation_timeline: Record<string, number>;
|
||||
}> {
|
||||
const response = await fetch(`${this.baseUrl}/admin/guests/statistics`, {
|
||||
headers: this.defaultHeaders
|
||||
});
|
||||
|
||||
return handleApiResponse<any>(response);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup inactive guests (admin only)
|
||||
*/
|
||||
async cleanupInactiveGuests(inactiveHours: number = 24): Promise<{
|
||||
message: string;
|
||||
cleaned_count: number;
|
||||
}> {
|
||||
const response = await fetch(`${this.baseUrl}/admin/guests/cleanup`, {
|
||||
method: 'POST',
|
||||
headers: this.defaultHeaders,
|
||||
body: JSON.stringify({ inactive_hours: inactiveHours })
|
||||
});
|
||||
|
||||
return handleApiResponse<{
|
||||
message: string;
|
||||
cleaned_count: number;
|
||||
}>(response);
|
||||
}
|
||||
|
||||
  // ============================
  // Enhanced Error Handling for Rate Limits
  // ============================

  /**
   * Enhanced API response handler with rate limit handling
   */
  private async handleApiResponseWithRateLimit<T>(response: Response): Promise<T> {
    if (response.status === 429) {
      const rateLimitData = await response.json();
      const retryAfter = response.headers.get('Retry-After');

      throw new RateLimitError(
        rateLimitData.detail?.message || 'Rate limit exceeded',
        parseInt(retryAfter || '60'),
        rateLimitData.detail?.remaining || {}
      );
    }

    return this.handleApiResponseWithConversion<T>(response);
  }

  /**
   * Retry mechanism for rate-limited requests
   */
  private async retryWithBackoff<T>(
    requestFn: () => Promise<Response>,
    maxRetries: number = 3
  ): Promise<T> {
    let lastError: Error;

    for (let attempt = 0; attempt <= maxRetries; attempt++) {
      try {
        const response = await requestFn();
        return await this.handleApiResponseWithRateLimit<T>(response);
      } catch (error) {
        lastError = error as Error;

        if (error instanceof RateLimitError && attempt < maxRetries) {
          const delayMs = Math.min(error.retryAfterSeconds * 1000, 60000); // Max 1 minute
          console.warn(`Rate limited, retrying in ${delayMs}ms (attempt ${attempt + 1}/${maxRetries + 1})`);
          await new Promise(resolve => setTimeout(resolve, delayMs));
          continue;
        }

        throw error;
      }
    }

    throw lastError!;
  }
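The client treats an HTTP 429 as a RateLimitError built from the Retry-After header and the detail payload's message and remaining counts. For reference only, a minimal FastAPI-style sketch of a backend producing that response shape; the middleware, the check_rate_limit helper, and its return fields are illustrative assumptions, not code from this commit.

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

def check_rate_limit(client_id: str) -> dict:
    """Hypothetical helper; the real logic would live in the Redis-backed rate limiter."""
    return {"allowed": True, "retry_after": 60, "remaining": {"minute": 10}}

@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    result = check_rate_limit(request.client.host if request.client else "anonymous")
    if not result["allowed"]:
        # Shape mirrored from the client code above: detail.message, detail.remaining, Retry-After
        return JSONResponse(
            status_code=429,
            content={"detail": {"message": "Rate limit exceeded", "remaining": result["remaining"]}},
            headers={"Retry-After": str(result["retry_after"])},
        )
    return await call_next(request)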
async resetChatSession(id: string): Promise<{ success: boolean; message: string }> {
|
||||
const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, {
|
||||
method: 'PATCH',
|
||||
@ -1384,5 +1614,5 @@ export interface PendingVerification {
|
||||
attempts: number;
|
||||
}
|
||||
|
||||
export { ApiClient }
|
||||
export { ApiClient, TOKEN_STORAGE }
|
||||
export type { StreamingOptions, StreamingResponse }
|
@ -1,6 +1,6 @@
|
||||
// Generated TypeScript types from Pydantic models
|
||||
// Source: src/backend/models.py
|
||||
// Generated on: 2025-06-07T20:43:58.855207
|
||||
// Generated on: 2025-06-09T03:32:01.335483
|
||||
// DO NOT EDIT MANUALLY - This file is auto-generated
|
||||
|
||||
// ============================
|
||||
@ -128,8 +128,10 @@ export interface Attachment {
|
||||
export interface AuthResponse {
|
||||
accessToken: string;
|
||||
refreshToken: string;
|
||||
user: any;
|
||||
user: Candidate | Employer | Guest;
|
||||
expiresAt: number;
|
||||
userType?: string;
|
||||
isGuest?: boolean;
|
||||
}
|
||||
|
||||
export interface Authentication {
|
||||
@ -148,7 +150,9 @@ export interface Authentication {
|
||||
}
|
||||
|
||||
export interface BaseUser {
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
lastActivity?: Date;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
@ -164,24 +168,15 @@ export interface BaseUser {
|
||||
}
|
||||
|
||||
export interface BaseUserWithType {
|
||||
id?: string;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
fullName: string;
|
||||
phone?: string;
|
||||
location?: Location;
|
||||
createdAt: Date;
|
||||
updatedAt: Date;
|
||||
lastLogin?: Date;
|
||||
profileImage?: string;
|
||||
status: "active" | "inactive" | "pending" | "banned";
|
||||
isAdmin: boolean;
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
lastActivity?: Date;
|
||||
}
|
||||
|
||||
export interface Candidate {
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
lastActivity?: Date;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
@ -194,7 +189,6 @@ export interface Candidate {
|
||||
profileImage?: string;
|
||||
status: "active" | "inactive" | "pending" | "banned";
|
||||
isAdmin: boolean;
|
||||
userType: "candidate";
|
||||
username: string;
|
||||
description?: string;
|
||||
resume?: string;
|
||||
@ -214,7 +208,9 @@ export interface Candidate {
|
||||
}
|
||||
|
||||
export interface CandidateAI {
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
lastActivity?: Date;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
@ -227,7 +223,6 @@ export interface CandidateAI {
|
||||
profileImage?: string;
|
||||
status: "active" | "inactive" | "pending" | "banned";
|
||||
isAdmin: boolean;
|
||||
userType: "candidate";
|
||||
username: string;
|
||||
description?: string;
|
||||
resume?: string;
|
||||
@ -301,7 +296,7 @@ export interface ChatMessage {
|
||||
role: "user" | "assistant" | "system" | "information" | "warning" | "error";
|
||||
content: string;
|
||||
tunables?: Tunables;
|
||||
metadata?: ChatMessageMetaData;
|
||||
metadata: ChatMessageMetaData;
|
||||
}
|
||||
|
||||
export interface ChatMessageError {
|
||||
@ -319,9 +314,9 @@ export interface ChatMessageMetaData {
|
||||
temperature: number;
|
||||
maxTokens: number;
|
||||
topP: number;
|
||||
frequencyPenalty?: number;
|
||||
presencePenalty?: number;
|
||||
stopSequences?: Array<string>;
|
||||
frequencyPenalty: number;
|
||||
presencePenalty: number;
|
||||
stopSequences: Array<string>;
|
||||
ragResults?: Array<ChromaDBGetResponse>;
|
||||
llmHistory?: Array<any>;
|
||||
evalCount: number;
|
||||
@ -392,7 +387,6 @@ export interface ChatQuery {
|
||||
export interface ChatSession {
|
||||
id?: string;
|
||||
userId?: string;
|
||||
guestId?: string;
|
||||
createdAt?: Date;
|
||||
lastActivity?: Date;
|
||||
title?: string;
|
||||
@ -507,7 +501,7 @@ export interface DocumentMessage {
|
||||
}
|
||||
|
||||
export interface DocumentOptions {
|
||||
includeInRAG?: boolean;
|
||||
includeInRAG: boolean;
|
||||
isJobDocument?: boolean;
|
||||
overwrite?: boolean;
|
||||
}
|
||||
@ -541,7 +535,9 @@ export interface EmailVerificationRequest {
|
||||
}
|
||||
|
||||
export interface Employer {
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
lastActivity?: Date;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
@ -554,7 +550,6 @@ export interface Employer {
|
||||
profileImage?: string;
|
||||
status: "active" | "inactive" | "pending" | "banned";
|
||||
isAdmin: boolean;
|
||||
userType: "employer";
|
||||
companyName: string;
|
||||
industry: string;
|
||||
description?: string;
|
||||
@ -580,14 +575,71 @@ export interface ErrorDetail {
|
||||
details?: any;
|
||||
}
|
||||
|
||||
export interface GPUInfo {
|
||||
name: string;
|
||||
memory: number;
|
||||
discrete: boolean;
|
||||
}
|
||||
|
||||
export interface Guest {
|
||||
userType: "candidate" | "employer" | "guest";
|
||||
id?: string;
|
||||
sessionId: string;
|
||||
lastActivity?: Date;
|
||||
email: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
fullName: string;
|
||||
phone?: string;
|
||||
location?: Location;
|
||||
createdAt: Date;
|
||||
lastActivity: Date;
|
||||
updatedAt: Date;
|
||||
lastLogin?: Date;
|
||||
profileImage?: string;
|
||||
status: "active" | "inactive" | "pending" | "banned";
|
||||
isAdmin: boolean;
|
||||
sessionId: string;
|
||||
username: string;
|
||||
convertedToUserId?: string;
|
||||
ipAddress?: string;
|
||||
userAgent?: string;
|
||||
ragContentSize: number;
|
||||
}
|
||||
|
||||
export interface GuestCleanupRequest {
|
||||
inactiveHours: number;
|
||||
}
|
||||
|
||||
export interface GuestConversionRequest {
|
||||
accountType: "candidate" | "employer";
|
||||
email: string;
|
||||
username: string;
|
||||
password: string;
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
phone?: string;
|
||||
companyName?: string;
|
||||
industry?: string;
|
||||
companySize?: string;
|
||||
companyDescription?: string;
|
||||
websiteUrl?: string;
|
||||
}
|
||||
|
||||
export interface GuestSessionResponse {
|
||||
accessToken: string;
|
||||
refreshToken: string;
|
||||
user: Guest;
|
||||
expiresAt: number;
|
||||
userType: "guest";
|
||||
isGuest: boolean;
|
||||
}
|
||||
|
||||
export interface GuestStatistics {
|
||||
totalGuests: number;
|
||||
activeLastHour: number;
|
||||
activeLastDay: number;
|
||||
convertedGuests: number;
|
||||
byIp: Record<string, number>;
|
||||
creationTimeline: Record<string, number>;
|
||||
}
|
||||
|
||||
export interface InterviewFeedback {
|
||||
@ -746,6 +798,7 @@ export interface MFARequest {
|
||||
password: string;
|
||||
deviceId: string;
|
||||
deviceName: string;
|
||||
email: string;
|
||||
}
|
||||
|
||||
export interface MFARequestResponse {
|
||||
@ -841,6 +894,33 @@ export interface RagEntry {
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
export interface RateLimitConfig {
|
||||
requestsPerMinute: number;
|
||||
requestsPerHour: number;
|
||||
requestsPerDay: number;
|
||||
burstLimit: number;
|
||||
burstWindowSeconds: number;
|
||||
}
|
||||
|
||||
export interface RateLimitResult {
|
||||
allowed: boolean;
|
||||
reason?: string;
|
||||
retryAfterSeconds?: number;
|
||||
remainingRequests?: Record<string, number>;
|
||||
resetTimes?: Record<string, Date>;
|
||||
}
|
||||
|
||||
export interface RateLimitStatus {
|
||||
userId: string;
|
||||
userType: string;
|
||||
isAdmin: boolean;
|
||||
currentUsage: Record<string, number>;
|
||||
limits: Record<string, number>;
|
||||
remaining: Record<string, number>;
|
||||
resetTimes: Record<string, Date>;
|
||||
config: RateLimitConfig;
|
||||
}
|
||||
|
||||
export interface RefreshToken {
|
||||
token: string;
|
||||
expiresAt: Date;
|
||||
@ -915,6 +995,15 @@ export interface SocialLink {
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface SystemInfo {
|
||||
installedRAM: string;
|
||||
graphicsCards: Array<GPUInfo>;
|
||||
CPU: string;
|
||||
llmModel: string;
|
||||
embeddingModel: string;
|
||||
maxContextLength: number;
|
||||
}
|
||||
|
||||
export interface Tunables {
|
||||
enableRAG: boolean;
|
||||
enableTools: boolean;
|
||||
@ -1035,13 +1124,15 @@ export function convertAuthenticationFromApi(data: any): Authentication {
|
||||
}
|
||||
/**
|
||||
* Convert BaseUser from API response, parsing date fields
|
||||
* Date fields: createdAt, updatedAt, lastLogin
|
||||
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||
*/
|
||||
export function convertBaseUserFromApi(data: any): BaseUser {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
@ -1052,30 +1143,28 @@ export function convertBaseUserFromApi(data: any): BaseUser {
|
||||
}
|
||||
/**
|
||||
* Convert BaseUserWithType from API response, parsing date fields
|
||||
* Date fields: createdAt, updatedAt, lastLogin
|
||||
* Date fields: lastActivity
|
||||
*/
|
||||
export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
updatedAt: new Date(data.updatedAt),
|
||||
// Convert lastLogin from ISO string to Date
|
||||
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Convert Candidate from API response, parsing date fields
|
||||
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate
|
||||
* Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
|
||||
*/
|
||||
export function convertCandidateFromApi(data: any): Candidate {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
@ -1088,13 +1177,15 @@ export function convertCandidateFromApi(data: any): Candidate {
|
||||
}
|
||||
/**
|
||||
* Convert CandidateAI from API response, parsing date fields
|
||||
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate
|
||||
* Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
|
||||
*/
|
||||
export function convertCandidateAIFromApi(data: any): CandidateAI {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
@ -1282,13 +1373,15 @@ export function convertEducationFromApi(data: any): Education {
|
||||
}
|
||||
/**
|
||||
* Convert Employer from API response, parsing date fields
|
||||
* Date fields: createdAt, updatedAt, lastLogin
|
||||
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||
*/
|
||||
export function convertEmployerFromApi(data: any): Employer {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
@ -1299,17 +1392,21 @@ export function convertEmployerFromApi(data: any): Employer {
|
||||
}
|
||||
/**
|
||||
* Convert Guest from API response, parsing date fields
|
||||
* Date fields: createdAt, lastActivity
|
||||
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||
*/
|
||||
export function convertGuestFromApi(data: any): Guest {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||
// Convert createdAt from ISO string to Date
|
||||
createdAt: new Date(data.createdAt),
|
||||
// Convert lastActivity from ISO string to Date
|
||||
lastActivity: new Date(data.lastActivity),
|
||||
// Convert updatedAt from ISO string to Date
|
||||
updatedAt: new Date(data.updatedAt),
|
||||
// Convert lastLogin from ISO string to Date
|
||||
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
|
||||
};
|
||||
}
|
||||
/**
|
||||
@ -1415,6 +1512,32 @@ export function convertRAGConfigurationFromApi(data: any): RAGConfiguration {
|
||||
updatedAt: new Date(data.updatedAt),
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Convert RateLimitResult from API response, parsing date fields
|
||||
* Date fields: resetTimes
|
||||
*/
|
||||
export function convertRateLimitResultFromApi(data: any): RateLimitResult {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert resetTimes from ISO string to Date
|
||||
resetTimes: data.resetTimes ? new Date(data.resetTimes) : undefined,
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Convert RateLimitStatus from API response, parsing date fields
|
||||
* Date fields: resetTimes
|
||||
*/
|
||||
export function convertRateLimitStatusFromApi(data: any): RateLimitStatus {
|
||||
if (!data) return data;
|
||||
|
||||
return {
|
||||
...data,
|
||||
// Convert resetTimes from ISO string to Date
|
||||
resetTimes: new Date(data.resetTimes),
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Convert RefreshToken from API response, parsing date fields
|
||||
* Date fields: expiresAt
|
||||
@ -1527,6 +1650,10 @@ export function convertFromApi<T>(data: any, modelType: string): T {
|
||||
return convertMessageReactionFromApi(data) as T;
|
||||
case 'RAGConfiguration':
|
||||
return convertRAGConfigurationFromApi(data) as T;
|
||||
case 'RateLimitResult':
|
||||
return convertRateLimitResultFromApi(data) as T;
|
||||
case 'RateLimitStatus':
|
||||
return convertRateLimitStatusFromApi(data) as T;
|
||||
case 'RefreshToken':
|
||||
return convertRefreshTokenFromApi(data) as T;
|
||||
case 'UserActivity':
|
||||
|
@ -1,15 +0,0 @@
|
||||
const getConnectionBase = (loc: any): string => {
|
||||
if (!loc.host.match(/.*battle-linux.*/)
|
||||
// && !loc.host.match(/.*backstory-beta.*/)
|
||||
) {
|
||||
return loc.protocol + "//" + loc.host;
|
||||
} else {
|
||||
return loc.protocol + "//battle-linux.ketrenos.com:8912";
|
||||
}
|
||||
}
|
||||
|
||||
const connectionBase = getConnectionBase(window.location);
|
||||
|
||||
export {
|
||||
connectionBase
|
||||
};
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import traceback
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -17,7 +17,7 @@ from typing import (
|
||||
import importlib
|
||||
import pathlib
|
||||
import inspect
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
from prometheus_client import CollectorRegistry
|
||||
|
||||
from . base import Agent
|
||||
from logger import logger
|
||||
@ -90,7 +90,7 @@ for path in package_dir.glob("*.py"):
|
||||
class_registry[name] = (full_module_name, name)
|
||||
globals()[name] = obj
|
||||
logger.info(f"Adding agent: {name}")
|
||||
__all__.append(name) # type: ignore
|
||||
__all__.append(name)
|
||||
except ImportError as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Error importing {full_module_name}: {e}")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -20,8 +20,8 @@ import re
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
from datetime import datetime, UTC
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry
|
||||
import numpy as np
|
||||
|
||||
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from logger import logger
|
||||
@ -491,7 +491,7 @@ Content: {content}
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessageRagSearch, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
|
@ -28,7 +28,7 @@ class CandidateChat(Agent):
|
||||
CandidateChat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
|
||||
agent_type: Literal["candidate_chat"] = "candidate_chat"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
@ -53,7 +53,7 @@ class Chat(Agent):
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["general"] = "general" # type: ignore
|
||||
agent_type: Literal["general"] = "general"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from pydantic import model_validator, Field, BaseModel
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
@ -37,7 +37,7 @@ seed = int(time.time())
|
||||
random.seed(seed)
|
||||
|
||||
class ImageGenerator(Agent):
|
||||
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
|
||||
agent_type: Literal["generate_image"] = "generate_image"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
agent_persist: bool = False
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from pydantic import model_validator, Field, BaseModel
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
@ -23,7 +23,7 @@ import asyncio
|
||||
import time
|
||||
import os
|
||||
import random
|
||||
from names_dataset import NameDataset, NameWrapper # type: ignore
|
||||
from names_dataset import NameDataset, NameWrapper
|
||||
|
||||
from .base import Agent, agent_registry, LLMMessage
|
||||
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
|
||||
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
|
||||
class EthnicNameGenerator:
|
||||
def __init__(self):
|
||||
try:
|
||||
from names_dataset import NameDataset # type: ignore
|
||||
from names_dataset import NameDataset
|
||||
self.nd = NameDataset()
|
||||
except ImportError:
|
||||
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
||||
@ -292,7 +292,7 @@ class EthnicNameGenerator:
|
||||
return names
|
||||
|
||||
class GeneratePersona(Agent):
|
||||
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
|
||||
agent_type: Literal["generate_persona"] = "generate_persona"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
agent_persist: bool = False
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from pydantic import model_validator, Field
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
@ -16,7 +16,7 @@ import json
|
||||
import asyncio
|
||||
import time
|
||||
import asyncio
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
|
||||
from .base import Agent, agent_registry, LLMMessage
|
||||
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
|
||||
@ -26,7 +26,7 @@ import defines
|
||||
import backstory_traceback as traceback
|
||||
|
||||
class JobRequirementsAgent(Agent):
|
||||
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
|
||||
agent_type: Literal["job_requirements"] = "job_requirements"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
# Stage 1A: Job Analysis Implementation
|
||||
|
@ -15,7 +15,7 @@ class Chat(Agent):
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||
agent_type: Literal["rag_search"] = "rag_search"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
async def generate(
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from pydantic import model_validator, Field
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
@ -16,7 +16,7 @@ import json
|
||||
import asyncio
|
||||
import time
|
||||
import asyncio
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
|
||||
from .base import Agent, agent_registry, LLMMessage
|
||||
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
|
||||
@ -25,7 +25,7 @@ from logger import logger
|
||||
import defines
|
||||
|
||||
class SkillMatchAgent(Agent):
|
||||
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
|
||||
agent_type: Literal["skill_match"] = "skill_match"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
def generate_skill_assessment_prompt(self, skill, rag_context):
|
||||
|
@ -5,12 +5,12 @@ Provides password hashing, verification, and security features
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import bcrypt # type: ignore
|
||||
import bcrypt
|
||||
import secrets
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
src/backend/background_tasks.py (new file)
@ -0,0 +1,175 @@
"""
Background tasks for guest cleanup and system maintenance
"""

import asyncio
import schedule  # type: ignore
import threading
import time
from datetime import datetime, timedelta, UTC
from typing import Optional
from logger import logger
from database import DatabaseManager

class BackgroundTaskManager:
    """Manages background tasks for the application"""

    def __init__(self, database_manager: DatabaseManager):
        self.database_manager = database_manager
        self.running = False
        self.tasks = []
        self.scheduler_thread: Optional[threading.Thread] = None

    async def cleanup_inactive_guests(self, inactive_hours: int = 24):
        """Clean up inactive guest sessions"""
        try:
            database = self.database_manager.get_database()
            cleaned_count = await database.cleanup_inactive_guests(inactive_hours)

            if cleaned_count > 0:
                logger.info(f"🧹 Background cleanup: removed {cleaned_count} inactive guest sessions")

            return cleaned_count
        except Exception as e:
            logger.error(f"❌ Error in guest cleanup: {e}")
            return 0

    async def cleanup_expired_verification_tokens(self):
        """Clean up expired email verification tokens"""
        try:
            database = self.database_manager.get_database()
            cleaned_count = await database.cleanup_expired_verification_tokens()

            if cleaned_count > 0:
                logger.info(f"🧹 Background cleanup: removed {cleaned_count} expired verification tokens")

            return cleaned_count
        except Exception as e:
            logger.error(f"❌ Error in verification token cleanup: {e}")
            return 0

    async def update_guest_statistics(self):
        """Update guest usage statistics"""
        try:
            database = self.database_manager.get_database()
            stats = await database.get_guest_statistics()

            # Log interesting statistics
            if stats.get('total_guests', 0) > 0:
                logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
                            f"{stats['active_last_hour']} active in last hour, "
                            f"{stats['converted_guests']} converted")

            return stats
        except Exception as e:
            logger.error(f"❌ Error updating guest statistics: {e}")
            return {}

    async def cleanup_old_rate_limit_data(self, days_old: int = 7):
        """Clean up old rate limiting data"""
        try:
            database = self.database_manager.get_database()
            redis = database.redis

            # Clean up rate limit keys older than specified days
            cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
            pattern = "rate_limit:*"

            cursor = 0
            deleted_count = 0

            while True:
                cursor, keys = await redis.scan(cursor, match=pattern, count=100)

                for key in keys:
                    # Check if key is old enough to delete
                    try:
                        ttl = await redis.ttl(key)
                        if ttl == -1:  # No expiration set, check creation time
                            # For simplicity, delete keys without TTL
                            await redis.delete(key)
                            deleted_count += 1
                    except Exception:
                        continue

                if cursor == 0:
                    break

            if deleted_count > 0:
                logger.info(f"🧹 Cleaned up {deleted_count} old rate limit keys")

            return deleted_count
        except Exception as e:
            logger.error(f"❌ Error cleaning up rate limit data: {e}")
            return 0

    def schedule_periodic_tasks(self):
        """Schedule periodic background tasks with safer intervals"""

        # Guest cleanup - every 6 hours instead of every hour (less aggressive)
        schedule.every(6).hours.do(self._run_async_task, self.cleanup_inactive_guests, 48)  # 48 hours instead of 24

        # Verification token cleanup - every 12 hours
        schedule.every(12).hours.do(self._run_async_task, self.cleanup_expired_verification_tokens)

        # Guest statistics update - every hour
        schedule.every().hour.do(self._run_async_task, self.update_guest_statistics)

        # Rate limit data cleanup - daily at 3 AM
        schedule.every().day.at("03:00").do(self._run_async_task, self.cleanup_old_rate_limit_data, 7)

        logger.info("📅 Background tasks scheduled with safer intervals")

    def _run_async_task(self, coro_func, *args, **kwargs):
        """Run an async task in the background"""
        try:
            # Create new event loop for this thread if needed
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

            # Run the coroutine
            loop.run_until_complete(coro_func(*args, **kwargs))
        except Exception as e:
            logger.error(f"❌ Error running background task {coro_func.__name__}: {e}")

    def _scheduler_worker(self):
        """Worker thread for running scheduled tasks"""
        while self.running:
            try:
                schedule.run_pending()
                time.sleep(60)  # Check every minute
            except Exception as e:
                logger.error(f"❌ Error in scheduler worker: {e}")
                time.sleep(60)

    def start(self):
        """Start the background task manager"""
        if self.running:
            logger.warning("⚠️ Background task manager already running")
            return

        self.running = True
        self.schedule_periodic_tasks()

        # Start scheduler thread
        self.scheduler_thread = threading.Thread(target=self._scheduler_worker, daemon=True)
        self.scheduler_thread.start()

        logger.info("🚀 Background task manager started")

    def stop(self):
        """Stop the background task manager"""
        self.running = False

        if self.scheduler_thread and self.scheduler_thread.is_alive():
            self.scheduler_thread.join(timeout=5)

        # Clear scheduled tasks
        schedule.clear()

        logger.info("🛑 Background task manager stopped")

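BackgroundTaskManager is meant to be started once at application startup and stopped on shutdown. A minimal sketch of wiring it into a FastAPI lifespan handler, assuming a DatabaseManager that exposes get_database() as used above; the DatabaseManager constructor shown here is an assumption, not taken from this commit.

from contextlib import asynccontextmanager
from fastapi import FastAPI

from database import DatabaseManager          # assumed import path and constructor
from background_tasks import BackgroundTaskManager

@asynccontextmanager
async def lifespan(app: FastAPI):
    db_manager = DatabaseManager()            # however the app actually builds it
    task_manager = BackgroundTaskManager(db_manager)
    task_manager.start()                      # spawns the daemon scheduler thread
    try:
        yield
    finally:
        task_manager.stop()                   # joins the thread and clears schedules

app = FastAPI(lifespan=lifespan)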
@ -9,6 +9,7 @@ from models import (
|
||||
# User models
|
||||
Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
|
||||
)
|
||||
import backstory_traceback as traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -198,6 +199,7 @@ class RedisDatabase:
|
||||
try:
|
||||
return json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Failed to deserialize data: {data}")
|
||||
return None
|
||||
|
||||
@ -254,43 +256,44 @@
            raise

    async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
        """Retrieve cached skill match assessment"""
        """Get cached skill match assessment"""
        try:
            cached_data = await self.redis.get(cache_key)
            if cached_data:
                return json.loads(cached_data)
            data = await self.redis.get(cache_key)
            if data:
                return json.loads(data)
            return None
        except Exception as e:
            logger.error(f"Error retrieving cached skill match: {e}")
            logger.error(f"❌ Error getting cached skill match: {e}")
            return None

    async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
        """Cache skill match assessment with TTL (default 30 days)"""

    async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any]) -> None:
        """Cache skill match assessment"""
        try:
            # Cache for 1 hour by default
            await self.redis.setex(
                cache_key,
                ttl,
                json.dumps(assessment_data, default=str)
                3600,
                json.dumps(assessment_data)
            )
            return True
            logger.debug(f"💾 Skill match cached: {cache_key}")
        except Exception as e:
            logger.error(f"Error caching skill match: {e}")
            return False

            logger.error(f"❌ Error caching skill match: {e}")

    async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
        """Get the last time candidate's skill information was updated"""
        """Get the last time candidate skills were updated"""
        try:
            # This assumes you track skill update timestamps in your candidate data
            candidate_data = await self.get_candidate(candidate_id)
            if candidate_data and 'skills_updated_at' in candidate_data:
                return datetime.fromisoformat(candidate_data['skills_updated_at'])
            if candidate_data:
                updated_at_str = candidate_data.get("updated_at")
                if updated_at_str:
                    return datetime.fromisoformat(updated_at_str.replace('Z', '+00:00'))
            return None
        except Exception as e:
            logger.error(f"Error getting candidate skill update time: {e}")
            logger.error(f"❌ Error getting candidate skill update time: {e}")
            return None

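The cache helpers above take an opaque cache_key and rely on the one-hour TTL for expiry. A small sketch of how a caller might derive that key and decide whether a cached assessment is stale relative to the candidate's last update; the key format, the cached_at field, and the compute_fn callable are assumptions for illustration only.

import hashlib
from datetime import datetime, UTC

async def get_or_refresh_skill_match(database, candidate_id: str, skill: str, compute_fn):
    # Hypothetical key format; any stable, collision-resistant scheme works
    cache_key = f"skill_match:{candidate_id}:{hashlib.sha256(skill.encode()).hexdigest()[:16]}"

    cached = await database.get_cached_skill_match(cache_key)
    if cached:
        updated_at = await database.get_candidate_skill_update_time(candidate_id)
        cached_at = datetime.fromisoformat(cached.get("cached_at", datetime.now(UTC).isoformat()))
        # Reuse the cache only if the candidate has not changed since it was written
        if updated_at is None or updated_at <= cached_at:
            return cached

    assessment = await compute_fn(candidate_id, skill)   # the expensive LLM/RAG assessment
    assessment["cached_at"] = datetime.now(UTC).isoformat()
    await database.cache_skill_match(cache_key, assessment)
    return assessment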
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
||||
"""Get the timestamp of the latest RAG data update for a specific user"""
|
||||
"""Get the last time user's RAG data was updated"""
|
||||
try:
|
||||
rag_update_key = f"user:{user_id}:rag_last_update"
|
||||
timestamp_str = await self.redis.get(rag_update_key)
|
||||
@ -298,8 +301,8 @@ class RedisDatabase:
|
||||
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
|
||||
return None
|
||||
logger.error(f"❌ Error getting user RAG update time: {e}")
|
||||
return None
|
||||
|
||||
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
||||
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
|
||||
@ -359,7 +362,7 @@ class RedisDatabase:
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
"""Get all documents for a specific candidate"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1)
|
||||
document_ids = await self.redis.lrange(key, 0, -1)
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
@ -1707,7 +1710,11 @@ class RedisDatabase:
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
email = key.replace(self.KEY_PREFIXES['users'], '')
|
||||
result[email] = self._deserialize(value)
|
||||
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
|
||||
if type(value) == str:
|
||||
result[email] = value
|
||||
else:
|
||||
result[email] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
|
||||
@ -1850,17 +1857,16 @@ class RedisDatabase:
|
||||
return False
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve user data by ID"""
|
||||
"""Get user lookup data by user ID"""
|
||||
try:
|
||||
key = f"user_by_id:{user_id}"
|
||||
data = await self.redis.get(key)
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id)
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving user by ID {user_id}: {e}")
|
||||
return None
|
||||
|
||||
logger.error(f"❌ Error getting user by ID {user_id}: {e}")
|
||||
return None
|
||||
|
||||
async def user_exists_by_email(self, email: str) -> bool:
|
||||
"""Check if a user exists with the given email"""
|
||||
try:
|
||||
@ -2104,6 +2110,208 @@ class RedisDatabase:
|
||||
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
    # ============================
    # Guest Management Methods
    # ============================

    async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
        """Store guest data with enhanced persistence"""
        try:
            # Ensure last_activity is always set
            guest_data["last_activity"] = datetime.now(UTC).isoformat()

            # Store in Redis with both hash and individual key for redundancy
            await self.redis.hset("guests", guest_id, json.dumps(guest_data))

            # Also store with a longer TTL as backup
            await self.redis.setex(
                f"guest_backup:{guest_id}",
                86400 * 7,  # 7 days TTL
                json.dumps(guest_data)
            )

            logger.debug(f"💾 Guest stored with backup: {guest_id}")
        except Exception as e:
            logger.error(f"❌ Error storing guest {guest_id}: {e}")
            raise

    async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
        """Get guest data with fallback to backup"""
        try:
            # Try primary storage first
            data = await self.redis.hget("guests", guest_id)
            if data:
                guest_data = json.loads(data)
                # Update last activity when accessed
                guest_data["last_activity"] = datetime.now(UTC).isoformat()
                await self.set_guest(guest_id, guest_data)
                logger.debug(f"🔍 Guest found in primary storage: {guest_id}")
                return guest_data

            # Fallback to backup storage
            backup_data = await self.redis.get(f"guest_backup:{guest_id}")
            if backup_data:
                guest_data = json.loads(backup_data)
                guest_data["last_activity"] = datetime.now(UTC).isoformat()

                # Restore to primary storage
                await self.set_guest(guest_id, guest_data)
                logger.info(f"🔄 Guest restored from backup: {guest_id}")
                return guest_data

            logger.warning(f"⚠️ Guest not found: {guest_id}")
            return None
        except Exception as e:
            logger.error(f"❌ Error getting guest {guest_id}: {e}")
            return None

    async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get guest data by session ID"""
        try:
            all_guests = await self.get_all_guests()
            for guest_data in all_guests.values():
                if guest_data.get("session_id") == session_id:
                    return guest_data
            return None
        except Exception as e:
            logger.error(f"❌ Error getting guest by session ID {session_id}: {e}")
            return None

    async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
        """Get all guests"""
        try:
            data = await self.redis.hgetall("guests")
            return {
                guest_id: json.loads(guest_json)
                for guest_id, guest_json in data.items()
            }
        except Exception as e:
            logger.error(f"❌ Error getting all guests: {e}")
            return {}

    async def delete_guest(self, guest_id: str) -> bool:
        """Delete a guest"""
        try:
            result = await self.redis.hdel("guests", guest_id)
            if result:
                logger.info(f"🗑️ Guest deleted: {guest_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"❌ Error deleting guest {guest_id}: {e}")
            return False

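set_guest and get_guest implement a write-through primary hash plus a TTL'd backup key, so a lost hash entry can be recovered for up to seven days. A short round-trip sketch, assuming db is an initialized RedisDatabase instance from this module; the guest fields used here are placeholders.

import uuid
from datetime import datetime, UTC

async def demo_guest_roundtrip(db):
    """db is an initialized RedisDatabase from the module above."""
    guest_id = str(uuid.uuid4())
    await db.set_guest(guest_id, {
        "session_id": str(uuid.uuid4()),
        "username": f"guest-{guest_id[:8]}",
        "created_at": datetime.now(UTC).isoformat(),
        "ip_address": "127.0.0.1",
    })

    # Simulate losing the primary hash entry; get_guest falls back to the backup key
    await db.redis.hdel("guests", guest_id)
    restored = await db.get_guest(guest_id)
    assert restored is not None and restored["username"].startswith("guest-")

# asyncio.run(demo_guest_roundtrip(db)) once a RedisDatabase instance is available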
    async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
        """Clean up inactive guest sessions with safety checks"""
        try:
            all_guests = await self.get_all_guests()
            current_time = datetime.now(UTC)
            cutoff_time = current_time - timedelta(hours=inactive_hours)

            deleted_count = 0
            preserved_count = 0

            for guest_id, guest_data in all_guests.items():
                try:
                    last_activity_str = guest_data.get("last_activity")
                    created_at_str = guest_data.get("created_at")

                    # Skip cleanup if guest is very new (less than 1 hour old)
                    if created_at_str:
                        created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
                        if current_time - created_at < timedelta(hours=1):
                            preserved_count += 1
                            logger.debug(f"🛡️ Preserving new guest: {guest_id}")
                            continue

                    # Check last activity
                    should_delete = False
                    if last_activity_str:
                        try:
                            last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
                            if last_activity < cutoff_time:
                                should_delete = True
                        except ValueError:
                            # Invalid date format, but don't delete if guest is new
                            if not created_at_str:
                                should_delete = True
                    else:
                        # No last activity, but don't delete if guest is new
                        if not created_at_str:
                            should_delete = True

                    if should_delete:
                        await self.delete_guest(guest_id)
                        deleted_count += 1
                    else:
                        preserved_count += 1

                except Exception as e:
                    logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}")
                    preserved_count += 1  # Preserve on error

            if deleted_count > 0:
                logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}")

            return deleted_count
        except Exception as e:
            logger.error(f"❌ Error in guest cleanup: {e}")
            return 0

async def get_guest_statistics(self) -> Dict[str, Any]:
|
||||
"""Get guest usage statistics"""
|
||||
try:
|
||||
all_guests = await self.get_all_guests()
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
stats = {
|
||||
"total_guests": len(all_guests),
|
||||
"active_last_hour": 0,
|
||||
"active_last_day": 0,
|
||||
"converted_guests": 0,
|
||||
"by_ip": {},
|
||||
"creation_timeline": {}
|
||||
}
|
||||
|
||||
hour_ago = current_time - timedelta(hours=1)
|
||||
day_ago = current_time - timedelta(days=1)
|
||||
|
||||
for guest_data in all_guests.values():
|
||||
# Check activity
|
||||
last_activity_str = guest_data.get("last_activity")
|
||||
if last_activity_str:
|
||||
try:
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||
if last_activity > hour_ago:
|
||||
stats["active_last_hour"] += 1
|
||||
if last_activity > day_ago:
|
||||
stats["active_last_day"] += 1
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Check conversions
|
||||
if guest_data.get("converted_to_user_id"):
|
||||
stats["converted_guests"] += 1
|
||||
|
||||
# IP tracking
|
||||
ip = guest_data.get("ip_address", "unknown")
|
||||
stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1
|
||||
|
||||
# Creation timeline
|
||||
created_at_str = guest_data.get("created_at")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
date_key = created_at.strftime('%Y-%m-%d')
|
||||
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting guest statistics: {e}")
|
||||
return {}
|
||||
|
||||
# Global Redis manager instance
|
||||
redis_manager = _RedisManager()
|
||||
|
||||
|
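For reference, a minimal usage sketch of the guest persistence helpers above. It assumes a database object (`db`) that exposes these methods and a running Redis instance; the guest ID, session ID, and the accessor commented at the bottom are illustrative assumptions, not part of this diff.

# Illustrative only: exercises set_guest / get_guest / cleanup_inactive_guests.
import asyncio
from datetime import datetime, UTC

async def demo(db):
    guest_id = "guest-1234"  # hypothetical ID
    await db.set_guest(guest_id, {
        "session_id": "sess-abcd",                      # hypothetical session
        "created_at": datetime.now(UTC).isoformat(),
        "ip_address": "203.0.113.7",
        "user_agent": "demo-agent/1.0",
    })

    guest = await db.get_guest(guest_id)                # refreshes last_activity on read
    print(guest["session_id"] if guest else "not found")

    removed = await db.cleanup_inactive_guests(inactive_hours=24)
    print(f"removed {removed} inactive guests")

# asyncio.run(demo(database))  # how the database object is obtained is app-specific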
@ -1,9 +1,9 @@
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks # type: ignore
|
||||
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks
|
||||
from database import RedisDatabase
|
||||
import hashlib
|
||||
from logger import logger
|
||||
from datetime import datetime, timezone
|
||||
from user_agents import parse # type: ignore
|
||||
from user_agents import parse
|
||||
import json
|
||||
|
||||
class DeviceManager:
|
||||
|
@ -1,8 +1,8 @@
|
||||
import os
|
||||
from typing import Tuple
|
||||
from logger import logger
|
||||
from email.mime.text import MIMEText # type: ignore
|
||||
from email.mime.multipart import MIMEMultipart # type: ignore
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
import smtplib
|
||||
import asyncio
|
||||
from email_templates import EMAIL_TEMPLATES
|
||||
|
@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||
|
||||
from typing_extensions import Annotated, Union
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
from prometheus_client import CollectorRegistry, Counter
|
||||
import traceback
|
||||
import os
|
||||
import json
|
||||
|
@ -3,12 +3,12 @@ import weakref
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models import ( Candidate )
|
||||
from .candidate_entity import CandidateEntity
|
||||
from database import RedisDatabase
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
from prometheus_client import CollectorRegistry
|
||||
|
||||
class EntityManager:
|
||||
"""Manages lifecycle of CandidateEntity instances"""
|
||||
|
@ -64,7 +64,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
try:
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from pydantic import BaseModel
|
||||
except ImportError as e:
|
||||
print(f"Error importing pydantic: {e}")
|
||||
print("Make sure pydantic is installed: pip install pydantic")
|
||||
|
@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
from typing import Any, List, Set
|
||||
|
||||
|
@ -4,8 +4,8 @@ import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import torch # type: ignore
|
||||
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, FluxPipeline
|
||||
|
||||
class ImageModelCache: # Stay loaded for 3 hours
|
||||
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
||||
import inspect
|
||||
import random
|
||||
@ -13,7 +13,7 @@ import os
|
||||
import gc
|
||||
import tempfile
|
||||
import uuid
|
||||
import torch # type: ignore
|
||||
import torch
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
|
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
import asyncio
|
||||
import json
|
||||
@ -179,7 +179,7 @@ class OllamaAdapter(BaseLLMAdapter):
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import ollama
|
||||
self.client = ollama.AsyncClient( # type: ignore
|
||||
self.client = ollama.AsyncClient(
|
||||
host=config.get('host', defines.ollama_api_url)
|
||||
)
|
||||
|
||||
@ -386,7 +386,7 @@ class OpenAIAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import openai # type: ignore
|
||||
import openai
|
||||
self.client = openai.AsyncOpenAI(
|
||||
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
||||
)
|
||||
@ -522,7 +522,7 @@ class AnthropicAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import anthropic # type: ignore
|
||||
import anthropic
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
||||
)
|
||||
@ -656,7 +656,7 @@ class GeminiAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import google.generativeai as genai # type: ignore
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
||||
self.genai = genai
|
||||
|
||||
@ -867,7 +867,7 @@ class UnifiedLLMProxy:
|
||||
raise ValueError("stream must be True for chat_stream")
|
||||
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
||||
# Type checker now knows this is an AsyncGenerator due to stream=True
|
||||
async for chunk in result: # type: ignore
|
||||
async for chunk in result:
|
||||
yield chunk
|
||||
|
||||
async def chat_single(
|
||||
@ -881,7 +881,7 @@ class UnifiedLLMProxy:
|
||||
|
||||
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
||||
# Type checker now knows this is a ChatResponse due to stream=False
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@ -908,7 +908,7 @@ class UnifiedLLMProxy:
|
||||
"""Stream text generation using specified or default provider"""
|
||||
|
||||
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
||||
async for chunk in result: # type: ignore
|
||||
async for chunk in result:
|
||||
yield chunk
|
||||
|
||||
async def generate_single(
|
||||
@ -921,7 +921,7 @@ class UnifiedLLMProxy:
|
||||
"""Get single generation response using specified or default provider"""
|
||||
|
||||
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
@ -1148,7 +1148,7 @@ async def example_embeddings_usage():
|
||||
|
||||
# Calculate similarity between first two texts (requires numpy)
|
||||
try:
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
emb1 = np.array(response.data[0].embedding)
|
||||
emb2 = np.array(response.data[1].embedding)
|
||||
similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
|
||||
|
src/backend/main.py: 1411 changes (file diff suppressed because it is too large)
@ -1,7 +1,14 @@
from typing import Type, TypeVar
from pydantic import BaseModel # type: ignore
from pydantic import BaseModel
import copy

from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType

# Ensure all user models inherit from BaseUserWithType
assert issubclass(Candidate, BaseUserWithType), "Candidate must inherit from BaseUserWithType"
assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType"
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"

T = TypeVar('T', bound=BaseModel)

@ -12,3 +19,35 @@ def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
    data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
    return model_cls(**data)

def cast_to_base_user_with_type(user) -> BaseUserWithType:
    """
    Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
    This is useful for FastAPI dependencies that expect a common user type.
    """
    if isinstance(user, BaseUserWithType):
        return user
    # If it's a dict, try to detect type
    if isinstance(user, dict):
        user_type = user.get("user_type") or user.get("type")
        if user_type == "candidate":
            if user.get("is_AI"):
                return CandidateAI.model_validate(user)
            return Candidate.model_validate(user)
        elif user_type == "employer":
            return Employer.model_validate(user)
        elif user_type == "guest":
            return Guest.model_validate(user)
        else:
            raise ValueError(f"Unknown user_type: {user_type}")
    # If it's a model, check its type
    if hasattr(user, "user_type"):
        if getattr(user, "user_type", None) == "candidate":
            if getattr(user, "is_AI", False):
                return CandidateAI.model_validate(user.model_dump())
            return Candidate.model_validate(user.model_dump())
        elif getattr(user, "user_type", None) == "employer":
            return Employer.model_validate(user.model_dump())
        elif getattr(user, "user_type", None) == "guest":
            return Guest.model_validate(user.model_dump())
    raise TypeError(f"Cannot cast object of type {type(user)} to BaseUserWithType")
@ -10,6 +10,7 @@ from auth_utils import (
    sanitize_login_input,
    SecurityConfig
)
import defines

# Generic type variable
T = TypeVar('T')
@ -256,6 +257,7 @@ class MFARequest(BaseModel):
    password: str
    device_id: str = Field(..., alias="deviceId")
    device_name: str = Field(..., alias="deviceName")
    email: str = Field(..., alias="email")
    model_config = {
        "populate_by_name": True,  # Allow both field names and aliases
    }
@ -464,9 +466,18 @@ class ErrorDetail(BaseModel):
# Main Models
# ============================

# Base user model without user_type field
class BaseUser(BaseModel):
# Generic base user with user_type for API responses
class BaseUserWithType(BaseModel):
    user_type: UserType = Field(..., alias="userType")
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
    model_config = {
        "populate_by_name": True,  # Allow both field names and aliases
        "use_enum_values": True  # Use enum values instead of names
    }

# Base user model without user_type field
class BaseUser(BaseUserWithType):
    email: EmailStr
    first_name: str = Field(..., alias="firstName")
    last_name: str = Field(..., alias="lastName")
@ -485,9 +496,6 @@ class BaseUser(BaseModel):
        "use_enum_values": True  # Use enum values instead of names
    }

# Generic base user with user_type for API responses
class BaseUserWithType(BaseUser):
    user_type: UserType = Field(..., alias="userType")

class RagEntry(BaseModel):
    name: str
@ -519,9 +527,9 @@ class DocumentType(str, Enum):
    IMAGE = "image"

class DocumentOptions(BaseModel):
    include_in_RAG: Optional[bool] = Field(True, alias="includeInRAG")
    is_job_document: Optional[bool] = Field(False, alias="isJobDocument")
    overwrite: Optional[bool] = Field(False, alias="overwrite")
    include_in_RAG: bool = Field(default=True, alias="includeInRAG")
    is_job_document: Optional[bool] = Field(default=False, alias="isJobDocument")
    overwrite: Optional[bool] = Field(default=False, alias="overwrite")
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }
@ -534,7 +542,7 @@ class Document(BaseModel):
    type: DocumentType
    size: int
    upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate")
    options: DocumentOptions = Field(default_factory=DocumentOptions, alias="options")
    options: DocumentOptions = Field(default_factory=lambda: DocumentOptions(), alias="options")
    rag_chunks: Optional[int] = Field(default=0, alias="ragChunks")
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
@ -565,7 +573,7 @@ class DocumentUpdateRequest(BaseModel):
    }

class Candidate(BaseUser):
    user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
    user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
    username: str
    description: Optional[str] = None
    resume: Optional[str] = None
@ -584,14 +592,14 @@ class Candidate(BaseUser):
    rag_content_size: int = 0

class CandidateAI(Candidate):
    user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
    user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
    is_AI: bool = Field(True, alias="isAI")
    age: Optional[int] = None
    gender: Optional[UserGender] = None
    ethnicity: Optional[str] = None

class Employer(BaseUser):
    user_type: Literal[UserType.EMPLOYER] = Field(UserType.EMPLOYER, alias="userType")
    user_type: UserType = Field(UserType.EMPLOYER, alias="userType")
    company_name: str = Field(..., alias="companyName")
    industry: str
    description: Optional[str] = None
@ -603,16 +611,18 @@ class Employer(BaseUser):
    social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks")
    poc: Optional[PointOfContact] = None

class Guest(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
class Guest(BaseUser):
    user_type: UserType = Field(UserType.GUEST, alias="userType")
    session_id: str = Field(..., alias="sessionId")
    created_at: datetime = Field(..., alias="createdAt")
    last_activity: datetime = Field(..., alias="lastActivity")
    username: str  # Add username for consistency with other user types
    converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId")
    ip_address: Optional[str] = Field(None, alias="ipAddress")
    created_at: datetime = Field(..., alias="createdAt")
    user_agent: Optional[str] = Field(None, alias="userAgent")
    rag_content_size: int = 0
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
        "populate_by_name": True,  # Allow both field names and aliases
        "use_enum_values": True  # Use enum values instead of names
    }

class Authentication(BaseModel):
@ -635,10 +645,21 @@ class Authentication(BaseModel):
class AuthResponse(BaseModel):
    access_token: str = Field(..., alias="accessToken")
    refresh_token: str = Field(..., alias="refreshToken")
    user: Candidate | Employer
    user: Union[Candidate, Employer, Guest]  # Add Guest support
    expires_at: int = Field(..., alias="expiresAt")
    user_type: Optional[str] = Field(default=UserType.GUEST, alias="userType")  # Explicit user type
    is_guest: Optional[bool] = Field(default=True, alias="isGuest")  # Guest indicator

    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
        "populate_by_name": True
    }

class GuestCleanupRequest(BaseModel):
    """Request to cleanup inactive guests"""
    inactive_hours: int = Field(24, alias="inactiveHours")

    model_config = {
        "populate_by_name": True
    }

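Since `GuestCleanupRequest` is fully visible in this hunk, here is a minimal sketch of the alias behaviour its `populate_by_name` config enables; the same pattern applies to the `Guest` and `AuthResponse` models above (whose remaining required fields are not all shown here).

# Illustrative only: camelCase alias vs. snake_case field name on the request model above.
req = GuestCleanupRequest(inactiveHours=48)       # alias accepted
same = GuestCleanupRequest(inactive_hours=48)     # field name accepted via populate_by_name
assert req.inactive_hours == same.inactive_hours == 48
print(req.model_dump(by_alias=True))              # {'inactiveHours': 48}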
class JobRequirements(BaseModel):
@ -751,6 +772,19 @@ class ChromaDBGetResponse(BaseModel):
    umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
    umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")

class GuestSessionResponse(BaseModel):
    """Response for guest session creation"""
    access_token: str = Field(..., alias="accessToken")
    refresh_token: str = Field(..., alias="refreshToken")
    user: Guest
    expires_at: int = Field(..., alias="expiresAt")
    user_type: Literal["guest"] = Field("guest", alias="userType")
    is_guest: bool = Field(True, alias="isGuest")

    model_config = {
        "populate_by_name": True
    }

class ChatContext(BaseModel):
    type: ChatContextType
    related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
@ -768,6 +802,98 @@ class ChatOptions(BaseModel):
        "populate_by_name": True  # Allow both field names and aliases
    }

# Add rate limiting configuration models
class RateLimitConfig(BaseModel):
    """Rate limit configuration"""
    requests_per_minute: int = Field(..., alias="requestsPerMinute")
    requests_per_hour: int = Field(..., alias="requestsPerHour")
    requests_per_day: int = Field(..., alias="requestsPerDay")
    burst_limit: int = Field(..., alias="burstLimit")
    burst_window_seconds: int = Field(60, alias="burstWindowSeconds")

    model_config = {
        "populate_by_name": True
    }

class RateLimitResult(BaseModel):
    """Result of rate limit check"""
    allowed: bool
    reason: Optional[str] = None
    retry_after_seconds: Optional[int] = Field(None, alias="retryAfterSeconds")
    remaining_requests: Dict[str, int] = Field(default_factory=dict, alias="remainingRequests")
    reset_times: Dict[str, datetime] = Field(default_factory=dict, alias="resetTimes")

    model_config = {
        "populate_by_name": True
    }

class RateLimitStatus(BaseModel):
    """Rate limit status for a user"""
    user_id: str = Field(..., alias="userId")
    user_type: str = Field(..., alias="userType")
    is_admin: bool = Field(..., alias="isAdmin")
    current_usage: Dict[str, int] = Field(..., alias="currentUsage")
    limits: Dict[str, int] = Field(..., alias="limits")
    remaining: Dict[str, int] = Field(..., alias="remaining")
    reset_times: Dict[str, datetime] = Field(..., alias="resetTimes")
    config: RateLimitConfig

    model_config = {
        "populate_by_name": True
    }


# Add guest conversion request models
class GuestConversionRequest(BaseModel):
    """Request to convert guest to permanent user"""
    account_type: Literal["candidate", "employer"] = Field(..., alias="accountType")
    email: EmailStr
    username: str
    password: str
    first_name: str = Field(..., alias="firstName")
    last_name: str = Field(..., alias="lastName")
    phone: Optional[str] = None

    # Employer-specific fields (optional)
    company_name: Optional[str] = Field(None, alias="companyName")
    industry: Optional[str] = None
    company_size: Optional[str] = Field(None, alias="companySize")
    company_description: Optional[str] = Field(None, alias="companyDescription")
    website_url: Optional[HttpUrl] = Field(None, alias="websiteUrl")

    model_config = {
        "populate_by_name": True
    }

    @field_validator('username')
    def validate_username(cls, v):
        if not v or len(v.strip()) < 3:
            raise ValueError('Username must be at least 3 characters long')
        return v.strip().lower()

    @field_validator('password')
    def validate_password_strength(cls, v):
        # Import here to avoid circular imports
        from auth_utils import validate_password_strength
        is_valid, issues = validate_password_strength(v)
        if not is_valid:
            raise ValueError('; '.join(issues))
        return v

# Add guest statistics response model
class GuestStatistics(BaseModel):
    """Guest usage statistics"""
    total_guests: int = Field(..., alias="totalGuests")
    active_last_hour: int = Field(..., alias="activeLastHour")
    active_last_day: int = Field(..., alias="activeLastDay")
    converted_guests: int = Field(..., alias="convertedGuests")
    by_ip: Dict[str, int] = Field(..., alias="byIp")
    creation_timeline: Dict[str, int] = Field(..., alias="creationTimeline")

    model_config = {
        "populate_by_name": True
    }

from llm_proxy import (LLMMessage)

class ApiMessage(BaseModel):
@ -800,12 +926,14 @@ class ApiActivityType(str, Enum):
    HEARTBEAT = "heartbeat"  # Used for periodic updates

class ChatMessageStatus(ApiMessage):
    sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
    status: ApiStatusType = ApiStatusType.STATUS
    type: ApiMessageType = ApiMessageType.TEXT
    activity: ApiActivityType
    content: Any

class ChatMessageError(ApiMessage):
    sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
    status: ApiStatusType = ApiStatusType.ERROR
    type: ApiMessageType = ApiMessageType.TEXT
    content: str
@ -825,6 +953,7 @@ class JobRequirementsMessage(ApiMessage):

class DocumentMessage(ApiMessage):
    type: ApiMessageType = ApiMessageType.JSON
    sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
    document: Document = Field(..., alias="document")
    content: Optional[str] = ""
    converted: bool = Field(False, alias="converted")
@ -837,9 +966,9 @@ class ChatMessageMetaData(BaseModel):
    temperature: float = 0.7
    max_tokens: int = Field(default=8092, alias="maxTokens")
    top_p: float = Field(default=1, alias="topP")
    frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
    presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
    stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
    frequency_penalty: float = Field(default=0, alias="frequencyPenalty")
    presence_penalty: float = Field(default=0, alias="presencePenalty")
    stop_sequences: List[str] = Field(default=[], alias="stopSequences")
    rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
    llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
    eval_count: int = 0
@ -862,16 +991,31 @@ class ChatMessageUser(ApiMessage):

class ChatMessage(ChatMessageUser):
    role: ChatSenderType = ChatSenderType.ASSISTANT
    metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
    metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData())
    #attachments: Optional[List[Attachment]] = None
    #reactions: Optional[List[MessageReaction]] = None
    #is_edited: bool = Field(False, alias="isEdited")
    #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")

class GPUInfo(BaseModel):
    name: str
    memory: int
    discrete: bool

class SystemInfo(BaseModel):
    installed_RAM: str = Field(..., alias="installedRAM")
    graphics_cards: List[GPUInfo] = Field(..., alias="graphicsCards")
    CPU: str
    llm_model: str = Field(default=defines.model, alias="llmModel")
    embedding_model: str = Field(default=defines.embedding_model, alias="embeddingModel")
    max_context_length: int = Field(default=defines.max_context, alias="maxContextLength")
    model_config = {
        "populate_by_name": True  # Allow both field names and aliases
    }

class ChatSession(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: Optional[str] = Field(None, alias="userId")
    guest_id: Optional[str] = Field(None, alias="guestId")
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
    last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
    title: Optional[str] = None
@ -937,7 +1081,7 @@ class UserActivity(BaseModel):
    }

    @model_validator(mode="after")
    def check_user_or_guest(self) -> "ChatSession":
    def check_user_or_guest(self):
        if not self.user_id and not self.guest_id:
            raise ValueError("Either user_id or guest_id must be provided")
        return self
@ -1102,6 +1246,8 @@ class JobListResponse(BaseModel):
    error: Optional[ErrorDetail] = None
    meta: Optional[Dict[str, Any]] = None

User = Union[Candidate, CandidateAI, Employer, Guest]

# Forward references resolution
Candidate.update_forward_refs()
Employer.update_forward_refs()
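A hedged sketch tying the `GuestStatistics` response model to the Redis helper earlier in this diff: `get_guest_statistics()` returns a snake_case dict, which the model can validate directly and re-serialize with camelCase aliases for the API. The numbers below are made up.

# Illustrative only: snake_case stats dict -> camelCase API payload.
stats_dict = {
    "total_guests": 12,
    "active_last_hour": 3,
    "active_last_day": 7,
    "converted_guests": 2,
    "by_ip": {"203.0.113.7": 4, "unknown": 8},
    "creation_timeline": {"2024-01-01": 5, "2024-01-02": 7},
}
stats = GuestStatistics.model_validate(stats_dict)   # field names accepted via populate_by_name
payload = stats.model_dump(by_alias=True)            # {'totalGuests': 12, 'activeLastHour': 3, ...}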
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
from typing import List, Optional, Dict, Any, Union
import os
import glob
@ -9,15 +9,15 @@ import hashlib
import asyncio
import logging
import json
import numpy as np # type: ignore
import numpy as np
import traceback

import chromadb # type: ignore
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
import chromadb
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import umap
from markitdown import MarkItDown
from chromadb.api.models.Collection import Collection

from .markdown_chunker import (
    MarkdownChunker,
@ -351,9 +351,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            os.makedirs(self.persist_directory)

        # Initialize ChromaDB client
        chroma_client = chromadb.PersistentClient( # type: ignore
        chroma_client = chromadb.PersistentClient(
            path=self.persist_directory,
            settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
            settings=chromadb.Settings(anonymized_telemetry=False),
        )

        # Check if the collection exists
src/backend/rate_limiter.py (new file, 287 lines)
@ -0,0 +1,287 @@
"""
|
||||
Rate limiting utilities for guest and authenticated users
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from database import RedisDatabase
|
||||
from logger import logger
|
||||
|
||||
class RateLimitConfig(BaseModel):
|
||||
"""Rate limit configuration"""
|
||||
requests_per_minute: int
|
||||
requests_per_hour: int
|
||||
requests_per_day: int
|
||||
burst_limit: int # Maximum requests in a short burst
|
||||
burst_window_seconds: int = 60 # Window for burst detection
|
||||
|
||||
class GuestRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for guest users - more restrictive"""
|
||||
requests_per_minute: int = 10
|
||||
requests_per_hour: int = 100
|
||||
requests_per_day: int = 500
|
||||
burst_limit: int = 15
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for authenticated users - more generous"""
|
||||
requests_per_minute: int = 60
|
||||
requests_per_hour: int = 1000
|
||||
requests_per_day: int = 10000
|
||||
burst_limit: int = 100
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
class PremiumUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for premium/admin users - most generous"""
|
||||
requests_per_minute: int = 120
|
||||
requests_per_hour: int = 5000
|
||||
requests_per_day: int = 50000
|
||||
burst_limit: int = 200
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
class RateLimitResult(BaseModel):
|
||||
"""Result of rate limit check"""
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
retry_after_seconds: Optional[int] = None
|
||||
remaining_requests: Dict[str, int] = {}
|
||||
reset_times: Dict[str, datetime] = {}
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter using Redis for distributed rate limiting"""
|
||||
|
||||
def __init__(self, database: RedisDatabase):
|
||||
self.database = database
|
||||
self.redis = database.redis
|
||||
|
||||
# Rate limit configurations
|
||||
self.guest_config = GuestRateLimitConfig()
|
||||
self.user_config = AuthenticatedUserRateLimitConfig()
|
||||
self.premium_config = PremiumUserRateLimitConfig()
|
||||
|
||||
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
|
||||
"""Get rate limit configuration based on user type"""
|
||||
if user_type == "guest":
|
||||
return self.guest_config
|
||||
elif is_admin:
|
||||
return self.premium_config
|
||||
else:
|
||||
return self.user_config
|
||||
|
||||
async def check_rate_limit(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False,
|
||||
endpoint: Optional[str] = None
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check if user has exceeded rate limits
|
||||
|
||||
Args:
|
||||
user_id: Unique identifier for the user (guest session ID or user ID)
|
||||
user_type: "guest", "candidate", or "employer"
|
||||
is_admin: Whether user has admin privileges
|
||||
endpoint: Optional endpoint-specific rate limiting
|
||||
|
||||
Returns:
|
||||
RateLimitResult indicating if request is allowed
|
||||
"""
|
||||
config = self.get_config_for_user(user_type, is_admin)
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
# Create Redis keys for different time windows
|
||||
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||
keys = {
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
}
|
||||
|
||||
# Add endpoint-specific limiting if provided
|
||||
if endpoint:
|
||||
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
|
||||
|
||||
try:
|
||||
# Use Redis pipeline for atomic operations
|
||||
pipe = self.redis.pipeline()
|
||||
|
||||
# Get current counts
|
||||
for key in keys.values():
|
||||
pipe.get(key)
|
||||
|
||||
results = await pipe.execute()
|
||||
current_counts = {
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
}
|
||||
|
||||
# Check limits
|
||||
limits = {
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
}
|
||||
|
||||
# Check each limit
|
||||
for window, current_count in current_counts.items():
|
||||
limit = limits[window]
|
||||
if current_count >= limit:
|
||||
# Calculate retry after time
|
||||
if window == "minute":
|
||||
retry_after = 60 - current_time.second
|
||||
elif window == "hour":
|
||||
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
|
||||
elif window == "day":
|
||||
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
|
||||
else: # burst
|
||||
retry_after = config.burst_window_seconds
|
||||
|
||||
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
|
||||
retry_after_seconds=retry_after,
|
||||
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
)
|
||||
|
||||
# If we get here, request is allowed - increment counters
|
||||
pipe = self.redis.pipeline()
|
||||
|
||||
# Increment minute counter (expires after 2 minutes)
|
||||
pipe.incr(keys["minute"])
|
||||
pipe.expire(keys["minute"], 120)
|
||||
|
||||
# Increment hour counter (expires after 2 hours)
|
||||
pipe.incr(keys["hour"])
|
||||
pipe.expire(keys["hour"], 7200)
|
||||
|
||||
# Increment day counter (expires after 2 days)
|
||||
pipe.incr(keys["day"])
|
||||
pipe.expire(keys["day"], 172800)
|
||||
|
||||
# Increment burst counter (expires after burst window)
|
||||
pipe.incr(keys["burst"])
|
||||
pipe.expire(keys["burst"], config.burst_window_seconds)
|
||||
|
||||
await pipe.execute()
|
||||
|
||||
# Calculate remaining requests
|
||||
remaining = {
|
||||
k: max(0, limits[k] - (current_counts[k] + 1))
|
||||
for k in current_counts.keys()
|
||||
}
|
||||
|
||||
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining_requests=remaining,
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
|
||||
# Fail open - allow request if rate limiting system fails
|
||||
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
|
||||
|
||||
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
|
||||
"""Calculate when each rate limit window resets"""
|
||||
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
return {
|
||||
"minute": next_minute,
|
||||
"hour": next_hour,
|
||||
"day": next_day
|
||||
}
|
||||
|
||||
async def get_user_rate_limit_status(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current rate limit status for a user"""
|
||||
config = self.get_config_for_user(user_type, is_admin)
|
||||
current_time = datetime.now(UTC)
|
||||
|
||||
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||
keys = {
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
}
|
||||
|
||||
try:
|
||||
pipe = self.redis.pipeline()
|
||||
for key in keys.values():
|
||||
pipe.get(key)
|
||||
|
||||
results = await pipe.execute()
|
||||
current_counts = {
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
}
|
||||
|
||||
limits = {
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
}
|
||||
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"user_type": user_type,
|
||||
"is_admin": is_admin,
|
||||
"current_usage": current_counts,
|
||||
"limits": limits,
|
||||
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
|
||||
"reset_times": self._calculate_reset_times(current_time),
|
||||
"config": config.model_dump()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
|
||||
"""Reset all rate limits for a user (admin function)"""
|
||||
try:
|
||||
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||
pattern = f"{base_key}:*"
|
||||
|
||||
cursor = 0
|
||||
deleted_count = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
deleted_count += len(keys)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
|
||||
return False
|
||||
|
||||
|
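A minimal sketch of how this limiter might be wired into a FastAPI route guard; the caller is assumed to supply the resolved user object and the RedisDatabase instance, and the `is_admin` attribute name is an assumption rather than something defined in this diff.

# Illustrative only: enforcing RateLimitResult as an HTTP 429 in FastAPI.
from fastapi import HTTPException, Request

async def enforce_rate_limit(request: Request, user, database: RedisDatabase) -> RateLimitResult:
    """Raise HTTP 429 when the RateLimiter rejects the request; return the result otherwise."""
    limiter = RateLimiter(database)
    result = await limiter.check_rate_limit(
        user_id=user.id,
        user_type=getattr(user, "user_type", "guest"),
        is_admin=getattr(user, "is_admin", False),   # attribute name is an assumption
        endpoint=request.url.path,
    )
    if not result.allowed:
        raise HTTPException(
            status_code=429,
            detail=result.reason,
            headers={"Retry-After": str(result.retry_after_seconds or 60)},
        )
    return result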
@ -2,6 +2,7 @@ import defines
import re
import subprocess
import math
from models import SystemInfo

def get_installed_ram():
    try:
@ -70,12 +71,18 @@ def get_cpu_info():
    except Exception as e:
        return f"Error retrieving CPU info: {e}"

def system_info():
    return {
        "System RAM": get_installed_ram(),
        "Graphics Card": get_graphics_cards(),
        "CPU": get_cpu_info(),
        "LLM Model": defines.model,
        "Embedding Model": defines.embedding_model,
        "Context length": defines.max_context,
    }
def system_info() -> SystemInfo:
    """
    Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
    Returns:
        SystemInfo: An object containing the collected system information.
    """
    system = SystemInfo(
        installed_RAM=get_installed_ram(),
        graphics_cards=get_graphics_cards(),
        CPU=get_cpu_info(),
        llm_model=defines.model,
        embedding_model=defines.embedding_model,
        max_context_length=defines.max_context,
    )
    return system
@ -1,5 +1,5 @@
|
||||
import os
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
@ -7,12 +7,12 @@ from typing import (
|
||||
)
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from bs4 import BeautifulSoup # type: ignore
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from geopy.geocoders import Nominatim # type: ignore
|
||||
import pytz # type: ignore
|
||||
from geopy.geocoders import Nominatim
|
||||
import pytz
|
||||
import requests
|
||||
import yfinance as yf # type: ignore
|
||||
import yfinance as yf
|
||||
import logging
|
||||
|
||||
|
||||
@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
||||
|
||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
# __all__.extend(__tool_functions__) # type: ignore
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI, HTTPException # type: ignore
|
||||
from fastapi.responses import StreamingResponse # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
import json
|
||||
import asyncio
|
||||
@ -234,5 +234,5 @@ async def health_check():
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn # type: ignore
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
@ -92,7 +92,7 @@ class OllamaAdapter(BaseLLMAdapter):
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import ollama
|
||||
self.client = ollama.AsyncClient( # type: ignore
|
||||
self.client = ollama.AsyncClient(
|
||||
host=config.get('host', 'http://localhost:11434')
|
||||
)
|
||||
|
||||
@ -187,7 +187,7 @@ class OpenAIAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import openai # type: ignore
|
||||
import openai
|
||||
self.client = openai.AsyncOpenAI(
|
||||
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
||||
)
|
||||
@ -259,7 +259,7 @@ class AnthropicAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import anthropic # type: ignore
|
||||
import anthropic
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
||||
)
|
||||
@ -344,7 +344,7 @@ class GeminiAdapter(BaseLLMAdapter):
|
||||
|
||||
def __init__(self, **config):
|
||||
super().__init__(**config)
|
||||
import google.generativeai as genai # type: ignore
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
||||
self.genai = genai
|
||||
|
||||
@ -476,7 +476,7 @@ class UnifiedLLMProxy:
|
||||
|
||||
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
||||
# Type checker now knows this is an AsyncGenerator due to stream=True
|
||||
async for chunk in result: # type: ignore
|
||||
async for chunk in result:
|
||||
yield chunk
|
||||
|
||||
async def chat_single(
|
||||
@ -490,7 +490,7 @@ class UnifiedLLMProxy:
|
||||
|
||||
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
||||
# Type checker now knows this is a ChatResponse due to stream=False
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
@ -517,7 +517,7 @@ class UnifiedLLMProxy:
|
||||
"""Stream text generation using specified or default provider"""
|
||||
|
||||
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
||||
async for chunk in result: # type: ignore
|
||||
async for chunk in result:
|
||||
yield chunk
|
||||
|
||||
async def generate_single(
|
||||
@ -530,7 +530,7 @@ class UnifiedLLMProxy:
|
||||
"""Get single generation response using specified or default provider"""
|
||||
|
||||
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
||||
return result # type: ignore
|
||||
return result
|
||||
|
||||
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
|
||||
"""List available models for specified or default provider"""
|
||||
|
@ -1,8 +1,8 @@
|
||||
LLM_TIMEOUT = 600
|
||||
|
||||
from utils import logger
|
||||
from pydantic import BaseModel, Field, ValidationError # type: ignore
|
||||
from pydantic_core import PydanticSerializationError # type: ignore
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic_core import PydanticSerializationError
|
||||
from typing import List
|
||||
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
@ -48,18 +48,18 @@ try_import("prometheus_fastapi_instrumentator")
|
||||
|
||||
import ollama
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends # type: ignore
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
|
||||
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
||||
import uvicorn # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
import numpy as np
|
||||
from utils import redis_manager
|
||||
import redis.asyncio as redis # type: ignore
|
||||
import redis.asyncio as redis
|
||||
|
||||
# Prometheus
|
||||
from prometheus_client import Summary # type: ignore
|
||||
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
from prometheus_client import Summary
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_client import CollectorRegistry, Counter
|
||||
|
||||
from utils import (
|
||||
rag as Rag,
|
||||
@ -1308,7 +1308,7 @@ def main():
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
|
||||
|
||||
llm = ollama.Client(host=args.ollama_server) # type: ignore
|
||||
llm = ollama.Client(host=args.ollama_server)
|
||||
web_server = WebServer(llm, args.ollama_model)
|
||||
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
||||
|
||||
|
@ -5,7 +5,7 @@ import sys
|
||||
import os
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
|
||||
from markdown_chunker import MarkdownChunker # type: ignore
|
||||
from markdown_chunker import MarkdownChunker
|
||||
|
||||
chunker = MarkdownChunker()
|
||||
chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md")
|
||||
|
@ -1,10 +1,10 @@
|
||||
from fastapi import FastAPI, Request, Depends, Query # type: ignore
|
||||
from fastapi.responses import RedirectResponse, JSONResponse # type: ignore
|
||||
from fastapi import FastAPI, Request, Depends, Query
|
||||
from fastapi.responses import RedirectResponse, JSONResponse
|
||||
from uuid import UUID, uuid4
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
from anyio.to_thread import run_sync # type: ignore
|
||||
from anyio.to_thread import run_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -69,7 +69,7 @@ class ContextRouteManager:
|
||||
logger.info(f"Invalid UUID, redirecting to {redirect_url}")
|
||||
raise RedirectToContext(redirect_url)
|
||||
|
||||
return _ensure_context_dependency # type: ignore
|
||||
return _ensure_context_dependency
|
||||
|
||||
def route_pattern(self, path: str, *dependencies, **kwargs):
|
||||
logger.info(f"Registering route: {path}")
|
||||
@ -134,6 +134,6 @@ async def redirect_history(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn # type: ignore
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8900)
|
||||
|
@ -1,9 +1,9 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-embedding
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
import logging
|
||||
import argparse
|
||||
from ollama import Client # type: ignore
|
||||
from ollama import Client
|
||||
from ..utils import defines
|
||||
|
||||
# Configure logging
|
||||
|
@ -1,11 +1,11 @@
|
||||
# From /opt/backstory run:
|
||||
# python -m src.tests.test-rag
|
||||
from ..utils import logger
|
||||
from pydantic import BaseModel, field_validator # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
from pydantic import BaseModel, field_validator
|
||||
from prometheus_client import CollectorRegistry
|
||||
from typing import List, Dict, Any, Optional
|
||||
import ollama
|
||||
import numpy as np # type: ignore
|
||||
import numpy as np
|
||||
from ..utils import (rag as Rag, ChromaDBGetResponse)
|
||||
from ..utils import Context
|
||||
from ..utils import defines
|
||||
@ -44,7 +44,7 @@ rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
json_str = rag.model_dump(mode="json")
|
||||
logger.info(json_str)
|
||||
rag = ChromaDBGetResponse.model_validate(json_str)
|
||||
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
||||
llm = ollama.Client(host=defines.ollama_api_url)
|
||||
prometheus_collector = CollectorRegistry()
|
||||
observer, file_watcher = Rag.start_file_watcher(
|
||||
llm=llm,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from pydantic import BaseModel
|
||||
from typing import (
|
||||
Any,
|
||||
Set
|
||||
@ -39,7 +39,7 @@ __all__ = [
|
||||
"generate_image", "ImageRequest"
|
||||
]
|
||||
|
||||
__all__.extend(agents_all) # type: ignore
|
||||
__all__.extend(agents_all)
|
||||
|
||||
logger = setup_logging(level=defines.logging_level)
|
||||
|
||||
|
@ -44,7 +44,7 @@ for path in package_dir.glob("*.py"):
|
||||
class_registry[name] = (full_module_name, name)
|
||||
globals()[name] = obj
|
||||
logger.info(f"Adding agent: {name}")
|
||||
__all__.append(name) # type: ignore
|
||||
__all__.append(name)
|
||||
except ImportError as e:
|
||||
logger.error(f"Error importing {full_module_name}: {e}")
|
||||
raise e
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -19,7 +19,7 @@ import inspect
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry
|
||||
|
||||
from ..setup_logging import setup_logging
|
||||
|
||||
@ -33,7 +33,7 @@ from .types import agent_registry
|
||||
from .. import defines
|
||||
from ..message import Message, Tunables
|
||||
from ..metrics import Metrics
|
||||
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
||||
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools
|
||||
from ..conversation import Conversation
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
|
@ -53,7 +53,7 @@ class Chat(Agent):
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["chat"] = "chat" # type: ignore
|
||||
agent_type: Literal["chat"] = "chat"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import model_validator # type: ignore
from pydantic import model_validator
from typing import (
Literal,
ClassVar,
@ -31,7 +31,7 @@ When answering queries, follow these steps:


class FactCheck(Agent):
agent_type: Literal["fact_check"] = "fact_check" # type: ignore
agent_type: Literal["fact_check"] = "fact_check"
_agent_type: ClassVar[str] = agent_type # Add this for registration

system_prompt: str = system_fact_check
|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore
from pydantic import model_validator, Field, BaseModel
from typing import (
Dict,
Literal,
@ -37,7 +37,7 @@ seed = int(time.time())
random.seed(seed)

class ImageGenerator(Agent):
agent_type: Literal["image"] = "image" # type: ignore
agent_type: Literal["image"] = "image"
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False

|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from pydantic import model_validator, Field
from typing import (
Dict,
Literal,
@ -17,7 +17,7 @@ import traceback
import asyncio
import time
import asyncio
import numpy as np # type: ignore
import numpy as np

from . base import Agent, agent_registry, LLMMessage
from .. message import Message
@ -36,7 +36,7 @@ Answer questions about the job description.


class JobDescription(Agent):
agent_type: Literal["job_description"] = "job_description" # type: ignore
agent_type: Literal["job_description"] = "job_description"
_agent_type: ClassVar[str] = agent_type # Add this for registration

system_prompt: str = system_generate_resume
|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore
from pydantic import model_validator, Field, BaseModel
from typing import (
Dict,
Literal,
@ -23,7 +23,7 @@ import asyncio
import time
import os
import random
from names_dataset import NameDataset, NameWrapper # type: ignore
from names_dataset import NameDataset, NameWrapper

from .base import Agent, agent_registry, LLMMessage
from ..message import Message
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
class EthnicNameGenerator:
def __init__(self):
try:
from names_dataset import NameDataset # type: ignore
from names_dataset import NameDataset
self.nd = NameDataset()
except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset")
@ -292,7 +292,7 @@ class EthnicNameGenerator:
return names

class PersonaGenerator(Agent):
agent_type: Literal["persona"] = "persona" # type: ignore
agent_type: Literal["persona"] = "persona"
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False

|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import model_validator # type: ignore
from pydantic import model_validator
from typing import (
Literal,
ClassVar,
@ -46,7 +46,7 @@ When answering queries, follow these steps:


class Resume(Agent):
agent_type: Literal["resume"] = "resume" # type: ignore
agent_type: Literal["resume"] = "resume"
_agent_type: ClassVar[str] = agent_type # Add this for registration

system_prompt: str = system_fact_check
|
@ -1,4 +1,4 @@
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field
import json
from typing import Any, List, Set

|
@ -1,9 +1,9 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field, model_validator
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
from typing_extensions import Annotated, Union
import numpy as np # type: ignore
import numpy as np
import logging
from uuid import uuid4
import traceback
@ -44,7 +44,7 @@ class Context(BaseModel):
user_facts: Optional[str] = None

# Class managed fields
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
default_factory=list
)
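The `agents` field above builds a discriminated union over `Agent.__subclasses__()`, keyed by each subclass's `agent_type` literal; since `__subclasses__()` is evaluated when `Context` is defined, the agent modules presumably need to be imported before this class. A rough sketch of how such a tagged union behaves during validation is below; `Chat` and `Resume` here are stand-in models, not the project's classes.

# Sketch: a tagged union lets Pydantic pick the right subclass from serialized data.
# The model names below are illustrative stand-ins for the real agent models.
from typing import List, Literal, Union
from typing_extensions import Annotated
from pydantic import BaseModel, Field

class Chat(BaseModel):
    agent_type: Literal["chat"] = "chat"

class Resume(BaseModel):
    agent_type: Literal["resume"] = "resume"

AgentUnion = Annotated[Union[Chat, Resume], Field(discriminator="agent_type")]

class Context(BaseModel):
    agents: List[AgentUnion] = Field(default_factory=list)

ctx = Context.model_validate({"agents": [{"agent_type": "resume"}]})
assert isinstance(ctx.agents[0], Resume)  # the tag selects the subclass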
|
@ -1,4 +1,4 @@
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field
from typing import List
from .message import Message

|
@ -4,8 +4,8 @@ import re
import time
from typing import Any

import torch # type: ignore
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline

class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
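`ImageModelCache` above takes a three-hour `timeout_seconds`, which suggests the diffusion pipeline is kept resident and dropped after a period of inactivity. The class body is not part of this hunk, so the following is only a hypothetical sketch of that shape; `get`, `evict_if_stale`, and the `loader` callable are assumed names, not the project's methods.

# Hypothetical sketch of a time-based pipeline cache; method names are assumptions.
import time
from typing import Any, Callable, Optional

class ImageModelCache:
    def __init__(self, timeout_seconds: float = 3 * 60 * 60):
        self._pipeline: Optional[Any] = None
        self._last_used = 0.0
        self._timeout = timeout_seconds

    def get(self, loader: Callable[[], Any]) -> Any:
        # Reload the pipeline if it was evicted or never loaded.
        now = time.time()
        if self._pipeline is None or now - self._last_used > self._timeout:
            self._pipeline = loader()  # e.g. a from_pretrained(...) call
        self._last_used = now
        return self._pipeline

    def evict_if_stale(self) -> None:
        # Drop the reference after the timeout so GPU memory can be reclaimed.
        if self._pipeline is not None and time.time() - self._last_used > self._timeout:
            self._pipeline = None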
|
@ -1,8 +1,8 @@
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field
from typing import Dict, List, Optional, Any, Union, Mapping
from datetime import datetime, timezone
from . rag import ChromaDBGetResponse
from ollama._types import Options # type: ignore
from ollama._types import Options

class Tunables(BaseModel):
enable_rag: bool = True # Enable RAG collection chromadb matching
|
@ -1,4 +1,4 @@
from prometheus_client import Counter, Histogram # type: ignore
from prometheus_client import Counter, Histogram
from threading import Lock

def singleton(cls):
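The metrics module combines a `singleton(cls)` decorator with `threading.Lock` and Prometheus instruments, presumably so every caller shares one set of counters; registering the same metric name twice raises a duplicate-timeseries error, which is one reason to keep a single instance. The decorator body is not shown here, so below is the generic thread-safe version of that pattern, with illustrative metric names rather than the project's.

# Generic thread-safe singleton decorator; the Metrics fields are illustrative.
from threading import Lock
from prometheus_client import Counter, Histogram

def singleton(cls):
    instances = {}
    lock = Lock()

    def get_instance(*args, **kwargs):
        # Double-checked locking so only one instance is ever created.
        if cls not in instances:
            with lock:
                if cls not in instances:
                    instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance

@singleton
class Metrics:
    def __init__(self):
        self.requests = Counter("app_requests_total", "Total chat requests")
        self.latency = Histogram("app_request_seconds", "Request latency in seconds")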
|
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel, Field
from typing import Dict, Literal, Any, AsyncGenerator, Optional
import inspect
import random
@ -12,7 +12,7 @@ import os
import gc
import tempfile
import uuid
import torch # type: ignore
import torch
import asyncio
import time
import json
|
@ -1,4 +1,4 @@
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
from typing import List, Optional, Dict, Any, Union
import os
import glob
@ -8,16 +8,16 @@ import hashlib
import asyncio
import logging
import json
import numpy as np # type: ignore
import numpy as np
import traceback

import chromadb # type: ignore
import chromadb
import ollama
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
import umap
from markitdown import MarkItDown
from chromadb.api.models.Collection import Collection

from .markdown_chunker import (
MarkdownChunker,
@ -388,9 +388,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
os.makedirs(self.persist_directory)

# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient( # type: ignore
chroma_client = chromadb.PersistentClient(
path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
settings=chromadb.Settings(anonymized_telemetry=False),
)

# Check if the collection exists
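The hunk above creates a persistent ChromaDB client with telemetry disabled and then checks for an existing collection. For reference, a standalone version of that setup could look roughly like this; the path and collection name are placeholders, not values from the project.

# Standalone sketch of the persistent-client setup shown above; values are placeholders.
import chromadb

persist_directory = "./chromadb"  # placeholder path
client = chromadb.PersistentClient(
    path=persist_directory,
    settings=chromadb.Settings(anonymized_telemetry=False),
)

# get_or_create_collection avoids an explicit "does it exist?" check.
collection = client.get_or_create_collection(name="documents")  # placeholder name
collection.add(ids=["doc-1"], documents=["hello world"])
print(collection.count())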
|
@ -1,4 +1,4 @@
import redis.asyncio as redis # type: ignore
import redis.asyncio as redis
from typing import Optional
import os
import logging
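This module uses redis-py's asyncio client. A minimal standalone example of that API is below; the URL, key, and expiry are placeholders.

# Minimal async Redis usage with redis-py's asyncio API; values are placeholders.
import asyncio
import os
import redis.asyncio as redis

async def main() -> None:
    client = redis.from_url(
        os.environ.get("REDIS_URL", "redis://redis:6379/0"),
        decode_responses=True,
    )
    await client.set("healthcheck", "ok", ex=30)  # expires after 30 seconds
    print(await client.get("healthcheck"))
    await client.close()

asyncio.run(main())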
|
@ -1,5 +1,5 @@
import os
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field, model_validator
from typing import List, Optional, Generator, ClassVar, Any, Dict
from datetime import datetime
from typing import (
@ -7,12 +7,12 @@ from typing import (
)
from typing_extensions import Annotated

from bs4 import BeautifulSoup # type: ignore
from bs4 import BeautifulSoup

from geopy.geocoders import Nominatim # type: ignore
import pytz # type: ignore
from geopy.geocoders import Nominatim
import pytz
import requests
import yfinance as yf # type: ignore
import yfinance as yf
import logging

@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:

tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__) # type: ignore
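The closing hunk exports `tool_functions` and `__all__` alongside `enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]`. The filter's body is not visible in this diff, so the sketch below only guesses at its shape, assuming `ToolEntry` carries an `enabled` flag; the field names are illustrative.

# Hypothetical ToolEntry filter; the real model's fields may differ.
from typing import Any, Dict, List
from pydantic import BaseModel, Field

class ToolEntry(BaseModel):
    name: str
    enabled: bool = True
    function: Dict[str, Any] = Field(default_factory=dict)  # schema handed to the LLM

def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
    # Keep only the entries the user (or config) has switched on.
    return [tool for tool in tools if tool.enabled]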
|
@ -1,13 +1,13 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field, model_validator
from uuid import uuid4
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal

from typing_extensions import Annotated, Union
import numpy as np # type: ignore
import numpy as np
import logging
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
from prometheus_client import CollectorRegistry, Counter
import traceback
import os
import json