Guest seems to work!
This commit is contained in:
parent
35ef9898f1
commit
20f8d7bd32
@ -188,7 +188,7 @@ RUN pip install prometheus-client prometheus-fastapi-instrumentator
|
|||||||
RUN pip install "redis[hiredis]>=4.5.0"
|
RUN pip install "redis[hiredis]>=4.5.0"
|
||||||
|
|
||||||
# New backend implementation
|
# New backend implementation
|
||||||
RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart
|
RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multipart schedule
|
||||||
|
|
||||||
# Needed for email verification
|
# Needed for email verification
|
||||||
RUN pip install pyyaml user-agents cryptography
|
RUN pip install pyyaml user-agents cryptography
|
||||||
|
@ -91,11 +91,11 @@ services:
|
|||||||
# Optional: Redis Commander for GUI management
|
# Optional: Redis Commander for GUI management
|
||||||
redis-commander:
|
redis-commander:
|
||||||
image: rediscommander/redis-commander:latest
|
image: rediscommander/redis-commander:latest
|
||||||
container_name: backstory-redis-commander
|
container_name: redis-commander
|
||||||
ports:
|
ports:
|
||||||
- "8081:8081"
|
- "8081:8081"
|
||||||
environment:
|
environment:
|
||||||
- REDIS_HOSTS=redis:6379
|
- REDIS_HOSTS=redis:redis:6379
|
||||||
networks:
|
networks:
|
||||||
- internal
|
- internal
|
||||||
depends_on:
|
depends_on:
|
||||||
|
@ -20,8 +20,8 @@ import '@fontsource/roboto/700.css';
|
|||||||
const BackstoryApp = () => {
|
const BackstoryApp = () => {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
const snackRef = useRef<any>(null);
|
|
||||||
const chatRef = useRef<ConversationHandle>(null);
|
const chatRef = useRef<ConversationHandle>(null);
|
||||||
|
const snackRef = useRef<any>(null);
|
||||||
const setSnack = useCallback((message: string, severity?: SeverityType) => {
|
const setSnack = useCallback((message: string, severity?: SeverityType) => {
|
||||||
snackRef.current?.setSnack(message, severity);
|
snackRef.current?.setSnack(message, severity);
|
||||||
}, [snackRef]);
|
}, [snackRef]);
|
||||||
|
@ -12,7 +12,6 @@ import { Message } from './Message';
|
|||||||
import { DeleteConfirmation } from 'components/DeleteConfirmation';
|
import { DeleteConfirmation } from 'components/DeleteConfirmation';
|
||||||
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
||||||
import { BackstoryElementProps } from './BackstoryTab';
|
import { BackstoryElementProps } from './BackstoryTab';
|
||||||
import { connectionBase } from 'utils/Global';
|
|
||||||
import { useAuth } from "hooks/AuthContext";
|
import { useAuth } from "hooks/AuthContext";
|
||||||
import { StreamingResponse } from 'services/api-client';
|
import { StreamingResponse } from 'services/api-client';
|
||||||
import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
|
import { ChatMessage, ChatContext, ChatSession, ChatQuery, ChatMessageUser, ChatMessageError, ChatMessageStreaming, ChatMessageStatus } from 'types/types';
|
||||||
@ -22,7 +21,7 @@ import './Conversation.css';
|
|||||||
import { useAppState } from 'hooks/GlobalContext';
|
import { useAppState } from 'hooks/GlobalContext';
|
||||||
|
|
||||||
const defaultMessage: ChatMessage = {
|
const defaultMessage: ChatMessage = {
|
||||||
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
|
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
|
||||||
};
|
};
|
||||||
|
|
||||||
const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
|
const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing connection with server..." };
|
||||||
@ -325,16 +324,16 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
|||||||
<Box sx={{ p: 1, mt: 0, ...sx }}>
|
<Box sx={{ p: 1, mt: 0, ...sx }}>
|
||||||
{
|
{
|
||||||
filteredConversation.map((message, index) =>
|
filteredConversation.map((message, index) =>
|
||||||
<Message key={index} {...{ chatSession, sendQuery: processQuery, message, connectionBase, }} />
|
<Message key={index} {...{ chatSession, sendQuery: processQuery, message, }} />
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
processingMessage !== undefined &&
|
processingMessage !== undefined &&
|
||||||
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: processingMessage, }} />
|
<Message {...{ chatSession, sendQuery: processQuery, message: processingMessage, }} />
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
streamingMessage !== undefined &&
|
streamingMessage !== undefined &&
|
||||||
<Message {...{ chatSession, sendQuery: processQuery, connectionBase, message: streamingMessage }} />
|
<Message {...{ chatSession, sendQuery: processQuery, message: streamingMessage }} />
|
||||||
}
|
}
|
||||||
<Box sx={{
|
<Box sx={{
|
||||||
display: "flex",
|
display: "flex",
|
||||||
|
@ -29,7 +29,7 @@ import {
|
|||||||
} from '@mui/icons-material';
|
} from '@mui/icons-material';
|
||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { BackstoryPageProps } from './BackstoryTab';
|
import { BackstoryPageProps } from './BackstoryTab';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { Navigate, useNavigate } from 'react-router-dom';
|
||||||
|
|
||||||
// Email Verification Component
|
// Email Verification Component
|
||||||
const EmailVerificationPage = (props: BackstoryPageProps) => {
|
const EmailVerificationPage = (props: BackstoryPageProps) => {
|
||||||
@ -488,10 +488,11 @@ const RegistrationSuccessDialog = ({
|
|||||||
|
|
||||||
// Enhanced Login Component with MFA Support
|
// Enhanced Login Component with MFA Support
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const { login, mfaResponse, isLoading, error } = useAuth();
|
const { login, mfaResponse, isLoading, error, user } = useAuth();
|
||||||
const [email, setEmail] = useState('');
|
const [email, setEmail] = useState('');
|
||||||
const [password, setPassword] = useState('');
|
const [password, setPassword] = useState('');
|
||||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||||
|
const navigate = useNavigate();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!error) {
|
if (!error) {
|
||||||
@ -524,10 +525,13 @@ const LoginForm = () => {
|
|||||||
handleLoginSuccess();
|
handleLoginSuccess();
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleLoginSuccess = () => {
|
const handleLoginSuccess = () => {
|
||||||
// This could be handled by a router or parent component
|
if (!user) {
|
||||||
// For now, just showing the pattern
|
navigate('/');
|
||||||
console.log('Login successful - redirect to dashboard');
|
} else {
|
||||||
|
navigate(`/${user.userType}/dashboard`);
|
||||||
|
}
|
||||||
|
console.log('Login successful - redirect to dashboard');
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -39,7 +39,7 @@ interface JobAnalysisProps extends BackstoryPageProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const defaultMessage: ChatMessage = {
|
const defaultMessage: ChatMessage = {
|
||||||
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant"
|
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "assistant", metadata: null as any
|
||||||
};
|
};
|
||||||
|
|
||||||
const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
|
const JobMatchAnalysis: React.FC<JobAnalysisProps> = (props: JobAnalysisProps) => {
|
||||||
|
@ -101,6 +101,7 @@ const getStyle = (theme: Theme, type: ApiActivityType | ChatSenderType | "error"
|
|||||||
color: theme.palette.text.primary,
|
color: theme.palette.text.primary,
|
||||||
opacity: 0.95,
|
opacity: 0.95,
|
||||||
},
|
},
|
||||||
|
info: 'information',
|
||||||
preparing: 'status',
|
preparing: 'status',
|
||||||
processing: 'status',
|
processing: 'status',
|
||||||
qualifications: {
|
qualifications: {
|
||||||
|
@ -17,7 +17,6 @@ import TableContainer from '@mui/material/TableContainer';
|
|||||||
import TableRow from '@mui/material/TableRow';
|
import TableRow from '@mui/material/TableRow';
|
||||||
|
|
||||||
import { Scrollable } from './Scrollable';
|
import { Scrollable } from './Scrollable';
|
||||||
import { connectionBase } from '../utils/Global';
|
|
||||||
|
|
||||||
import './VectorVisualizer.css';
|
import './VectorVisualizer.css';
|
||||||
import { BackstoryPageProps } from './BackstoryTab';
|
import { BackstoryPageProps } from './BackstoryTab';
|
||||||
@ -195,7 +194,7 @@ const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualiz
|
|||||||
const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
|
const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
|
||||||
const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user : null;
|
const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;
|
||||||
|
|
||||||
/* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
|
/* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
|
||||||
* off screen (eg., the VectorVisualizer is not on the tab the app loads to) */
|
* off screen (eg., the VectorVisualizer is not on the tab the app loads to) */
|
||||||
|
@ -84,7 +84,6 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
|
|||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
const { guest, user } = useAuth();
|
const { guest, user } = useAuth();
|
||||||
const { selectedCandidate } = useSelectedCandidate();
|
|
||||||
const [navigationItems, setNavigationItems] = useState<NavigationItem[]>([]);
|
const [navigationItems, setNavigationItems] = useState<NavigationItem[]>([]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -92,9 +91,13 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
|
|||||||
setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false));
|
setNavigationItems(getMainNavigationItems(userType, user?.isAdmin ? true : false));
|
||||||
}, [user]);
|
}, [user]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
console.log({ guest, user });
|
||||||
|
}, [guest, user]);
|
||||||
|
|
||||||
// Generate dynamic routes from navigation config
|
// Generate dynamic routes from navigation config
|
||||||
const generateRoutes = () => {
|
const generateRoutes = () => {
|
||||||
if (!guest) return null;
|
if (!guest && !user) return null;
|
||||||
|
|
||||||
const userType = user?.userType || null;
|
const userType = user?.userType || null;
|
||||||
const isAdmin = user?.isAdmin ? true : false;
|
const isAdmin = user?.isAdmin ? true : false;
|
||||||
@ -161,7 +164,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<BackstoryPageContainer>
|
<BackstoryPageContainer>
|
||||||
{!guest && (
|
{!guest && !user && (
|
||||||
<Box>
|
<Box>
|
||||||
<LoadingComponent
|
<LoadingComponent
|
||||||
loadingText="Creating session..."
|
loadingText="Creating session..."
|
||||||
@ -171,7 +174,7 @@ const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutP
|
|||||||
/>
|
/>
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
{guest && (
|
{(guest || user) && (
|
||||||
<>
|
<>
|
||||||
<Outlet />
|
<Outlet />
|
||||||
<Routes>
|
<Routes>
|
||||||
|
@ -156,7 +156,7 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
|||||||
id: 'profile',
|
id: 'profile',
|
||||||
label: 'Profile',
|
label: 'Profile',
|
||||||
icon: <Person fontSize="small" />,
|
icon: <Person fontSize="small" />,
|
||||||
action: () => navigate(`/${user?.userType}/dashboard/profile`)
|
action: () => navigate(`/${user?.userType}/profile`)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
id: 'dashboard',
|
id: 'dashboard',
|
||||||
|
@ -74,7 +74,7 @@ const CandidateInfo: React.FC<CandidateInfoProps> = (props: CandidateInfoProps)
|
|||||||
maxWidth: "80px"
|
maxWidth: "80px"
|
||||||
}}>
|
}}>
|
||||||
<Avatar
|
<Avatar
|
||||||
src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}?timestamp=${Date.now()}` : ''}
|
src={candidate.profileImage ? `/api/1.0/candidates/profile/${candidate.username}` : ''}
|
||||||
alt={`${candidate.fullName}'s profile`}
|
alt={`${candidate.fullName}'s profile`}
|
||||||
sx={{
|
sx={{
|
||||||
alignSelf: "flex-start",
|
alignSelf: "flex-start",
|
||||||
|
@ -46,6 +46,7 @@ import { JobPicker } from 'components/ui/JobPicker';
|
|||||||
import { DocumentManager } from 'components/DocumentManager';
|
import { DocumentManager } from 'components/DocumentManager';
|
||||||
import { VectorVisualizer } from 'components/VectorVisualizer';
|
import { VectorVisualizer } from 'components/VectorVisualizer';
|
||||||
import { ComingSoon } from 'components/ui/ComingSoon';
|
import { ComingSoon } from 'components/ui/ComingSoon';
|
||||||
|
import { Beta } from 'components/ui/Beta';
|
||||||
|
|
||||||
// Beta page components for placeholder routes
|
// Beta page components for placeholder routes
|
||||||
const SearchPage = () => (<BetaPage><Typography variant="h4">Search</Typography></BetaPage>);
|
const SearchPage = () => (<BetaPage><Typography variant="h4">Search</Typography></BetaPage>);
|
||||||
@ -58,20 +59,19 @@ const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typogra
|
|||||||
|
|
||||||
export const navigationConfig: NavigationConfig = {
|
export const navigationConfig: NavigationConfig = {
|
||||||
items: [
|
items: [
|
||||||
{ id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, },
|
{ id: 'home', label: <BackstoryLogo />, path: '/', component: <HomePage />, userTypes: ['guest', 'candidate', 'employer'], exact: true, },
|
||||||
{ id: 'chat', label: 'Chat', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], },
|
{ id: 'chat', label: 'Chat about a Candidate', path: '/chat', icon: <ChatIcon />, component: <CandidateChatPage />, userTypes: ['guest', 'candidate', 'employer',], },
|
||||||
{
|
{
|
||||||
id: 'candidate-menu', label: 'Tools', icon: <PersonIcon />, userTypes: ['candidate'], children: [
|
id: 'candidate-menu', label: 'Tools', icon: <PersonIcon />, userTypes: ['candidate'], children: [
|
||||||
{ id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: <DashboardIcon />, component: <CandidateDashboard />, userTypes: ['candidate'] },
|
{ id: 'candidate-dashboard', label: 'Dashboard', path: '/candidate/dashboard', icon: <DashboardIcon />, component: <CandidateDashboard />, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-profile', label: 'Profile', icon: <PersonIcon />, path: '/candidate/profile', component: <CandidateProfile />, userTypes: ['candidate'] },
|
{ id: 'candidate-profile', label: 'Profile', icon: <PersonIcon />, path: '/candidate/profile', component: <CandidateProfile />, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-qa-setup', label: 'Q&A Setup', icon: <QuizIcon />, path: '/candidate/qa-setup', component: <BetaPage><Box>Candidate q&a setup page</Box></BetaPage>, userTypes: ['candidate'] },
|
{ id: 'candidate-qa-setup', label: 'Q&A Setup', icon: <QuizIcon />, path: '/candidate/qa-setup', component: <BetaPage><Box>Candidate q&a setup page</Box></BetaPage>, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] },
|
{ id: 'candidate-analytics', label: 'Analytics', icon: <AnalyticsIcon />, path: '/candidate/analytics', component: <BetaPage><Box>Candidate analytics page</Box></BetaPage>, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-jobs', label: 'Jobs', icon: <WorkIcon />, path: '/candidate/jobs', component: <JobPicker />, userTypes: ['candidate'] },
|
|
||||||
{ id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: <WorkIcon />, component: <JobAnalysisPage />, userTypes: ['candidate'], },
|
{ id: 'candidate-job-analysis', label: 'Job Analysis', path: '/candidate/job-analysis', icon: <WorkIcon />, component: <JobAnalysisPage />, userTypes: ['candidate'], },
|
||||||
{ id: 'candidate-resumes', label: 'Resumes', icon: <DescriptionIcon />, path: '/candidate/resumes', component: <BetaPage><Box>Candidate resumes page</Box></BetaPage>, userTypes: ['candidate'] },
|
{ id: 'candidate-resumes', label: 'Resumes', icon: <DescriptionIcon />, path: '/candidate/resumes', component: <BetaPage><Box>Candidate resumes page</Box></BetaPage>, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: <DescriptionIcon />, component: <ResumeBuilderPage />, userTypes: ['candidate'], },
|
{ id: 'candidate-resume-builder', label: 'Resume Builder', path: '/candidate/resume-builder', icon: <DescriptionIcon />, component: <ResumeBuilderPage />, userTypes: ['candidate'], },
|
||||||
{ id: 'candidate-content', label: 'Content', icon: <BubbleChart />, path: '/candidate/content', component: <Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}><VectorVisualizer /><DocumentManager /></Box>, userTypes: ['candidate'] },
|
{ id: 'candidate-content', label: 'Content', icon: <BubbleChart />, path: '/candidate/content', component: <Box sx={{ display: "flex", width: "100%", flexDirection: "column" }}><VectorVisualizer /><DocumentManager /></Box>, userTypes: ['candidate'] },
|
||||||
{ id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <ComingSoon><Settings /></ComingSoon>, userTypes: ['candidate'], },
|
{ id: 'candidate-settings', label: 'Settings', path: '/candidate/settings', icon: <SettingsIcon />, component: <Settings />, userTypes: ['candidate'], },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -87,7 +87,7 @@ export const navigationConfig: NavigationConfig = {
|
|||||||
{ id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: <SettingsIcon />, component: <SettingsPage />, userTypes: ['employer'], },
|
{ id: 'employer-settings', label: 'Settings', path: '/employer/settings', icon: <SettingsIcon />, component: <SettingsPage />, userTypes: ['employer'], },
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{ id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], },
|
// { id: 'find-candidate', label: 'Find a Candidate', path: '/find-a-candidate', icon: <PersonSearchIcon />, component: <CandidateListingPage />, userTypes: ['guest', 'candidate', 'employer'], },
|
||||||
{ id: 'docs', label: 'Docs', path: '/docs/*', icon: <LibraryBooksIcon />, component: <DocsPage />, userTypes: ['guest', 'candidate', 'employer'], },
|
{ id: 'docs', label: 'Docs', path: '/docs/*', icon: <LibraryBooksIcon />, component: <DocsPage />, userTypes: ['guest', 'candidate', 'employer'], },
|
||||||
{
|
{
|
||||||
id: 'admin-menu',
|
id: 'admin-menu',
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
|
// Replace the existing AuthContext.tsx with these enhancements
|
||||||
|
|
||||||
import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react';
|
import React, { createContext, useContext, useState, useCallback, useEffect, useRef } from 'react';
|
||||||
import * as Types from '../types/types';
|
import * as Types from '../types/types';
|
||||||
import { ApiClient, CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
|
import { ApiClient, CreateCandidateRequest, CreateEmployerRequest, GuestConversionRequest } from '../services/api-client';
|
||||||
import { formatApiRequest, toCamelCase } from '../types/conversion';
|
import { formatApiRequest, toCamelCase } from '../types/conversion';
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Types and Interfaces
|
// Enhanced Types and Interfaces
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
|
|
||||||
interface AuthState {
|
interface AuthState {
|
||||||
user: Types.User | null;
|
user: Types.User | null;
|
||||||
guest: Types.Guest | null;
|
guest: Types.Guest | null;
|
||||||
isAuthenticated: boolean;
|
isAuthenticated: boolean;
|
||||||
|
isGuest: boolean;
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
isInitializing: boolean;
|
isInitializing: boolean;
|
||||||
error: string | null;
|
error: string | null;
|
||||||
@ -36,7 +38,7 @@ interface PasswordResetRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Token Storage Constants
|
// Enhanced Token Storage Constants
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
const TOKEN_STORAGE = {
|
const TOKEN_STORAGE = {
|
||||||
@ -44,7 +46,8 @@ const TOKEN_STORAGE = {
|
|||||||
REFRESH_TOKEN: 'refreshToken',
|
REFRESH_TOKEN: 'refreshToken',
|
||||||
USER_DATA: 'userData',
|
USER_DATA: 'userData',
|
||||||
TOKEN_EXPIRY: 'tokenExpiry',
|
TOKEN_EXPIRY: 'tokenExpiry',
|
||||||
GUEST_DATA: 'guestData',
|
USER_TYPE: 'userType',
|
||||||
|
IS_GUEST: 'isGuest',
|
||||||
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
|
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
|
||||||
} as const;
|
} as const;
|
||||||
|
|
||||||
@ -84,7 +87,7 @@ function isTokenExpired(token: string): boolean {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Storage Utilities with Date Conversion
|
// Enhanced Storage Utilities
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
function clearStoredAuth(): void {
|
function clearStoredAuth(): void {
|
||||||
@ -92,6 +95,8 @@ function clearStoredAuth(): void {
|
|||||||
localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
localStorage.removeItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
||||||
localStorage.removeItem(TOKEN_STORAGE.USER_DATA);
|
localStorage.removeItem(TOKEN_STORAGE.USER_DATA);
|
||||||
localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
localStorage.removeItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
||||||
|
localStorage.removeItem(TOKEN_STORAGE.USER_TYPE);
|
||||||
|
localStorage.removeItem(TOKEN_STORAGE.IS_GUEST);
|
||||||
}
|
}
|
||||||
|
|
||||||
function prepareUserDataForStorage(user: Types.User): string {
|
function prepareUserDataForStorage(user: Types.User): string {
|
||||||
@ -119,11 +124,21 @@ function parseStoredUserData(userDataStr: string): Types.User | null {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function storeAuthData(authResponse: Types.AuthResponse): void {
|
function updateStoredUserData(user: Types.User): void {
|
||||||
|
try {
|
||||||
|
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user));
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to update stored user data:', error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function storeAuthData(authResponse: Types.AuthResponse, isGuest: boolean = false): void {
|
||||||
localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken);
|
localStorage.setItem(TOKEN_STORAGE.ACCESS_TOKEN, authResponse.accessToken);
|
||||||
localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken);
|
localStorage.setItem(TOKEN_STORAGE.REFRESH_TOKEN, authResponse.refreshToken);
|
||||||
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user));
|
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(authResponse.user));
|
||||||
localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString());
|
localStorage.setItem(TOKEN_STORAGE.TOKEN_EXPIRY, authResponse.expiresAt.toString());
|
||||||
|
localStorage.setItem(TOKEN_STORAGE.USER_TYPE, authResponse.user.userType);
|
||||||
|
localStorage.setItem(TOKEN_STORAGE.IS_GUEST, isGuest.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
function getStoredAuthData(): {
|
function getStoredAuthData(): {
|
||||||
@ -131,11 +146,15 @@ function getStoredAuthData(): {
|
|||||||
refreshToken: string | null;
|
refreshToken: string | null;
|
||||||
userData: Types.User | null;
|
userData: Types.User | null;
|
||||||
expiresAt: number | null;
|
expiresAt: number | null;
|
||||||
|
userType: string | null;
|
||||||
|
isGuest: boolean;
|
||||||
} {
|
} {
|
||||||
const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN);
|
const accessToken = localStorage.getItem(TOKEN_STORAGE.ACCESS_TOKEN);
|
||||||
const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
const refreshToken = localStorage.getItem(TOKEN_STORAGE.REFRESH_TOKEN);
|
||||||
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||||
const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
const expiryStr = localStorage.getItem(TOKEN_STORAGE.TOKEN_EXPIRY);
|
||||||
|
const userType = localStorage.getItem(TOKEN_STORAGE.USER_TYPE);
|
||||||
|
const isGuestStr = localStorage.getItem(TOKEN_STORAGE.IS_GUEST);
|
||||||
|
|
||||||
let userData: Types.User | null = null;
|
let userData: Types.User | null = null;
|
||||||
let expiresAt: number | null = null;
|
let expiresAt: number | null = null;
|
||||||
@ -152,55 +171,18 @@ function getStoredAuthData(): {
|
|||||||
clearStoredAuth();
|
clearStoredAuth();
|
||||||
}
|
}
|
||||||
|
|
||||||
return { accessToken, refreshToken, userData, expiresAt };
|
return {
|
||||||
}
|
accessToken,
|
||||||
|
refreshToken,
|
||||||
function updateStoredUserData(user: Types.User): void {
|
userData,
|
||||||
try {
|
expiresAt,
|
||||||
localStorage.setItem(TOKEN_STORAGE.USER_DATA, prepareUserDataForStorage(user));
|
userType,
|
||||||
} catch (error) {
|
isGuest: isGuestStr === 'true'
|
||||||
console.error('Failed to update stored user data:', error);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================
|
|
||||||
// Guest Session Utilities
|
|
||||||
// ============================
|
|
||||||
|
|
||||||
function createGuestSession(): Types.Guest {
|
|
||||||
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
|
||||||
const guest: Types.Guest = {
|
|
||||||
sessionId,
|
|
||||||
createdAt: new Date(),
|
|
||||||
lastActivity: new Date(),
|
|
||||||
ipAddress: 'unknown',
|
|
||||||
userAgent: navigator.userAgent
|
|
||||||
};
|
};
|
||||||
|
|
||||||
try {
|
|
||||||
localStorage.setItem(TOKEN_STORAGE.GUEST_DATA, JSON.stringify(formatApiRequest(guest)));
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Failed to store guest session:', error);
|
|
||||||
}
|
|
||||||
|
|
||||||
return guest;
|
|
||||||
}
|
|
||||||
|
|
||||||
function getStoredGuestData(): Types.Guest | null {
|
|
||||||
try {
|
|
||||||
const guestDataStr = localStorage.getItem(TOKEN_STORAGE.GUEST_DATA);
|
|
||||||
if (guestDataStr) {
|
|
||||||
return toCamelCase<Types.Guest>(JSON.parse(guestDataStr));
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Failed to parse stored guest data:', error);
|
|
||||||
localStorage.removeItem(TOKEN_STORAGE.GUEST_DATA);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Main Authentication Hook
|
// Enhanced Authentication Hook
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
function useAuthenticationLogic() {
|
function useAuthenticationLogic() {
|
||||||
@ -208,6 +190,7 @@ function useAuthenticationLogic() {
|
|||||||
user: null,
|
user: null,
|
||||||
guest: null,
|
guest: null,
|
||||||
isAuthenticated: false,
|
isAuthenticated: false,
|
||||||
|
isGuest: false,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
isInitializing: true,
|
isInitializing: true,
|
||||||
error: null,
|
error: null,
|
||||||
@ -216,6 +199,7 @@ function useAuthenticationLogic() {
|
|||||||
|
|
||||||
const [apiClient] = useState(() => new ApiClient());
|
const [apiClient] = useState(() => new ApiClient());
|
||||||
const initializationCompleted = useRef(false);
|
const initializationCompleted = useRef(false);
|
||||||
|
const guestCreationAttempted = useRef(false);
|
||||||
|
|
||||||
// Token refresh function
|
// Token refresh function
|
||||||
const refreshAccessToken = useCallback(async (refreshToken: string): Promise<Types.AuthResponse | null> => {
|
const refreshAccessToken = useCallback(async (refreshToken: string): Promise<Types.AuthResponse | null> => {
|
||||||
@ -228,6 +212,58 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
}, [apiClient]);
|
}, [apiClient]);
|
||||||
|
|
||||||
|
// Create guest session
|
||||||
|
const createGuestSession = useCallback(async (): Promise<boolean> => {
|
||||||
|
if (guestCreationAttempted.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
guestCreationAttempted.current = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log('🔄 Creating guest session...');
|
||||||
|
const guestAuth = await apiClient.createGuestSession();
|
||||||
|
|
||||||
|
if (guestAuth && guestAuth.user && guestAuth.user.userType === 'guest') {
|
||||||
|
storeAuthData(guestAuth, true);
|
||||||
|
apiClient.setAuthToken(guestAuth.accessToken);
|
||||||
|
|
||||||
|
setAuthState({
|
||||||
|
user: null,
|
||||||
|
guest: guestAuth.user as Types.Guest,
|
||||||
|
isAuthenticated: true,
|
||||||
|
isGuest: true,
|
||||||
|
isLoading: false,
|
||||||
|
isInitializing: false,
|
||||||
|
error: null,
|
||||||
|
mfaResponse: null,
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log('👤 Guest session created successfully:', guestAuth.user);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ Failed to create guest session:', error);
|
||||||
|
guestCreationAttempted.current = false;
|
||||||
|
|
||||||
|
// Set to unauthenticated state if guest creation fails
|
||||||
|
setAuthState(prev => ({
|
||||||
|
...prev,
|
||||||
|
user: null,
|
||||||
|
guest: null,
|
||||||
|
isAuthenticated: false,
|
||||||
|
isGuest: false,
|
||||||
|
isLoading: false,
|
||||||
|
isInitializing: false,
|
||||||
|
error: 'Failed to create guest session',
|
||||||
|
}));
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}, [apiClient]);
|
||||||
|
|
||||||
// Initialize authentication state
|
// Initialize authentication state
|
||||||
const initializeAuth = useCallback(async () => {
|
const initializeAuth = useCallback(async () => {
|
||||||
if (initializationCompleted.current) {
|
if (initializationCompleted.current) {
|
||||||
@ -235,99 +271,94 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Initialize guest session first
|
const stored = getStoredAuthData();
|
||||||
let guest = getStoredGuestData();
|
|
||||||
if (!guest) {
|
// If no stored tokens, create guest session
|
||||||
guest = createGuestSession();
|
if (!stored.accessToken || !stored.refreshToken || !stored.userData) {
|
||||||
|
console.log('🔄 No stored auth found, creating guest session...');
|
||||||
|
await createGuestSession();
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const stored = getStoredAuthData();
|
// For guests, always verify the session exists on server
|
||||||
|
if (stored.userType === 'guest' && stored.userData) {
|
||||||
// If no stored tokens, user is not authenticated but has guest session
|
console.log(stored.userData);
|
||||||
if (!stored.accessToken || !stored.refreshToken || !stored.userData) {
|
try {
|
||||||
setAuthState({
|
// Make a quick API call to verify guest still exists
|
||||||
user: null,
|
const response = await fetch(`${apiClient.getBaseUrl()}/users/${stored.userData.id}`, {
|
||||||
guest,
|
headers: { 'Authorization': `Bearer ${stored.accessToken}` }
|
||||||
isAuthenticated: false,
|
});
|
||||||
isLoading: false,
|
|
||||||
isInitializing: false,
|
if (!response.ok) {
|
||||||
error: null,
|
console.log('🔄 Guest session invalid, creating new guest session...');
|
||||||
mfaResponse: null,
|
clearStoredAuth();
|
||||||
});
|
await createGuestSession();
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.log('🔄 Guest verification failed, creating new guest session...');
|
||||||
|
clearStoredAuth();
|
||||||
|
await createGuestSession();
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if access token is expired
|
// Check if access token is expired
|
||||||
if (isTokenExpired(stored.accessToken)) {
|
if (isTokenExpired(stored.accessToken)) {
|
||||||
console.log('Access token expired, attempting refresh...');
|
console.log('🔄 Access token expired, attempting refresh...');
|
||||||
|
|
||||||
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
||||||
|
|
||||||
if (refreshResult) {
|
if (refreshResult) {
|
||||||
storeAuthData(refreshResult);
|
const isGuest = stored.userType === 'guest';
|
||||||
|
storeAuthData(refreshResult, isGuest);
|
||||||
apiClient.setAuthToken(refreshResult.accessToken);
|
apiClient.setAuthToken(refreshResult.accessToken);
|
||||||
|
|
||||||
setAuthState({
|
setAuthState({
|
||||||
user: refreshResult.user,
|
user: isGuest ? null : refreshResult.user,
|
||||||
guest,
|
guest: isGuest ? refreshResult.user as Types.Guest : null,
|
||||||
isAuthenticated: true,
|
isAuthenticated: true,
|
||||||
|
isGuest,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
isInitializing: false,
|
isInitializing: false,
|
||||||
error: null,
|
error: null,
|
||||||
mfaResponse: null
|
mfaResponse: null
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log('Token refreshed successfully');
|
console.log('✅ Token refreshed successfully');
|
||||||
} else {
|
} else {
|
||||||
console.log('Token refresh failed, clearing stored auth');
|
console.log('❌ Token refresh failed, creating new guest session...');
|
||||||
clearStoredAuth();
|
clearStoredAuth();
|
||||||
apiClient.clearAuthToken();
|
apiClient.clearAuthToken();
|
||||||
|
await createGuestSession();
|
||||||
setAuthState({
|
|
||||||
user: null,
|
|
||||||
guest,
|
|
||||||
isAuthenticated: false,
|
|
||||||
isLoading: false,
|
|
||||||
isInitializing: false,
|
|
||||||
error: null,
|
|
||||||
mfaResponse: null
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Access token is still valid
|
// Access token is still valid
|
||||||
apiClient.setAuthToken(stored.accessToken);
|
apiClient.setAuthToken(stored.accessToken);
|
||||||
|
const isGuest = stored.userType === 'guest';
|
||||||
|
|
||||||
setAuthState({
|
setAuthState({
|
||||||
user: stored.userData,
|
user: isGuest ? null : stored.userData,
|
||||||
guest,
|
guest: isGuest ? stored.userData as Types.Guest : null,
|
||||||
isAuthenticated: true,
|
isAuthenticated: true,
|
||||||
|
isGuest,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
isInitializing: false,
|
isInitializing: false,
|
||||||
error: null,
|
error: null,
|
||||||
mfaResponse: null
|
mfaResponse: null
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log('Restored authentication from stored tokens');
|
console.log('✅ Restored authentication from stored tokens');
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error initializing auth:', error);
|
console.error('❌ Error initializing auth:', error);
|
||||||
clearStoredAuth();
|
clearStoredAuth();
|
||||||
apiClient.clearAuthToken();
|
apiClient.clearAuthToken();
|
||||||
|
await createGuestSession();
|
||||||
const guest = createGuestSession();
|
|
||||||
setAuthState({
|
|
||||||
user: null,
|
|
||||||
guest,
|
|
||||||
isAuthenticated: false,
|
|
||||||
isLoading: false,
|
|
||||||
isInitializing: false,
|
|
||||||
error: null,
|
|
||||||
mfaResponse: null
|
|
||||||
});
|
|
||||||
} finally {
|
} finally {
|
||||||
initializationCompleted.current = true;
|
initializationCompleted.current = true;
|
||||||
}
|
}
|
||||||
}, [apiClient, refreshAccessToken]);
|
}, [apiClient, refreshAccessToken, createGuestSession]);
|
||||||
|
|
||||||
// Run initialization on mount
|
// Run initialization on mount
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -355,7 +386,7 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const refreshTimer = setTimeout(() => {
|
const refreshTimer = setTimeout(() => {
|
||||||
console.log('Auto-refreshing token before expiry...');
|
console.log('🔄 Auto-refreshing token before expiry...');
|
||||||
initializeAuth();
|
initializeAuth();
|
||||||
}, timeUntilExpiry);
|
}, timeUntilExpiry);
|
||||||
|
|
||||||
@ -364,7 +395,7 @@ function useAuthenticationLogic() {
|
|||||||
|
|
||||||
// Enhanced login with MFA support
|
// Enhanced login with MFA support
|
||||||
const login = useCallback(async (loginData: LoginRequest): Promise<boolean> => {
|
const login = useCallback(async (loginData: LoginRequest): Promise<boolean> => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null, mfaData: null }));
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null, mfaResponse: null }));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await apiClient.login({
|
const result = await apiClient.login({
|
||||||
@ -381,21 +412,23 @@ function useAuthenticationLogic() {
|
|||||||
}));
|
}));
|
||||||
return false; // Login not complete yet
|
return false; // Login not complete yet
|
||||||
} else {
|
} else {
|
||||||
// Normal login success
|
// Normal login success - convert from guest to authenticated user
|
||||||
const authResponse: Types.AuthResponse = result;
|
const authResponse: Types.AuthResponse = result;
|
||||||
storeAuthData(authResponse);
|
storeAuthData(authResponse, false);
|
||||||
apiClient.setAuthToken(authResponse.accessToken);
|
apiClient.setAuthToken(authResponse.accessToken);
|
||||||
|
|
||||||
setAuthState(prev => ({
|
setAuthState(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
user: authResponse.user,
|
user: authResponse.user,
|
||||||
|
guest: null,
|
||||||
isAuthenticated: true,
|
isAuthenticated: true,
|
||||||
|
isGuest: false,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
error: null,
|
error: null,
|
||||||
mfaResponse: null,
|
mfaResponse: null,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
console.log('Login successful');
|
console.log('✅ Login successful, converted from guest to authenticated user');
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@ -410,6 +443,44 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
}, [apiClient]);
|
}, [apiClient]);
|
||||||
|
|
||||||
|
// Convert guest to permanent user
|
||||||
|
const convertGuestToUser = useCallback(async (registrationData: GuestConversionRequest): Promise<boolean> => {
|
||||||
|
if (!authState.isGuest || !authState.guest) {
|
||||||
|
throw new Error('Not currently a guest user');
|
||||||
|
}
|
||||||
|
|
||||||
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await apiClient.convertGuestToUser(registrationData);
|
||||||
|
|
||||||
|
// Store new authentication
|
||||||
|
storeAuthData(result.auth, false);
|
||||||
|
apiClient.setAuthToken(result.auth.accessToken);
|
||||||
|
|
||||||
|
setAuthState(prev => ({
|
||||||
|
...prev,
|
||||||
|
user: result.auth.user,
|
||||||
|
guest: null,
|
||||||
|
isAuthenticated: true,
|
||||||
|
isGuest: false,
|
||||||
|
isLoading: false,
|
||||||
|
error: null,
|
||||||
|
}));
|
||||||
|
|
||||||
|
console.log('✅ Guest successfully converted to permanent user');
|
||||||
|
return true;
|
||||||
|
} catch (error: any) {
|
||||||
|
const errorMessage = error instanceof Error ? error.message : 'Failed to convert guest account';
|
||||||
|
setAuthState(prev => ({
|
||||||
|
...prev,
|
||||||
|
isLoading: false,
|
||||||
|
error: errorMessage,
|
||||||
|
}));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}, [apiClient, authState.isGuest, authState.guest]);
|
||||||
|
|
||||||
// MFA verification
|
// MFA verification
|
||||||
const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise<boolean> => {
|
const verifyMFA = useCallback(async (mfaData: Types.MFAVerifyRequest): Promise<boolean> => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
@ -419,26 +490,27 @@ function useAuthenticationLogic() {
|
|||||||
|
|
||||||
if (result.accessToken) {
|
if (result.accessToken) {
|
||||||
const authResponse: Types.AuthResponse = result;
|
const authResponse: Types.AuthResponse = result;
|
||||||
storeAuthData(authResponse);
|
storeAuthData(authResponse, false);
|
||||||
apiClient.setAuthToken(authResponse.accessToken);
|
apiClient.setAuthToken(authResponse.accessToken);
|
||||||
|
|
||||||
setAuthState(prev => ({
|
setAuthState(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
user: authResponse.user,
|
user: authResponse.user,
|
||||||
|
guest: null,
|
||||||
isAuthenticated: true,
|
isAuthenticated: true,
|
||||||
|
isGuest: false,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
error: null,
|
error: null,
|
||||||
mfaResponse: null,
|
mfaResponse: null,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
console.log('MFA verification successful');
|
console.log('✅ MFA verification successful, converted from guest');
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage = error instanceof Error ? error.message : 'MFA verification failed';
|
const errorMessage = error instanceof Error ? error.message : 'MFA verification failed';
|
||||||
console.log(errorMessage);
|
|
||||||
setAuthState(prev => ({
|
setAuthState(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
@ -448,42 +520,48 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
}, [apiClient]);
|
}, [apiClient]);
|
||||||
|
|
||||||
// Resend MFA code
|
// Logout - returns to guest session
|
||||||
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => {
|
const logout = useCallback(async () => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await apiClient.requestMFA({
|
// If authenticated, try to logout gracefully
|
||||||
email,
|
if (authState.isAuthenticated && !authState.isGuest) {
|
||||||
password: '', // This would need to be stored securely or re-entered
|
const stored = getStoredAuthData();
|
||||||
deviceId,
|
if (stored.accessToken && stored.refreshToken) {
|
||||||
deviceName,
|
try {
|
||||||
});
|
await apiClient.logout(stored.accessToken, stored.refreshToken);
|
||||||
|
} catch (error) {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: false }));
|
console.warn('Logout request failed, proceeding with local cleanup');
|
||||||
return true;
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code';
|
console.warn('Error during logout:', error);
|
||||||
setAuthState(prev => ({
|
} finally {
|
||||||
...prev,
|
// Always clear stored auth and create new guest session
|
||||||
isLoading: false,
|
clearStoredAuth();
|
||||||
error: errorMessage
|
apiClient.clearAuthToken();
|
||||||
}));
|
guestCreationAttempted.current = false;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}, [apiClient]);
|
|
||||||
|
|
||||||
// Clear MFA state
|
// Create new guest session
|
||||||
const clearMFA = useCallback(() => {
|
await createGuestSession();
|
||||||
|
|
||||||
|
console.log('🔄 Logged out, created new guest session');
|
||||||
|
}
|
||||||
|
}, [apiClient, authState.isAuthenticated, authState.isGuest, createGuestSession]);
|
||||||
|
|
||||||
|
// Update user data
|
||||||
|
const updateUserData = useCallback((updatedUser: Types.User) => {
|
||||||
|
updateStoredUserData(updatedUser);
|
||||||
setAuthState(prev => ({
|
setAuthState(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
mfaResponse: null,
|
user: authState.isGuest ? null : updatedUser,
|
||||||
error: null
|
guest: authState.isGuest ? updatedUser as Types.Guest : prev.guest
|
||||||
}));
|
}));
|
||||||
}, []);
|
console.log('✅ User data updated');
|
||||||
|
}, [authState.isGuest]);
|
||||||
|
|
||||||
// Email verification
|
// Email verification functions (unchanged)
|
||||||
const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest): Promise<{ message: string; userType: string } | null> => {
|
const verifyEmail = useCallback(async (verificationData: EmailVerificationRequest) => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -504,7 +582,7 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
}, [apiClient]);
|
}, [apiClient]);
|
||||||
|
|
||||||
// Resend email verification
|
// Other existing methods remain the same...
|
||||||
const resendEmailVerification = useCallback(async (email: string): Promise<boolean> => {
|
const resendEmailVerification = useCallback(async (email: string): Promise<boolean> => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
|
|
||||||
@ -523,53 +601,21 @@ function useAuthenticationLogic() {
|
|||||||
}
|
}
|
||||||
}, [apiClient]);
|
}, [apiClient]);
|
||||||
|
|
||||||
// Store pending verification email
|
|
||||||
const setPendingVerificationEmail = useCallback((email: string) => {
|
const setPendingVerificationEmail = useCallback((email: string) => {
|
||||||
localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email);
|
localStorage.setItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL, email);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Get pending verification email
|
|
||||||
const getPendingVerificationEmail = useCallback((): string | null => {
|
const getPendingVerificationEmail = useCallback((): string | null => {
|
||||||
return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL);
|
return localStorage.getItem(TOKEN_STORAGE.PENDING_VERIFICATION_EMAIL);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const logout = useCallback(() => {
|
|
||||||
clearStoredAuth();
|
|
||||||
apiClient.clearAuthToken();
|
|
||||||
|
|
||||||
// Create new guest session after logout
|
|
||||||
const guest = createGuestSession();
|
|
||||||
|
|
||||||
setAuthState(prev => ({
|
|
||||||
...prev,
|
|
||||||
user: null,
|
|
||||||
guest,
|
|
||||||
isAuthenticated: false,
|
|
||||||
isLoading: false,
|
|
||||||
error: null,
|
|
||||||
mfaResponse: null,
|
|
||||||
}));
|
|
||||||
|
|
||||||
console.log('User logged out');
|
|
||||||
}, [apiClient]);
|
|
||||||
|
|
||||||
const updateUserData = useCallback((updatedUser: Types.User) => {
|
|
||||||
updateStoredUserData(updatedUser);
|
|
||||||
setAuthState(prev => ({
|
|
||||||
...prev,
|
|
||||||
user: updatedUser
|
|
||||||
}));
|
|
||||||
console.log('User data updated', updatedUser);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise<boolean> => {
|
const createEmployerAccount = useCallback(async (employerData: CreateEmployerRequest): Promise<boolean> => {
|
||||||
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const employer = await apiClient.createEmployer(employerData);
|
const employer = await apiClient.createEmployer(employerData);
|
||||||
console.log('Employer created:', employer);
|
console.log('✅ Employer created:', employer);
|
||||||
|
|
||||||
// Store email for potential verification resend
|
|
||||||
setPendingVerificationEmail(employerData.email);
|
setPendingVerificationEmail(employerData.email);
|
||||||
|
|
||||||
setAuthState(prev => ({ ...prev, isLoading: false }));
|
setAuthState(prev => ({ ...prev, isLoading: false }));
|
||||||
@ -614,24 +660,61 @@ function useAuthenticationLogic() {
|
|||||||
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
const refreshResult = await refreshAccessToken(stored.refreshToken);
|
||||||
|
|
||||||
if (refreshResult) {
|
if (refreshResult) {
|
||||||
storeAuthData(refreshResult);
|
const isGuest = stored.userType === 'guest';
|
||||||
|
storeAuthData(refreshResult, isGuest);
|
||||||
apiClient.setAuthToken(refreshResult.accessToken);
|
apiClient.setAuthToken(refreshResult.accessToken);
|
||||||
|
|
||||||
setAuthState(prev => ({
|
setAuthState(prev => ({
|
||||||
...prev,
|
...prev,
|
||||||
user: refreshResult.user,
|
user: isGuest ? null : refreshResult.user,
|
||||||
|
guest: isGuest ? refreshResult.user as Types.Guest : null,
|
||||||
isAuthenticated: true,
|
isAuthenticated: true,
|
||||||
|
isGuest,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
error: null
|
error: null
|
||||||
}));
|
}));
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
logout();
|
await logout();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}, [refreshAccessToken, logout]);
|
}, [refreshAccessToken, logout]);
|
||||||
|
|
||||||
|
// Resend MFA code
|
||||||
|
const resendMFACode = useCallback(async (email: string, deviceId: string, deviceName: string): Promise<boolean> => {
|
||||||
|
setAuthState(prev => ({ ...prev, isLoading: true, error: null }));
|
||||||
|
|
||||||
|
try {
|
||||||
|
await apiClient.requestMFA({
|
||||||
|
email,
|
||||||
|
password: '', // This would need to be stored securely or re-entered
|
||||||
|
deviceId,
|
||||||
|
deviceName,
|
||||||
|
});
|
||||||
|
|
||||||
|
setAuthState(prev => ({ ...prev, isLoading: false }));
|
||||||
|
return true;
|
||||||
|
} catch (error) {
|
||||||
|
const errorMessage = error instanceof Error ? error.message : 'Failed to resend MFA code';
|
||||||
|
setAuthState(prev => ({
|
||||||
|
...prev,
|
||||||
|
isLoading: false,
|
||||||
|
error: errorMessage
|
||||||
|
}));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}, [apiClient]);
|
||||||
|
|
||||||
|
// Clear MFA state
|
||||||
|
const clearMFA = useCallback(() => {
|
||||||
|
setAuthState(prev => ({
|
||||||
|
...prev,
|
||||||
|
mfaResponse: null,
|
||||||
|
error: null
|
||||||
|
}));
|
||||||
|
}, []);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...authState,
|
...authState,
|
||||||
apiClient,
|
apiClient,
|
||||||
@ -647,12 +730,14 @@ function useAuthenticationLogic() {
|
|||||||
createEmployerAccount,
|
createEmployerAccount,
|
||||||
requestPasswordReset,
|
requestPasswordReset,
|
||||||
refreshAuth,
|
refreshAuth,
|
||||||
updateUserData
|
updateUserData,
|
||||||
|
convertGuestToUser,
|
||||||
|
createGuestSession
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Context Provider
|
// Enhanced Context Provider
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
const AuthContext = createContext<ReturnType<typeof useAuthenticationLogic> | null>(null);
|
const AuthContext = createContext<ReturnType<typeof useAuthenticationLogic> | null>(null);
|
||||||
@ -676,34 +761,41 @@ function useAuth() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Protected Route Component
|
// Enhanced Protected Route Component
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
interface ProtectedRouteProps {
|
interface ProtectedRouteProps {
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
fallback?: React.ReactNode;
|
fallback?: React.ReactNode;
|
||||||
requiredUserType?: Types.UserType;
|
requiredUserType?: Types.UserType;
|
||||||
|
allowGuests?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
function ProtectedRoute({
|
function ProtectedRoute({
|
||||||
children,
|
children,
|
||||||
fallback = <div>Please log in to access this page.</div>,
|
fallback = <div>Please log in to access this page.</div>,
|
||||||
requiredUserType
|
requiredUserType,
|
||||||
|
allowGuests = false
|
||||||
}: ProtectedRouteProps) {
|
}: ProtectedRouteProps) {
|
||||||
const { isAuthenticated, isInitializing, user } = useAuth();
|
const { isAuthenticated, isInitializing, user, isGuest } = useAuth();
|
||||||
|
|
||||||
// Show loading while checking stored tokens
|
// Show loading while checking stored tokens
|
||||||
if (isInitializing) {
|
if (isInitializing) {
|
||||||
return <div>Loading...</div>;
|
return <div>Loading...</div>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not authenticated
|
// Not authenticated at all (shouldn't happen with guest sessions)
|
||||||
if (!isAuthenticated) {
|
if (!isAuthenticated) {
|
||||||
return <>{fallback}</>;
|
return <>{fallback}</>;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check user type if required
|
// Guest access control
|
||||||
if (requiredUserType && user?.userType !== requiredUserType) {
|
if (isGuest && !allowGuests) {
|
||||||
|
return <div>Please create an account or log in to access this page.</div>;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check user type if required (only for non-guests)
|
||||||
|
if (requiredUserType && !isGuest && user?.userType !== requiredUserType) {
|
||||||
return <div>Access denied. Required user type: {requiredUserType}</div>;
|
return <div>Access denied. Required user type: {requiredUserType}</div>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -711,11 +803,19 @@ function ProtectedRoute({
|
|||||||
}
|
}
|
||||||
|
|
||||||
export type {
|
export type {
|
||||||
AuthState, LoginRequest, EmailVerificationRequest, ResendVerificationRequest, PasswordResetRequest
|
AuthState,
|
||||||
|
LoginRequest,
|
||||||
|
EmailVerificationRequest,
|
||||||
|
ResendVerificationRequest,
|
||||||
|
PasswordResetRequest,
|
||||||
|
GuestConversionRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
|
export type { CreateCandidateRequest, CreateEmployerRequest } from '../services/api-client';
|
||||||
|
|
||||||
export {
|
export {
|
||||||
useAuthenticationLogic, AuthProvider, useAuth, ProtectedRoute
|
useAuthenticationLogic,
|
||||||
|
AuthProvider,
|
||||||
|
useAuth,
|
||||||
|
ProtectedRoute
|
||||||
}
|
}
|
@ -23,15 +23,16 @@ import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext';
|
|||||||
import PropagateLoader from 'react-spinners/PropagateLoader';
|
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||||
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryTextField';
|
||||||
import { BackstoryQuery } from 'components/BackstoryQuery';
|
import { BackstoryQuery } from 'components/BackstoryQuery';
|
||||||
|
import { CandidatePicker } from 'components/ui/CandidatePicker';
|
||||||
|
|
||||||
const defaultMessage: ChatMessage = {
|
const defaultMessage: ChatMessage = {
|
||||||
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
|
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
|
||||||
};
|
};
|
||||||
|
|
||||||
const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
||||||
const { apiClient } = useAuth();
|
const { apiClient } = useAuth();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const { selectedCandidate } = useSelectedCandidate()
|
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate()
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
|
const [processingMessage, setProcessingMessage] = useState<ChatMessageStatus | ChatMessageError | null>(null);
|
||||||
const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
|
const [streamingMessage, setStreamingMessage] = useState<ChatMessage | null>(null);
|
||||||
@ -92,7 +93,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
|
|||||||
timestamp: new Date()
|
timestamp: new Date()
|
||||||
};
|
};
|
||||||
|
|
||||||
setProcessingMessage({ ...defaultMessage, status: 'status', content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
|
setProcessingMessage({ ...defaultMessage, status: 'status', activity: "info", content: `Establishing connection with ${selectedCandidate.firstName}'s chat session.` });
|
||||||
|
|
||||||
setMessages(prev => {
|
setMessages(prev => {
|
||||||
const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
|
const filtered = prev.filter((m: any) => m.id !== chatMessage.id);
|
||||||
@ -123,7 +124,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
|
|||||||
},
|
},
|
||||||
onStreaming: (chunk: ChatMessageStreaming) => {
|
onStreaming: (chunk: ChatMessageStreaming) => {
|
||||||
// console.log("onStreaming:", chunk);
|
// console.log("onStreaming:", chunk);
|
||||||
setStreamingMessage({ ...chunk, role: 'assistant' });
|
setStreamingMessage({ ...chunk, role: 'assistant', metadata: null as any });
|
||||||
},
|
},
|
||||||
onStatus: (status: ChatMessageStatus) => {
|
onStatus: (status: ChatMessageStatus) => {
|
||||||
setProcessingMessage(status);
|
setProcessingMessage(status);
|
||||||
@ -171,8 +172,7 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
|
|||||||
}, [chatSession]);
|
}, [chatSession]);
|
||||||
|
|
||||||
if (!selectedCandidate) {
|
if (!selectedCandidate) {
|
||||||
navigate('/find-a-candidate');
|
return <CandidatePicker />;
|
||||||
return (<></>);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const welcomeMessage: ChatMessage = {
|
const welcomeMessage: ChatMessage = {
|
||||||
@ -181,7 +181,8 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
|
|||||||
type: "text",
|
type: "text",
|
||||||
status: "done",
|
status: "done",
|
||||||
timestamp: new Date(),
|
timestamp: new Date(),
|
||||||
content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`
|
content: `Welcome to the Backstory Chat about ${selectedCandidate.fullName}. Ask any questions you have about ${selectedCandidate.firstName}.`,
|
||||||
|
metadata: null as any
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -192,12 +193,14 @@ const CandidateChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((pr
|
|||||||
gap: 1,
|
gap: 1,
|
||||||
}}>
|
}}>
|
||||||
<CandidateInfo
|
<CandidateInfo
|
||||||
|
key={selectedCandidate.username}
|
||||||
action={`Chat with Backstory about ${selectedCandidate.firstName}`}
|
action={`Chat with Backstory about ${selectedCandidate.firstName}`}
|
||||||
elevation={4}
|
elevation={4}
|
||||||
candidate={selectedCandidate}
|
candidate={selectedCandidate}
|
||||||
variant="small"
|
variant="small"
|
||||||
sx={{ flexShrink: 0 }} // Prevent header from shrinking
|
sx={{ flexShrink: 0 }} // Prevent header from shrinking
|
||||||
/>
|
/>
|
||||||
|
<Button onClick={() => { setSelectedCandidate(null); }} variant="contained">Change Candidates</Button>
|
||||||
|
|
||||||
{/* Chat Interface */}
|
{/* Chat Interface */}
|
||||||
<Paper
|
<Paper
|
||||||
|
@ -23,7 +23,7 @@ import { Message } from 'components/Message';
|
|||||||
import { useAppState } from 'hooks/GlobalContext';
|
import { useAppState } from 'hooks/GlobalContext';
|
||||||
|
|
||||||
const defaultMessage: ChatMessage = {
|
const defaultMessage: ChatMessage = {
|
||||||
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user"
|
status: "done", type: "text", sessionId: "", timestamp: new Date(), content: "", role: "user", metadata: null as any
|
||||||
};
|
};
|
||||||
|
|
||||||
const GenerateCandidate = (props: BackstoryElementProps) => {
|
const GenerateCandidate = (props: BackstoryElementProps) => {
|
||||||
|
@ -20,6 +20,7 @@ import QuestionAnswerIcon from '@mui/icons-material/QuestionAnswer';
|
|||||||
import DescriptionIcon from '@mui/icons-material/Description';
|
import DescriptionIcon from '@mui/icons-material/Description';
|
||||||
import professionalConversationPng from './Conversation.png';
|
import professionalConversationPng from './Conversation.png';
|
||||||
import { ComingSoon } from 'components/ui/ComingSoon';
|
import { ComingSoon } from 'components/ui/ComingSoon';
|
||||||
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
|
|
||||||
// Placeholder for Testimonials component
|
// Placeholder for Testimonials component
|
||||||
const Testimonials = () => {
|
const Testimonials = () => {
|
||||||
@ -135,6 +136,15 @@ const FeatureCard = ({
|
|||||||
|
|
||||||
const HomePage = () => {
|
const HomePage = () => {
|
||||||
const testimonials = false;
|
const testimonials = false;
|
||||||
|
const { isGuest, guest, user } = useAuth();
|
||||||
|
|
||||||
|
if (isGuest) {
|
||||||
|
// Show guest-specific UI
|
||||||
|
console.log('Guest session:', guest?.sessionId || "No guest");
|
||||||
|
} else {
|
||||||
|
// Show authenticated user UI
|
||||||
|
console.log('Authenticated user:', user?.email || "No user");
|
||||||
|
}
|
||||||
|
|
||||||
return (<Box sx={{display: "flex", flexDirection: "column"}}>
|
return (<Box sx={{display: "flex", flexDirection: "column"}}>
|
||||||
{/* Hero Section */}
|
{/* Hero Section */}
|
||||||
|
@ -10,7 +10,8 @@ const LoadingPage = (props: BackstoryPageProps) => {
|
|||||||
status: 'done',
|
status: 'done',
|
||||||
sessionId: '',
|
sessionId: '',
|
||||||
content: 'Please wait while connecting to Backstory...',
|
content: 'Please wait while connecting to Backstory...',
|
||||||
timestamp: new Date()
|
timestamp: new Date(),
|
||||||
|
metadata: null as any
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
|
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
|
||||||
|
@ -26,6 +26,7 @@ import { LoginForm } from "components/EmailVerificationComponents";
|
|||||||
import { CandidateRegistrationForm } from "components/RegistrationForms";
|
import { CandidateRegistrationForm } from "components/RegistrationForms";
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate } from 'react-router-dom';
|
||||||
import { useAppState } from 'hooks/GlobalContext';
|
import { useAppState } from 'hooks/GlobalContext';
|
||||||
|
import * as Types from 'types/types';
|
||||||
|
|
||||||
const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
|
const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
@ -34,7 +35,7 @@ const LoginPage: React.FC<BackstoryPageProps> = (props: BackstoryPageProps) => {
|
|||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [success, setSuccess] = useState<string | null>(null);
|
const [success, setSuccess] = useState<string | null>(null);
|
||||||
const { guest, user, login, isLoading, error } = useAuth();
|
const { guest, user, login, isLoading, error } = useAuth();
|
||||||
const name = (user?.userType === 'candidate') ? user.username : user?.email || '';
|
const name = (user?.userType === 'candidate') ? (user as Types.Candidate).username : user?.email || '';
|
||||||
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
const [errorMessage, setErrorMessage] = useState<string | null>(null);
|
||||||
|
|
||||||
const showGuest: boolean = false;
|
const showGuest: boolean = false;
|
||||||
|
@ -10,7 +10,8 @@ const LoginRequired = (props: BackstoryPageProps) => {
|
|||||||
status: 'done',
|
status: 'done',
|
||||||
sessionId: '',
|
sessionId: '',
|
||||||
content: 'You must be logged to view this feature.',
|
content: 'You must be logged to view this feature.',
|
||||||
timestamp: new Date()
|
timestamp: new Date(),
|
||||||
|
metadata: null as any
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
|
return <Box sx={{display: "flex", flexGrow: 1, maxWidth: "1024px", margin: "0 auto"}}>
|
||||||
|
@ -11,6 +11,7 @@ import { CandidateInfo } from 'components/ui/CandidateInfo';
|
|||||||
import { useAuth } from 'hooks/AuthContext';
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
import { Candidate } from 'types/types';
|
import { Candidate } from 'types/types';
|
||||||
import { useAppState } from 'hooks/GlobalContext';
|
import { useAppState } from 'hooks/GlobalContext';
|
||||||
|
import * as Types from 'types/types';
|
||||||
|
|
||||||
const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
||||||
const { setSnack } = useAppState();
|
const { setSnack } = useAppState();
|
||||||
@ -18,7 +19,7 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
|
|||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
|
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
|
||||||
const [questions, setQuestions] = useState<React.ReactElement[]>([]);
|
const [questions, setQuestions] = useState<React.ReactElement[]>([]);
|
||||||
const candidate: Candidate | null = user?.userType === 'candidate' ? user : null;
|
const candidate: Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;
|
||||||
|
|
||||||
// console.log("ChatPage candidate =>", candidate);
|
// console.log("ChatPage candidate =>", candidate);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
@ -74,7 +74,7 @@ const CandidateDashboard = (props: CandidateDashboardProps) => {
|
|||||||
variant="contained"
|
variant="contained"
|
||||||
color="primary"
|
color="primary"
|
||||||
sx={{ mt: 1 }}
|
sx={{ mt: 1 }}
|
||||||
onClick={(e) => {e.stopPropagation(); navigate('/candidate/dashboard/profile'); }}
|
onClick={(e) => { e.stopPropagation(); navigate('/candidate/profile'); }}
|
||||||
>
|
>
|
||||||
Complete Your Profile
|
Complete Your Profile
|
||||||
</Button>
|
</Button>
|
||||||
|
@ -14,9 +14,10 @@ import Typography from '@mui/material/Typography';
|
|||||||
// import ResetIcon from '@mui/icons-material/History';
|
// import ResetIcon from '@mui/icons-material/History';
|
||||||
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
|
||||||
|
|
||||||
import { connectionBase } from '../../utils/Global';
|
|
||||||
import { BackstoryPageProps } from '../../components/BackstoryTab';
|
import { BackstoryPageProps } from '../../components/BackstoryTab';
|
||||||
import { useAppState } from 'hooks/GlobalContext';
|
import { useAppState } from 'hooks/GlobalContext';
|
||||||
|
import { useAuth } from 'hooks/AuthContext';
|
||||||
|
import * as Types from 'types/types';
|
||||||
|
|
||||||
interface ServerTunables {
|
interface ServerTunables {
|
||||||
system_prompt: string,
|
system_prompt: string,
|
||||||
@ -33,19 +34,7 @@ type Tool = {
|
|||||||
returns?: any
|
returns?: any
|
||||||
};
|
};
|
||||||
|
|
||||||
type GPUInfo = {
|
const SystemInfoComponent: React.FC<{ systemInfo: Types.SystemInfo | undefined }> = ({ systemInfo }) => {
|
||||||
name: string,
|
|
||||||
memory: number,
|
|
||||||
discrete: boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
type SystemInfo = {
|
|
||||||
"Installed RAM": string,
|
|
||||||
"Graphics Card": GPUInfo[],
|
|
||||||
"CPU": string
|
|
||||||
};
|
|
||||||
|
|
||||||
const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ systemInfo }) => {
|
|
||||||
const [systemElements, setSystemElements] = useState<ReactElement[]>([]);
|
const [systemElements, setSystemElements] = useState<ReactElement[]>([]);
|
||||||
|
|
||||||
const convertToSymbols = (text: string) => {
|
const convertToSymbols = (text: string) => {
|
||||||
@ -86,91 +75,92 @@ const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
const Settings = (props: BackstoryPageProps) => {
|
const Settings = (props: BackstoryPageProps) => {
|
||||||
|
const { apiClient } = useAuth();
|
||||||
const { setSnack } = useAppState();
|
const { setSnack } = useAppState();
|
||||||
const [editSystemPrompt, setEditSystemPrompt] = useState<string>("");
|
const [editSystemPrompt, setEditSystemPrompt] = useState<string>("");
|
||||||
const [systemInfo, setSystemInfo] = useState<SystemInfo | undefined>(undefined);
|
const [systemInfo, setSystemInfo] = useState<Types.SystemInfo | undefined>(undefined);
|
||||||
const [tools, setTools] = useState<Tool[]>([]);
|
const [tools, setTools] = useState<Tool[]>([]);
|
||||||
const [rags, setRags] = useState<Tool[]>([]);
|
const [rags, setRags] = useState<Tool[]>([]);
|
||||||
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
const [systemPrompt, setSystemPrompt] = useState<string>("");
|
||||||
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
|
||||||
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
|
const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);
|
||||||
|
|
||||||
useEffect(() => {
|
// useEffect(() => {
|
||||||
if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) {
|
// if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim()) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
const sendSystemPrompt = async (prompt: string) => {
|
// const sendSystemPrompt = async (prompt: string) => {
|
||||||
try {
|
// try {
|
||||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
// method: 'PUT',
|
||||||
headers: {
|
// headers: {
|
||||||
'Content-Type': 'application/json',
|
// 'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
// 'Accept': 'application/json',
|
||||||
},
|
// },
|
||||||
body: JSON.stringify({ "system_prompt": prompt }),
|
// body: JSON.stringify({ "system_prompt": prompt }),
|
||||||
});
|
// });
|
||||||
|
|
||||||
const tunables = await response.json();
|
// const tunables = await response.json();
|
||||||
serverTunables.system_prompt = tunables.system_prompt;
|
// serverTunables.system_prompt = tunables.system_prompt;
|
||||||
console.log(tunables);
|
// console.log(tunables);
|
||||||
setSystemPrompt(tunables.system_prompt)
|
// setSystemPrompt(tunables.system_prompt)
|
||||||
setSnack("System prompt updated", "success");
|
// setSnack("System prompt updated", "success");
|
||||||
} catch (error) {
|
// } catch (error) {
|
||||||
console.error('Fetch error:', error);
|
// console.error('Fetch error:', error);
|
||||||
setSnack("System prompt update failed", "error");
|
// setSnack("System prompt update failed", "error");
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
sendSystemPrompt(systemPrompt);
|
// sendSystemPrompt(systemPrompt);
|
||||||
|
|
||||||
}, [systemPrompt, setSnack, serverTunables]);
|
// }, [systemPrompt, setSnack, serverTunables]);
|
||||||
|
|
||||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
// const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
||||||
try {
|
// try {
|
||||||
const response = await fetch(connectionBase + `/api/1.0/reset/`, {
|
// const response = await fetch(connectionBase + `/api/1.0/reset/`, {
|
||||||
method: 'PUT',
|
// method: 'PUT',
|
||||||
headers: {
|
// headers: {
|
||||||
'Content-Type': 'application/json',
|
// 'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
// 'Accept': 'application/json',
|
||||||
},
|
// },
|
||||||
body: JSON.stringify({ "reset": types }),
|
// body: JSON.stringify({ "reset": types }),
|
||||||
});
|
// });
|
||||||
|
|
||||||
if (!response.ok) {
|
// if (!response.ok) {
|
||||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
// throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (!response.body) {
|
// if (!response.body) {
|
||||||
throw new Error('Response body is null');
|
// throw new Error('Response body is null');
|
||||||
}
|
// }
|
||||||
|
|
||||||
const data = await response.json();
|
// const data = await response.json();
|
||||||
if (data.error) {
|
// if (data.error) {
|
||||||
throw Error(data.error);
|
// throw Error(data.error);
|
||||||
}
|
// }
|
||||||
|
|
||||||
for (const [key, value] of Object.entries(data)) {
|
// for (const [key, value] of Object.entries(data)) {
|
||||||
switch (key) {
|
// switch (key) {
|
||||||
case "rags":
|
// case "rags":
|
||||||
setRags(value as Tool[]);
|
// setRags(value as Tool[]);
|
||||||
break;
|
// break;
|
||||||
case "tools":
|
// case "tools":
|
||||||
setTools(value as Tool[]);
|
// setTools(value as Tool[]);
|
||||||
break;
|
// break;
|
||||||
case "system_prompt":
|
// case "system_prompt":
|
||||||
setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
// setSystemPrompt((value as ServerTunables)["system_prompt"].trim());
|
||||||
break;
|
// break;
|
||||||
case "history":
|
// case "history":
|
||||||
console.log('TODO: handle history reset');
|
// console.log('TODO: handle history reset');
|
||||||
break;
|
// break;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
setSnack(message, "success");
|
// setSnack(message, "success");
|
||||||
} catch (error) {
|
// } catch (error) {
|
||||||
console.error('Fetch error:', error);
|
// console.error('Fetch error:', error);
|
||||||
setSnack("Unable to restore defaults", "error");
|
// setSnack("Unable to restore defaults", "error");
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
// Get the system information
|
// Get the system information
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -179,27 +169,8 @@ const Settings = (props: BackstoryPageProps) => {
|
|||||||
}
|
}
|
||||||
const fetchSystemInfo = async () => {
|
const fetchSystemInfo = async () => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/1.0/system-info`, {
|
const response: Types.SystemInfo = await apiClient.getSystemInfo();
|
||||||
method: 'GET',
|
setSystemInfo(response);
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error(`Server responded with ${response.status}: ${response.statusText}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!response.body) {
|
|
||||||
throw new Error('Response body is null');
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await response.json();
|
|
||||||
if (data.error) {
|
|
||||||
throw Error(data.error);
|
|
||||||
}
|
|
||||||
|
|
||||||
setSystemInfo(data);
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error obtaining system information:', error);
|
console.error('Error obtaining system information:', error);
|
||||||
setSnack("Unable to obtain system information.", "error");
|
setSnack("Unable to obtain system information.", "error");
|
||||||
@ -217,101 +188,101 @@ const Settings = (props: BackstoryPageProps) => {
|
|||||||
setEditSystemPrompt(systemPrompt.trim());
|
setEditSystemPrompt(systemPrompt.trim());
|
||||||
}, [systemPrompt, setEditSystemPrompt]);
|
}, [systemPrompt, setEditSystemPrompt]);
|
||||||
|
|
||||||
const toggleRag = async (tool: Tool) => {
|
// const toggleRag = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
// tool.enabled = !tool.enabled
|
||||||
try {
|
// try {
|
||||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
// method: 'PUT',
|
||||||
headers: {
|
// headers: {
|
||||||
'Content-Type': 'application/json',
|
// 'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
// 'Accept': 'application/json',
|
||||||
},
|
// },
|
||||||
body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
// body: JSON.stringify({ "rags": [{ "name": tool?.name, "enabled": tool.enabled }] }),
|
||||||
});
|
// });
|
||||||
|
|
||||||
const tunables: ServerTunables = await response.json();
|
// const tunables: ServerTunables = await response.json();
|
||||||
setRags(tunables.rags)
|
// setRags(tunables.rags)
|
||||||
setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
// setSnack(`${tool?.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||||
} catch (error) {
|
// } catch (error) {
|
||||||
console.error('Fetch error:', error);
|
// console.error('Fetch error:', error);
|
||||||
setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
// setSnack(`${tool?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||||
tool.enabled = !tool.enabled
|
// tool.enabled = !tool.enabled
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
const toggleTool = async (tool: Tool) => {
|
// const toggleTool = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
// tool.enabled = !tool.enabled
|
||||||
try {
|
// try {
|
||||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
// method: 'PUT',
|
||||||
headers: {
|
// headers: {
|
||||||
'Content-Type': 'application/json',
|
// 'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
// 'Accept': 'application/json',
|
||||||
},
|
// },
|
||||||
body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
// body: JSON.stringify({ "tools": [{ "name": tool.name, "enabled": tool.enabled }] }),
|
||||||
});
|
// });
|
||||||
|
|
||||||
const tunables: ServerTunables = await response.json();
|
// const tunables: ServerTunables = await response.json();
|
||||||
setTools(tunables.tools)
|
// setTools(tunables.tools)
|
||||||
setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
// setSnack(`${tool.name} ${tool.enabled ? "enabled" : "disabled"}`);
|
||||||
} catch (error) {
|
// } catch (error) {
|
||||||
console.error('Fetch error:', error);
|
// console.error('Fetch error:', error);
|
||||||
setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
// setSnack(`${tool.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error");
|
||||||
tool.enabled = !tool.enabled
|
// tool.enabled = !tool.enabled
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
// If the systemPrompt has not been set, fetch it from the server
|
// If the systemPrompt has not been set, fetch it from the server
|
||||||
useEffect(() => {
|
// useEffect(() => {
|
||||||
if (serverTunables !== undefined) {
|
// if (serverTunables !== undefined) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
const fetchTunables = async () => {
|
// const fetchTunables = async () => {
|
||||||
try {
|
// try {
|
||||||
// Make the fetch request with proper headers
|
// // Make the fetch request with proper headers
|
||||||
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
// const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'GET',
|
// method: 'GET',
|
||||||
headers: {
|
// headers: {
|
||||||
'Content-Type': 'application/json',
|
// 'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json',
|
// 'Accept': 'application/json',
|
||||||
},
|
// },
|
||||||
});
|
// });
|
||||||
const data = await response.json();
|
// const data = await response.json();
|
||||||
// console.log("Server tunables: ", data);
|
// // console.log("Server tunables: ", data);
|
||||||
setServerTunables(data);
|
// setServerTunables(data);
|
||||||
setSystemPrompt(data["system_prompt"]);
|
// setSystemPrompt(data["system_prompt"]);
|
||||||
setTools(data["tools"]);
|
// setTools(data["tools"]);
|
||||||
setRags(data["rags"]);
|
// setRags(data["rags"]);
|
||||||
} catch (error) {
|
// } catch (error) {
|
||||||
console.error('Fetch error:', error);
|
// console.error('Fetch error:', error);
|
||||||
setSnack("System prompt update failed", "error");
|
// setSnack("System prompt update failed", "error");
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
fetchTunables();
|
// fetchTunables();
|
||||||
}, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
|
// }, [setServerTunables, setSystemPrompt, setMessageHistoryLength, serverTunables, setTools, setRags, setSnack]);
|
||||||
|
|
||||||
const toggle = async (type: string, index: number) => {
|
// const toggle = async (type: string, index: number) => {
|
||||||
switch (type) {
|
// switch (type) {
|
||||||
case "rag":
|
// case "rag":
|
||||||
if (rags === undefined) {
|
// if (rags === undefined) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
toggleRag(rags[index])
|
// toggleRag(rags[index])
|
||||||
break;
|
// break;
|
||||||
case "tool":
|
// case "tool":
|
||||||
if (tools === undefined) {
|
// if (tools === undefined) {
|
||||||
return;
|
// return;
|
||||||
}
|
// }
|
||||||
toggleTool(tools[index]);
|
// toggleTool(tools[index]);
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
const handleKeyPress = (event: any) => {
|
// const handleKeyPress = (event: any) => {
|
||||||
if (event.key === 'Enter' && event.ctrlKey) {
|
// if (event.key === 'Enter' && event.ctrlKey) {
|
||||||
setSystemPrompt(editSystemPrompt);
|
// setSystemPrompt(editSystemPrompt);
|
||||||
}
|
// }
|
||||||
};
|
// };
|
||||||
|
|
||||||
return (<div className="Controls">
|
return (<div className="Controls">
|
||||||
{/* <Typography component="span" sx={{ mb: 1 }}>
|
{/* <Typography component="span" sx={{ mb: 1 }}>
|
||||||
|
@ -6,6 +6,7 @@ import { SetSnackType } from '../components/Snack';
|
|||||||
import { LoadingComponent } from "../components/LoadingComponent";
|
import { LoadingComponent } from "../components/LoadingComponent";
|
||||||
import { User, Guest, Candidate } from 'types/types';
|
import { User, Guest, Candidate } from 'types/types';
|
||||||
import { useAuth } from "hooks/AuthContext";
|
import { useAuth } from "hooks/AuthContext";
|
||||||
|
import { useSelectedCandidate } from "hooks/GlobalContext";
|
||||||
|
|
||||||
interface CandidateRouteProps {
|
interface CandidateRouteProps {
|
||||||
guest?: Guest | null;
|
guest?: Guest | null;
|
||||||
@ -15,19 +16,19 @@ interface CandidateRouteProps {
|
|||||||
|
|
||||||
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
|
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
|
||||||
const { apiClient } = useAuth();
|
const { apiClient } = useAuth();
|
||||||
|
const { selectedCandidate, setSelectedCandidate } = useSelectedCandidate();
|
||||||
const { setSnack } = props;
|
const { setSnack } = props;
|
||||||
const { username } = useParams<{ username: string }>();
|
const { username } = useParams<{ username: string }>();
|
||||||
const [candidate, setCandidate] = useState<Candidate|null>(null);
|
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (candidate?.username === username || !username) {
|
if (selectedCandidate?.username === username || !username) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const getCandidate = async (reference: string) => {
|
const getCandidate = async (reference: string) => {
|
||||||
try {
|
try {
|
||||||
const result: Candidate = await apiClient.getCandidate(reference);
|
const result: Candidate = await apiClient.getCandidate(reference);
|
||||||
setCandidate(result);
|
setSelectedCandidate(result);
|
||||||
navigate('/chat');
|
navigate('/chat');
|
||||||
} catch {
|
} catch {
|
||||||
setSnack(`Unable to obtain information for ${username}.`, "error");
|
setSnack(`Unable to obtain information for ${username}.`, "error");
|
||||||
@ -36,9 +37,9 @@ const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProp
|
|||||||
}
|
}
|
||||||
|
|
||||||
getCandidate(username);
|
getCandidate(username);
|
||||||
}, [candidate, username, setCandidate, navigate, setSnack, apiClient]);
|
}, [selectedCandidate, username, selectedCandidate, navigate, setSnack, apiClient]);
|
||||||
|
|
||||||
if (candidate === null) {
|
if (selectedCandidate?.username !== username) {
|
||||||
return (<Box>
|
return (<Box>
|
||||||
<LoadingComponent
|
<LoadingComponent
|
||||||
loadingText="Fetching candidate information..."
|
loadingText="Fetching candidate information..."
|
||||||
|
@ -29,10 +29,33 @@ import {
|
|||||||
convertArrayFromApi
|
convertArrayFromApi
|
||||||
} from 'types/types';
|
} from 'types/types';
|
||||||
|
|
||||||
|
const TOKEN_STORAGE = {
|
||||||
|
ACCESS_TOKEN: 'accessToken',
|
||||||
|
REFRESH_TOKEN: 'refreshToken',
|
||||||
|
USER_DATA: 'userData',
|
||||||
|
TOKEN_EXPIRY: 'tokenExpiry',
|
||||||
|
USER_TYPE: 'userType',
|
||||||
|
IS_GUEST: 'isGuest',
|
||||||
|
PENDING_VERIFICATION_EMAIL: 'pendingVerificationEmail'
|
||||||
|
} as const;
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Streaming Types and Interfaces
|
// Streaming Types and Interfaces
|
||||||
// ============================
|
// ============================
|
||||||
|
export interface GuestConversionRequest extends CreateCandidateRequest {
|
||||||
|
accountType: 'candidate';
|
||||||
|
}
|
||||||
|
|
||||||
|
export class RateLimitError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public retryAfterSeconds: number,
|
||||||
|
public remainingRequests: Record<string, number>
|
||||||
|
) {
|
||||||
|
super(message);
|
||||||
|
this.name = 'RateLimitError';
|
||||||
|
}
|
||||||
|
}
|
||||||
interface StreamingOptions<T = Types.ChatMessage> {
|
interface StreamingOptions<T = Types.ChatMessage> {
|
||||||
method?: string,
|
method?: string,
|
||||||
headers?: Record<string, any>,
|
headers?: Record<string, any>,
|
||||||
@ -777,7 +800,7 @@ class ApiClient {
|
|||||||
async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
|
async getOrCreateChatSession(candidate: Types.Candidate, title: string, context_type: Types.ChatContextType) : Promise<Types.ChatSession> {
|
||||||
const result = await this.getCandidateChatSessions(candidate.username);
|
const result = await this.getCandidateChatSessions(candidate.username);
|
||||||
/* Find the 'candidate_chat' session if it exists, otherwise create it */
|
/* Find the 'candidate_chat' session if it exists, otherwise create it */
|
||||||
let session = result.sessions.data.find(session => session.title === 'candidate_chat');
|
let session = result.sessions.data.find(session => session.title === title);
|
||||||
if (!session) {
|
if (!session) {
|
||||||
session = await this.createCandidateChatSession(
|
session = await this.createCandidateChatSession(
|
||||||
candidate.username,
|
candidate.username,
|
||||||
@ -788,6 +811,17 @@ class ApiClient {
|
|||||||
return session;
|
return session;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async getSystemInfo() : Promise<Types.SystemInfo> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/system-info`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handleApiResponse<Types.SystemInfo>(response);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
async getCandidateSimilarContent(query: string
|
async getCandidateSimilarContent(query: string
|
||||||
): Promise<Types.ChromaDBGetResponse> {
|
): Promise<Types.ChromaDBGetResponse> {
|
||||||
const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
|
const response = await fetch(`${this.baseUrl}/candidates/rag-search`, {
|
||||||
@ -1003,6 +1037,202 @@ class ApiClient {
|
|||||||
return handleApiResponse<{ success: boolean; message: string }>(response);
|
return handleApiResponse<{ success: boolean; message: string }>(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Guest Authentication Methods
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a guest session with authentication
|
||||||
|
*/
|
||||||
|
async createGuestSession(): Promise<Types.AuthResponse> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/auth/guest`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: this.defaultHeaders
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handleApiResponse<Types.AuthResponse>(response);
|
||||||
|
|
||||||
|
// Convert guest data if needed
|
||||||
|
if (result.user && result.user.userType === 'guest') {
|
||||||
|
result.user = convertFromApi<Types.Guest>(result.user, "Guest");
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert guest account to permanent user account
|
||||||
|
*/
|
||||||
|
async convertGuestToUser(
|
||||||
|
registrationData: CreateCandidateRequest & { accountType: 'candidate' }
|
||||||
|
): Promise<{
|
||||||
|
message: string;
|
||||||
|
auth: Types.AuthResponse;
|
||||||
|
conversionType: string
|
||||||
|
}> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/auth/guest/convert`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
body: JSON.stringify(formatApiRequest(registrationData))
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await handleApiResponse<{
|
||||||
|
message: string;
|
||||||
|
auth: Types.AuthResponse;
|
||||||
|
conversionType: string;
|
||||||
|
}>(response);
|
||||||
|
|
||||||
|
// Convert the auth user data
|
||||||
|
if (result.auth?.user) {
|
||||||
|
result.auth.user = convertFromApi<Types.Candidate>(result.auth.user, "Candidate");
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if current session is a guest
|
||||||
|
*/
|
||||||
|
isGuestSession(): boolean {
|
||||||
|
try {
|
||||||
|
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||||
|
if (userDataStr) {
|
||||||
|
const userData = JSON.parse(userDataStr);
|
||||||
|
return userData.userType === 'guest';
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get guest session info
|
||||||
|
*/
|
||||||
|
getGuestSessionInfo(): Types.Guest | null {
|
||||||
|
try {
|
||||||
|
const userDataStr = localStorage.getItem(TOKEN_STORAGE.USER_DATA);
|
||||||
|
if (userDataStr) {
|
||||||
|
const userData = JSON.parse(userDataStr);
|
||||||
|
if (userData.userType === 'guest') {
|
||||||
|
return convertFromApi<Types.Guest>(userData, "Guest");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
} catch {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get rate limit status for current user
|
||||||
|
*/
|
||||||
|
async getRateLimitStatus(): Promise<{
|
||||||
|
user_id: string;
|
||||||
|
user_type: string;
|
||||||
|
is_admin: boolean;
|
||||||
|
current_usage: Record<string, number>;
|
||||||
|
limits: Record<string, number>;
|
||||||
|
remaining: Record<string, number>;
|
||||||
|
reset_times: Record<string, string>;
|
||||||
|
config: any;
|
||||||
|
}> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/admin/rate-limits/info`, {
|
||||||
|
headers: this.defaultHeaders
|
||||||
|
});
|
||||||
|
|
||||||
|
return handleApiResponse<any>(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get guest statistics (admin only)
|
||||||
|
*/
|
||||||
|
async getGuestStatistics(): Promise<{
|
||||||
|
total_guests: number;
|
||||||
|
active_last_hour: number;
|
||||||
|
active_last_day: number;
|
||||||
|
converted_guests: number;
|
||||||
|
by_ip: Record<string, number>;
|
||||||
|
creation_timeline: Record<string, number>;
|
||||||
|
}> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/admin/guests/statistics`, {
|
||||||
|
headers: this.defaultHeaders
|
||||||
|
});
|
||||||
|
|
||||||
|
return handleApiResponse<any>(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cleanup inactive guests (admin only)
|
||||||
|
*/
|
||||||
|
async cleanupInactiveGuests(inactiveHours: number = 24): Promise<{
|
||||||
|
message: string;
|
||||||
|
cleaned_count: number;
|
||||||
|
}> {
|
||||||
|
const response = await fetch(`${this.baseUrl}/admin/guests/cleanup`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
body: JSON.stringify({ inactive_hours: inactiveHours })
|
||||||
|
});
|
||||||
|
|
||||||
|
return handleApiResponse<{
|
||||||
|
message: string;
|
||||||
|
cleaned_count: number;
|
||||||
|
}>(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================
|
||||||
|
// Enhanced Error Handling for Rate Limits
|
||||||
|
// ============================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Enhanced API response handler with rate limit handling
|
||||||
|
*/
|
||||||
|
private async handleApiResponseWithRateLimit<T>(response: Response): Promise<T> {
|
||||||
|
if (response.status === 429) {
|
||||||
|
const rateLimitData = await response.json();
|
||||||
|
const retryAfter = response.headers.get('Retry-After');
|
||||||
|
|
||||||
|
throw new RateLimitError(
|
||||||
|
rateLimitData.detail?.message || 'Rate limit exceeded',
|
||||||
|
parseInt(retryAfter || '60'),
|
||||||
|
rateLimitData.detail?.remaining || {}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.handleApiResponseWithConversion<T>(response);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retry mechanism for rate-limited requests
|
||||||
|
*/
|
||||||
|
private async retryWithBackoff<T>(
|
||||||
|
requestFn: () => Promise<Response>,
|
||||||
|
maxRetries: number = 3
|
||||||
|
): Promise<T> {
|
||||||
|
let lastError: Error;
|
||||||
|
|
||||||
|
for (let attempt = 0; attempt <= maxRetries; attempt++) {
|
||||||
|
try {
|
||||||
|
const response = await requestFn();
|
||||||
|
return await this.handleApiResponseWithRateLimit<T>(response);
|
||||||
|
} catch (error) {
|
||||||
|
lastError = error as Error;
|
||||||
|
|
||||||
|
if (error instanceof RateLimitError && attempt < maxRetries) {
|
||||||
|
const delayMs = Math.min(error.retryAfterSeconds * 1000, 60000); // Max 1 minute
|
||||||
|
console.warn(`Rate limited, retrying in ${delayMs}ms (attempt ${attempt + 1}/${maxRetries + 1})`);
|
||||||
|
await new Promise(resolve => setTimeout(resolve, delayMs));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw lastError!;
|
||||||
|
}
|
||||||
async resetChatSession(id: string): Promise<{ success: boolean; message: string }> {
|
async resetChatSession(id: string): Promise<{ success: boolean; message: string }> {
|
||||||
const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, {
|
const response = await fetch(`${this.baseUrl}/chat/sessions/${id}/reset`, {
|
||||||
method: 'PATCH',
|
method: 'PATCH',
|
||||||
@ -1384,5 +1614,5 @@ export interface PendingVerification {
|
|||||||
attempts: number;
|
attempts: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export { ApiClient }
|
export { ApiClient, TOKEN_STORAGE }
|
||||||
export type { StreamingOptions, StreamingResponse }
|
export type { StreamingOptions, StreamingResponse }
|
@ -1,6 +1,6 @@
|
|||||||
// Generated TypeScript types from Pydantic models
|
// Generated TypeScript types from Pydantic models
|
||||||
// Source: src/backend/models.py
|
// Source: src/backend/models.py
|
||||||
// Generated on: 2025-06-07T20:43:58.855207
|
// Generated on: 2025-06-09T03:32:01.335483
|
||||||
// DO NOT EDIT MANUALLY - This file is auto-generated
|
// DO NOT EDIT MANUALLY - This file is auto-generated
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
@ -128,8 +128,10 @@ export interface Attachment {
|
|||||||
export interface AuthResponse {
|
export interface AuthResponse {
|
||||||
accessToken: string;
|
accessToken: string;
|
||||||
refreshToken: string;
|
refreshToken: string;
|
||||||
user: any;
|
user: Candidate | Employer | Guest;
|
||||||
expiresAt: number;
|
expiresAt: number;
|
||||||
|
userType?: string;
|
||||||
|
isGuest?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Authentication {
|
export interface Authentication {
|
||||||
@ -148,7 +150,9 @@ export interface Authentication {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface BaseUser {
|
export interface BaseUser {
|
||||||
|
userType: "candidate" | "employer" | "guest";
|
||||||
id?: string;
|
id?: string;
|
||||||
|
lastActivity?: Date;
|
||||||
email: string;
|
email: string;
|
||||||
firstName: string;
|
firstName: string;
|
||||||
lastName: string;
|
lastName: string;
|
||||||
@ -164,24 +168,15 @@ export interface BaseUser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface BaseUserWithType {
|
export interface BaseUserWithType {
|
||||||
id?: string;
|
|
||||||
email: string;
|
|
||||||
firstName: string;
|
|
||||||
lastName: string;
|
|
||||||
fullName: string;
|
|
||||||
phone?: string;
|
|
||||||
location?: Location;
|
|
||||||
createdAt: Date;
|
|
||||||
updatedAt: Date;
|
|
||||||
lastLogin?: Date;
|
|
||||||
profileImage?: string;
|
|
||||||
status: "active" | "inactive" | "pending" | "banned";
|
|
||||||
isAdmin: boolean;
|
|
||||||
userType: "candidate" | "employer" | "guest";
|
userType: "candidate" | "employer" | "guest";
|
||||||
|
id?: string;
|
||||||
|
lastActivity?: Date;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Candidate {
|
export interface Candidate {
|
||||||
|
userType: "candidate" | "employer" | "guest";
|
||||||
id?: string;
|
id?: string;
|
||||||
|
lastActivity?: Date;
|
||||||
email: string;
|
email: string;
|
||||||
firstName: string;
|
firstName: string;
|
||||||
lastName: string;
|
lastName: string;
|
||||||
@ -194,7 +189,6 @@ export interface Candidate {
|
|||||||
profileImage?: string;
|
profileImage?: string;
|
||||||
status: "active" | "inactive" | "pending" | "banned";
|
status: "active" | "inactive" | "pending" | "banned";
|
||||||
isAdmin: boolean;
|
isAdmin: boolean;
|
||||||
userType: "candidate";
|
|
||||||
username: string;
|
username: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
resume?: string;
|
resume?: string;
|
||||||
@ -214,7 +208,9 @@ export interface Candidate {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface CandidateAI {
|
export interface CandidateAI {
|
||||||
|
userType: "candidate" | "employer" | "guest";
|
||||||
id?: string;
|
id?: string;
|
||||||
|
lastActivity?: Date;
|
||||||
email: string;
|
email: string;
|
||||||
firstName: string;
|
firstName: string;
|
||||||
lastName: string;
|
lastName: string;
|
||||||
@ -227,7 +223,6 @@ export interface CandidateAI {
|
|||||||
profileImage?: string;
|
profileImage?: string;
|
||||||
status: "active" | "inactive" | "pending" | "banned";
|
status: "active" | "inactive" | "pending" | "banned";
|
||||||
isAdmin: boolean;
|
isAdmin: boolean;
|
||||||
userType: "candidate";
|
|
||||||
username: string;
|
username: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
resume?: string;
|
resume?: string;
|
||||||
@ -301,7 +296,7 @@ export interface ChatMessage {
|
|||||||
role: "user" | "assistant" | "system" | "information" | "warning" | "error";
|
role: "user" | "assistant" | "system" | "information" | "warning" | "error";
|
||||||
content: string;
|
content: string;
|
||||||
tunables?: Tunables;
|
tunables?: Tunables;
|
||||||
metadata?: ChatMessageMetaData;
|
metadata: ChatMessageMetaData;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatMessageError {
|
export interface ChatMessageError {
|
||||||
@ -319,9 +314,9 @@ export interface ChatMessageMetaData {
|
|||||||
temperature: number;
|
temperature: number;
|
||||||
maxTokens: number;
|
maxTokens: number;
|
||||||
topP: number;
|
topP: number;
|
||||||
frequencyPenalty?: number;
|
frequencyPenalty: number;
|
||||||
presencePenalty?: number;
|
presencePenalty: number;
|
||||||
stopSequences?: Array<string>;
|
stopSequences: Array<string>;
|
||||||
ragResults?: Array<ChromaDBGetResponse>;
|
ragResults?: Array<ChromaDBGetResponse>;
|
||||||
llmHistory?: Array<any>;
|
llmHistory?: Array<any>;
|
||||||
evalCount: number;
|
evalCount: number;
|
||||||
@ -392,7 +387,6 @@ export interface ChatQuery {
|
|||||||
export interface ChatSession {
|
export interface ChatSession {
|
||||||
id?: string;
|
id?: string;
|
||||||
userId?: string;
|
userId?: string;
|
||||||
guestId?: string;
|
|
||||||
createdAt?: Date;
|
createdAt?: Date;
|
||||||
lastActivity?: Date;
|
lastActivity?: Date;
|
||||||
title?: string;
|
title?: string;
|
||||||
@ -507,7 +501,7 @@ export interface DocumentMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface DocumentOptions {
|
export interface DocumentOptions {
|
||||||
includeInRAG?: boolean;
|
includeInRAG: boolean;
|
||||||
isJobDocument?: boolean;
|
isJobDocument?: boolean;
|
||||||
overwrite?: boolean;
|
overwrite?: boolean;
|
||||||
}
|
}
|
||||||
@ -541,7 +535,9 @@ export interface EmailVerificationRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface Employer {
|
export interface Employer {
|
||||||
|
userType: "candidate" | "employer" | "guest";
|
||||||
id?: string;
|
id?: string;
|
||||||
|
lastActivity?: Date;
|
||||||
email: string;
|
email: string;
|
||||||
firstName: string;
|
firstName: string;
|
||||||
lastName: string;
|
lastName: string;
|
||||||
@ -554,7 +550,6 @@ export interface Employer {
|
|||||||
profileImage?: string;
|
profileImage?: string;
|
||||||
status: "active" | "inactive" | "pending" | "banned";
|
status: "active" | "inactive" | "pending" | "banned";
|
||||||
isAdmin: boolean;
|
isAdmin: boolean;
|
||||||
userType: "employer";
|
|
||||||
companyName: string;
|
companyName: string;
|
||||||
industry: string;
|
industry: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
@ -580,14 +575,71 @@ export interface ErrorDetail {
|
|||||||
details?: any;
|
details?: any;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface GPUInfo {
|
||||||
|
name: string;
|
||||||
|
memory: number;
|
||||||
|
discrete: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
export interface Guest {
|
export interface Guest {
|
||||||
|
userType: "candidate" | "employer" | "guest";
|
||||||
id?: string;
|
id?: string;
|
||||||
sessionId: string;
|
lastActivity?: Date;
|
||||||
|
email: string;
|
||||||
|
firstName: string;
|
||||||
|
lastName: string;
|
||||||
|
fullName: string;
|
||||||
|
phone?: string;
|
||||||
|
location?: Location;
|
||||||
createdAt: Date;
|
createdAt: Date;
|
||||||
lastActivity: Date;
|
updatedAt: Date;
|
||||||
|
lastLogin?: Date;
|
||||||
|
profileImage?: string;
|
||||||
|
status: "active" | "inactive" | "pending" | "banned";
|
||||||
|
isAdmin: boolean;
|
||||||
|
sessionId: string;
|
||||||
|
username: string;
|
||||||
convertedToUserId?: string;
|
convertedToUserId?: string;
|
||||||
ipAddress?: string;
|
ipAddress?: string;
|
||||||
userAgent?: string;
|
userAgent?: string;
|
||||||
|
ragContentSize: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GuestCleanupRequest {
|
||||||
|
inactiveHours: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GuestConversionRequest {
|
||||||
|
accountType: "candidate" | "employer";
|
||||||
|
email: string;
|
||||||
|
username: string;
|
||||||
|
password: string;
|
||||||
|
firstName: string;
|
||||||
|
lastName: string;
|
||||||
|
phone?: string;
|
||||||
|
companyName?: string;
|
||||||
|
industry?: string;
|
||||||
|
companySize?: string;
|
||||||
|
companyDescription?: string;
|
||||||
|
websiteUrl?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GuestSessionResponse {
|
||||||
|
accessToken: string;
|
||||||
|
refreshToken: string;
|
||||||
|
user: Guest;
|
||||||
|
expiresAt: number;
|
||||||
|
userType: "guest";
|
||||||
|
isGuest: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface GuestStatistics {
|
||||||
|
totalGuests: number;
|
||||||
|
activeLastHour: number;
|
||||||
|
activeLastDay: number;
|
||||||
|
convertedGuests: number;
|
||||||
|
byIp: Record<string, number>;
|
||||||
|
creationTimeline: Record<string, number>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface InterviewFeedback {
|
export interface InterviewFeedback {
|
||||||
@ -746,6 +798,7 @@ export interface MFARequest {
|
|||||||
password: string;
|
password: string;
|
||||||
deviceId: string;
|
deviceId: string;
|
||||||
deviceName: string;
|
deviceName: string;
|
||||||
|
email: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface MFARequestResponse {
|
export interface MFARequestResponse {
|
||||||
@ -841,6 +894,33 @@ export interface RagEntry {
|
|||||||
enabled: boolean;
|
enabled: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface RateLimitConfig {
|
||||||
|
requestsPerMinute: number;
|
||||||
|
requestsPerHour: number;
|
||||||
|
requestsPerDay: number;
|
||||||
|
burstLimit: number;
|
||||||
|
burstWindowSeconds: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RateLimitResult {
|
||||||
|
allowed: boolean;
|
||||||
|
reason?: string;
|
||||||
|
retryAfterSeconds?: number;
|
||||||
|
remainingRequests?: Record<string, number>;
|
||||||
|
resetTimes?: Record<string, Date>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface RateLimitStatus {
|
||||||
|
userId: string;
|
||||||
|
userType: string;
|
||||||
|
isAdmin: boolean;
|
||||||
|
currentUsage: Record<string, number>;
|
||||||
|
limits: Record<string, number>;
|
||||||
|
remaining: Record<string, number>;
|
||||||
|
resetTimes: Record<string, Date>;
|
||||||
|
config: RateLimitConfig;
|
||||||
|
}
|
||||||
|
|
||||||
export interface RefreshToken {
|
export interface RefreshToken {
|
||||||
token: string;
|
token: string;
|
||||||
expiresAt: Date;
|
expiresAt: Date;
|
||||||
@ -915,6 +995,15 @@ export interface SocialLink {
|
|||||||
url: string;
|
url: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface SystemInfo {
|
||||||
|
installedRAM: string;
|
||||||
|
graphicsCards: Array<GPUInfo>;
|
||||||
|
CPU: string;
|
||||||
|
llmModel: string;
|
||||||
|
embeddingModel: string;
|
||||||
|
maxContextLength: number;
|
||||||
|
}
|
||||||
|
|
||||||
export interface Tunables {
|
export interface Tunables {
|
||||||
enableRAG: boolean;
|
enableRAG: boolean;
|
||||||
enableTools: boolean;
|
enableTools: boolean;
|
||||||
@ -1035,13 +1124,15 @@ export function convertAuthenticationFromApi(data: any): Authentication {
|
|||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert BaseUser from API response, parsing date fields
|
* Convert BaseUser from API response, parsing date fields
|
||||||
* Date fields: createdAt, updatedAt, lastLogin
|
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||||
*/
|
*/
|
||||||
export function convertBaseUserFromApi(data: any): BaseUser {
|
export function convertBaseUserFromApi(data: any): BaseUser {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
|
// Convert lastActivity from ISO string to Date
|
||||||
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert createdAt from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
createdAt: new Date(data.createdAt),
|
||||||
// Convert updatedAt from ISO string to Date
|
// Convert updatedAt from ISO string to Date
|
||||||
@ -1052,30 +1143,28 @@ export function convertBaseUserFromApi(data: any): BaseUser {
|
|||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert BaseUserWithType from API response, parsing date fields
|
* Convert BaseUserWithType from API response, parsing date fields
|
||||||
* Date fields: createdAt, updatedAt, lastLogin
|
* Date fields: lastActivity
|
||||||
*/
|
*/
|
||||||
export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType {
|
export function convertBaseUserWithTypeFromApi(data: any): BaseUserWithType {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert lastActivity from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert updatedAt from ISO string to Date
|
|
||||||
updatedAt: new Date(data.updatedAt),
|
|
||||||
// Convert lastLogin from ISO string to Date
|
|
||||||
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert Candidate from API response, parsing date fields
|
* Convert Candidate from API response, parsing date fields
|
||||||
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate
|
* Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
|
||||||
*/
|
*/
|
||||||
export function convertCandidateFromApi(data: any): Candidate {
|
export function convertCandidateFromApi(data: any): Candidate {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
|
// Convert lastActivity from ISO string to Date
|
||||||
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert createdAt from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
createdAt: new Date(data.createdAt),
|
||||||
// Convert updatedAt from ISO string to Date
|
// Convert updatedAt from ISO string to Date
|
||||||
@ -1088,13 +1177,15 @@ export function convertCandidateFromApi(data: any): Candidate {
|
|||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert CandidateAI from API response, parsing date fields
|
* Convert CandidateAI from API response, parsing date fields
|
||||||
* Date fields: createdAt, updatedAt, lastLogin, availabilityDate
|
* Date fields: lastActivity, createdAt, updatedAt, lastLogin, availabilityDate
|
||||||
*/
|
*/
|
||||||
export function convertCandidateAIFromApi(data: any): CandidateAI {
|
export function convertCandidateAIFromApi(data: any): CandidateAI {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
|
// Convert lastActivity from ISO string to Date
|
||||||
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert createdAt from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
createdAt: new Date(data.createdAt),
|
||||||
// Convert updatedAt from ISO string to Date
|
// Convert updatedAt from ISO string to Date
|
||||||
@ -1282,13 +1373,15 @@ export function convertEducationFromApi(data: any): Education {
|
|||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert Employer from API response, parsing date fields
|
* Convert Employer from API response, parsing date fields
|
||||||
* Date fields: createdAt, updatedAt, lastLogin
|
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||||
*/
|
*/
|
||||||
export function convertEmployerFromApi(data: any): Employer {
|
export function convertEmployerFromApi(data: any): Employer {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
|
// Convert lastActivity from ISO string to Date
|
||||||
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert createdAt from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
createdAt: new Date(data.createdAt),
|
||||||
// Convert updatedAt from ISO string to Date
|
// Convert updatedAt from ISO string to Date
|
||||||
@ -1299,17 +1392,21 @@ export function convertEmployerFromApi(data: any): Employer {
|
|||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Convert Guest from API response, parsing date fields
|
* Convert Guest from API response, parsing date fields
|
||||||
* Date fields: createdAt, lastActivity
|
* Date fields: lastActivity, createdAt, updatedAt, lastLogin
|
||||||
*/
|
*/
|
||||||
export function convertGuestFromApi(data: any): Guest {
|
export function convertGuestFromApi(data: any): Guest {
|
||||||
if (!data) return data;
|
if (!data) return data;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...data,
|
...data,
|
||||||
|
// Convert lastActivity from ISO string to Date
|
||||||
|
lastActivity: data.lastActivity ? new Date(data.lastActivity) : undefined,
|
||||||
// Convert createdAt from ISO string to Date
|
// Convert createdAt from ISO string to Date
|
||||||
createdAt: new Date(data.createdAt),
|
createdAt: new Date(data.createdAt),
|
||||||
// Convert lastActivity from ISO string to Date
|
// Convert updatedAt from ISO string to Date
|
||||||
lastActivity: new Date(data.lastActivity),
|
updatedAt: new Date(data.updatedAt),
|
||||||
|
// Convert lastLogin from ISO string to Date
|
||||||
|
lastLogin: data.lastLogin ? new Date(data.lastLogin) : undefined,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
@ -1415,6 +1512,32 @@ export function convertRAGConfigurationFromApi(data: any): RAGConfiguration {
|
|||||||
updatedAt: new Date(data.updatedAt),
|
updatedAt: new Date(data.updatedAt),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* Convert RateLimitResult from API response, parsing date fields
|
||||||
|
* Date fields: resetTimes
|
||||||
|
*/
|
||||||
|
export function convertRateLimitResultFromApi(data: any): RateLimitResult {
|
||||||
|
if (!data) return data;
|
||||||
|
|
||||||
|
return {
|
||||||
|
...data,
|
||||||
|
// Convert resetTimes from ISO string to Date
|
||||||
|
resetTimes: data.resetTimes ? new Date(data.resetTimes) : undefined,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Convert RateLimitStatus from API response, parsing date fields
|
||||||
|
* Date fields: resetTimes
|
||||||
|
*/
|
||||||
|
export function convertRateLimitStatusFromApi(data: any): RateLimitStatus {
|
||||||
|
if (!data) return data;
|
||||||
|
|
||||||
|
return {
|
||||||
|
...data,
|
||||||
|
// Convert resetTimes from ISO string to Date
|
||||||
|
resetTimes: new Date(data.resetTimes),
|
||||||
|
};
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* Convert RefreshToken from API response, parsing date fields
|
* Convert RefreshToken from API response, parsing date fields
|
||||||
* Date fields: expiresAt
|
* Date fields: expiresAt
|
||||||
@ -1527,6 +1650,10 @@ export function convertFromApi<T>(data: any, modelType: string): T {
|
|||||||
return convertMessageReactionFromApi(data) as T;
|
return convertMessageReactionFromApi(data) as T;
|
||||||
case 'RAGConfiguration':
|
case 'RAGConfiguration':
|
||||||
return convertRAGConfigurationFromApi(data) as T;
|
return convertRAGConfigurationFromApi(data) as T;
|
||||||
|
case 'RateLimitResult':
|
||||||
|
return convertRateLimitResultFromApi(data) as T;
|
||||||
|
case 'RateLimitStatus':
|
||||||
|
return convertRateLimitStatusFromApi(data) as T;
|
||||||
case 'RefreshToken':
|
case 'RefreshToken':
|
||||||
return convertRefreshTokenFromApi(data) as T;
|
return convertRefreshTokenFromApi(data) as T;
|
||||||
case 'UserActivity':
|
case 'UserActivity':
|
||||||
|
@ -1,15 +0,0 @@
|
|||||||
const getConnectionBase = (loc: any): string => {
|
|
||||||
if (!loc.host.match(/.*battle-linux.*/)
|
|
||||||
// && !loc.host.match(/.*backstory-beta.*/)
|
|
||||||
) {
|
|
||||||
return loc.protocol + "//" + loc.host;
|
|
||||||
} else {
|
|
||||||
return loc.protocol + "//battle-linux.ketrenos.com:8912";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const connectionBase = getConnectionBase(window.location);
|
|
||||||
|
|
||||||
export {
|
|
||||||
connectionBase
|
|
||||||
};
|
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import traceback
|
import traceback
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
get_args,
|
get_args,
|
||||||
@ -17,7 +17,7 @@ from typing import (
|
|||||||
import importlib
|
import importlib
|
||||||
import pathlib
|
import pathlib
|
||||||
import inspect
|
import inspect
|
||||||
from prometheus_client import CollectorRegistry # type: ignore
|
from prometheus_client import CollectorRegistry
|
||||||
|
|
||||||
from . base import Agent
|
from . base import Agent
|
||||||
from logger import logger
|
from logger import logger
|
||||||
@ -90,7 +90,7 @@ for path in package_dir.glob("*.py"):
|
|||||||
class_registry[name] = (full_module_name, name)
|
class_registry[name] = (full_module_name, name)
|
||||||
globals()[name] = obj
|
globals()[name] = obj
|
||||||
logger.info(f"Adding agent: {name}")
|
logger.info(f"Adding agent: {name}")
|
||||||
__all__.append(name) # type: ignore
|
__all__.append(name)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"Error importing {full_module_name}: {e}")
|
logger.error(f"Error importing {full_module_name}: {e}")
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
get_args,
|
get_args,
|
||||||
@ -20,8 +20,8 @@ import re
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
from prometheus_client import Counter, Summary, CollectorRegistry
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
|
|
||||||
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||||
from logger import logger
|
from logger import logger
|
||||||
@ -491,7 +491,7 @@ Content: {content}
|
|||||||
session_id: str, prompt: str,
|
session_id: str, prompt: str,
|
||||||
tunables: Optional[Tunables] = None,
|
tunables: Optional[Tunables] = None,
|
||||||
temperature=0.7
|
temperature=0.7
|
||||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessageRagSearch, None]:
|
||||||
if not self.user:
|
if not self.user:
|
||||||
error_message = ChatMessageError(
|
error_message = ChatMessageError(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
@ -28,7 +28,7 @@ class CandidateChat(Agent):
|
|||||||
CandidateChat Agent
|
CandidateChat Agent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
|
agent_type: Literal["candidate_chat"] = "candidate_chat"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_message
|
system_prompt: str = system_message
|
||||||
|
@ -53,7 +53,7 @@ class Chat(Agent):
|
|||||||
Chat Agent
|
Chat Agent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
agent_type: Literal["general"] = "general" # type: ignore
|
agent_type: Literal["general"] = "general"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_message
|
system_prompt: str = system_message
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
from pydantic import model_validator, Field, BaseModel
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -37,7 +37,7 @@ seed = int(time.time())
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
class ImageGenerator(Agent):
|
class ImageGenerator(Agent):
|
||||||
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
|
agent_type: Literal["generate_image"] = "generate_image"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
agent_persist: bool = False
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
from pydantic import model_validator, Field, BaseModel
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -23,7 +23,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from names_dataset import NameDataset, NameWrapper # type: ignore
|
from names_dataset import NameDataset, NameWrapper
|
||||||
|
|
||||||
from .base import Agent, agent_registry, LLMMessage
|
from .base import Agent, agent_registry, LLMMessage
|
||||||
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
|
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Tunables
|
||||||
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class EthnicNameGenerator:
|
class EthnicNameGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
from names_dataset import NameDataset # type: ignore
|
from names_dataset import NameDataset
|
||||||
self.nd = NameDataset()
|
self.nd = NameDataset()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
||||||
@ -292,7 +292,7 @@ class EthnicNameGenerator:
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
class GeneratePersona(Agent):
|
class GeneratePersona(Agent):
|
||||||
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
|
agent_type: Literal["generate_persona"] = "generate_persona"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
agent_persist: bool = False
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator, Field # type: ignore
|
from pydantic import model_validator, Field
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -16,7 +16,7 @@ import json
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
|
|
||||||
from .base import Agent, agent_registry, LLMMessage
|
from .base import Agent, agent_registry, LLMMessage
|
||||||
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
|
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, JobRequirements, JobRequirementsMessage, Tunables
|
||||||
@ -26,7 +26,7 @@ import defines
|
|||||||
import backstory_traceback as traceback
|
import backstory_traceback as traceback
|
||||||
|
|
||||||
class JobRequirementsAgent(Agent):
|
class JobRequirementsAgent(Agent):
|
||||||
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
|
agent_type: Literal["job_requirements"] = "job_requirements"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
# Stage 1A: Job Analysis Implementation
|
# Stage 1A: Job Analysis Implementation
|
||||||
|
@ -15,7 +15,7 @@ class Chat(Agent):
|
|||||||
Chat Agent
|
Chat Agent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
agent_type: Literal["rag_search"] = "rag_search"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator, Field # type: ignore
|
from pydantic import model_validator, Field
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -16,7 +16,7 @@ import json
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
|
|
||||||
from .base import Agent, agent_registry, LLMMessage
|
from .base import Agent, agent_registry, LLMMessage
|
||||||
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
|
from models import Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, SkillMatch, Tunables
|
||||||
@ -25,7 +25,7 @@ from logger import logger
|
|||||||
import defines
|
import defines
|
||||||
|
|
||||||
class SkillMatchAgent(Agent):
|
class SkillMatchAgent(Agent):
|
||||||
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
|
agent_type: Literal["skill_match"] = "skill_match"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
def generate_skill_assessment_prompt(self, skill, rag_context):
|
def generate_skill_assessment_prompt(self, skill, rag_context):
|
||||||
|
@ -5,12 +5,12 @@ Provides password hashing, verification, and security features
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import bcrypt # type: ignore
|
import bcrypt
|
||||||
import secrets
|
import secrets
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone, timedelta
|
from datetime import datetime, timezone, timedelta
|
||||||
from typing import Dict, Any, Optional, Tuple
|
from typing import Dict, Any, Optional, Tuple
|
||||||
from pydantic import BaseModel # type: ignore
|
from pydantic import BaseModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
175
src/backend/background_tasks.py
Normal file
175
src/backend/background_tasks.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
Background tasks for guest cleanup and system maintenance
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import schedule # type: ignore
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, UTC
|
||||||
|
from typing import Optional
|
||||||
|
from logger import logger
|
||||||
|
from database import DatabaseManager
|
||||||
|
|
||||||
|
class BackgroundTaskManager:
|
||||||
|
"""Manages background tasks for the application"""
|
||||||
|
|
||||||
|
def __init__(self, database_manager: DatabaseManager):
|
||||||
|
self.database_manager = database_manager
|
||||||
|
self.running = False
|
||||||
|
self.tasks = []
|
||||||
|
self.scheduler_thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
|
async def cleanup_inactive_guests(self, inactive_hours: int = 24):
|
||||||
|
"""Clean up inactive guest sessions"""
|
||||||
|
try:
|
||||||
|
database = self.database_manager.get_database()
|
||||||
|
cleaned_count = await database.cleanup_inactive_guests(inactive_hours)
|
||||||
|
|
||||||
|
if cleaned_count > 0:
|
||||||
|
logger.info(f"🧹 Background cleanup: removed {cleaned_count} inactive guest sessions")
|
||||||
|
|
||||||
|
return cleaned_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error in guest cleanup: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def cleanup_expired_verification_tokens(self):
|
||||||
|
"""Clean up expired email verification tokens"""
|
||||||
|
try:
|
||||||
|
database = self.database_manager.get_database()
|
||||||
|
cleaned_count = await database.cleanup_expired_verification_tokens()
|
||||||
|
|
||||||
|
if cleaned_count > 0:
|
||||||
|
logger.info(f"🧹 Background cleanup: removed {cleaned_count} expired verification tokens")
|
||||||
|
|
||||||
|
return cleaned_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error in verification token cleanup: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def update_guest_statistics(self):
|
||||||
|
"""Update guest usage statistics"""
|
||||||
|
try:
|
||||||
|
database = self.database_manager.get_database()
|
||||||
|
stats = await database.get_guest_statistics()
|
||||||
|
|
||||||
|
# Log interesting statistics
|
||||||
|
if stats.get('total_guests', 0) > 0:
|
||||||
|
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
|
||||||
|
f"{stats['active_last_hour']} active in last hour, "
|
||||||
|
f"{stats['converted_guests']} converted")
|
||||||
|
|
||||||
|
return stats
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error updating guest statistics: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def cleanup_old_rate_limit_data(self, days_old: int = 7):
|
||||||
|
"""Clean up old rate limiting data"""
|
||||||
|
try:
|
||||||
|
database = self.database_manager.get_database()
|
||||||
|
redis = database.redis
|
||||||
|
|
||||||
|
# Clean up rate limit keys older than specified days
|
||||||
|
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
|
||||||
|
pattern = "rate_limit:*"
|
||||||
|
|
||||||
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await redis.scan(cursor, match=pattern, count=100)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
# Check if key is old enough to delete
|
||||||
|
try:
|
||||||
|
ttl = await redis.ttl(key)
|
||||||
|
if ttl == -1: # No expiration set, check creation time
|
||||||
|
# For simplicity, delete keys without TTL
|
||||||
|
await redis.delete(key)
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.info(f"🧹 Cleaned up {deleted_count} old rate limit keys")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error cleaning up rate limit data: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def schedule_periodic_tasks(self):
|
||||||
|
"""Schedule periodic background tasks with safer intervals"""
|
||||||
|
|
||||||
|
# Guest cleanup - every 6 hours instead of every hour (less aggressive)
|
||||||
|
schedule.every(6).hours.do(self._run_async_task, self.cleanup_inactive_guests, 48) # 48 hours instead of 24
|
||||||
|
|
||||||
|
# Verification token cleanup - every 12 hours
|
||||||
|
schedule.every(12).hours.do(self._run_async_task, self.cleanup_expired_verification_tokens)
|
||||||
|
|
||||||
|
# Guest statistics update - every hour
|
||||||
|
schedule.every().hour.do(self._run_async_task, self.update_guest_statistics)
|
||||||
|
|
||||||
|
# Rate limit data cleanup - daily at 3 AM
|
||||||
|
schedule.every().day.at("03:00").do(self._run_async_task, self.cleanup_old_rate_limit_data, 7)
|
||||||
|
|
||||||
|
logger.info("📅 Background tasks scheduled with safer intervals")
|
||||||
|
|
||||||
|
def _run_async_task(self, coro_func, *args, **kwargs):
|
||||||
|
"""Run an async task in the background"""
|
||||||
|
try:
|
||||||
|
# Create new event loop for this thread if needed
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
# Run the coroutine
|
||||||
|
loop.run_until_complete(coro_func(*args, **kwargs))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error running background task {coro_func.__name__}: {e}")
|
||||||
|
|
||||||
|
def _scheduler_worker(self):
|
||||||
|
"""Worker thread for running scheduled tasks"""
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
schedule.run_pending()
|
||||||
|
time.sleep(60) # Check every minute
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error in scheduler worker: {e}")
|
||||||
|
time.sleep(60)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""Start the background task manager"""
|
||||||
|
if self.running:
|
||||||
|
logger.warning("⚠️ Background task manager already running")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.running = True
|
||||||
|
self.schedule_periodic_tasks()
|
||||||
|
|
||||||
|
# Start scheduler thread
|
||||||
|
self.scheduler_thread = threading.Thread(target=self._scheduler_worker, daemon=True)
|
||||||
|
self.scheduler_thread.start()
|
||||||
|
|
||||||
|
logger.info("🚀 Background task manager started")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""Stop the background task manager"""
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
if self.scheduler_thread and self.scheduler_thread.is_alive():
|
||||||
|
self.scheduler_thread.join(timeout=5)
|
||||||
|
|
||||||
|
# Clear scheduled tasks
|
||||||
|
schedule.clear()
|
||||||
|
|
||||||
|
logger.info("🛑 Background task manager stopped")
|
||||||
|
|
||||||
|
|
@ -9,6 +9,7 @@ from models import (
|
|||||||
# User models
|
# User models
|
||||||
Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
|
Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
|
||||||
)
|
)
|
||||||
|
import backstory_traceback as traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -198,6 +199,7 @@ class RedisDatabase:
|
|||||||
try:
|
try:
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
logger.error(f"Failed to deserialize data: {data}")
|
logger.error(f"Failed to deserialize data: {data}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -254,43 +256,44 @@ class RedisDatabase:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
async def get_cached_skill_match(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieve cached skill match assessment"""
|
"""Get cached skill match assessment"""
|
||||||
try:
|
try:
|
||||||
cached_data = await self.redis.get(cache_key)
|
data = await self.redis.get(cache_key)
|
||||||
if cached_data:
|
if data:
|
||||||
return json.loads(cached_data)
|
return json.loads(data)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving cached skill match: {e}")
|
logger.error(f"❌ Error getting cached skill match: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any], ttl: int = 86400 * 30) -> bool:
|
async def cache_skill_match(self, cache_key: str, assessment_data: Dict[str, Any]) -> None:
|
||||||
"""Cache skill match assessment with TTL (default 30 days)"""
|
"""Cache skill match assessment"""
|
||||||
try:
|
try:
|
||||||
|
# Cache for 1 hour by default
|
||||||
await self.redis.setex(
|
await self.redis.setex(
|
||||||
cache_key,
|
cache_key,
|
||||||
ttl,
|
3600,
|
||||||
json.dumps(assessment_data, default=str)
|
json.dumps(assessment_data)
|
||||||
)
|
)
|
||||||
return True
|
logger.debug(f"💾 Skill match cached: {cache_key}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error caching skill match: {e}")
|
logger.error(f"❌ Error caching skill match: {e}")
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
|
async def get_candidate_skill_update_time(self, candidate_id: str) -> Optional[datetime]:
|
||||||
"""Get the last time candidate's skill information was updated"""
|
"""Get the last time candidate skills were updated"""
|
||||||
try:
|
try:
|
||||||
# This assumes you track skill update timestamps in your candidate data
|
|
||||||
candidate_data = await self.get_candidate(candidate_id)
|
candidate_data = await self.get_candidate(candidate_id)
|
||||||
if candidate_data and 'skills_updated_at' in candidate_data:
|
if candidate_data:
|
||||||
return datetime.fromisoformat(candidate_data['skills_updated_at'])
|
updated_at_str = candidate_data.get("updated_at")
|
||||||
|
if updated_at_str:
|
||||||
|
return datetime.fromisoformat(updated_at_str.replace('Z', '+00:00'))
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting candidate skill update time: {e}")
|
logger.error(f"❌ Error getting candidate skill update time: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
||||||
"""Get the timestamp of the latest RAG data update for a specific user"""
|
"""Get the last time user's RAG data was updated"""
|
||||||
try:
|
try:
|
||||||
rag_update_key = f"user:{user_id}:rag_last_update"
|
rag_update_key = f"user:{user_id}:rag_last_update"
|
||||||
timestamp_str = await self.redis.get(rag_update_key)
|
timestamp_str = await self.redis.get(rag_update_key)
|
||||||
@ -298,8 +301,8 @@ class RedisDatabase:
|
|||||||
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
|
return datetime.fromisoformat(timestamp_str.decode('utf-8'))
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting user RAG update time for user {user_id}: {e}")
|
logger.error(f"❌ Error getting user RAG update time: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
||||||
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
|
"""Update the RAG data timestamp for a specific user (call this when user's RAG data is updated)"""
|
||||||
@ -359,7 +362,7 @@ class RedisDatabase:
|
|||||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||||
"""Get all documents for a specific candidate"""
|
"""Get all documents for a specific candidate"""
|
||||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||||
document_ids = await self.redis.lrange(key, 0, -1)
|
document_ids = await self.redis.lrange(key, 0, -1)
|
||||||
|
|
||||||
if not document_ids:
|
if not document_ids:
|
||||||
return []
|
return []
|
||||||
@ -1707,7 +1710,11 @@ class RedisDatabase:
|
|||||||
result = {}
|
result = {}
|
||||||
for key, value in zip(keys, values):
|
for key, value in zip(keys, values):
|
||||||
email = key.replace(self.KEY_PREFIXES['users'], '')
|
email = key.replace(self.KEY_PREFIXES['users'], '')
|
||||||
result[email] = self._deserialize(value)
|
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
|
||||||
|
if type(value) == str:
|
||||||
|
result[email] = value
|
||||||
|
else:
|
||||||
|
result[email] = self._deserialize(value)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -1850,17 +1857,16 @@ class RedisDatabase:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Retrieve user data by ID"""
|
"""Get user lookup data by user ID"""
|
||||||
try:
|
try:
|
||||||
key = f"user_by_id:{user_id}"
|
data = await self.redis.hget("user_lookup_by_id", user_id)
|
||||||
data = await self.redis.get(key)
|
|
||||||
if data:
|
if data:
|
||||||
return json.loads(data)
|
return json.loads(data)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Error retrieving user by ID {user_id}: {e}")
|
logger.error(f"❌ Error getting user by ID {user_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def user_exists_by_email(self, email: str) -> bool:
|
async def user_exists_by_email(self, email: str) -> bool:
|
||||||
"""Check if a user exists with the given email"""
|
"""Check if a user exists with the given email"""
|
||||||
try:
|
try:
|
||||||
@ -2104,6 +2110,208 @@ class RedisDatabase:
|
|||||||
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
|
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# ============================
|
||||||
|
# Guest Management Methods
|
||||||
|
# ============================
|
||||||
|
|
||||||
|
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
|
||||||
|
"""Store guest data with enhanced persistence"""
|
||||||
|
try:
|
||||||
|
# Ensure last_activity is always set
|
||||||
|
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
# Store in Redis with both hash and individual key for redundancy
|
||||||
|
await self.redis.hset("guests", guest_id, json.dumps(guest_data))
|
||||||
|
|
||||||
|
# Also store with a longer TTL as backup
|
||||||
|
await self.redis.setex(
|
||||||
|
f"guest_backup:{guest_id}",
|
||||||
|
86400 * 7, # 7 days TTL
|
||||||
|
json.dumps(guest_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"💾 Guest stored with backup: {guest_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error storing guest {guest_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get guest data with fallback to backup"""
|
||||||
|
try:
|
||||||
|
# Try primary storage first
|
||||||
|
data = await self.redis.hget("guests", guest_id)
|
||||||
|
if data:
|
||||||
|
guest_data = json.loads(data)
|
||||||
|
# Update last activity when accessed
|
||||||
|
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||||
|
await self.set_guest(guest_id, guest_data)
|
||||||
|
logger.debug(f"🔍 Guest found in primary storage: {guest_id}")
|
||||||
|
return guest_data
|
||||||
|
|
||||||
|
# Fallback to backup storage
|
||||||
|
backup_data = await self.redis.get(f"guest_backup:{guest_id}")
|
||||||
|
if backup_data:
|
||||||
|
guest_data = json.loads(backup_data)
|
||||||
|
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
# Restore to primary storage
|
||||||
|
await self.set_guest(guest_id, guest_data)
|
||||||
|
logger.info(f"🔄 Guest restored from backup: {guest_id}")
|
||||||
|
return guest_data
|
||||||
|
|
||||||
|
logger.warning(f"⚠️ Guest not found: {guest_id}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error getting guest {guest_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Get guest data by session ID"""
|
||||||
|
try:
|
||||||
|
all_guests = await self.get_all_guests()
|
||||||
|
for guest_data in all_guests.values():
|
||||||
|
if guest_data.get("session_id") == session_id:
|
||||||
|
return guest_data
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error getting guest by session ID {session_id}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""Get all guests"""
|
||||||
|
try:
|
||||||
|
data = await self.redis.hgetall("guests")
|
||||||
|
return {
|
||||||
|
guest_id: json.loads(guest_json)
|
||||||
|
for guest_id, guest_json in data.items()
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error getting all guests: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def delete_guest(self, guest_id: str) -> bool:
|
||||||
|
"""Delete a guest"""
|
||||||
|
try:
|
||||||
|
result = await self.redis.hdel("guests", guest_id)
|
||||||
|
if result:
|
||||||
|
logger.info(f"🗑️ Guest deleted: {guest_id}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error deleting guest {guest_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
|
||||||
|
"""Clean up inactive guest sessions with safety checks"""
|
||||||
|
try:
|
||||||
|
all_guests = await self.get_all_guests()
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
cutoff_time = current_time - timedelta(hours=inactive_hours)
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
preserved_count = 0
|
||||||
|
|
||||||
|
for guest_id, guest_data in all_guests.items():
|
||||||
|
try:
|
||||||
|
last_activity_str = guest_data.get("last_activity")
|
||||||
|
created_at_str = guest_data.get("created_at")
|
||||||
|
|
||||||
|
# Skip cleanup if guest is very new (less than 1 hour old)
|
||||||
|
if created_at_str:
|
||||||
|
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||||
|
if current_time - created_at < timedelta(hours=1):
|
||||||
|
preserved_count += 1
|
||||||
|
logger.debug(f"🛡️ Preserving new guest: {guest_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check last activity
|
||||||
|
should_delete = False
|
||||||
|
if last_activity_str:
|
||||||
|
try:
|
||||||
|
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||||
|
if last_activity < cutoff_time:
|
||||||
|
should_delete = True
|
||||||
|
except ValueError:
|
||||||
|
# Invalid date format, but don't delete if guest is new
|
||||||
|
if not created_at_str:
|
||||||
|
should_delete = True
|
||||||
|
else:
|
||||||
|
# No last activity, but don't delete if guest is new
|
||||||
|
if not created_at_str:
|
||||||
|
should_delete = True
|
||||||
|
|
||||||
|
if should_delete:
|
||||||
|
await self.delete_guest(guest_id)
|
||||||
|
deleted_count += 1
|
||||||
|
else:
|
||||||
|
preserved_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}")
|
||||||
|
preserved_count += 1 # Preserve on error
|
||||||
|
|
||||||
|
if deleted_count > 0:
|
||||||
|
logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error in guest cleanup: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def get_guest_statistics(self) -> Dict[str, Any]:
|
||||||
|
"""Get guest usage statistics"""
|
||||||
|
try:
|
||||||
|
all_guests = await self.get_all_guests()
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"total_guests": len(all_guests),
|
||||||
|
"active_last_hour": 0,
|
||||||
|
"active_last_day": 0,
|
||||||
|
"converted_guests": 0,
|
||||||
|
"by_ip": {},
|
||||||
|
"creation_timeline": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
hour_ago = current_time - timedelta(hours=1)
|
||||||
|
day_ago = current_time - timedelta(days=1)
|
||||||
|
|
||||||
|
for guest_data in all_guests.values():
|
||||||
|
# Check activity
|
||||||
|
last_activity_str = guest_data.get("last_activity")
|
||||||
|
if last_activity_str:
|
||||||
|
try:
|
||||||
|
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||||
|
if last_activity > hour_ago:
|
||||||
|
stats["active_last_hour"] += 1
|
||||||
|
if last_activity > day_ago:
|
||||||
|
stats["active_last_day"] += 1
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check conversions
|
||||||
|
if guest_data.get("converted_to_user_id"):
|
||||||
|
stats["converted_guests"] += 1
|
||||||
|
|
||||||
|
# IP tracking
|
||||||
|
ip = guest_data.get("ip_address", "unknown")
|
||||||
|
stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1
|
||||||
|
|
||||||
|
# Creation timeline
|
||||||
|
created_at_str = guest_data.get("created_at")
|
||||||
|
if created_at_str:
|
||||||
|
try:
|
||||||
|
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||||
|
date_key = created_at.strftime('%Y-%m-%d')
|
||||||
|
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return stats
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Error getting guest statistics: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
# Global Redis manager instance
|
# Global Redis manager instance
|
||||||
redis_manager = _RedisManager()
|
redis_manager = _RedisManager()
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks # type: ignore
|
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request, BackgroundTasks
|
||||||
from database import RedisDatabase
|
from database import RedisDatabase
|
||||||
import hashlib
|
import hashlib
|
||||||
from logger import logger
|
from logger import logger
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from user_agents import parse # type: ignore
|
from user_agents import parse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
class DeviceManager:
|
class DeviceManager:
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from logger import logger
|
from logger import logger
|
||||||
from email.mime.text import MIMEText # type: ignore
|
from email.mime.text import MIMEText
|
||||||
from email.mime.multipart import MIMEMultipart # type: ignore
|
from email.mime.multipart import MIMEMultipart
|
||||||
import smtplib
|
import smtplib
|
||||||
import asyncio
|
import asyncio
|
||||||
from email_templates import EMAIL_TEMPLATES
|
from email_templates import EMAIL_TEMPLATES
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||||
|
|
||||||
from typing_extensions import Annotated, Union
|
from typing_extensions import Annotated, Union
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
|
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
from prometheus_client import CollectorRegistry, Counter
|
||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
@ -3,12 +3,12 @@ import weakref
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Dict, Optional, Any
|
from typing import Dict, Optional, Any
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from models import ( Candidate )
|
from models import ( Candidate )
|
||||||
from .candidate_entity import CandidateEntity
|
from .candidate_entity import CandidateEntity
|
||||||
from database import RedisDatabase
|
from database import RedisDatabase
|
||||||
from prometheus_client import CollectorRegistry # type: ignore
|
from prometheus_client import CollectorRegistry
|
||||||
|
|
||||||
class EntityManager:
|
class EntityManager:
|
||||||
"""Manages lifecycle of CandidateEntity instances"""
|
"""Manages lifecycle of CandidateEntity instances"""
|
||||||
|
@ -64,7 +64,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||||||
sys.path.insert(0, current_dir)
|
sys.path.insert(0, current_dir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from pydantic import BaseModel # type: ignore
|
from pydantic import BaseModel
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"Error importing pydantic: {e}")
|
print(f"Error importing pydantic: {e}")
|
||||||
print("Make sure pydantic is installed: pip install pydantic")
|
print("Make sure pydantic is installed: pip install pydantic")
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
import json
|
import json
|
||||||
from typing import Any, List, Set
|
from typing import Any, List, Set
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import re
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch # type: ignore
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore
|
from diffusers import StableDiffusionPipeline, FluxPipeline
|
||||||
|
|
||||||
class ImageModelCache: # Stay loaded for 3 hours
|
class ImageModelCache: # Stay loaded for 3 hours
|
||||||
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
||||||
import inspect
|
import inspect
|
||||||
import random
|
import random
|
||||||
@ -13,7 +13,7 @@ import os
|
|||||||
import gc
|
import gc
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
import torch # type: ignore
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
|
from typing import Dict, List, Any, AsyncGenerator, Optional, Union
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
@ -179,7 +179,7 @@ class OllamaAdapter(BaseLLMAdapter):
|
|||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import ollama
|
import ollama
|
||||||
self.client = ollama.AsyncClient( # type: ignore
|
self.client = ollama.AsyncClient(
|
||||||
host=config.get('host', defines.ollama_api_url)
|
host=config.get('host', defines.ollama_api_url)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -386,7 +386,7 @@ class OpenAIAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import openai # type: ignore
|
import openai
|
||||||
self.client = openai.AsyncOpenAI(
|
self.client = openai.AsyncOpenAI(
|
||||||
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
||||||
)
|
)
|
||||||
@ -522,7 +522,7 @@ class AnthropicAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import anthropic # type: ignore
|
import anthropic
|
||||||
self.client = anthropic.AsyncAnthropic(
|
self.client = anthropic.AsyncAnthropic(
|
||||||
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
||||||
)
|
)
|
||||||
@ -656,7 +656,7 @@ class GeminiAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import google.generativeai as genai # type: ignore
|
import google.generativeai as genai
|
||||||
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
||||||
self.genai = genai
|
self.genai = genai
|
||||||
|
|
||||||
@ -867,7 +867,7 @@ class UnifiedLLMProxy:
|
|||||||
raise ValueError("stream must be True for chat_stream")
|
raise ValueError("stream must be True for chat_stream")
|
||||||
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
||||||
# Type checker now knows this is an AsyncGenerator due to stream=True
|
# Type checker now knows this is an AsyncGenerator due to stream=True
|
||||||
async for chunk in result: # type: ignore
|
async for chunk in result:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def chat_single(
|
async def chat_single(
|
||||||
@ -881,7 +881,7 @@ class UnifiedLLMProxy:
|
|||||||
|
|
||||||
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
||||||
# Type checker now knows this is a ChatResponse due to stream=False
|
# Type checker now knows this is a ChatResponse due to stream=False
|
||||||
return result # type: ignore
|
return result
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@ -908,7 +908,7 @@ class UnifiedLLMProxy:
|
|||||||
"""Stream text generation using specified or default provider"""
|
"""Stream text generation using specified or default provider"""
|
||||||
|
|
||||||
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
||||||
async for chunk in result: # type: ignore
|
async for chunk in result:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def generate_single(
|
async def generate_single(
|
||||||
@ -921,7 +921,7 @@ class UnifiedLLMProxy:
|
|||||||
"""Get single generation response using specified or default provider"""
|
"""Get single generation response using specified or default provider"""
|
||||||
|
|
||||||
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
||||||
return result # type: ignore
|
return result
|
||||||
|
|
||||||
async def embeddings(
|
async def embeddings(
|
||||||
self,
|
self,
|
||||||
@ -1148,7 +1148,7 @@ async def example_embeddings_usage():
|
|||||||
|
|
||||||
# Calculate similarity between first two texts (requires numpy)
|
# Calculate similarity between first two texts (requires numpy)
|
||||||
try:
|
try:
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
emb1 = np.array(response.data[0].embedding)
|
emb1 = np.array(response.data[0].embedding)
|
||||||
emb2 = np.array(response.data[1].embedding)
|
emb2 = np.array(response.data[1].embedding)
|
||||||
similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
|
similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
|
||||||
|
1411
src/backend/main.py
1411
src/backend/main.py
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,14 @@
|
|||||||
from typing import Type, TypeVar
|
from typing import Type, TypeVar
|
||||||
from pydantic import BaseModel # type: ignore
|
from pydantic import BaseModel
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType
|
||||||
|
|
||||||
|
# Ensure all user models inherit from BaseUserWithType
|
||||||
|
assert issubclass(Candidate, BaseUserWithType), "Candidate must inherit from BaseUserWithType"
|
||||||
|
assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType"
|
||||||
|
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
|
||||||
|
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
|
||||||
|
|
||||||
T = TypeVar('T', bound=BaseModel)
|
T = TypeVar('T', bound=BaseModel)
|
||||||
|
|
||||||
@ -12,3 +19,35 @@ def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
|
|||||||
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
|
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
|
||||||
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
|
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
|
||||||
return model_cls(**data)
|
return model_cls(**data)
|
||||||
|
|
||||||
|
def cast_to_base_user_with_type(user) -> BaseUserWithType:
|
||||||
|
"""
|
||||||
|
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
|
||||||
|
This is useful for FastAPI dependencies that expect a common user type.
|
||||||
|
"""
|
||||||
|
if isinstance(user, BaseUserWithType):
|
||||||
|
return user
|
||||||
|
# If it's a dict, try to detect type
|
||||||
|
if isinstance(user, dict):
|
||||||
|
user_type = user.get("user_type") or user.get("type")
|
||||||
|
if user_type == "candidate":
|
||||||
|
if user.get("is_AI"):
|
||||||
|
return CandidateAI.model_validate(user)
|
||||||
|
return Candidate.model_validate(user)
|
||||||
|
elif user_type == "employer":
|
||||||
|
return Employer.model_validate(user)
|
||||||
|
elif user_type == "guest":
|
||||||
|
return Guest.model_validate(user)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown user_type: {user_type}")
|
||||||
|
# If it's a model, check its type
|
||||||
|
if hasattr(user, "user_type"):
|
||||||
|
if getattr(user, "user_type", None) == "candidate":
|
||||||
|
if getattr(user, "is_AI", False):
|
||||||
|
return CandidateAI.model_validate(user.model_dump())
|
||||||
|
return Candidate.model_validate(user.model_dump())
|
||||||
|
elif getattr(user, "user_type", None) == "employer":
|
||||||
|
return Employer.model_validate(user.model_dump())
|
||||||
|
elif getattr(user, "user_type", None) == "guest":
|
||||||
|
return Guest.model_validate(user.model_dump())
|
||||||
|
raise TypeError(f"Cannot cast object of type {type(user)} to BaseUserWithType")
|
||||||
|
@ -10,6 +10,7 @@ from auth_utils import (
|
|||||||
sanitize_login_input,
|
sanitize_login_input,
|
||||||
SecurityConfig
|
SecurityConfig
|
||||||
)
|
)
|
||||||
|
import defines
|
||||||
|
|
||||||
# Generic type variable
|
# Generic type variable
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
@ -256,6 +257,7 @@ class MFARequest(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
device_id: str = Field(..., alias="deviceId")
|
device_id: str = Field(..., alias="deviceId")
|
||||||
device_name: str = Field(..., alias="deviceName")
|
device_name: str = Field(..., alias="deviceName")
|
||||||
|
email: str = Field(..., alias="email")
|
||||||
model_config = {
|
model_config = {
|
||||||
"populate_by_name": True, # Allow both field names and aliases
|
"populate_by_name": True, # Allow both field names and aliases
|
||||||
}
|
}
|
||||||
@ -464,9 +466,18 @@ class ErrorDetail(BaseModel):
|
|||||||
# Main Models
|
# Main Models
|
||||||
# ============================
|
# ============================
|
||||||
|
|
||||||
# Base user model without user_type field
|
# Generic base user with user_type for API responses
|
||||||
class BaseUser(BaseModel):
|
class BaseUserWithType(BaseModel):
|
||||||
|
user_type: UserType = Field(..., alias="userType")
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
|
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True, # Allow both field names and aliases
|
||||||
|
"use_enum_values": True # Use enum values instead of names
|
||||||
|
}
|
||||||
|
|
||||||
|
# Base user model without user_type field
|
||||||
|
class BaseUser(BaseUserWithType):
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
first_name: str = Field(..., alias="firstName")
|
first_name: str = Field(..., alias="firstName")
|
||||||
last_name: str = Field(..., alias="lastName")
|
last_name: str = Field(..., alias="lastName")
|
||||||
@ -485,9 +496,6 @@ class BaseUser(BaseModel):
|
|||||||
"use_enum_values": True # Use enum values instead of names
|
"use_enum_values": True # Use enum values instead of names
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generic base user with user_type for API responses
|
|
||||||
class BaseUserWithType(BaseUser):
|
|
||||||
user_type: UserType = Field(..., alias="userType")
|
|
||||||
|
|
||||||
class RagEntry(BaseModel):
|
class RagEntry(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@ -519,9 +527,9 @@ class DocumentType(str, Enum):
|
|||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
|
|
||||||
class DocumentOptions(BaseModel):
|
class DocumentOptions(BaseModel):
|
||||||
include_in_RAG: Optional[bool] = Field(True, alias="includeInRAG")
|
include_in_RAG: bool = Field(default=True, alias="includeInRAG")
|
||||||
is_job_document: Optional[bool] = Field(False, alias="isJobDocument")
|
is_job_document: Optional[bool] = Field(default=False, alias="isJobDocument")
|
||||||
overwrite: Optional[bool] = Field(False, alias="overwrite")
|
overwrite: Optional[bool] = Field(default=False, alias="overwrite")
|
||||||
model_config = {
|
model_config = {
|
||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True # Allow both field names and aliases
|
||||||
}
|
}
|
||||||
@ -534,7 +542,7 @@ class Document(BaseModel):
|
|||||||
type: DocumentType
|
type: DocumentType
|
||||||
size: int
|
size: int
|
||||||
upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate")
|
upload_date: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="uploadDate")
|
||||||
options: DocumentOptions = Field(default_factory=DocumentOptions, alias="options")
|
options: DocumentOptions = Field(default_factory=lambda: DocumentOptions(), alias="options")
|
||||||
rag_chunks: Optional[int] = Field(default=0, alias="ragChunks")
|
rag_chunks: Optional[int] = Field(default=0, alias="ragChunks")
|
||||||
model_config = {
|
model_config = {
|
||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True # Allow both field names and aliases
|
||||||
@ -565,7 +573,7 @@ class DocumentUpdateRequest(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
class Candidate(BaseUser):
|
class Candidate(BaseUser):
|
||||||
user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
|
user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
|
||||||
username: str
|
username: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
resume: Optional[str] = None
|
resume: Optional[str] = None
|
||||||
@ -584,14 +592,14 @@ class Candidate(BaseUser):
|
|||||||
rag_content_size : int = 0
|
rag_content_size : int = 0
|
||||||
|
|
||||||
class CandidateAI(Candidate):
|
class CandidateAI(Candidate):
|
||||||
user_type: Literal[UserType.CANDIDATE] = Field(UserType.CANDIDATE, alias="userType")
|
user_type: UserType = Field(UserType.CANDIDATE, alias="userType")
|
||||||
is_AI: bool = Field(True, alias="isAI")
|
is_AI: bool = Field(True, alias="isAI")
|
||||||
age: Optional[int] = None
|
age: Optional[int] = None
|
||||||
gender: Optional[UserGender] = None
|
gender: Optional[UserGender] = None
|
||||||
ethnicity: Optional[str] = None
|
ethnicity: Optional[str] = None
|
||||||
|
|
||||||
class Employer(BaseUser):
|
class Employer(BaseUser):
|
||||||
user_type: Literal[UserType.EMPLOYER] = Field(UserType.EMPLOYER, alias="userType")
|
user_type: UserType = Field(UserType.EMPLOYER, alias="userType")
|
||||||
company_name: str = Field(..., alias="companyName")
|
company_name: str = Field(..., alias="companyName")
|
||||||
industry: str
|
industry: str
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
@ -603,16 +611,18 @@ class Employer(BaseUser):
|
|||||||
social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks")
|
social_links: Optional[List[SocialLink]] = Field(None, alias="socialLinks")
|
||||||
poc: Optional[PointOfContact] = None
|
poc: Optional[PointOfContact] = None
|
||||||
|
|
||||||
class Guest(BaseModel):
|
class Guest(BaseUser):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
user_type: UserType = Field(UserType.GUEST, alias="userType")
|
||||||
session_id: str = Field(..., alias="sessionId")
|
session_id: str = Field(..., alias="sessionId")
|
||||||
created_at: datetime = Field(..., alias="createdAt")
|
username: str # Add username for consistency with other user types
|
||||||
last_activity: datetime = Field(..., alias="lastActivity")
|
|
||||||
converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId")
|
converted_to_user_id: Optional[str] = Field(None, alias="convertedToUserId")
|
||||||
ip_address: Optional[str] = Field(None, alias="ipAddress")
|
ip_address: Optional[str] = Field(None, alias="ipAddress")
|
||||||
|
created_at: datetime = Field(..., alias="createdAt")
|
||||||
user_agent: Optional[str] = Field(None, alias="userAgent")
|
user_agent: Optional[str] = Field(None, alias="userAgent")
|
||||||
|
rag_content_size: int = 0
|
||||||
model_config = {
|
model_config = {
|
||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True, # Allow both field names and aliases
|
||||||
|
"use_enum_values": True # Use enum values instead of names
|
||||||
}
|
}
|
||||||
|
|
||||||
class Authentication(BaseModel):
|
class Authentication(BaseModel):
|
||||||
@ -635,10 +645,21 @@ class Authentication(BaseModel):
|
|||||||
class AuthResponse(BaseModel):
|
class AuthResponse(BaseModel):
|
||||||
access_token: str = Field(..., alias="accessToken")
|
access_token: str = Field(..., alias="accessToken")
|
||||||
refresh_token: str = Field(..., alias="refreshToken")
|
refresh_token: str = Field(..., alias="refreshToken")
|
||||||
user: Candidate | Employer
|
user: Union[Candidate, Employer, Guest] # Add Guest support
|
||||||
expires_at: int = Field(..., alias="expiresAt")
|
expires_at: int = Field(..., alias="expiresAt")
|
||||||
|
user_type: Optional[str] = Field(default=UserType.GUEST, alias="userType") # Explicit user type
|
||||||
|
is_guest: Optional[bool] = Field(default=True, alias="isGuest") # Guest indicator
|
||||||
|
|
||||||
model_config = {
|
model_config = {
|
||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
|
class GuestCleanupRequest(BaseModel):
|
||||||
|
"""Request to cleanup inactive guests"""
|
||||||
|
inactive_hours: int = Field(24, alias="inactiveHours")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
}
|
}
|
||||||
|
|
||||||
class JobRequirements(BaseModel):
|
class JobRequirements(BaseModel):
|
||||||
@ -751,6 +772,19 @@ class ChromaDBGetResponse(BaseModel):
|
|||||||
umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
|
umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
|
||||||
umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")
|
umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")
|
||||||
|
|
||||||
|
class GuestSessionResponse(BaseModel):
|
||||||
|
"""Response for guest session creation"""
|
||||||
|
access_token: str = Field(..., alias="accessToken")
|
||||||
|
refresh_token: str = Field(..., alias="refreshToken")
|
||||||
|
user: Guest
|
||||||
|
expires_at: int = Field(..., alias="expiresAt")
|
||||||
|
user_type: Literal["guest"] = Field("guest", alias="userType")
|
||||||
|
is_guest: bool = Field(True, alias="isGuest")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
class ChatContext(BaseModel):
|
class ChatContext(BaseModel):
|
||||||
type: ChatContextType
|
type: ChatContextType
|
||||||
related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
|
related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
|
||||||
@ -768,6 +802,98 @@ class ChatOptions(BaseModel):
|
|||||||
"populate_by_name": True # Allow both field names and aliases
|
"populate_by_name": True # Allow both field names and aliases
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add rate limiting configuration models
|
||||||
|
class RateLimitConfig(BaseModel):
|
||||||
|
"""Rate limit configuration"""
|
||||||
|
requests_per_minute: int = Field(..., alias="requestsPerMinute")
|
||||||
|
requests_per_hour: int = Field(..., alias="requestsPerHour")
|
||||||
|
requests_per_day: int = Field(..., alias="requestsPerDay")
|
||||||
|
burst_limit: int = Field(..., alias="burstLimit")
|
||||||
|
burst_window_seconds: int = Field(60, alias="burstWindowSeconds")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
|
class RateLimitResult(BaseModel):
|
||||||
|
"""Result of rate limit check"""
|
||||||
|
allowed: bool
|
||||||
|
reason: Optional[str] = None
|
||||||
|
retry_after_seconds: Optional[int] = Field(None, alias="retryAfterSeconds")
|
||||||
|
remaining_requests: Dict[str, int] = Field(default_factory=dict, alias="remainingRequests")
|
||||||
|
reset_times: Dict[str, datetime] = Field(default_factory=dict, alias="resetTimes")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
|
class RateLimitStatus(BaseModel):
|
||||||
|
"""Rate limit status for a user"""
|
||||||
|
user_id: str = Field(..., alias="userId")
|
||||||
|
user_type: str = Field(..., alias="userType")
|
||||||
|
is_admin: bool = Field(..., alias="isAdmin")
|
||||||
|
current_usage: Dict[str, int] = Field(..., alias="currentUsage")
|
||||||
|
limits: Dict[str, int] = Field(..., alias="limits")
|
||||||
|
remaining: Dict[str, int] = Field(..., alias="remaining")
|
||||||
|
reset_times: Dict[str, datetime] = Field(..., alias="resetTimes")
|
||||||
|
config: RateLimitConfig
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Add guest conversion request models
|
||||||
|
class GuestConversionRequest(BaseModel):
|
||||||
|
"""Request to convert guest to permanent user"""
|
||||||
|
account_type: Literal["candidate", "employer"] = Field(..., alias="accountType")
|
||||||
|
email: EmailStr
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
first_name: str = Field(..., alias="firstName")
|
||||||
|
last_name: str = Field(..., alias="lastName")
|
||||||
|
phone: Optional[str] = None
|
||||||
|
|
||||||
|
# Employer-specific fields (optional)
|
||||||
|
company_name: Optional[str] = Field(None, alias="companyName")
|
||||||
|
industry: Optional[str] = None
|
||||||
|
company_size: Optional[str] = Field(None, alias="companySize")
|
||||||
|
company_description: Optional[str] = Field(None, alias="companyDescription")
|
||||||
|
website_url: Optional[HttpUrl] = Field(None, alias="websiteUrl")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@field_validator('username')
|
||||||
|
def validate_username(cls, v):
|
||||||
|
if not v or len(v.strip()) < 3:
|
||||||
|
raise ValueError('Username must be at least 3 characters long')
|
||||||
|
return v.strip().lower()
|
||||||
|
|
||||||
|
@field_validator('password')
|
||||||
|
def validate_password_strength(cls, v):
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from auth_utils import validate_password_strength
|
||||||
|
is_valid, issues = validate_password_strength(v)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError('; '.join(issues))
|
||||||
|
return v
|
||||||
|
|
||||||
|
# Add guest statistics response model
|
||||||
|
class GuestStatistics(BaseModel):
|
||||||
|
"""Guest usage statistics"""
|
||||||
|
total_guests: int = Field(..., alias="totalGuests")
|
||||||
|
active_last_hour: int = Field(..., alias="activeLastHour")
|
||||||
|
active_last_day: int = Field(..., alias="activeLastDay")
|
||||||
|
converted_guests: int = Field(..., alias="convertedGuests")
|
||||||
|
by_ip: Dict[str, int] = Field(..., alias="byIp")
|
||||||
|
creation_timeline: Dict[str, int] = Field(..., alias="creationTimeline")
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True
|
||||||
|
}
|
||||||
|
|
||||||
from llm_proxy import (LLMMessage)
|
from llm_proxy import (LLMMessage)
|
||||||
|
|
||||||
class ApiMessage(BaseModel):
|
class ApiMessage(BaseModel):
|
||||||
@ -800,12 +926,14 @@ class ApiActivityType(str, Enum):
|
|||||||
HEARTBEAT = "heartbeat" # Used for periodic updates
|
HEARTBEAT = "heartbeat" # Used for periodic updates
|
||||||
|
|
||||||
class ChatMessageStatus(ApiMessage):
|
class ChatMessageStatus(ApiMessage):
|
||||||
|
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
|
||||||
status: ApiStatusType = ApiStatusType.STATUS
|
status: ApiStatusType = ApiStatusType.STATUS
|
||||||
type: ApiMessageType = ApiMessageType.TEXT
|
type: ApiMessageType = ApiMessageType.TEXT
|
||||||
activity: ApiActivityType
|
activity: ApiActivityType
|
||||||
content: Any
|
content: Any
|
||||||
|
|
||||||
class ChatMessageError(ApiMessage):
|
class ChatMessageError(ApiMessage):
|
||||||
|
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
|
||||||
status: ApiStatusType = ApiStatusType.ERROR
|
status: ApiStatusType = ApiStatusType.ERROR
|
||||||
type: ApiMessageType = ApiMessageType.TEXT
|
type: ApiMessageType = ApiMessageType.TEXT
|
||||||
content: str
|
content: str
|
||||||
@ -825,6 +953,7 @@ class JobRequirementsMessage(ApiMessage):
|
|||||||
|
|
||||||
class DocumentMessage(ApiMessage):
|
class DocumentMessage(ApiMessage):
|
||||||
type: ApiMessageType = ApiMessageType.JSON
|
type: ApiMessageType = ApiMessageType.JSON
|
||||||
|
sender_id: Optional[str] = Field(default=MOCK_UUID, alias="senderId")
|
||||||
document: Document = Field(..., alias="document")
|
document: Document = Field(..., alias="document")
|
||||||
content: Optional[str] = ""
|
content: Optional[str] = ""
|
||||||
converted: bool = Field(False, alias="converted")
|
converted: bool = Field(False, alias="converted")
|
||||||
@ -837,9 +966,9 @@ class ChatMessageMetaData(BaseModel):
|
|||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
max_tokens: int = Field(default=8092, alias="maxTokens")
|
max_tokens: int = Field(default=8092, alias="maxTokens")
|
||||||
top_p: float = Field(default=1, alias="topP")
|
top_p: float = Field(default=1, alias="topP")
|
||||||
frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
|
frequency_penalty: float = Field(default=0, alias="frequencyPenalty")
|
||||||
presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
|
presence_penalty: float = Field(default=0, alias="presencePenalty")
|
||||||
stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
|
stop_sequences: List[str] = Field(default=[], alias="stopSequences")
|
||||||
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
|
rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias="ragResults")
|
||||||
llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
|
llm_history: List[LLMMessage] = Field(default_factory=list, alias="llmHistory")
|
||||||
eval_count: int = 0
|
eval_count: int = 0
|
||||||
@ -862,16 +991,31 @@ class ChatMessageUser(ApiMessage):
|
|||||||
|
|
||||||
class ChatMessage(ChatMessageUser):
|
class ChatMessage(ChatMessageUser):
|
||||||
role: ChatSenderType = ChatSenderType.ASSISTANT
|
role: ChatSenderType = ChatSenderType.ASSISTANT
|
||||||
metadata: ChatMessageMetaData = Field(default_factory=ChatMessageMetaData)
|
metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData())
|
||||||
#attachments: Optional[List[Attachment]] = None
|
#attachments: Optional[List[Attachment]] = None
|
||||||
#reactions: Optional[List[MessageReaction]] = None
|
#reactions: Optional[List[MessageReaction]] = None
|
||||||
#is_edited: bool = Field(False, alias="isEdited")
|
#is_edited: bool = Field(False, alias="isEdited")
|
||||||
#edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
|
#edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
|
||||||
|
|
||||||
|
class GPUInfo(BaseModel):
|
||||||
|
name: str
|
||||||
|
memory: int
|
||||||
|
discrete: bool
|
||||||
|
|
||||||
|
class SystemInfo(BaseModel):
|
||||||
|
installed_RAM: str = Field(..., alias="installedRAM")
|
||||||
|
graphics_cards: List[GPUInfo] = Field(..., alias="graphicsCards")
|
||||||
|
CPU: str
|
||||||
|
llm_model: str = Field(default=defines.model, alias="llmModel")
|
||||||
|
embedding_model: str = Field(default=defines.embedding_model, alias="embeddingModel")
|
||||||
|
max_context_length: int = Field(default=defines.max_context, alias="maxContextLength")
|
||||||
|
model_config = {
|
||||||
|
"populate_by_name": True # Allow both field names and aliases
|
||||||
|
}
|
||||||
|
|
||||||
class ChatSession(BaseModel):
|
class ChatSession(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
user_id: Optional[str] = Field(None, alias="userId")
|
user_id: Optional[str] = Field(None, alias="userId")
|
||||||
guest_id: Optional[str] = Field(None, alias="guestId")
|
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
|
||||||
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
|
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
@ -937,7 +1081,7 @@ class UserActivity(BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_user_or_guest(self) -> "ChatSession":
|
def check_user_or_guest(self):
|
||||||
if not self.user_id and not self.guest_id:
|
if not self.user_id and not self.guest_id:
|
||||||
raise ValueError("Either user_id or guest_id must be provided")
|
raise ValueError("Either user_id or guest_id must be provided")
|
||||||
return self
|
return self
|
||||||
@ -1102,6 +1246,8 @@ class JobListResponse(BaseModel):
|
|||||||
error: Optional[ErrorDetail] = None
|
error: Optional[ErrorDetail] = None
|
||||||
meta: Optional[Dict[str, Any]] = None
|
meta: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
User = Union[Candidate, CandidateAI, Employer, Guest]
|
||||||
|
|
||||||
# Forward references resolution
|
# Forward references resolution
|
||||||
Candidate.update_forward_refs()
|
Candidate.update_forward_refs()
|
||||||
Employer.update_forward_refs()
|
Employer.update_forward_refs()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
|
||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Any, Union
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
@ -9,15 +9,15 @@ import hashlib
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import chromadb # type: ignore
|
import chromadb
|
||||||
from watchdog.observers import Observer # type: ignore
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
from watchdog.events import FileSystemEventHandler
|
||||||
import umap # type: ignore
|
import umap
|
||||||
from markitdown import MarkItDown # type: ignore
|
from markitdown import MarkItDown
|
||||||
from chromadb.api.models.Collection import Collection # type: ignore
|
from chromadb.api.models.Collection import Collection
|
||||||
|
|
||||||
from .markdown_chunker import (
|
from .markdown_chunker import (
|
||||||
MarkdownChunker,
|
MarkdownChunker,
|
||||||
@ -351,9 +351,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
os.makedirs(self.persist_directory)
|
os.makedirs(self.persist_directory)
|
||||||
|
|
||||||
# Initialize ChromaDB client
|
# Initialize ChromaDB client
|
||||||
chroma_client = chromadb.PersistentClient( # type: ignore
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=self.persist_directory,
|
path=self.persist_directory,
|
||||||
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
|
settings=chromadb.Settings(anonymized_telemetry=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the collection exists
|
# Check if the collection exists
|
||||||
|
287
src/backend/rate_limiter.py
Normal file
287
src/backend/rate_limiter.py
Normal file
@ -0,0 +1,287 @@
|
|||||||
|
"""
|
||||||
|
Rate limiting utilities for guest and authenticated users
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta, UTC
|
||||||
|
from typing import Dict, Optional, Tuple, Any
|
||||||
|
from pydantic import BaseModel # type: ignore
|
||||||
|
from database import RedisDatabase
|
||||||
|
from logger import logger
|
||||||
|
|
||||||
|
class RateLimitConfig(BaseModel):
|
||||||
|
"""Rate limit configuration"""
|
||||||
|
requests_per_minute: int
|
||||||
|
requests_per_hour: int
|
||||||
|
requests_per_day: int
|
||||||
|
burst_limit: int # Maximum requests in a short burst
|
||||||
|
burst_window_seconds: int = 60 # Window for burst detection
|
||||||
|
|
||||||
|
class GuestRateLimitConfig(RateLimitConfig):
|
||||||
|
"""Rate limits for guest users - more restrictive"""
|
||||||
|
requests_per_minute: int = 10
|
||||||
|
requests_per_hour: int = 100
|
||||||
|
requests_per_day: int = 500
|
||||||
|
burst_limit: int = 15
|
||||||
|
burst_window_seconds: int = 60
|
||||||
|
|
||||||
|
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
|
||||||
|
"""Rate limits for authenticated users - more generous"""
|
||||||
|
requests_per_minute: int = 60
|
||||||
|
requests_per_hour: int = 1000
|
||||||
|
requests_per_day: int = 10000
|
||||||
|
burst_limit: int = 100
|
||||||
|
burst_window_seconds: int = 60
|
||||||
|
|
||||||
|
class PremiumUserRateLimitConfig(RateLimitConfig):
|
||||||
|
"""Rate limits for premium/admin users - most generous"""
|
||||||
|
requests_per_minute: int = 120
|
||||||
|
requests_per_hour: int = 5000
|
||||||
|
requests_per_day: int = 50000
|
||||||
|
burst_limit: int = 200
|
||||||
|
burst_window_seconds: int = 60
|
||||||
|
|
||||||
|
class RateLimitResult(BaseModel):
|
||||||
|
"""Result of rate limit check"""
|
||||||
|
allowed: bool
|
||||||
|
reason: Optional[str] = None
|
||||||
|
retry_after_seconds: Optional[int] = None
|
||||||
|
remaining_requests: Dict[str, int] = {}
|
||||||
|
reset_times: Dict[str, datetime] = {}
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Rate limiter using Redis for distributed rate limiting"""
|
||||||
|
|
||||||
|
def __init__(self, database: RedisDatabase):
|
||||||
|
self.database = database
|
||||||
|
self.redis = database.redis
|
||||||
|
|
||||||
|
# Rate limit configurations
|
||||||
|
self.guest_config = GuestRateLimitConfig()
|
||||||
|
self.user_config = AuthenticatedUserRateLimitConfig()
|
||||||
|
self.premium_config = PremiumUserRateLimitConfig()
|
||||||
|
|
||||||
|
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
|
||||||
|
"""Get rate limit configuration based on user type"""
|
||||||
|
if user_type == "guest":
|
||||||
|
return self.guest_config
|
||||||
|
elif is_admin:
|
||||||
|
return self.premium_config
|
||||||
|
else:
|
||||||
|
return self.user_config
|
||||||
|
|
||||||
|
async def check_rate_limit(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
user_type: str,
|
||||||
|
is_admin: bool = False,
|
||||||
|
endpoint: Optional[str] = None
|
||||||
|
) -> RateLimitResult:
|
||||||
|
"""
|
||||||
|
Check if user has exceeded rate limits
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Unique identifier for the user (guest session ID or user ID)
|
||||||
|
user_type: "guest", "candidate", or "employer"
|
||||||
|
is_admin: Whether user has admin privileges
|
||||||
|
endpoint: Optional endpoint-specific rate limiting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RateLimitResult indicating if request is allowed
|
||||||
|
"""
|
||||||
|
config = self.get_config_for_user(user_type, is_admin)
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
# Create Redis keys for different time windows
|
||||||
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||||
|
keys = {
|
||||||
|
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||||
|
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||||
|
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||||
|
"burst": f"{base_key}:burst"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add endpoint-specific limiting if provided
|
||||||
|
if endpoint:
|
||||||
|
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use Redis pipeline for atomic operations
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
|
||||||
|
# Get current counts
|
||||||
|
for key in keys.values():
|
||||||
|
pipe.get(key)
|
||||||
|
|
||||||
|
results = await pipe.execute()
|
||||||
|
current_counts = {
|
||||||
|
"minute": int(results[0] or 0),
|
||||||
|
"hour": int(results[1] or 0),
|
||||||
|
"day": int(results[2] or 0),
|
||||||
|
"burst": int(results[3] or 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check limits
|
||||||
|
limits = {
|
||||||
|
"minute": config.requests_per_minute,
|
||||||
|
"hour": config.requests_per_hour,
|
||||||
|
"day": config.requests_per_day,
|
||||||
|
"burst": config.burst_limit
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check each limit
|
||||||
|
for window, current_count in current_counts.items():
|
||||||
|
limit = limits[window]
|
||||||
|
if current_count >= limit:
|
||||||
|
# Calculate retry after time
|
||||||
|
if window == "minute":
|
||||||
|
retry_after = 60 - current_time.second
|
||||||
|
elif window == "hour":
|
||||||
|
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
|
||||||
|
elif window == "day":
|
||||||
|
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
|
||||||
|
else: # burst
|
||||||
|
retry_after = config.burst_window_seconds
|
||||||
|
|
||||||
|
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
|
||||||
|
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=False,
|
||||||
|
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
|
||||||
|
retry_after_seconds=retry_after,
|
||||||
|
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
|
||||||
|
reset_times=self._calculate_reset_times(current_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we get here, request is allowed - increment counters
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
|
||||||
|
# Increment minute counter (expires after 2 minutes)
|
||||||
|
pipe.incr(keys["minute"])
|
||||||
|
pipe.expire(keys["minute"], 120)
|
||||||
|
|
||||||
|
# Increment hour counter (expires after 2 hours)
|
||||||
|
pipe.incr(keys["hour"])
|
||||||
|
pipe.expire(keys["hour"], 7200)
|
||||||
|
|
||||||
|
# Increment day counter (expires after 2 days)
|
||||||
|
pipe.incr(keys["day"])
|
||||||
|
pipe.expire(keys["day"], 172800)
|
||||||
|
|
||||||
|
# Increment burst counter (expires after burst window)
|
||||||
|
pipe.incr(keys["burst"])
|
||||||
|
pipe.expire(keys["burst"], config.burst_window_seconds)
|
||||||
|
|
||||||
|
await pipe.execute()
|
||||||
|
|
||||||
|
# Calculate remaining requests
|
||||||
|
remaining = {
|
||||||
|
k: max(0, limits[k] - (current_counts[k] + 1))
|
||||||
|
for k in current_counts.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
|
||||||
|
|
||||||
|
return RateLimitResult(
|
||||||
|
allowed=True,
|
||||||
|
remaining_requests=remaining,
|
||||||
|
reset_times=self._calculate_reset_times(current_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
|
||||||
|
# Fail open - allow request if rate limiting system fails
|
||||||
|
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
|
||||||
|
|
||||||
|
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
|
||||||
|
"""Calculate when each rate limit window resets"""
|
||||||
|
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
|
||||||
|
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||||
|
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"minute": next_minute,
|
||||||
|
"hour": next_hour,
|
||||||
|
"day": next_day
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_user_rate_limit_status(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
user_type: str,
|
||||||
|
is_admin: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Get current rate limit status for a user"""
|
||||||
|
config = self.get_config_for_user(user_type, is_admin)
|
||||||
|
current_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||||
|
keys = {
|
||||||
|
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||||
|
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||||
|
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||||
|
"burst": f"{base_key}:burst"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
pipe = self.redis.pipeline()
|
||||||
|
for key in keys.values():
|
||||||
|
pipe.get(key)
|
||||||
|
|
||||||
|
results = await pipe.execute()
|
||||||
|
current_counts = {
|
||||||
|
"minute": int(results[0] or 0),
|
||||||
|
"hour": int(results[1] or 0),
|
||||||
|
"day": int(results[2] or 0),
|
||||||
|
"burst": int(results[3] or 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
limits = {
|
||||||
|
"minute": config.requests_per_minute,
|
||||||
|
"hour": config.requests_per_hour,
|
||||||
|
"day": config.requests_per_day,
|
||||||
|
"burst": config.burst_limit
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_type": user_type,
|
||||||
|
"is_admin": is_admin,
|
||||||
|
"current_usage": current_counts,
|
||||||
|
"limits": limits,
|
||||||
|
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
|
||||||
|
"reset_times": self._calculate_reset_times(current_time),
|
||||||
|
"config": config.model_dump()
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
|
||||||
|
"""Reset all rate limits for a user (admin function)"""
|
||||||
|
try:
|
||||||
|
base_key = f"rate_limit:{user_type}:{user_id}"
|
||||||
|
pattern = f"{base_key}:*"
|
||||||
|
|
||||||
|
cursor = 0
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
|
||||||
|
if keys:
|
||||||
|
await self.redis.delete(*keys)
|
||||||
|
deleted_count += len(keys)
|
||||||
|
|
||||||
|
if cursor == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
@ -2,6 +2,7 @@ import defines
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import math
|
import math
|
||||||
|
from models import SystemInfo
|
||||||
|
|
||||||
def get_installed_ram():
|
def get_installed_ram():
|
||||||
try:
|
try:
|
||||||
@ -70,12 +71,18 @@ def get_cpu_info():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error retrieving CPU info: {e}"
|
return f"Error retrieving CPU info: {e}"
|
||||||
|
|
||||||
def system_info():
|
def system_info() -> SystemInfo:
|
||||||
return {
|
"""
|
||||||
"System RAM": get_installed_ram(),
|
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
|
||||||
"Graphics Card": get_graphics_cards(),
|
Returns:
|
||||||
"CPU": get_cpu_info(),
|
SystemInfo: An object containing the collected system information.
|
||||||
"LLM Model": defines.model,
|
"""
|
||||||
"Embedding Model": defines.embedding_model,
|
system = SystemInfo(
|
||||||
"Context length": defines.max_context,
|
installed_RAM=get_installed_ram(),
|
||||||
}
|
graphics_cards=get_graphics_cards(),
|
||||||
|
CPU=get_cpu_info(),
|
||||||
|
llm_model=defines.model,
|
||||||
|
embedding_model=defines.embedding_model,
|
||||||
|
max_context_length=defines.max_context,
|
||||||
|
)
|
||||||
|
return system
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -7,12 +7,12 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from bs4 import BeautifulSoup # type: ignore
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from geopy.geocoders import Nominatim # type: ignore
|
from geopy.geocoders import Nominatim
|
||||||
import pytz # type: ignore
|
import pytz
|
||||||
import requests
|
import requests
|
||||||
import yfinance as yf # type: ignore
|
import yfinance as yf
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
|||||||
|
|
||||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
||||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||||
# __all__.extend(__tool_functions__) # type: ignore
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from fastapi import FastAPI, HTTPException # type: ignore
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import StreamingResponse # type: ignore
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel # type: ignore
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -234,5 +234,5 @@ async def health_check():
|
|||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn # type: ignore
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
@ -92,7 +92,7 @@ class OllamaAdapter(BaseLLMAdapter):
|
|||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import ollama
|
import ollama
|
||||||
self.client = ollama.AsyncClient( # type: ignore
|
self.client = ollama.AsyncClient(
|
||||||
host=config.get('host', 'http://localhost:11434')
|
host=config.get('host', 'http://localhost:11434')
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -187,7 +187,7 @@ class OpenAIAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import openai # type: ignore
|
import openai
|
||||||
self.client = openai.AsyncOpenAI(
|
self.client = openai.AsyncOpenAI(
|
||||||
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
api_key=config.get('api_key', os.getenv('OPENAI_API_KEY'))
|
||||||
)
|
)
|
||||||
@ -259,7 +259,7 @@ class AnthropicAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import anthropic # type: ignore
|
import anthropic
|
||||||
self.client = anthropic.AsyncAnthropic(
|
self.client = anthropic.AsyncAnthropic(
|
||||||
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY'))
|
||||||
)
|
)
|
||||||
@ -344,7 +344,7 @@ class GeminiAdapter(BaseLLMAdapter):
|
|||||||
|
|
||||||
def __init__(self, **config):
|
def __init__(self, **config):
|
||||||
super().__init__(**config)
|
super().__init__(**config)
|
||||||
import google.generativeai as genai # type: ignore
|
import google.generativeai as genai
|
||||||
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY')))
|
||||||
self.genai = genai
|
self.genai = genai
|
||||||
|
|
||||||
@ -476,7 +476,7 @@ class UnifiedLLMProxy:
|
|||||||
|
|
||||||
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
result = await self.chat(model, messages, provider, stream=True, **kwargs)
|
||||||
# Type checker now knows this is an AsyncGenerator due to stream=True
|
# Type checker now knows this is an AsyncGenerator due to stream=True
|
||||||
async for chunk in result: # type: ignore
|
async for chunk in result:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def chat_single(
|
async def chat_single(
|
||||||
@ -490,7 +490,7 @@ class UnifiedLLMProxy:
|
|||||||
|
|
||||||
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
result = await self.chat(model, messages, provider, stream=False, **kwargs)
|
||||||
# Type checker now knows this is a ChatResponse due to stream=False
|
# Type checker now knows this is a ChatResponse due to stream=False
|
||||||
return result # type: ignore
|
return result
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
@ -517,7 +517,7 @@ class UnifiedLLMProxy:
|
|||||||
"""Stream text generation using specified or default provider"""
|
"""Stream text generation using specified or default provider"""
|
||||||
|
|
||||||
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
result = await self.generate(model, prompt, provider, stream=True, **kwargs)
|
||||||
async for chunk in result: # type: ignore
|
async for chunk in result:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def generate_single(
|
async def generate_single(
|
||||||
@ -530,7 +530,7 @@ class UnifiedLLMProxy:
|
|||||||
"""Get single generation response using specified or default provider"""
|
"""Get single generation response using specified or default provider"""
|
||||||
|
|
||||||
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
result = await self.generate(model, prompt, provider, stream=False, **kwargs)
|
||||||
return result # type: ignore
|
return result
|
||||||
|
|
||||||
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
|
async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]:
|
||||||
"""List available models for specified or default provider"""
|
"""List available models for specified or default provider"""
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
LLM_TIMEOUT = 600
|
LLM_TIMEOUT = 600
|
||||||
|
|
||||||
from utils import logger
|
from utils import logger
|
||||||
from pydantic import BaseModel, Field, ValidationError # type: ignore
|
from pydantic import BaseModel, Field, ValidationError
|
||||||
from pydantic_core import PydanticSerializationError # type: ignore
|
from pydantic_core import PydanticSerializationError
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from typing import AsyncGenerator, Dict, Optional
|
from typing import AsyncGenerator, Dict, Optional
|
||||||
@ -48,18 +48,18 @@ try_import("prometheus_fastapi_instrumentator")
|
|||||||
|
|
||||||
import ollama
|
import ollama
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, Request, HTTPException, Depends # type: ignore
|
from fastapi import FastAPI, Request, HTTPException, Depends
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import uvicorn # type: ignore
|
import uvicorn
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
from utils import redis_manager
|
from utils import redis_manager
|
||||||
import redis.asyncio as redis # type: ignore
|
import redis.asyncio as redis
|
||||||
|
|
||||||
# Prometheus
|
# Prometheus
|
||||||
from prometheus_client import Summary # type: ignore
|
from prometheus_client import Summary
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator # type: ignore
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
from prometheus_client import CollectorRegistry, Counter
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
rag as Rag,
|
rag as Rag,
|
||||||
@ -1308,7 +1308,7 @@ def main():
|
|||||||
|
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
|
warnings.filterwarnings("ignore", category=UserWarning, module="umap.*")
|
||||||
|
|
||||||
llm = ollama.Client(host=args.ollama_server) # type: ignore
|
llm = ollama.Client(host=args.ollama_server)
|
||||||
web_server = WebServer(llm, args.ollama_model)
|
web_server = WebServer(llm, args.ollama_model)
|
||||||
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
web_server.run(host=args.web_host, port=args.web_port, use_reloader=False)
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import sys
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils")))
|
||||||
from markdown_chunker import MarkdownChunker # type: ignore
|
from markdown_chunker import MarkdownChunker
|
||||||
|
|
||||||
chunker = MarkdownChunker()
|
chunker = MarkdownChunker()
|
||||||
chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md")
|
chunks = chunker.process_file("./src/tests/test.md") # docs/resume/resume.md")
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from fastapi import FastAPI, Request, Depends, Query # type: ignore
|
from fastapi import FastAPI, Request, Depends, Query
|
||||||
from fastapi.responses import RedirectResponse, JSONResponse # type: ignore
|
from fastapi.responses import RedirectResponse, JSONResponse
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
from anyio.to_thread import run_sync # type: ignore
|
from anyio.to_thread import run_sync
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -69,7 +69,7 @@ class ContextRouteManager:
|
|||||||
logger.info(f"Invalid UUID, redirecting to {redirect_url}")
|
logger.info(f"Invalid UUID, redirecting to {redirect_url}")
|
||||||
raise RedirectToContext(redirect_url)
|
raise RedirectToContext(redirect_url)
|
||||||
|
|
||||||
return _ensure_context_dependency # type: ignore
|
return _ensure_context_dependency
|
||||||
|
|
||||||
def route_pattern(self, path: str, *dependencies, **kwargs):
|
def route_pattern(self, path: str, *dependencies, **kwargs):
|
||||||
logger.info(f"Registering route: {path}")
|
logger.info(f"Registering route: {path}")
|
||||||
@ -134,6 +134,6 @@ async def redirect_history(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn # type: ignore
|
import uvicorn
|
||||||
|
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8900)
|
uvicorn.run(app, host="0.0.0.0", port=8900)
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# From /opt/backstory run:
|
# From /opt/backstory run:
|
||||||
# python -m src.tests.test-embedding
|
# python -m src.tests.test-embedding
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
import argparse
|
import argparse
|
||||||
from ollama import Client # type: ignore
|
from ollama import Client
|
||||||
from ..utils import defines
|
from ..utils import defines
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
# From /opt/backstory run:
|
# From /opt/backstory run:
|
||||||
# python -m src.tests.test-rag
|
# python -m src.tests.test-rag
|
||||||
from ..utils import logger
|
from ..utils import logger
|
||||||
from pydantic import BaseModel, field_validator # type: ignore
|
from pydantic import BaseModel, field_validator
|
||||||
from prometheus_client import CollectorRegistry # type: ignore
|
from prometheus_client import CollectorRegistry
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
import ollama
|
import ollama
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
from ..utils import (rag as Rag, ChromaDBGetResponse)
|
from ..utils import (rag as Rag, ChromaDBGetResponse)
|
||||||
from ..utils import Context
|
from ..utils import Context
|
||||||
from ..utils import defines
|
from ..utils import defines
|
||||||
@ -44,7 +44,7 @@ rag.embeddings = np.array([[1.0, 2.0], [3.0, 4.0]])
|
|||||||
json_str = rag.model_dump(mode="json")
|
json_str = rag.model_dump(mode="json")
|
||||||
logger.info(json_str)
|
logger.info(json_str)
|
||||||
rag = ChromaDBGetResponse.model_validate(json_str)
|
rag = ChromaDBGetResponse.model_validate(json_str)
|
||||||
llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
llm = ollama.Client(host=defines.ollama_api_url)
|
||||||
prometheus_collector = CollectorRegistry()
|
prometheus_collector = CollectorRegistry()
|
||||||
observer, file_watcher = Rag.start_file_watcher(
|
observer, file_watcher = Rag.start_file_watcher(
|
||||||
llm=llm,
|
llm=llm,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel # type: ignore
|
from pydantic import BaseModel
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Set
|
Set
|
||||||
@ -39,7 +39,7 @@ __all__ = [
|
|||||||
"generate_image", "ImageRequest"
|
"generate_image", "ImageRequest"
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.extend(agents_all) # type: ignore
|
__all__.extend(agents_all)
|
||||||
|
|
||||||
logger = setup_logging(level=defines.logging_level)
|
logger = setup_logging(level=defines.logging_level)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ for path in package_dir.glob("*.py"):
|
|||||||
class_registry[name] = (full_module_name, name)
|
class_registry[name] = (full_module_name, name)
|
||||||
globals()[name] = obj
|
globals()[name] = obj
|
||||||
logger.info(f"Adding agent: {name}")
|
logger.info(f"Adding agent: {name}")
|
||||||
__all__.append(name) # type: ignore
|
__all__.append(name)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.error(f"Error importing {full_module_name}: {e}")
|
logger.error(f"Error importing {full_module_name}: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
get_args,
|
get_args,
|
||||||
@ -19,7 +19,7 @@ import inspect
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
from prometheus_client import Counter, Summary, CollectorRegistry
|
||||||
|
|
||||||
from ..setup_logging import setup_logging
|
from ..setup_logging import setup_logging
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ from .types import agent_registry
|
|||||||
from .. import defines
|
from .. import defines
|
||||||
from ..message import Message, Tunables
|
from ..message import Message, Tunables
|
||||||
from ..metrics import Metrics
|
from ..metrics import Metrics
|
||||||
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, GenerateImage, DateTime, llm_tools
|
||||||
from ..conversation import Conversation
|
from ..conversation import Conversation
|
||||||
|
|
||||||
class LLMMessage(BaseModel):
|
class LLMMessage(BaseModel):
|
||||||
|
@ -53,7 +53,7 @@ class Chat(Agent):
|
|||||||
Chat Agent
|
Chat Agent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
agent_type: Literal["chat"] = "chat" # type: ignore
|
agent_type: Literal["chat"] = "chat"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_message
|
system_prompt: str = system_message
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator # type: ignore
|
from pydantic import model_validator
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
@ -31,7 +31,7 @@ When answering queries, follow these steps:
|
|||||||
|
|
||||||
|
|
||||||
class FactCheck(Agent):
|
class FactCheck(Agent):
|
||||||
agent_type: Literal["fact_check"] = "fact_check" # type: ignore
|
agent_type: Literal["fact_check"] = "fact_check"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_fact_check
|
system_prompt: str = system_fact_check
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
from pydantic import model_validator, Field, BaseModel
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -37,7 +37,7 @@ seed = int(time.time())
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
class ImageGenerator(Agent):
|
class ImageGenerator(Agent):
|
||||||
agent_type: Literal["image"] = "image" # type: ignore
|
agent_type: Literal["image"] = "image"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
agent_persist: bool = False
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator, Field # type: ignore
|
from pydantic import model_validator, Field
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -17,7 +17,7 @@ import traceback
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
|
|
||||||
from . base import Agent, agent_registry, LLMMessage
|
from . base import Agent, agent_registry, LLMMessage
|
||||||
from .. message import Message
|
from .. message import Message
|
||||||
@ -36,7 +36,7 @@ Answer questions about the job description.
|
|||||||
|
|
||||||
|
|
||||||
class JobDescription(Agent):
|
class JobDescription(Agent):
|
||||||
agent_type: Literal["job_description"] = "job_description" # type: ignore
|
agent_type: Literal["job_description"] = "job_description"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_generate_resume
|
system_prompt: str = system_generate_resume
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
from pydantic import model_validator, Field, BaseModel
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Literal,
|
Literal,
|
||||||
@ -23,7 +23,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from names_dataset import NameDataset, NameWrapper # type: ignore
|
from names_dataset import NameDataset, NameWrapper
|
||||||
|
|
||||||
from .base import Agent, agent_registry, LLMMessage
|
from .base import Agent, agent_registry, LLMMessage
|
||||||
from ..message import Message
|
from ..message import Message
|
||||||
@ -128,7 +128,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class EthnicNameGenerator:
|
class EthnicNameGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
try:
|
||||||
from names_dataset import NameDataset # type: ignore
|
from names_dataset import NameDataset
|
||||||
self.nd = NameDataset()
|
self.nd = NameDataset()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
||||||
@ -292,7 +292,7 @@ class EthnicNameGenerator:
|
|||||||
return names
|
return names
|
||||||
|
|
||||||
class PersonaGenerator(Agent):
|
class PersonaGenerator(Agent):
|
||||||
agent_type: Literal["persona"] = "persona" # type: ignore
|
agent_type: Literal["persona"] = "persona"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
agent_persist: bool = False
|
agent_persist: bool = False
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import model_validator # type: ignore
|
from pydantic import model_validator
|
||||||
from typing import (
|
from typing import (
|
||||||
Literal,
|
Literal,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
@ -46,7 +46,7 @@ When answering queries, follow these steps:
|
|||||||
|
|
||||||
|
|
||||||
class Resume(Agent):
|
class Resume(Agent):
|
||||||
agent_type: Literal["resume"] = "resume" # type: ignore
|
agent_type: Literal["resume"] = "resume"
|
||||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
system_prompt: str = system_fact_check
|
system_prompt: str = system_fact_check
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
import json
|
import json
|
||||||
from typing import Any, List, Set
|
from typing import Any, List, Set
|
||||||
|
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
|
from typing import List, Optional, Generator, ClassVar, Any, TYPE_CHECKING
|
||||||
from typing_extensions import Annotated, Union
|
from typing_extensions import Annotated, Union
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
import traceback
|
import traceback
|
||||||
@ -44,7 +44,7 @@ class Context(BaseModel):
|
|||||||
user_facts: Optional[str] = None
|
user_facts: Optional[str] = None
|
||||||
|
|
||||||
# Class managed fields
|
# Class managed fields
|
||||||
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
|
agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import List
|
from typing import List
|
||||||
from .message import Message
|
from .message import Message
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import re
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch # type: ignore
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline, FluxPipeline # type: ignore
|
from diffusers import StableDiffusionPipeline, FluxPipeline
|
||||||
|
|
||||||
class ImageModelCache: # Stay loaded for 3 hours
|
class ImageModelCache: # Stay loaded for 3 hours
|
||||||
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, List, Optional, Any, Union, Mapping
|
from typing import Dict, List, Optional, Any, Union, Mapping
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from . rag import ChromaDBGetResponse
|
from . rag import ChromaDBGetResponse
|
||||||
from ollama._types import Options # type: ignore
|
from ollama._types import Options
|
||||||
|
|
||||||
class Tunables(BaseModel):
|
class Tunables(BaseModel):
|
||||||
enable_rag: bool = True # Enable RAG collection chromadb matching
|
enable_rag: bool = True # Enable RAG collection chromadb matching
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from prometheus_client import Counter, Histogram # type: ignore
|
from prometheus_client import Counter, Histogram
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
def singleton(cls):
|
def singleton(cls):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field # type: ignore
|
from pydantic import BaseModel, Field
|
||||||
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
from typing import Dict, Literal, Any, AsyncGenerator, Optional
|
||||||
import inspect
|
import inspect
|
||||||
import random
|
import random
|
||||||
@ -12,7 +12,7 @@ import os
|
|||||||
import gc
|
import gc
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
import torch # type: ignore
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field
|
||||||
from typing import List, Optional, Dict, Any, Union
|
from typing import List, Optional, Dict, Any, Union
|
||||||
import os
|
import os
|
||||||
import glob
|
import glob
|
||||||
@ -8,16 +8,16 @@ import hashlib
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import chromadb # type: ignore
|
import chromadb
|
||||||
import ollama
|
import ollama
|
||||||
from watchdog.observers import Observer # type: ignore
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
from watchdog.events import FileSystemEventHandler
|
||||||
import umap # type: ignore
|
import umap
|
||||||
from markitdown import MarkItDown # type: ignore
|
from markitdown import MarkItDown
|
||||||
from chromadb.api.models.Collection import Collection # type: ignore
|
from chromadb.api.models.Collection import Collection
|
||||||
|
|
||||||
from .markdown_chunker import (
|
from .markdown_chunker import (
|
||||||
MarkdownChunker,
|
MarkdownChunker,
|
||||||
@ -388,9 +388,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
|||||||
os.makedirs(self.persist_directory)
|
os.makedirs(self.persist_directory)
|
||||||
|
|
||||||
# Initialize ChromaDB client
|
# Initialize ChromaDB client
|
||||||
chroma_client = chromadb.PersistentClient( # type: ignore
|
chroma_client = chromadb.PersistentClient(
|
||||||
path=self.persist_directory,
|
path=self.persist_directory,
|
||||||
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
|
settings=chromadb.Settings(anonymized_telemetry=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the collection exists
|
# Check if the collection exists
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import redis.asyncio as redis # type: ignore
|
import redis.asyncio as redis
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -7,12 +7,12 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from bs4 import BeautifulSoup # type: ignore
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
from geopy.geocoders import Nominatim # type: ignore
|
from geopy.geocoders import Nominatim
|
||||||
import pytz # type: ignore
|
import pytz
|
||||||
import requests
|
import requests
|
||||||
import yfinance as yf # type: ignore
|
import yfinance as yf
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -523,4 +523,4 @@ def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
|||||||
|
|
||||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
||||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||||
# __all__.extend(__tool_functions__) # type: ignore
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||||
|
|
||||||
from typing_extensions import Annotated, Union
|
from typing_extensions import Annotated, Union
|
||||||
import numpy as np # type: ignore
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
from prometheus_client import CollectorRegistry, Counter
|
||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
Loading…
x
Reference in New Issue
Block a user