Compare commits

...

2 Commits

SHA1        Message                        Date
02a278736e  Almost working again           2025-05-28 22:50:38 -07:00
b5b3a1f5dc  Working on model conversion    2025-05-28 19:09:02 -07:00
32 changed files with 2067 additions and 754 deletions

View File

@@ -1,15 +1,14 @@
 import React, { useEffect, useState, useRef, useCallback } from 'react';
 import { Route, Routes, useLocation, useNavigate } from 'react-router-dom';
 import { ThemeProvider } from '@mui/material/styles';
 import { backstoryTheme } from './BackstoryTheme';
 import { SeverityType } from 'components/Snack';
-import { Query } from 'types/types';
 import { ConversationHandle } from 'components/Conversation';
 import { UserProvider } from 'hooks/useUser';
 import { CandidateRoute } from 'routes/CandidateRoute';
 import { BackstoryLayout } from 'components/layout/BackstoryLayout';
+import { ChatQuery } from 'types/types';
 import './BackstoryApp.css';
 import '@fontsource/roboto/300.css';
@@ -17,13 +16,7 @@ import '@fontsource/roboto/400.css';
 import '@fontsource/roboto/500.css';
 import '@fontsource/roboto/700.css';
-import { debugConversion } from 'types/conversion';
-import { User, Guest, Candidate } from 'types/types';
 const BackstoryApp = () => {
-    const [user, setUser] = useState<User | null>(null);
-    const [guest, setGuest] = useState<Guest | null>(null);
-    const [candidate, setCandidate] = useState<Candidate | null>(null);
     const navigate = useNavigate();
     const location = useLocation();
     const snackRef = useRef<any>(null);
@@ -31,58 +24,13 @@ const BackstoryApp = () => {
     const setSnack = useCallback((message: string, severity?: SeverityType) => {
         snackRef.current?.setSnack(message, severity);
     }, [snackRef]);
-    const submitQuery = (query: Query) => {
+    const submitQuery = (query: ChatQuery) => {
         console.log(`handleSubmitChatQuery:`, query, chatRef.current ? ' sending' : 'no handler');
         chatRef.current?.submitQuery(query);
         navigate('/chat');
     };
     const [page, setPage] = useState<string>("");
-    const createGuestSession = () => {
-        console.log("TODO: Convert this to query the server for the session instead of generating it.");
-        const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
-        const guest: Guest = {
-            sessionId,
-            createdAt: new Date(),
-            lastActivity: new Date(),
-            ipAddress: 'unknown',
-            userAgent: navigator.userAgent
-        };
-        setGuest(guest);
-        debugConversion(guest, 'Guest Session');
-    };
-    const checkExistingAuth = () => {
-        const token = localStorage.getItem('accessToken');
-        const userData = localStorage.getItem('userData');
-        if (token && userData) {
-            try {
-                const user = JSON.parse(userData);
-                // Convert dates back to Date objects if they're stored as strings
-                if (user.createdAt && typeof user.createdAt === 'string') {
-                    user.createdAt = new Date(user.createdAt);
-                }
-                if (user.updatedAt && typeof user.updatedAt === 'string') {
-                    user.updatedAt = new Date(user.updatedAt);
-                }
-                if (user.lastLogin && typeof user.lastLogin === 'string') {
-                    user.lastLogin = new Date(user.lastLogin);
-                }
-                setUser(user);
-            } catch (e) {
-                localStorage.removeItem('accessToken');
-                localStorage.removeItem('refreshToken');
-                localStorage.removeItem('userData');
-            }
-        }
-    };
-    // Create guest session on component mount
-    useEffect(() => {
-        createGuestSession();
-        checkExistingAuth();
-    }, []);
     useEffect(() => {
         const currentRoute = location.pathname.split("/")[1] ? `/${location.pathname.split("/")[1]}` : "/";
         setPage(currentRoute);
@@ -91,27 +39,20 @@ const BackstoryApp = () => {
     // Render appropriate routes based on user type
     return (
         <ThemeProvider theme={backstoryTheme}>
-            <UserProvider {...{ guest, user, candidate, setSnack }}>
+            <UserProvider {...{ setSnack }}>
                 <Routes>
-                    <Route path="/u/:username" element={<CandidateRoute {...{ guest, candidate, setCandidate, setSnack }} />} />
+                    <Route path="/u/:username" element={<CandidateRoute {...{ setSnack }} />} />
                     {/* Static/shared routes */}
                     <Route
                         path="/*"
                         element={
-                            <BackstoryLayout
-                                setSnack={setSnack}
-                                page={page}
-                                chatRef={chatRef}
-                                snackRef={snackRef}
-                                submitQuery={submitQuery}
-                            />
+                            <BackstoryLayout {...{ setSnack, page, chatRef, snackRef, submitQuery }} />
                         }
                     />
                 </Routes>
             </UserProvider>
         </ThemeProvider>
     );
 };
 export {

View File

@@ -1,16 +1,16 @@
 import Box from '@mui/material/Box';
 import Button from '@mui/material/Button';
-import { Query } from "../types/types";
-type ChatSubmitQueryInterface = (query: Query) => void;
-interface ChatQueryInterface {
-    query: Query,
+import { ChatQuery } from "types/types";
+type ChatSubmitQueryInterface = (query: ChatQuery) => void;
+interface BackstoryQueryInterface {
+    query: ChatQuery,
     submitQuery?: ChatSubmitQueryInterface
 }
-const ChatQuery = (props : ChatQueryInterface) => {
+const BackstoryQuery = (props : BackstoryQueryInterface) => {
     const { query, submitQuery } = props;
     if (submitQuery === undefined) {
@@ -29,11 +29,11 @@ const ChatQuery = (props : ChatQueryInterface) => {
     }
 export type {
-    ChatQueryInterface,
+    BackstoryQueryInterface,
     ChatSubmitQueryInterface,
 };
 export {
-    ChatQuery,
+    BackstoryQuery,
 };
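
Note: the rename from ChatQuery to BackstoryQuery frees the ChatQuery name for the data type in types/types. For orientation, a minimal consumer might look like the following sketch (the wrapper component and prompt text are illustrative; the props match the interface above):

    import React from 'react';
    import { BackstoryQuery } from 'components/BackstoryQuery';
    import { ChatQuery } from 'types/types';

    // Illustrative wrapper: renders one canned query button that forwards
    // the query to whatever submit handler the page provides.
    const ExampleQuestion = (props: { submitQuery: (query: ChatQuery) => void }) => (
        <BackstoryQuery
            query={{ prompt: "What are your strongest skills?" }}
            submitQuery={props.submitQuery}
        />
    );

    export { ExampleQuestion };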

View File

@@ -1,7 +1,7 @@
 import React, { ReactElement, JSXElementConstructor } from 'react';
 import Box from '@mui/material/Box';
 import { SxProps, Theme } from '@mui/material';
-import { ChatSubmitQueryInterface } from './ChatQuery';
+import { ChatSubmitQueryInterface } from './BackstoryQuery';
 import { SetSnackType } from './Snack';
 interface BackstoryElementProps {

View File

@@ -14,8 +14,8 @@ import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryT
 import { BackstoryElementProps } from './BackstoryTab';
 import { connectionBase } from 'utils/Global';
 import { useUser } from "hooks/useUser";
-import { ApiClient, StreamingResponse } from 'types/api-client';
-import { ChatMessage, ChatContext, ChatSession, AIParameters, Query } from 'types/types';
+import { StreamingResponse } from 'types/api-client';
+import { ChatMessage, ChatContext, ChatSession, ChatQuery } from 'types/types';
 import { PaginatedResponse } from 'types/conversion';
 import './Conversation.css';
@@ -29,7 +29,7 @@ const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing
 type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check' | 'persona';
 interface ConversationHandle {
-    submitQuery: (query: Query) => void;
+    submitQuery: (query: ChatQuery) => void;
     fetchHistory: () => void;
 }
@@ -69,8 +69,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
         sx,
         type,
     } = props;
-    const apiClient = new ApiClient();
-    const { candidate } = useUser()
+    const { candidate, apiClient } = useUser()
     const [processing, setProcessing] = useState<boolean>(false);
     const [countdown, setCountdown] = useState<number>(0);
     const [conversation, setConversation] = useState<ChatMessage[]>([]);
@@ -125,23 +124,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
     }
     const createChatSession = async () => {
         try {
-            const aiParameters: AIParameters = {
-                name: '',
-                model: 'custom',
-                temperature: 0.7,
-                maxTokens: -1,
-                topP: 1,
-                frequencyPenalty: 0,
-                presencePenalty: 0,
-                isDefault: true,
-                createdAt: new Date(),
-                updatedAt: new Date()
-            };
-            const chatContext: ChatContext = {
-                type: "general",
-                aiParameters
-            };
+            const chatContext: ChatContext = { type: "general" };
             const response: ChatSession = await apiClient.createChatSession(chatContext);
             setChatSession(response);
         } catch (e) {
@@ -203,14 +186,14 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
     }, [chatSession]);
     const handleEnter = (value: string) => {
-        const query: Query = {
+        const query: ChatQuery = {
             prompt: value
         }
         processQuery(query);
     };
     useImperativeHandle(ref, () => ({
-        submitQuery: (query: Query) => {
+        submitQuery: (query: ChatQuery) => {
             processQuery(query);
         },
         fetchHistory: () => { getChatMessages(); }
@@ -255,7 +238,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
         controllerRef.current = null;
     };
-    const processQuery = (query: Query) => {
+    const processQuery = (query: ChatQuery) => {
         if (controllerRef.current || !chatSession || !chatSession.id) {
             return;
         }
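
Aside: with the hard-coded AIParameters gone, ChatContext reduces to a type tag, so session creation collapses to a couple of lines anywhere the shared client from useUser() is available. A minimal sketch under that assumption (the wrapper function is illustrative; createChatSession is the call used above):

    import { ApiClient } from 'types/api-client';
    import { ChatContext, ChatSession } from 'types/types';

    // Sketch: create a general-purpose chat session with the shared client.
    const startGeneralSession = (apiClient: ApiClient): Promise<ChatSession> => {
        const chatContext: ChatContext = { type: "general" };
        return apiClient.createChatSession(chatContext);
    };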

View File

@@ -2,14 +2,14 @@ import React from 'react';
 import { MuiMarkdown } from 'mui-markdown';
 import { useTheme } from '@mui/material/styles';
 import { Link } from '@mui/material';
-import { ChatQuery } from './ChatQuery';
+import { BackstoryQuery } from 'components/BackstoryQuery';
 import Box from '@mui/material/Box';
 import JsonView from '@uiw/react-json-view';
 import { vscodeTheme } from '@uiw/react-json-view/vscode';
-import { Mermaid } from './Mermaid';
-import { Scrollable } from './Scrollable';
+import { Mermaid } from 'components/Mermaid';
+import { Scrollable } from 'components/Scrollable';
 import { jsonrepair } from 'jsonrepair';
-import { GenerateImage } from './GenerateImage';
+import { GenerateImage } from 'components/GenerateImage';
 import './StyledMarkdown.css';
 import { BackstoryElementProps } from './BackstoryTab';
@@ -98,13 +98,13 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
             }
         }
     },
-    ChatQuery: {
+    BackstoryQuery: {
         component: (props: { query: string }) => {
             const queryString = props.query.replace(/(\w+):/g, '"$1":');
             try {
                 const query = JSON.parse(queryString);
-                return <ChatQuery submitQuery={submitQuery} query={query} />
+                return <BackstoryQuery submitQuery={submitQuery} query={query} />
             } catch (e) {
                 console.log("StyledMarkdown error:", queryString, e);
                 return props.query;

View File

@@ -38,17 +38,17 @@ const DefaultNavItems: NavigationLinkType[] = [
 const CandidateNavItems : NavigationLinkType[]= [
     { name: 'Chat', path: '/chat', icon: <ChatIcon /> },
-    { name: 'Job Analysis', path: '/job-analysis', icon: <WorkIcon /> },
+    // { name: 'Job Analysis', path: '/job-analysis', icon: <WorkIcon /> },
     { name: 'Resume Builder', path: '/resume-builder', icon: <WorkIcon /> },
-    { name: 'Knowledge Explorer', path: '/knowledge-explorer', icon: <WorkIcon /> },
+    // { name: 'Knowledge Explorer', path: '/knowledge-explorer', icon: <WorkIcon /> },
     { name: 'Find a Candidate', path: '/find-a-candidate', icon: <InfoIcon /> },
     // { name: 'Dashboard', icon: <DashboardIcon />, path: '/dashboard' },
     // { name: 'Profile', icon: <PersonIcon />, path: '/profile' },
     // { name: 'Backstory', icon: <HistoryIcon />, path: '/backstory' },
-    { name: 'Resumes', icon: <DescriptionIcon />, path: '/resumes' },
+    // { name: 'Resumes', icon: <DescriptionIcon />, path: '/resumes' },
     // { name: 'Q&A Setup', icon: <QuestionAnswerIcon />, path: '/qa-setup' },
-    { name: 'Analytics', icon: <BarChartIcon />, path: '/analytics' },
-    { name: 'Settings', icon: <SettingsIcon />, path: '/settings' },
+    // { name: 'Analytics', icon: <BarChartIcon />, path: '/analytics' },
+    // { name: 'Settings', icon: <SettingsIcon />, path: '/settings' },
 ];
 const EmployerNavItems: NavigationLinkType[] = [
@@ -121,13 +121,16 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
     );
 }
-const BackstoryLayout: React.FC<{
+interface BackstoryLayoutProps {
     setSnack: SetSnackType;
     page: string;
     chatRef: React.Ref<any>;
     snackRef: React.Ref<any>;
     submitQuery: any;
-}> = ({ setSnack, page, chatRef, snackRef, submitQuery }) => {
+};
+const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutProps) => {
+    const { setSnack, page, chatRef, snackRef, submitQuery } = props;
     const navigate = useNavigate();
     const location = useLocation();
     const { user, guest, candidate } = useUser();

View File

@@ -17,6 +17,7 @@ import { CandidateListingPage } from 'pages/FindCandidatePage';
 import { JobAnalysisPage } from 'pages/JobAnalysisPage';
 import { GenerateCandidate } from "pages/GenerateCandidate";
 import { ControlsPage } from 'pages/ControlsPage';
+import { LoginPage } from "pages/LoginPage";
 const ProfilePage = () => (<BetaPage><Typography variant="h4">Profile</Typography></BetaPage>);
 const BackstoryPage = () => (<BetaPage><Typography variant="h4">Backstory</Typography></BetaPage>);
@@ -27,7 +28,6 @@ const SavedPage = () => (<BetaPage><Typography variant="h4">Saved</Typography></
 const JobsPage = () => (<BetaPage><Typography variant="h4">Jobs</Typography></BetaPage>);
 const CompanyPage = () => (<BetaPage><Typography variant="h4">Company</Typography></BetaPage>);
 const LogoutPage = () => (<BetaPage><Typography variant="h4">Logout page...</Typography></BetaPage>);
-const LoginPage = () => (<BetaPage><Typography variant="h4">Login page...</Typography></BetaPage>);
 // const DashboardPage = () => (<BetaPage><Typography variant="h4">Dashboard</Typography></BetaPage>);
 // const AnalyticsPage = () => (<BetaPage><Typography variant="h4">Analytics</Typography></BetaPage>);
 // const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typography></BetaPage>);
@@ -57,7 +57,7 @@ const getBackstoryDynamicRoutes = (props: BackstoryDynamicRoutesProps): ReactNod
         routes.push(<Route key={`${index++}`} path="/login" element={<LoginPage />} />);
         routes.push(<Route key={`${index++}`} path="*" element={<BetaPage />} />);
     } else {
+        routes.push(<Route key={`${index++}`} path="/login" element={<LoginPage />} />);
         routes.push(<Route key={`${index++}`} path="/logout" element={<LogoutPage />} />);
         if (user.userType === 'candidate') {

View File

@@ -87,7 +87,6 @@ const MobileDrawer = styled(Drawer)(({ theme }) => ({
 interface HeaderProps {
     transparent?: boolean;
-    onLogout?: () => void;
     className?: string;
     navigate: NavigateFunction;
     navigationLinks: NavigationLinkType[];
@@ -98,7 +97,7 @@ interface HeaderProps {
 }
 const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
-    const { user } = useUser();
+    const { user, setUser } = useUser();
     const candidate: Candidate | null = (user && user.userType === "candidate") ? user as Candidate : null;
     const employer: Employer | null = (user && user.userType === "employer") ? user as Employer : null;
     const {
@@ -108,7 +107,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
         navigationLinks,
         showLogin,
         sessionId,
-        onLogout,
         setSnack,
     } = props;
     const theme = useTheme();
@@ -177,9 +175,7 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
     const handleLogout = () => {
         handleUserMenuClose();
-        if (onLogout) {
-            onLogout();
-        }
+        setUser(null);
     };
     const handleDrawerToggle = () => {
@@ -245,14 +241,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
                         >
                             Login
                         </Button>
-                        <Button
-                            variant="outlined"
-                            color="secondary"
-                            fullWidth
-                            onClick={() => { navigate("/register"); }}
-                        >
-                            Register
-                        </Button>
                     </Box>
                 )}
             </>
@@ -279,14 +267,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
                 >
                     Login
                 </Button>
-                <Button
-                    color="secondary"
-                    variant="contained"
-                    onClick={() => { navigate("/register"); }}
-                    sx={{ display: { xs: 'none', sm: 'block' } }}
-                >
-                    Register
-                </Button>
             </>
     );
 }

View File

@@ -1,11 +1,16 @@
 import React, { createContext, useContext, useEffect, useState } from "react";
 import { SetSnackType } from '../components/Snack';
 import { User, Guest, Candidate } from 'types/types';
+import { ApiClient } from "types/api-client";
+import { debugConversion } from "types/conversion";
 type UserContextType = {
+    apiClient: ApiClient;
     user: User | null;
     guest: Guest;
     candidate: Candidate | null;
+    setUser: (user: User | null) => void;
+    setCandidate: (candidate: Candidate | null) => void;
 };
 const UserContext = createContext<UserContextType | undefined>(undefined);
@@ -18,19 +23,136 @@ const useUser = () => {
 interface UserProviderProps {
     children: React.ReactNode;
-    candidate: Candidate | null;
-    user: User | null;
-    guest: Guest | null;
     setSnack: SetSnackType;
 };
 const UserProvider: React.FC<UserProviderProps> = (props: UserProviderProps) => {
-    const { guest, user, children, candidate, setSnack } = props;
+    const { children, setSnack } = props;
+    const [apiClient, setApiClient] = useState<ApiClient>(new ApiClient());
+    const [candidate, setCandidate] = useState<Candidate | null>(null);
+    const [guest, setGuest] = useState<Guest | null>(null);
+    const [user, setUser] = useState<User | null>(null);
+    const [activeUser, setActiveUser] = useState<User | null>(null);
+    useEffect(() => {
+        console.log("Candidate =>", candidate);
+    }, [candidate]);
+    useEffect(() => {
+        console.log("Guest =>", guest);
+    }, [guest]);
+    /* If the user changes to a non-null value, create a new
+     * apiClient with the access token */
+    useEffect(() => {
+        console.log("User => ", user);
+        if (user === null) {
+            return;
+        }
+        /* This apiClient will persist until the user is changed
+         * or logged out */
+        const accessToken = localStorage.getItem('accessToken');
+        if (!accessToken) {
+            throw Error("accessToken is not set for user!");
+        }
+        setApiClient(new ApiClient(accessToken));
+    }, [user]);
+    /* Handle logout if any consumers of UserProvider setUser to NULL */
+    useEffect(() => {
+        /* If there is an active user and it is the same as the
+         * new user, do nothing */
+        if (activeUser && activeUser.email === user?.email) {
+            return;
+        }
+        const logout = async () => {
+            if (!activeUser) {
+                return;
+            }
+            console.log(`Logging out ${activeUser.email}`);
+            try {
+                const accessToken = localStorage.getItem('accessToken');
+                const refreshToken = localStorage.getItem('refreshToken');
+                if (!accessToken || !refreshToken) {
+                    setSnack("Authentication tokens are invalid.", "error");
+                    return;
+                }
+                const results = await apiClient.logout(accessToken, refreshToken);
+                if (results.error) {
+                    console.error(results.error);
+                    setSnack(results.error.message, "error")
+                }
+            } catch (e) {
+                console.error(e);
+                setSnack(`Unable to logout: ${e}`, "error")
+            }
+            localStorage.removeItem('accessToken');
+            localStorage.removeItem('refreshToken');
+            localStorage.removeItem('userData');
+            createGuestSession();
+            setUser(null);
+        };
+        setActiveUser(user);
+        if (!user) {
+            logout();
+        }
+    }, [user, apiClient, activeUser]);
+    const createGuestSession = () => {
+        console.log("TODO: Convert this to query the server for the session instead of generating it.");
+        const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
+        const guest: Guest = {
+            sessionId,
+            createdAt: new Date(),
+            lastActivity: new Date(),
+            ipAddress: 'unknown',
+            userAgent: navigator.userAgent
+        };
+        setGuest(guest);
+        debugConversion(guest, 'Guest Session');
+    };
+    const checkExistingAuth = () => {
+        const accessToken = localStorage.getItem('accessToken');
+        const userData = localStorage.getItem('userData');
+        if (accessToken && userData) {
+            try {
+                const user = JSON.parse(userData);
+                // Convert dates back to Date objects if they're stored as strings
+                if (user.createdAt && typeof user.createdAt === 'string') {
+                    user.createdAt = new Date(user.createdAt);
+                }
+                if (user.updatedAt && typeof user.updatedAt === 'string') {
+                    user.updatedAt = new Date(user.updatedAt);
+                }
+                if (user.lastLogin && typeof user.lastLogin === 'string') {
+                    user.lastLogin = new Date(user.lastLogin);
+                }
+                setApiClient(new ApiClient(accessToken));
+                setUser(user);
+            } catch (e) {
+                localStorage.removeItem('accessToken');
+                localStorage.removeItem('refreshToken');
+                localStorage.removeItem('userData');
+            }
+        }
+    };
+    // Create guest session on component mount
+    useEffect(() => {
+        createGuestSession();
+        checkExistingAuth();
+    }, []);
     if (guest === null) {
         return <></>;
     }
     return (
-        <UserContext.Provider value={{ candidate, user, guest }}>
+        <UserContext.Provider value={{ apiClient, candidate, setCandidate, user, setUser, guest }}>
             {children}
         </UserContext.Provider>
     );
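
With user state now owned by UserProvider, any descendant can log out simply by setting the user to null; the provider's effect above then performs the server-side logout and recreates a guest session. A sketch of a consumer under that assumption (the component name is hypothetical):

    import React from 'react';
    import { useUser } from 'hooks/useUser';

    // Hypothetical consumer: clearing the user triggers the provider's
    // logout effect shown in this diff.
    const LogoutButton: React.FC = () => {
        const { user, setUser } = useUser();
        if (!user) {
            return null;
        }
        return <button onClick={() => setUser(null)}>Log out {user.email}</button>;
    };

    export { LogoutButton };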

View File

@@ -1,5 +1,5 @@
 import React, { useState, useEffect } from 'react';
-import { useNavigate } from 'react-router-dom';
+import { useNavigate, useLocation } from 'react-router-dom';
 import {
     Box,
     Container,
@@ -35,7 +35,11 @@ const BetaPage: React.FC<BetaPageProps> = ({
     const theme = useTheme();
     const [showSparkle, setShowSparkle] = useState<boolean>(false);
     const navigate = useNavigate();
+    const location = useLocation();
+    if (!children) {
+        children = (<Box>Location: {location.pathname}</Box>);
+    }
     console.log("BetaPage", children);
     // Enhanced sparkle effect for background elements

View File

@@ -6,19 +6,18 @@ import MuiMarkdown from 'mui-markdown';
 import { BackstoryPageProps } from '../components/BackstoryTab';
 import { Conversation, ConversationHandle } from '../components/Conversation';
-import { ChatQuery } from '../components/ChatQuery';
+import { BackstoryQuery } from '../components/BackstoryQuery';
 import { CandidateInfo } from 'components/CandidateInfo';
 import { useUser } from "../hooks/useUser";
-import { Candidate } from "../types/types";
 const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
     const { setSnack, submitQuery } = props;
+    const { candidate } = useUser();
     const theme = useTheme();
     const isMobile = useMediaQuery(theme.breakpoints.down('md'));
     const [questions, setQuestions] = useState<React.ReactElement[]>([]);
-    const { user } = useUser();
-    const candidate: Candidate | null = (user && user.userType === 'candidate') ? user as Candidate : null;
+    console.log("ChatPage candidate =>", candidate);
     useEffect(() => {
         if (!candidate) {
             return;
@@ -27,7 +26,7 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
         setQuestions([
             <Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
                 {candidate.questions?.map(({ question, tunables }, i: number) =>
-                    <ChatQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
+                    <BackstoryQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
                 )}
             </Box>,
             <Box sx={{ p: 1 }}>
@@ -38,8 +37,8 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
     }, [candidate, isMobile, submitQuery]);
     if (!candidate) {
         return (<></>);
     }
     return (
         <Box>
             <CandidateInfo candidate={candidate} action="Chat with Backstory AI about " />

View File

@@ -100,7 +100,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
     }
     const sendSystemPrompt = async (prompt: string) => {
         try {
-            const response = await fetch(connectionBase + `/api/tunables`, {
+            const response = await fetch(connectionBase + `/api/1.0/tunables`, {
                 method: 'PUT',
                 headers: {
                     'Content-Type': 'application/json',
@@ -126,7 +126,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
     const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
         try {
-            const response = await fetch(connectionBase + `/api/reset/`, {
+            const response = await fetch(connectionBase + `/api/1.0/reset/`, {
                 method: 'PUT',
                 headers: {
                     'Content-Type': 'application/json',
@@ -178,7 +178,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
     }
     const fetchSystemInfo = async () => {
         try {
-            const response = await fetch(connectionBase + `/api/system-info`, {
+            const response = await fetch(connectionBase + `/api/1.0/system-info`, {
                 method: 'GET',
                 headers: {
                     'Content-Type': 'application/json',
@@ -210,13 +210,16 @@ const ControlsPage = (props: BackstoryPageProps) => {
     }, [systemInfo, setSystemInfo, setSnack])
     useEffect(() => {
+        if (!systemPrompt) {
+            return;
+        }
         setEditSystemPrompt(systemPrompt.trim());
     }, [systemPrompt, setEditSystemPrompt]);
     const toggleRag = async (tool: Tool) => {
         tool.enabled = !tool.enabled
         try {
-            const response = await fetch(connectionBase + `/api/tunables`, {
+            const response = await fetch(connectionBase + `/api/1.0/tunables`, {
                 method: 'PUT',
                 headers: {
                     'Content-Type': 'application/json',
@@ -238,7 +241,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
     const toggleTool = async (tool: Tool) => {
         tool.enabled = !tool.enabled
         try {
-            const response = await fetch(connectionBase + `/api/tunables`, {
+            const response = await fetch(connectionBase + `/api/1.0/tunables`, {
                 method: 'PUT',
                 headers: {
                     'Content-Type': 'application/json',
@@ -265,7 +268,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
     const fetchTunables = async () => {
         try {
             // Make the fetch request with proper headers
-            const response = await fetch(connectionBase + `/api/tunables`, {
+            const response = await fetch(connectionBase + `/api/1.0/tunables`, {
                 method: 'GET',
                 headers: {
                     'Content-Type': 'application/json',
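
Each of these handlers repeats the same fetch boilerplate against the newly versioned /api/1.0 prefix; a small helper could centralize the endpoint and headers. A sketch only (the helper name is hypothetical; connectionBase is the existing global imported above):

    import { connectionBase } from 'utils/Global';

    // Hypothetical helper: send JSON to a versioned endpoint and parse the reply.
    const requestApi = async (path: string, method: 'GET' | 'PUT', body?: unknown): Promise<unknown> => {
        const response = await fetch(`${connectionBase}/api/1.0/${path}`, {
            method,
            headers: { 'Content-Type': 'application/json' },
            ...(body !== undefined && { body: JSON.stringify(body) }),
        });
        if (!response.ok) {
            throw new Error(`Request failed with status ${response.status}`);
        }
        return response.json();
    };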

View File

@@ -5,11 +5,11 @@ import Box from '@mui/material/Box';
 import { BackstoryPageProps } from '../components/BackstoryTab';
 import { CandidateInfo } from 'components/CandidateInfo';
-import { connectionBase } from '../utils/Global';
 import { Candidate } from "../types/types";
-import { ApiClient } from 'types/api-client';
+import { useUser } from 'hooks/useUser';
 const CandidateListingPage = (props: BackstoryPageProps) => {
-    const apiClient = new ApiClient();
+    const { apiClient, setCandidate } = useUser();
     const navigate = useNavigate();
     const { setSnack } = props;
     const [candidates, setCandidates] = useState<Candidate[] | null>(null);
@@ -44,27 +44,24 @@ const CandidateListingPage = (props: BackstoryPageProps) => {
     return (
         <Box sx={{display: "flex", flexDirection: "column"}}>
             <Box sx={{ p: 1, textAlign: "center" }}>
                 Not seeing a candidate you like?
                 <Button
                     variant="contained"
-                    sx={{m: 1}}
-                    onClick={() => { navigate('/generate-candidate')}}>
+                    sx={{ m: 1 }}
+                    onClick={() => { navigate('/generate-candidate') }}>
                     Generate your own perfect AI candidate!
                 </Button>
             </Box>
-            <Box sx={{ display: "flex", gap: 1, flexWrap: "wrap"}}>
+            <Box sx={{ display: "flex", gap: 1, flexWrap: "wrap" }}>
                 {candidates?.map((u, i) =>
                     <Box key={`${u.username}`}
-                        onClick={(event: React.MouseEvent<HTMLDivElement>) : void => {
-                            navigate(`/u/${u.username}`)
-                        }}
-                        sx={{ cursor: "pointer" }}
-                    >
+                        onClick={() => { setCandidate(u); navigate("/chat"); }}
+                        sx={{ cursor: "pointer" }}>
                         <CandidateInfo sx={{ maxWidth: "320px", "cursor": "pointer", "&:hover": { border: "2px solid orange" }, border: "2px solid transparent" }} candidate={u} />
                     </Box>
                 )}
             </Box>
         </Box>
     );
 };

View File

@@ -10,9 +10,7 @@ import SendIcon from '@mui/icons-material/Send';
 import PropagateLoader from 'react-spinners/PropagateLoader';
 import { jsonrepair } from 'jsonrepair';
 import { CandidateInfo } from '../components/CandidateInfo';
-import { Query } from '../types/types'
 import { Quote } from 'components/Quote';
 import { Candidate } from '../types/types';
 import { BackstoryElementProps } from 'components/BackstoryTab';
@@ -21,6 +19,8 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
 import { Scrollable } from '../components/Scrollable';
 import { Pulse } from 'components/Pulse';
 import { StreamingResponse } from 'types/api-client';
+import { ChatContext, ChatSession, ChatQuery } from 'types/types';
+import { useUser } from 'hooks/useUser';
 const emptyUser: Candidate = {
     description: "[blank]",
@@ -46,6 +46,7 @@ const emptyUser: Candidate = {
 };
 const GenerateCandidate = (props: BackstoryElementProps) => {
+    const { apiClient } = useUser();
     const { setSnack, submitQuery } = props;
     const [streaming, setStreaming] = useState<string>('');
     const [processing, setProcessing] = useState<boolean>(false);
@@ -57,16 +58,40 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     const [timestamp, setTimestamp] = useState<number>(0);
     const [state, setState] = useState<number>(0); // Replaced stateRef
     const [shouldGenerateProfile, setShouldGenerateProfile] = useState<boolean>(false);
+    const [chatSession, setChatSession] = useState<ChatSession | null>(null);
     // Only keep refs that are truly necessary
     const controllerRef = useRef<StreamingResponse>(null);
     const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
-    const generatePersona = useCallback((query: Query) => {
-        if (controllerRef.current) {
-            return;
-        }
-        setPrompt(query.prompt);
+    /* Create the chat session */
+    useEffect(() => {
+        if (chatSession) {
+            return;
+        }
+        const createChatSession = async () => {
+            try {
+                const chatContext: ChatContext = { type: "generate_persona" };
+                const response: ChatSession = await apiClient.createChatSession(chatContext);
+                setChatSession(response);
+                setSnack(`Chat session created for generate_persona: ${response.id}`);
+            } catch (e) {
+                console.error(e);
+                setSnack("Unable to create chat session.", "error");
+            }
+        };
+        createChatSession();
+    }, [chatSession, setChatSession]);
+    const generatePersona = useCallback((query: ChatQuery) => {
+        if (!chatSession || !chatSession.id) {
+            return;
+        }
+        const sessionId: string = chatSession.id;
+        setPrompt(query.prompt || '');
     setState(0);
     setStatus("Generating persona...");
     setUser(emptyUser);
@@ -76,6 +101,24 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
     setCanGenImage(false);
     setShouldGenerateProfile(false); // Reset the flag
+    const streamResponse = apiClient.sendMessageStream(sessionId, query, {
+        onPartialMessage: (content, messageId) => {
+            console.log('Partial content:', content);
+            // Update UI with partial content
+        },
+        onStatusChange: (status) => {
+            console.log('Status changed:', status);
+            // Update UI status indicator
+        },
+        onComplete: (finalMessage) => {
+            console.log('Final message:', finalMessage.content);
+            // Handle completed message
+        },
+        onError: (error) => {
+            console.error('Streaming error:', error);
+            // Handle error
+        }
+    });
     // controllerRef.current = streamQueryResponse({
     //     query,
     //     type: "persona",
@@ -148,7 +191,7 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
         if (processing) {
             return;
         }
-        const query: Query = {
+        const query: ChatQuery = {
            prompt: value,
        }
        generatePersona(query);

View File

@@ -22,9 +22,10 @@ import { Person, PersonAdd, AccountCircle, ExitToApp } from '@mui/icons-material
 import 'react-phone-number-input/style.css';
 import PhoneInput from 'react-phone-number-input';
 import { E164Number } from 'libphonenumber-js/core';
-import './PhoneInput.css';
+import './LoginPage.css';
 import { ApiClient } from 'types/api-client';
+import { useUser } from 'hooks/useUser';
 // Import conversion utilities
 import {
@@ -35,11 +36,12 @@
     isSuccessResponse,
     debugConversion,
     type ApiResponse
-} from './types/conversion';
+} from 'types/conversion';
 import {
     AuthResponse, User, Guest, Candidate
-} from './types/types'
+} from 'types/types'
+import { useNavigate } from 'react-router-dom';
 interface LoginRequest {
     login: string;
@@ -55,16 +57,17 @@
     phone?: string;
 }
-const BackstoryTestApp: React.FC = () => {
-    const apiClient = new ApiClient();
-    const [currentUser, setCurrentUser] = useState<User | null>(null);
-    const [guestSession, setGuestSession] = useState<Guest | null>(null);
+const apiClient = new ApiClient();
+const LoginPage: React.FC = () => {
+    const navigate = useNavigate();
+    const { user, setUser, guest } = useUser();
     const [tabValue, setTabValue] = useState(0);
     const [loading, setLoading] = useState(false);
     const [error, setError] = useState<string | null>(null);
     const [success, setSuccess] = useState<string | null>(null);
     const [phone, setPhone] = useState<E164Number | null>(null);
-    const name = (currentUser?.userType === 'candidate' ? (currentUser as Candidate).username : currentUser?.email) || '';
+    const name = (user?.userType === 'candidate' ? (user as Candidate).username : user?.email) || '';
     // Login form state
     const [loginForm, setLoginForm] = useState<LoginRequest>({
@@ -82,12 +85,6 @@
         phone: ''
     });
-    // Create guest session on component mount
-    useEffect(() => {
-        createGuestSession();
-        checkExistingAuth();
-    }, []);
     useEffect(() => {
         if (phone !== registerForm.phone && phone) {
             console.log({ phone });
         }
     }, [phone, registerForm]);
-    const createGuestSession = () => {
-        const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
-        const guest: Guest = {
-            sessionId,
-            createdAt: new Date(),
-            lastActivity: new Date(),
-            ipAddress: 'unknown',
-            userAgent: navigator.userAgent
-        };
-        setGuestSession(guest);
-        debugConversion(guest, 'Guest Session');
-    };
-    const checkExistingAuth = () => {
-        const token = localStorage.getItem('accessToken');
-        const userData = localStorage.getItem('userData');
-        if (token && userData) {
-            try {
-                const user = JSON.parse(userData);
-                // Convert dates back to Date objects if they're stored as strings
-                if (user.createdAt && typeof user.createdAt === 'string') {
-                    user.createdAt = new Date(user.createdAt);
-                }
-                if (user.updatedAt && typeof user.updatedAt === 'string') {
-                    user.updatedAt = new Date(user.updatedAt);
-                }
-                if (user.lastLogin && typeof user.lastLogin === 'string') {
-                    user.lastLogin = new Date(user.lastLogin);
-                }
-                setCurrentUser(user);
-            } catch (e) {
-                localStorage.removeItem('accessToken');
-                localStorage.removeItem('refreshToken');
-                localStorage.removeItem('userData');
-            }
-        }
-    };
     const handleLogin = async (e: React.FormEvent) => {
         e.preventDefault();
         setLoading(true);
@@ -149,12 +108,11 @@
         localStorage.setItem('refreshToken', authResponse.refreshToken);
         localStorage.setItem('userData', JSON.stringify(authResponse.user));
-        setCurrentUser(authResponse.user);
         setSuccess('Login successful!');
+        navigate('/');
+        setUser(authResponse.user);
         // Clear form
         setLoginForm({ login: '', password: '' });
     } catch (err) {
         console.error('Login error:', err);
         setError(err instanceof Error ? err.message : 'Login failed');
@@ -219,113 +177,68 @@
     }
     };
-    const handleLogout = () => {
-        localStorage.removeItem('accessToken');
-        localStorage.removeItem('refreshToken');
-        localStorage.removeItem('userData');
-        setCurrentUser(null);
-        setSuccess('Logged out successfully');
-        createGuestSession();
-    };
     const handleTabChange = (event: React.SyntheticEvent, newValue: number) => {
         setTabValue(newValue);
         setError(null);
         setSuccess(null);
     };
-    // API helper function for authenticated requests
-    const makeAuthenticatedRequest = async (url: string, options: RequestInit = {}) => {
-        const token = localStorage.getItem('accessToken');
-        const headers = {
-            'Content-Type': 'application/json',
-            ...(token && { 'Authorization': `Bearer ${token}` }),
-            ...options.headers,
-        };
-        const response = await fetch(url, {
-            ...options,
-            headers,
-        });
-        return handleApiResponse(response);
-    };
     // If user is logged in, show their profile
-    if (currentUser) {
+    if (user) {
         return (
-            <Box sx={{ flexGrow: 1 }}>
-                <AppBar position="static">
-                    <Toolbar>
-                        <AccountCircle sx={{ mr: 2 }} />
-                        <Typography variant="h6" component="div" sx={{ flexGrow: 1 }}>
-                            Welcome, {name}
-                        </Typography>
-                        <Button
-                            color="inherit"
-                            onClick={handleLogout}
-                            startIcon={<ExitToApp />}
-                        >
-                            Logout
-                        </Button>
-                    </Toolbar>
-                </AppBar>
-                <Container maxWidth="md" sx={{ mt: 4 }}>
-                    <Card elevation={3}>
-                        <CardContent>
-                            <Box sx={{ display: 'flex', alignItems: 'center', mb: 3 }}>
-                                <Avatar sx={{ mr: 2, bgcolor: 'primary.main' }}>
-                                    <AccountCircle />
-                                </Avatar>
-                                <Typography variant="h4" component="h1">
-                                    User Profile
-                                </Typography>
-                            </Box>
-                            <Divider sx={{ mb: 3 }} />
-                            <Grid container spacing={3}>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Username:</strong> {name}
-                                    </Typography>
-                                </Grid>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Email:</strong> {currentUser.email}
-                                    </Typography>
-                                </Grid>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Status:</strong> {currentUser.status}
-                                    </Typography>
-                                </Grid>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Phone:</strong> {currentUser.phone || 'Not provided'}
-                                    </Typography>
-                                </Grid>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Last Login:</strong> {
-                                            currentUser.lastLogin
-                                                ? currentUser.lastLogin.toLocaleString()
-                                                : 'N/A'
-                                        }
-                                    </Typography>
-                                </Grid>
-                                <Grid size={{ xs: 12, md: 6 }}>
-                                    <Typography variant="body1" sx={{ mb: 1 }}>
-                                        <strong>Member Since:</strong> {currentUser.createdAt.toLocaleDateString()}
-                                    </Typography>
-                                </Grid>
-                            </Grid>
-                        </CardContent>
-                    </Card>
-                </Container>
-            </Box>
+            <Container maxWidth="md" sx={{ mt: 4 }}>
+                <Card elevation={3}>
+                    <CardContent>
+                        <Box sx={{ display: 'flex', alignItems: 'center', mb: 3 }}>
+                            <Avatar sx={{ mr: 2, bgcolor: 'primary.main' }}>
+                                <AccountCircle />
+                            </Avatar>
+                            <Typography variant="h4" component="h1">
+                                User Profile
+                            </Typography>
+                        </Box>
+                        <Divider sx={{ mb: 3 }} />
+                        <Grid container spacing={3}>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Username:</strong> {name}
+                                </Typography>
+                            </Grid>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Email:</strong> {user.email}
+                                </Typography>
+                            </Grid>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Status:</strong> {user.status}
+                                </Typography>
+                            </Grid>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Phone:</strong> {user.phone || 'Not provided'}
+                                </Typography>
+                            </Grid>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Last Login:</strong> {
+                                        user.lastLogin
+                                            ? user.lastLogin.toLocaleString()
+                                            : 'N/A'
+                                    }
+                                </Typography>
+                            </Grid>
+                            <Grid size={{ xs: 12, md: 6 }}>
+                                <Typography variant="body1" sx={{ mb: 1 }}>
+                                    <strong>Member Since:</strong> {user.createdAt.toLocaleDateString()}
+                                </Typography>
+                            </Grid>
+                        </Grid>
+                    </CardContent>
+                </Card>
+            </Container>
         );
     }
@@ -352,20 +265,20 @@
     <Container maxWidth="sm" sx={{ mt: 4 }}>
         <Paper elevation={3} sx={{ p: 4 }}>
             <Typography variant="h4" component="h1" gutterBottom align="center" color="primary">
-                Backstory Platform
+                Backstory
             </Typography>
-            {guestSession && (
+            {guest && (
                 <Card sx={{ mb: 3, bgcolor: 'grey.50' }} elevation={1}>
                     <CardContent>
                         <Typography variant="h6" gutterBottom color="primary">
                             Guest Session Active
                         </Typography>
                         <Typography variant="body2" color="text.secondary" sx={{ mb: 0.5 }}>
-                            Session ID: {guestSession.sessionId}
+                            Session ID: {guest.sessionId}
                         </Typography>
                         <Typography variant="body2" color="text.secondary">
-                            Created: {guestSession.createdAt.toLocaleString()}
+                            Created: {guest.createdAt.toLocaleString()}
                         </Typography>
                     </CardContent>
                 </Card>
@@ -537,4 +450,4 @@
     );
 };
-export { BackstoryTestApp };
+export { LoginPage };

View File

@@ -6,11 +6,11 @@ import {
 } from '@mui/material';
 import { SxProps } from '@mui/material';
-import { ChatQuery } from '../components/ChatQuery';
-import { MessageList, BackstoryMessage } from '../components/Message';
-import { Conversation } from '../components/Conversation';
-import { BackstoryPageProps } from '../components/BackstoryTab';
-import { Query } from "../types/types";
+import { BackstoryQuery } from 'components/BackstoryQuery';
+import { MessageList, BackstoryMessage } from 'components/Message';
+import { Conversation } from 'components/Conversation';
+import { BackstoryPageProps } from 'components/BackstoryTab';
+import { ChatQuery } from "types/types";
 import './ResumeBuilderPage.css';
@@ -43,17 +43,17 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
         setActiveTab(newValue);
     };
-    const handleJobQuery = (query: Query) => {
+    const handleJobQuery = (query: ChatQuery) => {
         console.log(`handleJobQuery: ${query.prompt} -- `, jobConversationRef.current ? ' sending' : 'no handler');
         jobConversationRef.current?.submitQuery(query);
     };
-    const handleResumeQuery = (query: Query) => {
+    const handleResumeQuery = (query: ChatQuery) => {
         console.log(`handleResumeQuery: ${query.prompt} -- `, resumeConversationRef.current ? ' sending' : 'no handler');
         resumeConversationRef.current?.submitQuery(query);
     };
-    const handleFactsQuery = (query: Query) => {
+    const handleFactsQuery = (query: ChatQuery) => {
         console.log(`handleFactsQuery: ${query.prompt} -- `, factsConversationRef.current ? ' sending' : 'no handler');
         factsConversationRef.current?.submitQuery(query);
     };
@@ -202,8 +202,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     // console.log('renderJobDescriptionView');
     // const jobDescriptionQuestions = [
     //     <Box sx={{ display: "flex", flexDirection: "column" }}>
-    //         <ChatQuery query={{ prompt: "What are the key skills necessary for this position?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
-    //         <ChatQuery query={{ prompt: "How much should this position pay (accounting for inflation)?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
+    //         <BackstoryQuery query={{ prompt: "What are the key skills necessary for this position?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
+    //         <BackstoryQuery query={{ prompt: "How much should this position pay (accounting for inflation)?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
     //     </Box>,
     // ];
@@ -213,7 +213,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     // 1. **Job Analysis**: LLM extracts requirements from '\`Job Description\`' to generate a list of desired '\`Skills\`'.
     // 2. **Candidate Analysis**: LLM determines candidate qualifications by performing skill assessments.
     //    For each '\`Skill\`' from **Job Analysis** phase:
     //    1. **RAG**: Retrieval Augmented Generation collection is queried for context related content for each '\`Skill\`'.
@@ -274,8 +274,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     // const renderResumeView = useCallback((sx?: SxProps) => {
     //     const resumeQuestions = [
     //         <Box sx={{ display: "flex", flexDirection: "column" }}>
-    //             <ChatQuery query={{ prompt: "Is this resume a good fit for the provided job description?", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
-    //             <ChatQuery query={{ prompt: "Provide a more concise resume.", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
+    //             <BackstoryQuery query={{ prompt: "Is this resume a good fit for the provided job description?", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
+    //             <BackstoryQuery query={{ prompt: "Provide a more concise resume.", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
    //         </Box>,
    //     ];
@@ -323,7 +323,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
     // const renderFactCheckView = useCallback((sx?: SxProps) => {
     //     const factsQuestions = [
     //         <Box sx={{ display: "flex", flexDirection: "column" }}>
-    //             <ChatQuery query={{ prompt: "Rewrite the resume to address any discrepancies.", tunables: { enableTools: false } }} submitQuery={handleFactsQuery} />
+    //             <BackstoryQuery query={{ prompt: "Rewrite the resume to address any discrepancies.", tunables: { enableTools: false } }} submitQuery={handleFactsQuery} />
     //         </Box>,
     //     ];

View File

@@ -1,12 +1,11 @@
 import React, { useEffect, useState } from "react";
 import { useParams, useNavigate } from "react-router-dom";
-import { useUser } from "../hooks/useUser";
 import { Box } from "@mui/material";
 import { SetSnackType } from '../components/Snack';
 import { LoadingComponent } from "../components/LoadingComponent";
 import { User, Guest, Candidate } from 'types/types';
-import { ApiClient } from "types/api-client";
+import { useUser } from "hooks/useUser";
 interface CandidateRouteProps {
     guest?: Guest | null;
@@ -15,7 +14,7 @@
 };
 const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
-    const apiClient = new ApiClient();
+    const { apiClient } = useUser();
     const { setSnack } = props;
     const { username } = useParams<{ username: string }>();
     const [candidate, setCandidate] = useState<Candidate|null>(null);
@@ -32,11 +31,12 @@ const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProp
             navigate('/chat');
         } catch {
             setSnack(`Unable to obtain information for ${username}.`, "error");
+            navigate('/');
         }
     }
     getCandidate(username);
-    }, [candidate, username, setCandidate]);
+    }, [candidate, username, setCandidate, navigate, setSnack, apiClient]);
     if (candidate === null) {
         return (<Box>

View File

@ -9,14 +9,14 @@
import * as Types from './types'; import * as Types from './types';
import { import {
formatApiRequest, formatApiRequest,
parseApiResponse, // parseApiResponse,
parsePaginatedResponse, // parsePaginatedResponse,
handleApiResponse, handleApiResponse,
handlePaginatedApiResponse, handlePaginatedApiResponse,
createPaginatedRequest, createPaginatedRequest,
toUrlParams, toUrlParams,
extractApiData, // extractApiData,
ApiResponse, // ApiResponse,
PaginatedResponse, PaginatedResponse,
PaginatedRequest PaginatedRequest
} from './conversion'; } from './conversion';
@ -59,7 +59,7 @@ class ApiClient {
private baseUrl: string; private baseUrl: string;
private defaultHeaders: Record<string, string>; private defaultHeaders: Record<string, string>;
constructor(authToken?: string) { constructor(accessToken?: string) {
const loc = window.location; const loc = window.location;
if (!loc.host.match(/.*battle-linux.*/)) { if (!loc.host.match(/.*battle-linux.*/)) {
this.baseUrl = loc.protocol + "//" + loc.host + "/api/1.0"; this.baseUrl = loc.protocol + "//" + loc.host + "/api/1.0";
@ -68,7 +68,7 @@ class ApiClient {
} }
this.defaultHeaders = { this.defaultHeaders = {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
...(authToken && { 'Authorization': `Bearer ${authToken}` }) ...(accessToken && { 'Authorization': `Bearer ${accessToken}` })
}; };
} }
@ -76,16 +76,27 @@ class ApiClient {
// Authentication Methods // Authentication Methods
// ============================ // ============================
async login(email: string, password: string): Promise<Types.AuthResponse> { async login(login: string, password: string): Promise<Types.AuthResponse> {
const response = await fetch(`${this.baseUrl}/auth/login`, { const response = await fetch(`${this.baseUrl}/auth/login`, {
method: 'POST', method: 'POST',
headers: this.defaultHeaders, headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest({ email, password })) body: JSON.stringify(formatApiRequest({ login, password }))
}); });
return handleApiResponse<Types.AuthResponse>(response); return handleApiResponse<Types.AuthResponse>(response);
} }
async logout(accessToken: string, refreshToken: string): Promise<Types.ApiResponse> {
console.log(this.defaultHeaders);
const response = await fetch(`${this.baseUrl}/auth/logout`, {
method: 'POST',
headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest({ accessToken, refreshToken }))
});
return handleApiResponse<Types.ApiResponse>(response);
}
async refreshToken(refreshToken: string): Promise<Types.AuthResponse> { async refreshToken(refreshToken: string): Promise<Types.AuthResponse> {
const response = await fetch(`${this.baseUrl}/auth/refresh`, { const response = await fetch(`${this.baseUrl}/auth/refresh`, {
method: 'POST', method: 'POST',
@ -110,8 +121,8 @@ class ApiClient {
return handleApiResponse<Types.Candidate>(response); return handleApiResponse<Types.Candidate>(response);
} }
async getCandidate(id: string): Promise<Types.Candidate> { async getCandidate(username: string): Promise<Types.Candidate> {
const response = await fetch(`${this.baseUrl}/candidates/${id}`, { const response = await fetch(`${this.baseUrl}/candidates/${username}`, {
headers: this.defaultHeaders headers: this.defaultHeaders
}); });
@ -315,11 +326,11 @@ class ApiClient {
/** /**
* Send message with standard response (non-streaming) * Send message with standard response (non-streaming)
*/ */
async sendMessage(sessionId: string, query: Types.Query): Promise<Types.ChatMessage> { async sendMessage(sessionId: string, query: Types.ChatQuery): Promise<Types.ChatMessage> {
const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, { const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, {
method: 'POST', method: 'POST',
headers: this.defaultHeaders, headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest({ query })) body: JSON.stringify(formatApiRequest({query}))
}); });
return handleApiResponse<Types.ChatMessage>(response); return handleApiResponse<Types.ChatMessage>(response);
@ -330,7 +341,7 @@ class ApiClient {
*/ */
sendMessageStream( sendMessageStream(
sessionId: string, sessionId: string,
query: Types.Query, query: Types.ChatQuery,
options: StreamingOptions = {} options: StreamingOptions = {}
): StreamingResponse { ): StreamingResponse {
const abortController = new AbortController(); const abortController = new AbortController();
@ -479,7 +490,7 @@ class ApiClient {
*/ */
async sendMessageAuto( async sendMessageAuto(
sessionId: string, sessionId: string,
query: Types.Query, query: Types.ChatQuery,
options?: StreamingOptions options?: StreamingOptions
): Promise<Types.ChatMessage> { ): Promise<Types.ChatMessage> {
// If streaming options are provided, use streaming // If streaming options are provided, use streaming
@ -503,36 +514,6 @@ class ApiClient {
return handlePaginatedApiResponse<Types.ChatMessage>(response); return handlePaginatedApiResponse<Types.ChatMessage>(response);
} }
// ============================
// AI Configuration Methods
// ============================
async createAIParameters(params: Omit<Types.AIParameters, 'id' | 'createdAt' | 'updatedAt'>): Promise<Types.AIParameters> {
const response = await fetch(`${this.baseUrl}/ai/parameters`, {
method: 'POST',
headers: this.defaultHeaders,
body: JSON.stringify(formatApiRequest(params))
});
return handleApiResponse<Types.AIParameters>(response);
}
async getAIParameters(id: string): Promise<Types.AIParameters> {
const response = await fetch(`${this.baseUrl}/ai/parameters/${id}`, {
headers: this.defaultHeaders
});
return handleApiResponse<Types.AIParameters>(response);
}
async getUserAIParameters(userId: string): Promise<Types.AIParameters[]> {
const response = await fetch(`${this.baseUrl}/users/${userId}/ai/parameters`, {
headers: this.defaultHeaders
});
return handleApiResponse<Types.AIParameters[]>(response);
}
// ============================ // ============================
// Error Handling Helper // Error Handling Helper
// ============================ // ============================
@ -580,7 +561,7 @@ export function useStreamingChat(sessionId: string) {
const apiClient = useApiClient(); const apiClient = useApiClient();
const streamingRef = useRef<StreamingResponse | null>(null); const streamingRef = useRef<StreamingResponse | null>(null);
const sendMessage = useCallback(async (query: Types.Query) => { const sendMessage = useCallback(async (query: Types.ChatQuery) => {
setError(null); setError(null);
setIsStreaming(true); setIsStreaming(true);
setCurrentMessage(null); setCurrentMessage(null);

View File

@ -1,23 +1,23 @@
// Generated TypeScript types from Pydantic models // Generated TypeScript types from Pydantic models
// Source: src/backend/models.py // Source: src/backend/models.py
// Generated on: 2025-05-28T21:47:08.590102 // Generated on: 2025-05-29T05:47:25.809967
// DO NOT EDIT MANUALLY - This file is auto-generated // DO NOT EDIT MANUALLY - This file is auto-generated
// ============================ // ============================
// Enums // Enums
// ============================ // ============================
export type AIModelType = "gpt-4" | "gpt-3.5-turbo" | "claude-3" | "claude-3-opus" | "custom"; export type AIModelType = "qwen2.5" | "flux-schnell";
export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat"; export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";
export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn"; export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";
export type ChatContextType = "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general"; export type ChatContextType = "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile";
export type ChatSenderType = "user" | "ai" | "system"; export type ChatSenderType = "user" | "ai" | "system";
export type ChatStatusType = "partial" | "done" | "streaming" | "thinking" | "error"; export type ChatStatusType = "preparing" | "thinking" | "partial" | "streaming" | "done" | "error";
export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none"; export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";
@ -63,24 +63,6 @@ export type VectorStoreType = "pinecone" | "qdrant" | "faiss" | "milvus" | "weav
// Interfaces // Interfaces
// ============================ // ============================
export interface AIParameters {
id?: string;
userId?: string;
name: string;
description?: string;
model: "gpt-4" | "gpt-3.5-turbo" | "claude-3" | "claude-3-opus" | "custom";
temperature: number;
maxTokens: number;
topP: number;
frequencyPenalty: number;
presencePenalty: number;
systemPrompt?: string;
isDefault: boolean;
createdAt: Date;
updatedAt: Date;
customModelConfig?: Record<string, any>;
}
export interface AccessibilitySettings { export interface AccessibilitySettings {
fontSize: "small" | "medium" | "large"; fontSize: "small" | "medium" | "large";
highContrast: boolean; highContrast: boolean;
@ -240,34 +222,63 @@ export interface Certification {
} }
export interface ChatContext { export interface ChatContext {
type: "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general"; type: "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile";
relatedEntityId?: string; relatedEntityId?: string;
relatedEntityType?: "job" | "candidate" | "employer"; relatedEntityType?: "job" | "candidate" | "employer";
aiParameters: AIParameters;
additionalContext?: Record<string, any>; additionalContext?: Record<string, any>;
} }
export interface ChatMessage { export interface ChatMessage {
id?: string; id?: string;
sessionId: string; sessionId: string;
status: "partial" | "done" | "streaming" | "thinking" | "error"; status: "preparing" | "thinking" | "partial" | "streaming" | "done" | "error";
sender: "user" | "ai" | "system"; sender: "user" | "ai" | "system";
senderId?: string; senderId?: string;
content: string; prompt?: string;
content?: string;
chunk?: string;
timestamp: Date; timestamp: Date;
attachments?: Array<Attachment>;
reactions?: Array<MessageReaction>;
isEdited?: boolean; isEdited?: boolean;
editHistory?: Array<EditHistory>; metadata?: ChatMessageMetaData;
metadata?: Record<string, any>; }
export interface ChatMessageMetaData {
model?: "qwen2.5" | "flux-schnell";
temperature?: number;
maxTokens?: number;
topP?: number;
frequencyPenalty?: number;
presencePenalty?: number;
stopSequences?: Array<string>;
tunables?: Tunables;
rag?: Array<ChromaDBGetResponse>;
evalCount?: number;
evalDuration?: number;
promptEvalCount?: number;
promptEvalDuration?: number;
options?: ChatOptions;
tools?: Record<string, any>;
timers?: Record<string, number>;
}
export interface ChatOptions {
seed?: number;
numCtx?: number;
temperature?: number;
}
export interface ChatQuery {
prompt: string;
tunables?: Tunables;
agentOptions?: Record<string, any>;
} }
export interface ChatSession { export interface ChatSession {
id?: string; id?: string;
userId?: string; userId?: string;
guestId?: string; guestId?: string;
createdAt: Date; createdAt?: Date;
lastActivity: Date; lastActivity?: Date;
title?: string; title?: string;
context: ChatContext; context: ChatContext;
messages?: Array<ChatMessage>; messages?: Array<ChatMessage>;
@ -275,6 +286,19 @@ export interface ChatSession {
systemPrompt?: string; systemPrompt?: string;
} }
export interface ChromaDBGetResponse {
ids?: Array<string>;
embeddings?: Array<Array<number>>;
documents?: Array<string>;
metadatas?: Array<Record<string, any>>;
name?: string;
size?: number;
query?: string;
queryEmbedding?: Array<number>;
umapEmbedding2D?: Array<number>;
umapEmbedding3D?: Array<number>;
}
export interface CustomQuestion { export interface CustomQuestion {
question: string; question: string;
answer: string; answer: string;
@ -511,12 +535,6 @@ export interface ProcessingStep {
dependsOn?: Array<string>; dependsOn?: Array<string>;
} }
export interface Query {
prompt: string;
tunables?: Tunables;
agentOptions?: Record<string, any>;
}
export interface RAGConfiguration { export interface RAGConfiguration {
id?: string; id?: string;
userId: string; userId: string;
@ -528,11 +546,16 @@ export interface RAGConfiguration {
retrievalParameters: RetrievalParameters; retrievalParameters: RetrievalParameters;
createdAt: Date; createdAt: Date;
updatedAt: Date; updatedAt: Date;
isDefault: boolean;
version: number; version: number;
isActive: boolean; isActive: boolean;
} }
export interface RagEntry {
name: string;
description?: string;
enabled?: boolean;
}
export interface RefreshToken { export interface RefreshToken {
token: string; token: string;
expiresAt: Date; expiresAt: Date;

View File

@ -0,0 +1,97 @@
from __future__ import annotations
from pydantic import BaseModel, Field # type: ignore
from typing import (
Literal,
get_args,
List,
AsyncGenerator,
TYPE_CHECKING,
Optional,
ClassVar,
Any,
TypeAlias,
Dict,
Tuple,
)
import importlib
import pathlib
import inspect
from prometheus_client import CollectorRegistry # type: ignore
from database import RedisDatabase
from .base import Agent
from logger import logger
_agents: List[Agent] = []
def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, database: RedisDatabase, **kwargs) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
Args:
agent_type: The type of agent to create (e.g., 'web', 'database').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The existing agent of this type if one was already created, otherwise the newly created instance.
Raises:
ValueError: If no agent class matches the given agent_type.
"""
# Check if an agent with the given agent_type already exists; reuse it if so
for agent in _agents:
if agent.agent_type == agent_type:
return agent
# Find the matching subclass
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, prometheus_collector=prometheus_collector, database=database, **kwargs)
# if agent.agent_persist: # If an agent is not set to persist, do not add it to the list
_agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
# Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent  # Agent covers Agent and all of its subclasses
# Maps class_name to (module_name, class_name)
class_registry: Dict[str, Tuple[str, str]] = {}
__all__ = ['get_or_create_agent']
package_dir = pathlib.Path(__file__).parent
package_name = __name__
for path in package_dir.glob("*.py"):
if path.name in ("__init__.py", "base.py") or path.name.startswith("_"):
continue
module_name = path.stem
full_module_name = f"{package_name}.{module_name}"
try:
module = importlib.import_module(full_module_name)
# Find all Agent subclasses in the module
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, AnyAgent)
and obj is not AnyAgent
and obj is not Agent
and name not in class_registry
):
class_registry[name] = (full_module_name, name)
globals()[name] = obj
logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore
except ImportError as e:
logger.error(f"Error importing {full_module_name}: {e}")
raise e
except Exception as e:
logger.error(f"Error processing {full_module_name}: {e}")
raise e
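# Because of the discovery loop above, adding a new agent requires only a new
# module in this package; nothing here needs editing. A minimal sketch of such
# a module (hypothetical file agents/echo.py):
#
#     from typing import Literal, ClassVar
#     from .base import Agent
#
#     class Echo(Agent):
#         agent_type: Literal["echo"] = "echo"  # type: ignore
#         _agent_type: ClassVar[str] = agent_type
#         system_prompt: str = "Repeat the user's prompt back verbatim."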

605
src/backend/agents/base.py Normal file
View File

@ -0,0 +1,605 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from typing import (
Literal,
get_args,
List,
AsyncGenerator,
TYPE_CHECKING,
Optional,
ClassVar,
Any,
TypeAlias,
Dict,
Tuple,
)
import json
import time
import inspect
from abc import ABC
import asyncio
from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType, ChatMessageMetaData)
from logger import logger
import defines
from .registry import agent_registry
from metrics import Metrics
from database import RedisDatabase # type: ignore
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
class Config:
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
# Agent management with pydantic
agent_type: Literal["base"] = "base"
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = True # Whether this agent will persist in the database
database: RedisDatabase = Field(
...,
description="Database connection for this agent, used to store and retrieve data."
)
prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)
# Tunables (sets default for new Messages attached to this agent)
tunables: Tunables = Field(default_factory=Tunables)
metrics: Optional[Metrics] = Field(
default=None, description="Metrics collector for this agent, used to track performance and usage."
)
@model_validator(mode="after")
def initialize_metrics(self) -> "Agent":
if self.metrics is None:
self.metrics = Metrics(prometheus_collector=self.prometheus_collector)
return self
# Agent properties
system_prompt: str # Mandatory
context_tokens: int = 0
# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
conversation: List[ChatMessage] = Field(
default_factory=list,
description="Conversation history for this agent, used to maintain context across messages."
)
@property
def context_size(self) -> int:
return Agent._context_size
@context_size.setter
def context_size(self, value: int):
Agent._context_size = value
def set_optimal_context_size(
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
) -> int:
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
# Add buffer for safety
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
# Use actual model maximum context size
return self.context_size
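# Worked example of the heuristic above: a 3,000-word prompt is estimated
# at 3000 * 1.4 = 4200 tokens; adding the default 2048-token buffer gives
# 6248, so the shared context size grows to 6248 if it was smaller.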
# Class and pydantic model management
def __init_subclass__(cls, **kwargs) -> None:
"""Auto-register subclasses"""
super().__init_subclass__(**kwargs)
# Register this class if it has an agent_type
if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
agent_registry.register(cls.agent_type, cls)
def model_dump(self, *args, **kwargs) -> Any:
# Ensure context is always excluded, even with exclude_unset=True
kwargs.setdefault("exclude", set())
if isinstance(kwargs["exclude"], set):
kwargs["exclude"].add("context")
elif isinstance(kwargs["exclude"], dict):
kwargs["exclude"]["context"] = True
return super().model_dump(*args, **kwargs)
@classmethod
def valid_agent_types(cls) -> set[str]:
"""Return the set of valid agent_type values."""
return set(get_args(cls.__annotations__["agent_type"]))
# Agent methods
def get_agent_type(self):
return self._agent_type
# async def prepare_message(self, message: ChatMessage) -> AsyncGenerator[ChatMessage, None]:
# """
# Prepare message with context information in message.preamble
# """
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.prepare_count.labels(agent=self.agent_type).inc()
# with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# # Generate RAG content if enabled, based on the content
# rag_context = ""
# if message.tunables.enable_rag and message.prompt:
# # Gather RAG results, yielding each result
# # as it becomes available
# for message in self.context.user.generate_rag_results(message):
# logger.info(f"RAG: {message.status} - {message.content}")
# if message.status == "error":
# yield message
# return
# if message.status != "done":
# yield message
# # for rag in message.metadata.rag:
# # for doc in rag.documents:
# # rag_context += f"{doc}\n"
# message.preamble = {}
# if rag_context:
# message.preamble["context"] = f"The following is context information about {self.context.user.full_name}:\n{rag_context}"
# if message.tunables.enable_context and self.context.user_resume:
# message.preamble["resume"] = self.context.user_resume
# message.system_prompt = self.system_prompt
# message.status = ChatStatusType.DONE
# yield message
# return
# async def process_tool_calls(
# self,
# llm: Any,
# model: str,
# message: ChatMessage,
# tool_message: Any, # llama response message
# messages: List[LLMMessage],
# ) -> AsyncGenerator[ChatMessage, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.tool_count.labels(agent=self.agent_type).inc()
# with self.metrics.tool_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# if not message.metadata.tools:
# raise ValueError("tools field not initialized")
# tool_metadata = message.metadata.tools
# tool_metadata["tool_calls"] = []
# message.status = "tooling"
# for i, tool_call in enumerate(tool_message.tool_calls):
# arguments = tool_call.function.arguments
# tool = tool_call.function.name
# # Yield status update before processing each tool
# message.content = (
# f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
# )
# yield message
# logger.info(f"LLM - {message.content}")
# # Process the tool based on its type
# match tool:
# case "TickerValue":
# ticker = arguments.get("ticker")
# if not ticker:
# ret = None
# else:
# ret = TickerValue(ticker)
# case "AnalyzeSite":
# url = arguments.get("url")
# question = arguments.get(
# "question", "what is the summary of this content?"
# )
# # Additional status update for long-running operations
# message.content = (
# f"Retrieving and summarizing content from {url}..."
# )
# yield message
# ret = await AnalyzeSite(
# llm=llm, model=model, url=url, question=question
# )
# case "GenerateImage":
# prompt = arguments.get("prompt", None)
# if not prompt:
# logger.info("No prompt supplied to GenerateImage")
# ret = { "error": "No prompt supplied to GenerateImage" }
# # Additional status update for long-running operations
# message.content = (
# f"Generating image for {prompt}..."
# )
# yield message
# ret = await GenerateImage(
# llm=llm, model=model, prompt=prompt
# )
# logger.info("GenerateImage returning", ret)
# case "DateTime":
# tz = arguments.get("timezone")
# ret = DateTime(tz)
# case "WeatherForecast":
# city = arguments.get("city")
# state = arguments.get("state")
# message.content = (
# f"Fetching weather data for {city}, {state}..."
# )
# yield message
# ret = WeatherForecast(city, state)
# case _:
# logger.error(f"Requested tool {tool} does not exist")
# ret = None
# # Build response for this tool
# tool_response = {
# "role": "tool",
# "content": json.dumps(ret),
# "name": tool_call.function.name,
# }
# tool_metadata["tool_calls"].append(tool_response)
# if len(tool_metadata["tool_calls"]) == 0:
# message.status = "done"
# yield message
# return
# message_dict = LLMMessage(
# role=tool_message.get("role", "assistant"),
# content=tool_message.get("content", ""),
# tool_calls=[
# {
# "function": {
# "name": tc["function"]["name"],
# "arguments": tc["function"]["arguments"],
# }
# }
# for tc in tool_message.tool_calls
# ],
# )
# messages.append(message_dict)
# messages.extend(tool_metadata["tool_calls"])
# message.status = "thinking"
# message.content = "Incorporating tool results into response..."
# yield message
# # Decrease creativity when processing tool call requests
# message.content = ""
# start_time = time.perf_counter()
# for response in llm.chat(
# model=model,
# messages=messages,
# options={
# **message.metadata.options,
# },
# stream=True,
# ):
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
# message.status = "streaming"
# message.chunk = response.message.content
# message.content += message.chunk
# if not response.done:
# yield message
# if response.done:
# self.collect_metrics(response)
# message.metadata.eval_count += response.eval_count
# message.metadata.eval_duration += response.eval_duration
# message.metadata.prompt_eval_count += response.prompt_eval_count
# message.metadata.prompt_eval_duration += response.prompt_eval_duration
# self.context_tokens = (
# response.prompt_eval_count + response.eval_count
# )
# message.status = "done"
# yield message
# end_time = time.perf_counter()
# message.metadata.timers["llm_with_tools"] = end_time - start_time
# return
def collect_metrics(self, response):
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
response.prompt_eval_count
)
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
async def generate(
self, llm: Any, model: str, query: ChatQuery, session_id: str, user_id: str, temperature=0.7
) -> AsyncGenerator[ChatMessage, None]:
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
chat_message = ChatMessage(
session_id=session_id,
prompt=query.prompt,
tunables=query.tunables,
status=ChatStatusType.PREPARING,
sender="user",
content="",
timestamp=datetime.now(UTC)
)
self.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
messages.extend(
[
item
for m in self.conversation
for item in [
LLMMessage(role="user", content=m.prompt.strip() if m.prompt else ""),
LLMMessage(role="assistant", content=m.response.strip()),
]
]
)
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role="user", content=query.prompt.strip())
)
# message.messages = messages
chat_message.metadata = ChatMessageMetaData()
chat_message.metadata.options = {
"seed": 8911,
"num_ctx": self.context_size,
"temperature": temperature, # Higher temperature to encourage tool usage
}
# Create a dict for storing various timing stats
chat_message.metadata.timers = {}
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
# "available": llm_tools(self.context.tools),
# "used": False,
# }
# tool_metadata = message.metadata.tools
# if use_tools:
# message.status = "thinking"
# message.content = f"Performing tool analysis step 1/2..."
# yield message
# logger.info("Checking for LLM tool usage")
# start_time = time.perf_counter()
# # Tools are enabled and available, so query the LLM with a short context of messages
# # in case the LLM did something like ask "Do you want me to run the tool?" and the
# # user said "Yes" -- need to keep the context in the thread.
# tool_metadata["messages"] = (
# [{"role": "system", "content": self.system_prompt}] + messages[-6:]
# if len(messages) >= 7
# else messages
# )
# response = llm.chat(
# model=model,
# messages=tool_metadata["messages"],
# tools=tool_metadata["available"],
# options={
# **message.metadata.options,
# },
# stream=False, # No need to stream the probe
# )
# self.collect_metrics(response)
# end_time = time.perf_counter()
# message.metadata.timers["tool_check"] = end_time - start_time
# if not response.message.tool_calls:
# logger.info("LLM indicates tools will not be used")
# # The LLM will not use tools, so disable use_tools so we can stream the full response
# use_tools = False
# else:
# tool_metadata["attempted"] = response.message.tool_calls
# if use_tools:
# logger.info("LLM indicates tools will be used")
# # Tools are enabled and available and the LLM indicated it will use them
# message.content = (
# f"Performing tool analysis step 2/2 (tool use suspected)..."
# )
# yield message
# logger.info(f"Performing LLM call with tools")
# start_time = time.perf_counter()
# response = llm.chat(
# model=model,
# messages=tool_metadata["messages"], # messages,
# tools=tool_metadata["available"],
# options={
# **message.metadata.options,
# },
# stream=False,
# )
# self.collect_metrics(response)
# end_time = time.perf_counter()
# message.metadata.timers["non_streaming"] = end_time - start_time
# if not response:
# message.status = "error"
# message.content = "No response from LLM."
# yield message
# return
# if response.message.tool_calls:
# tool_metadata["used"] = response.message.tool_calls
# # Process all yielded items from the handler
# start_time = time.perf_counter()
# async for message in self.process_tool_calls(
# llm=llm,
# model=model,
# message=message,
# tool_message=response.message,
# messages=messages,
# ):
# if message.status == "error":
# yield message
# return
# yield message
# end_time = time.perf_counter()
# message.metadata.timers["process_tool_calls"] = end_time - start_time
# message.status = "done"
# return
# logger.info("LLM indicated tools will be used, and then they weren't")
# message.content = response.message.content
# message.status = "done"
# yield message
# return
# not use_tools
chat_message.status = ChatStatusType.THINKING
chat_message.content = f"Generating response..."
yield chat_message
# Reset the response for streaming
chat_message.content = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
options={
**chat_message.metadata.options,
},
stream=True,
):
if not response:
chat_message.status = ChatStatusType.ERROR
chat_message.content = "No response from LLM."
yield chat_message
return
chat_message.status = ChatStatusType.STREAMING
chat_message.chunk = response.message.content
chat_message.content += chat_message.chunk
if not response.done:
yield chat_message
if response.done:
self.collect_metrics(response)
chat_message.metadata.eval_count += response.eval_count
chat_message.metadata.eval_duration += response.eval_duration
chat_message.metadata.prompt_eval_count += response.prompt_eval_count
chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
self.context_tokens = (
response.prompt_eval_count + response.eval_count
)
chat_message.status = ChatStatusType.DONE
yield chat_message
end_time = time.perf_counter()
chat_message.metadata.timers["streamed"] = end_time - start_time
chat_message.status = ChatStatusType.DONE
self.conversation.append(chat_message)
return
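# Consumption sketch (assumed llm handle and agent instance): generate() is an
# async generator, so callers stream partial messages and keep the final one.
#
#     async for msg in agent.generate(llm=llm, model="qwen2.5", query=query,
#                                     session_id=session_id, user_id=user_id):
#         if msg.status == ChatStatusType.STREAMING and msg.chunk:
#             send_chunk_to_client(msg.chunk)  # hypothetical transport helper
#     # after the loop, msg holds the full content, metadata and timers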
# async def process_message(
# self, llm: Any, model: str, message: Message
# ) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# self.metrics.process_count.labels(agent=self.agent_type).inc()
# with self.metrics.process_duration.labels(agent=self.agent_type).time():
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# logger.info(
# "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
# )
# spinner: List[str] = ["\\", "|", "/", "-"]
# tick: int = 0
# while self.context.processing:
# message.status = "waiting"
# message.content = (
# f"Busy processing another request. Please wait. {spinner[tick]}"
# )
# tick = (tick + 1) % len(spinner)
# yield message
# await asyncio.sleep(1) # Allow the event loop to process the write
# self.context.processing = True
# message.system_prompt = (
# f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
# )
# message.context_prompt = ""
# for p in message.preamble.keys():
# message.context_prompt += (
# f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
# )
# message.context_prompt += f"{message.prompt}"
# # Estimate token length of new messages
# message.content = f"Optimizing context..."
# message.status = "thinking"
# yield message
# message.context_size = self.set_optimal_context_size(
# llm, model, prompt=message.context_prompt
# )
# message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
# message.status = "thinking"
# yield message
# async for message in self.generate_llm_response(
# llm=llm, model=model, message=message
# ):
# # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}")
# if message.status == "error":
# yield message
# self.context.processing = False
# return
# yield message
# # Done processing, add message to conversation
# message.status = "done"
# self.conversation.add(message)
# self.context.processing = False
# return
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)

View File

@ -0,0 +1,88 @@
from __future__ import annotations
from typing import Literal, AsyncGenerator, ClassVar, Optional, Any
from datetime import datetime
import inspect
from .base import Agent
from logger import logger
from .registry import agent_registry
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType)
system_message = f"""
Launched on {datetime.now().isoformat()}.
When answering queries, follow these steps:
- First analyze the query to determine if real-time information from the tools might be helpful
- Even when <|context|> or <|resume|> is provided, consider whether the tools would provide more current or comprehensive information
- Use the provided tools whenever they would enhance your response, regardless of whether context is also available
- When presenting weather forecasts, include relevant emojis immediately before the corresponding text. For example, for a sunny day, say \"☀️ Sunny\", or if the forecast says there will be \"rain showers\", say \"🌧️ Rain showers\". Use this mapping for weather emojis: Sunny: ☀️, Cloudy: ☁️, Rainy: 🌧️, Snowy: ❄️
- When any combination of <|context|>, <|resume|> and tool outputs are relevant, synthesize information from all sources to provide the most complete answer
- Always prioritize the most up-to-date and relevant information, whether it comes from <|context|>, <|resume|> or tools
- If <|context|> and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data
- If there is information in the <|context|> or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it as 'the latest information' or 'recent data' instead of mentioning '<|context|>' (etc.) or quoting it directly.
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|resume|>.
CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
1. When the user requests to generate an image, inject the following into the response: <GenerateImage prompt="USER-PROMPT"/>. Do this when users request images, drawings, or visual content.
2. MANDATORY: You must respond with EXACTLY this format: <GenerateImage prompt="{{USER-PROMPT}}"/>
3. FORBIDDEN: DO NOT use markdown image syntax ![](url)
4. FORBIDDEN: DO NOT create fake URLs or file paths
5. FORBIDDEN: DO NOT use any other image embedding format
CORRECT EXAMPLE:
User: "Draw a cat"
Your response: "<GenerateImage prompt='Draw a cat'/>"
WRONG EXAMPLES (DO NOT DO THIS):
- ![](https://example.com/...)
- ![Cat image](any_url)
- <img src="...">
The <GenerateImage prompt="{{USER-PROMPT}}"/> format is the ONLY way to display images in this system.
DO NOT make up a URL for an image or provide markdown syntax for embedding an image. Only use <GenerateImage prompt="{{USER-PROMPT}}"/>.
Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
"""
class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["general"] = "general" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# if not self.context:
# raise ValueError("Context is not set for this agent.")
# async for message in super().prepare_message(message):
# if message.status != "done":
# yield message
# if message.preamble:
# excluded = {}
# preamble_types = [
# f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
# ]
# preamble_types_AND = " and ".join(preamble_types)
# preamble_types_OR = " or ".join(preamble_types)
# message.preamble[
# "rules"
# ] = f"""\
# - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
# - If there is no information in these sections, answer based on your knowledge, or use any available tools.
# - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
# """
# message.preamble["question"] = "Respond to:"
# Register the chat agent
agent_registry.register(Chat._agent_type, Chat)

View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type
# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
"""Registry for agent types and classes"""
_registry: Dict[str, Type] = {}
@classmethod
def register(cls, agent_type: str, agent_class: Type) -> Type:
"""Register an agent class with its type"""
cls._registry[agent_type] = agent_class
return agent_class
@classmethod
def get_class(cls, agent_type: str) -> Optional[Type]:
"""Get the class for a given agent type"""
return cls._registry.get(agent_type)
@classmethod
def get_types(cls) -> List[str]:
"""Get all registered agent types"""
return list(cls._registry.keys())
@classmethod
def get_classes(cls) -> Dict[str, Type]:
"""Get all registered agent classes"""
return cls._registry.copy()
# Create a singleton instance
agent_registry = AgentRegistry()
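# Illustrative sketch: resolving a class back from its type string. Lookups
# only succeed after the module defining the subclass has been imported,
# since registration happens in Agent.__init_subclass__.
def _example_lookup(agent_type: str):
    cls = agent_registry.get_class(agent_type)  # e.g. "general" -> Chat
    if cls is None:
        raise ValueError(f"Unknown agent type {agent_type!r}; known: {agent_registry.get_types()}")
    return cls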

View File

@ -8,9 +8,10 @@ import sys
from datetime import datetime from datetime import datetime
from models import ( from models import (
UserStatus, UserType, SkillLevel, EmploymentType, UserStatus, UserType, SkillLevel, EmploymentType,
Candidate, Employer, Location, Skill, AIParameters, AIModelType Candidate, Employer, Location, Skill, AIModelType
) )
def test_model_creation(): def test_model_creation():
"""Test that we can create models successfully""" """Test that we can create models successfully"""
print("🧪 Testing model creation...") print("🧪 Testing model creation...")
@ -118,41 +119,23 @@ def test_api_dict_format():
def test_validation_constraints(): def test_validation_constraints():
"""Test that validation constraints work""" """Test that validation constraints work"""
print("\n🔒 Testing validation constraints...") print("\n🔒 Testing validation constraints...")
# Test AI Parameters with constraints
valid_params = AIParameters(
name="Test Config",
model=AIModelType.GPT_4,
temperature=0.7, # Valid: 0-1
maxTokens=2000, # Valid: > 0
topP=0.95, # Valid: 0-1
frequencyPenalty=0.0, # Valid: -2 to 2
presencePenalty=0.0, # Valid: -2 to 2
isDefault=True,
createdAt=datetime.now(),
updatedAt=datetime.now()
)
print(f"✅ Valid AI parameters created")
# Test constraint violation
try: try:
invalid_params = AIParameters( # Create a candidate with invalid email
name="Invalid Config", invalid_candidate = Candidate(
model=AIModelType.GPT_4, email="invalid-email",
temperature=1.5, # Invalid: > 1 username="test_invalid",
maxTokens=2000,
topP=0.95,
frequencyPenalty=0.0,
presencePenalty=0.0,
isDefault=True,
createdAt=datetime.now(), createdAt=datetime.now(),
updatedAt=datetime.now() updatedAt=datetime.now(),
status=UserStatus.ACTIVE,
firstName="Jane",
lastName="Doe",
fullName="Jane Doe"
) )
print("❌ Should have rejected invalid temperature") print("❌ Validation should have failed but didn't")
return False return False
except Exception: except ValueError as e:
print(f"✅ Constraint validation working") print(f"✅ Validation error caught: {e}")
return True return True
def test_enum_values(): def test_enum_values():
@ -200,6 +183,7 @@ def main():
print(f"\n❌ Test failed: {type(e).__name__}: {e}") print(f"\n❌ Test failed: {type(e).__name__}: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
print(f"\n{traceback.format_exc()}")
return False return False
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -13,6 +13,7 @@ from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
import stat import stat
def run_command(command: str, description: str, cwd: str | None = None) -> bool: def run_command(command: str, description: str, cwd: str | None = None) -> bool:
"""Run a command and return success status""" """Run a command and return success status"""
try: try:
@ -68,9 +69,34 @@ except ImportError as e:
print("Make sure pydantic is installed: pip install pydantic") print("Make sure pydantic is installed: pip install pydantic")
sys.exit(1) sys.exit(1)
def python_type_to_typescript(python_type: Any) -> str: def unwrap_annotated_type(python_type: Any) -> Any:
"""Unwrap Annotated types to get the actual type"""
# Handle typing_extensions.Annotated and typing.Annotated
origin = get_origin(python_type)
args = get_args(python_type)
# Check for Annotated types - more robust detection
if origin is not None and args:
origin_str = str(origin)
if 'Annotated' in origin_str or (hasattr(origin, '__name__') and origin.__name__ == 'Annotated'):
# Return the first argument (the actual type)
return unwrap_annotated_type(args[0]) # Recursive unwrap in case of nested annotations
return python_type
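# Illustrative sketch (not called by the generator): constrained Pydantic
# fields are typically Annotated, and unwrapping recovers the base type.
def _example_unwrap() -> None:
    from typing import Annotated
    from pydantic import Field  # type: ignore

    assert unwrap_annotated_type(Annotated[int, Field(gt=0)]) is int
    assert unwrap_annotated_type(str) is str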
def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
"""Convert a Python type to TypeScript type string""" """Convert a Python type to TypeScript type string"""
if debug:
print(f" 🔍 Converting type: {python_type} (type: {type(python_type)})")
# First unwrap any Annotated types
original_type = python_type
python_type = unwrap_annotated_type(python_type)
if debug and original_type != python_type:
print(f" 🔄 Unwrapped: {original_type} -> {python_type}")
# Handle None/null # Handle None/null
if python_type is type(None): if python_type is type(None):
return "null" return "null"
@ -79,6 +105,8 @@ def python_type_to_typescript(python_type: Any) -> str:
if python_type == str: if python_type == str:
return "string" return "string"
elif python_type == int or python_type == float: elif python_type == int or python_type == float:
if debug:
print(f" ✅ Converting {python_type} to number")
return "number" return "number"
elif python_type == bool: elif python_type == bool:
return "boolean" return "boolean"
@ -91,30 +119,33 @@ def python_type_to_typescript(python_type: Any) -> str:
origin = get_origin(python_type) origin = get_origin(python_type)
args = get_args(python_type) args = get_args(python_type)
if debug and origin:
print(f" 🔍 Generic type - origin: {origin}, args: {args}")
if origin is Union: if origin is Union:
# Handle Optional (Union[T, None]) # Handle Optional (Union[T, None])
if len(args) == 2 and type(None) in args: if len(args) == 2 and type(None) in args:
non_none_type = next(arg for arg in args if arg is not type(None)) non_none_type = next(arg for arg in args if arg is not type(None))
return python_type_to_typescript(non_none_type) return python_type_to_typescript(non_none_type, debug)
# Handle other unions # Handle other unions
union_types = [python_type_to_typescript(arg) for arg in args if arg is not type(None)] union_types = [python_type_to_typescript(arg, debug) for arg in args if arg is not type(None)]
return " | ".join(union_types) return " | ".join(union_types)
elif origin is list or origin is List: elif origin is list or origin is List:
if args: if args:
item_type = python_type_to_typescript(args[0]) item_type = python_type_to_typescript(args[0], debug)
return f"Array<{item_type}>" return f"Array<{item_type}>"
return "Array<any>" return "Array<any>"
elif origin is dict or origin is Dict: elif origin is dict or origin is Dict:
if len(args) == 2: if len(args) == 2:
key_type = python_type_to_typescript(args[0]) key_type = python_type_to_typescript(args[0], debug)
value_type = python_type_to_typescript(args[1]) value_type = python_type_to_typescript(args[1], debug)
return f"Record<{key_type}, {value_type}>" return f"Record<{key_type}, {value_type}>"
return "Record<string, any>" return "Record<string, any>"
# Handle Literal types - UPDATED SECTION # Handle Literal types
if hasattr(python_type, '__origin__') and str(python_type.__origin__).endswith('Literal'): if hasattr(python_type, '__origin__') and str(python_type.__origin__).endswith('Literal'):
if args: if args:
literal_values = [] literal_values = []
@ -155,6 +186,8 @@ def python_type_to_typescript(python_type: Any) -> str:
return "string" return "string"
# Default fallback # Default fallback
if debug:
print(f" ⚠️ Falling back to 'any' for type: {python_type}")
return "any" return "any"
def snake_to_camel(snake_str: str) -> str: def snake_to_camel(snake_str: str) -> str:
@ -162,32 +195,148 @@ def snake_to_camel(snake_str: str) -> str:
components = snake_str.split('_') components = snake_str.split('_')
return components[0] + ''.join(x.title() for x in components[1:]) return components[0] + ''.join(x.title() for x in components[1:])
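# Examples: snake_to_camel("user_id") == "userId" and
# snake_to_camel("prompt_eval_duration") == "promptEvalDuration".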
def process_pydantic_model(model_class) -> Dict[str, Any]: def is_field_optional(field_info: Any, field_type: Any, debug: bool = False) -> bool:
"""Determine if a field should be optional in TypeScript"""
if debug:
print(f" 🔍 Analyzing field optionality:")
# First, check if the type itself is Optional (Union with None)
origin = get_origin(field_type)
args = get_args(field_type)
is_union_with_none = origin is Union and type(None) in args
if debug:
print(f" └─ Type is Optional[T]: {is_union_with_none}")
# If the type is Optional[T], it's always optional regardless of Field settings
if is_union_with_none:
if debug:
print(f" └─ RESULT: Optional (type is Optional[T])")
return True
# For non-Optional types, check Field settings and defaults
# Check for default factory (makes field optional)
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
if debug:
print(f" └─ Has default factory: {has_default_factory}")
if has_default_factory:
if debug:
print(f" └─ RESULT: Optional (has default factory)")
return True
# Check the default value - this is the tricky part
if hasattr(field_info, 'default'):
default_val = field_info.default
if debug:
print(f" └─ Has default attribute: {repr(default_val)} (type: {type(default_val)})")
# Check for different types of "no default" markers
# Pydantic uses various markers for "no default"
if default_val is ...: # Ellipsis
if debug:
print(f" └─ RESULT: Required (default is Ellipsis)")
return False
# Check for Pydantic's internal "PydanticUndefined" or similar markers
default_str = str(default_val)
default_type_str = str(type(default_val))
# Common patterns for "undefined" in Pydantic
undefined_patterns = [
'PydanticUndefined',
'Undefined',
'_Unset',
'UNSET',
'NotSet',
'_MISSING'
]
is_undefined_marker = any(pattern in default_str or pattern in default_type_str
for pattern in undefined_patterns)
if debug:
print(f" └─ Checking for undefined markers in: {default_str} | {default_type_str}")
print(f" └─ Is undefined marker: {is_undefined_marker}")
if is_undefined_marker:
if debug:
print(f" └─ RESULT: Required (default is undefined marker)")
return False
# Any other actual default value makes it optional
if debug:
print(f" └─ RESULT: Optional (has actual default value)")
return True
else:
if debug:
print(f" └─ No default attribute found")
# If no default attribute exists, check Pydantic's required flag
if hasattr(field_info, 'is_required'):
try:
is_required = field_info.is_required()
if debug:
print(f" └─ is_required(): {is_required}")
return not is_required
except Exception:
if debug:
print(f" └─ is_required() failed")
pass
# Check the 'required' attribute (Pydantic v1 style)
if hasattr(field_info, 'required'):
is_required = field_info.required
if debug:
print(f" └─ required attribute: {is_required}")
return not is_required
# Default: if type is not Optional and no clear default, it's required (not optional)
if debug:
print(f" └─ RESULT: Required (fallback - no Optional type, no default)")
return False
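# Illustrative sketch of the rules above, using a hypothetical model:
#
#     class Example(BaseModel):
#         a: str                                       # required -> a: string
#         b: Optional[str] = None                      # Optional[T] -> b?: string
#         c: int = 5                                   # has default -> c?: number
#         d: List[str] = Field(default_factory=list)   # factory -> d?: Array<string>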
def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
"""Process a Pydantic model and return TypeScript interface definition""" """Process a Pydantic model and return TypeScript interface definition"""
interface_name = model_class.__name__ interface_name = model_class.__name__
properties = [] properties = []
if debug:
print(f" 🔍 Processing model: {interface_name}")
# Get fields from the model # Get fields from the model
if hasattr(model_class, 'model_fields'): if hasattr(model_class, 'model_fields'):
# Pydantic v2 # Pydantic v2
fields = model_class.model_fields fields = model_class.model_fields
for field_name, field_info in fields.items(): for field_name, field_info in fields.items():
ts_name = snake_to_camel(field_name) if debug:
print(f" 📝 Field: {field_name}")
print(f" Field info: {field_info}")
print(f" Default: {getattr(field_info, 'default', 'NO_DEFAULT')}")
# Check for alias # Use alias if available, otherwise convert snake_case to camelCase
if hasattr(field_info, 'alias') and field_info.alias: if hasattr(field_info, 'alias') and field_info.alias:
ts_name = field_info.alias ts_name = field_info.alias
else:
ts_name = snake_to_camel(field_name)
# Get type annotation # Get type annotation
field_type = getattr(field_info, 'annotation', str) field_type = getattr(field_info, 'annotation', str)
ts_type = python_type_to_typescript(field_type) if debug:
print(f" Raw type: {field_type}")
ts_type = python_type_to_typescript(field_type, debug)
# Check if optional # Check if optional
is_optional = False is_optional = is_field_optional(field_info, field_type, debug)
if hasattr(field_info, 'is_required'):
is_optional = not field_info.is_required() if debug:
elif hasattr(field_info, 'default'): print(f" TS name: {ts_name}")
is_optional = field_info.default is not None print(f" TS type: {ts_type}")
print(f" Optional: {is_optional}")
print()
properties.append({ properties.append({
'name': ts_name, 'name': ts_name,
@ -199,17 +348,45 @@ def process_pydantic_model(model_class) -> Dict[str, Any]:
# Pydantic v1 # Pydantic v1
fields = model_class.__fields__ fields = model_class.__fields__
for field_name, field_info in fields.items(): for field_name, field_info in fields.items():
ts_name = snake_to_camel(field_name) if debug:
print(f" 📝 Field: {field_name} (Pydantic v1)")
print(f" Field info: {field_info}")
# Use alias if available, otherwise convert snake_case to camelCase
if hasattr(field_info, 'alias') and field_info.alias: if hasattr(field_info, 'alias') and field_info.alias:
ts_name = field_info.alias ts_name = field_info.alias
else:
ts_name = snake_to_camel(field_name)
field_type = getattr(field_info, 'annotation', getattr(field_info, 'type_', str)) field_type = getattr(field_info, 'annotation', getattr(field_info, 'type_', str))
ts_type = python_type_to_typescript(field_type) if debug:
print(f" Raw type: {field_type}")
is_optional = not getattr(field_info, 'required', True) ts_type = python_type_to_typescript(field_type, debug)
if hasattr(field_info, 'default') and field_info.default is not None:
is_optional = True # For Pydantic v1, check required and default
is_optional = is_field_optional(field_info, field_type)
if debug:
print(f" TS name: {ts_name}")
print(f" TS type: {ts_type}")
print(f" Optional: {is_optional}")
# Debug the optional logic
origin = get_origin(field_type)
args = get_args(field_type)
is_union_with_none = origin is Union and type(None) in args
has_default = hasattr(field_info, 'default')
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
print(f" └─ Type is Optional: {is_union_with_none}")
if has_default:
default_val = field_info.default
print(f" └─ Has default: {default_val} (is ...? {default_val is ...})")
else:
print(f" └─ No default attribute")
print(f" └─ Has default factory: {has_default_factory}")
print()
properties.append({ properties.append({
'name': ts_name, 'name': ts_name,
@ -233,7 +410,7 @@ def process_enum(enum_class) -> Dict[str, Any]:
'values': " | ".join(values) 'values': " | ".join(values)
} }
def generate_typescript_interfaces(source_file: str): def generate_typescript_interfaces(source_file: str, debug: bool = False):
"""Generate TypeScript interfaces from models""" """Generate TypeScript interfaces from models"""
print(f"📖 Scanning {source_file} for Pydantic models and enums...") print(f"📖 Scanning {source_file} for Pydantic models and enums...")
@ -270,7 +447,7 @@ def generate_typescript_interfaces(source_file: str):
issubclass(obj, BaseModel) and issubclass(obj, BaseModel) and
obj != BaseModel): obj != BaseModel):
interface = process_pydantic_model(obj) interface = process_pydantic_model(obj, debug)
interfaces.append(interface) interfaces.append(interface)
print(f" ✅ Found Pydantic model: {name}") print(f" ✅ Found Pydantic model: {name}")
@ -284,6 +461,9 @@ def generate_typescript_interfaces(source_file: str):
except Exception as e: except Exception as e:
print(f" ⚠️ Warning: Error processing {name}: {e}") print(f" ⚠️ Warning: Error processing {name}: {e}")
if debug:
import traceback
traceback.print_exc()
continue continue
print(f"\n📊 Found {len(interfaces)} interfaces and {len(enums)} enums") print(f"\n📊 Found {len(interfaces)} interfaces and {len(enums)} enums")
@ -362,7 +542,8 @@ Examples:
python generate_types.py --source models.py --output types.ts # Specify files python generate_types.py --source models.py --output types.ts # Specify files
python generate_types.py --skip-test # Skip model validation python generate_types.py --skip-test # Skip model validation
python generate_types.py --skip-compile # Skip TS compilation python generate_types.py --skip-compile # Skip TS compilation
python generate_types.py --source models.py --output types.ts --skip-test --skip-compile python generate_types.py --debug # Enable debug output
python generate_types.py --source models.py --output types.ts --skip-test --skip-compile --debug
""" """
) )
@ -390,6 +571,12 @@ Examples:
        help='Skip TypeScript compilation check after generation'
    )

    parser.add_argument(
        '--debug',
        action='store_true',
        help='Enable debug output to troubleshoot type conversion issues'
    )

    parser.add_argument(
        '--version', '-v',
        action='version',
@ -422,7 +609,11 @@ Examples:
    # Step 3: Generate TypeScript content
    print("🔄 Generating TypeScript types...")
    if args.debug:
        print("🐛 Debug mode enabled - detailed output follows:")
        print()
    ts_content = generate_typescript_interfaces(args.source, args.debug)

    if ts_content is None:
        print("❌ Failed to generate TypeScript content")

View File

@ -0,0 +1,41 @@
import ollama

import defines

_llm = ollama.Client(host=defines.ollama_api_url)  # type: ignore


class llm_manager:
    """
    A class to manage LLM operations using the Ollama client.
    """

    @staticmethod
    def get_llm() -> ollama.Client:  # type: ignore
        """
        Get the Ollama client instance.

        Returns:
            An instance of the Ollama client.
        """
        return _llm

    @staticmethod
    def get_models() -> list[str]:
        """
        Get a list of available models from the Ollama client.

        Returns:
            List of model names.
        """
        # The Ollama client exposes list(), not models(); each entry reports
        # its name under "model" on current client versions.
        return [m["model"] for m in _llm.list()["models"]]

    @staticmethod
    def get_model_info(model_name: str) -> dict:
        """
        Get information about a specific model.

        Args:
            model_name: The name of the model to retrieve information for.

        Returns:
            A dictionary containing model information.
        """
        # The client method for model details is show(); depending on the
        # client version it returns a plain dict or a pydantic object.
        info = _llm.show(model_name)
        return info if isinstance(info, dict) else info.model_dump()
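A minimal usage sketch for the manager above, assuming an Ollama instance is reachable at defines.ollama_api_url:

if __name__ == "__main__":
    # Iterate over locally available models and print their details
    for name in llm_manager.get_models():
        info = llm_manager.get_model_info(name)
        print(name, info.get("details", {}))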

View File

@ -1,7 +1,7 @@
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request # type: ignore
from fastapi.middleware.cors import CORSMiddleware # type: ignore
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials # type: ignore
from fastapi.responses import JSONResponse, StreamingResponse # type: ignore
from fastapi.staticfiles import StaticFiles # type: ignore
import uvicorn # type: ignore
from typing import List, Optional, Dict, Any
@ -15,6 +15,7 @@ import re
import asyncio
import signal
import json
import traceback

# Prometheus
from prometheus_client import Summary # type: ignore
@ -24,25 +25,24 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
# Import Pydantic models
from models import (
    # User models
    Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse,
    # Job models
    Job, JobApplication, ApplicationStatus,
    # Chat models
    ChatSession, ChatMessage, ChatContext, ChatQuery,
    # Supporting models
    Location, Skill, WorkExperience, Education
)

import defines
import agents
from logger import logger
from database import RedisDatabase, redis_manager, DatabaseManager
from metrics import Metrics
from llm_manager import llm_manager

# Initialize FastAPI app
# ============================
@ -140,20 +140,45 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify token and check if it's blacklisted"""
    try:
        # First decode the token
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
        user_id: str = payload.get("sub")
        if user_id is None:
            raise HTTPException(status_code=401, detail="Invalid authentication credentials")

        # Check if token is blacklisted
        redis_client = redis_manager.get_client()
        blacklist_key = f"blacklisted_token:{credentials.credentials}"
        is_blacklisted = await redis_client.exists(blacklist_key)

        if is_blacklisted:
            logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
            raise HTTPException(status_code=401, detail="Token has been revoked")

        # Optional: Check if all user tokens are revoked (for "logout from all devices")
        # user_revoked_key = f"user_tokens_revoked:{user_id}"
        # user_tokens_revoked_at = await redis_client.get(user_revoked_key)
        # if user_tokens_revoked_at:
        #     revoked_timestamp = datetime.fromisoformat(user_tokens_revoked_at.decode())
        #     token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), UTC)
        #     if token_issued_at < revoked_timestamp:
        #         raise HTTPException(status_code=401, detail="All user tokens have been revoked")

        return user_id
    except HTTPException:
        # Re-raise the 401s above unchanged rather than masking them below
        raise
    except jwt.PyJWTError:
        raise HTTPException(status_code=401, detail="Invalid authentication credentials")
    except Exception as e:
        logger.error(f"Token verification error: {e}")
        raise HTTPException(status_code=401, detail="Token verification failed")

async def get_current_user(
    user_id: str = Depends(verify_token_with_blacklist),
    database: RedisDatabase = Depends(lambda: db_manager.get_database())
) -> BaseUserWithType:
    """Get current user from database"""
    try:
        # Check candidates first
@ -321,12 +346,147 @@ async def login(
        return create_success_response(auth_response.model_dump(by_alias=True))

    except Exception as e:
        logger.error(f"⚠️ Login error: {e}")
        return JSONResponse(
            status_code=500,
            content=create_error_response("LOGIN_ERROR", str(e))
        )
@api_router.post("/auth/logout")
async def logout(
access_token: str = Body(..., alias="accessToken"),
refresh_token: str = Body(..., alias="refreshToken"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Logout endpoint - revokes both access and refresh tokens"""
logger.info(f"🔑 User {current_user.id} is logging out")
try:
# Verify refresh token
try:
refresh_payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
user_id = refresh_payload.get("sub")
token_type = refresh_payload.get("type")
refresh_exp = refresh_payload.get("exp")
if not user_id or token_type != "refresh":
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
except jwt.PyJWTError as e:
logger.warning(f"Invalid refresh token during logout: {e}")
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
# Verify that the refresh token belongs to the current user
if user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
)
# Get Redis client
redis_client = redis_manager.get_client()
# Revoke refresh token (blacklist it until its natural expiration)
refresh_ttl = max(0, refresh_exp - int(datetime.now(UTC).timestamp()))
if refresh_ttl > 0:
await redis_client.setex(
f"blacklisted_token:{refresh_token}",
refresh_ttl,
json.dumps({
"user_id": user_id,
"token_type": "refresh",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
)
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
# If access token is provided, revoke it too
if access_token:
try:
access_payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
access_user_id = access_payload.get("sub")
access_exp = access_payload.get("exp")
# Verify access token belongs to same user
if access_user_id == user_id:
access_ttl = max(0, access_exp - int(datetime.now(UTC).timestamp()))
if access_ttl > 0:
await redis_client.setex(
f"blacklisted_token:{access_token}",
access_ttl,
json.dumps({
"user_id": user_id,
"token_type": "access",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
)
logger.info(f"🔒 Blacklisted access token for user {user_id}")
else:
logger.warning(f"Access token user mismatch during logout: {access_user_id} != {user_id}")
except jwt.PyJWTError as e:
logger.warning(f"Invalid access token during logout (non-critical): {e}")
# Don't fail logout if access token is invalid
# Optional: Revoke all tokens for this user (for "logout from all devices")
# Uncomment the following lines if you want to implement this feature:
#
# await redis_client.setex(
# f"user_tokens_revoked:{user_id}",
# timedelta(days=30).total_seconds(), # Max refresh token lifetime
# datetime.now(UTC).isoformat()
# )
logger.info(f"🔑 User {user_id} logged out successfully")
return create_success_response({
"message": "Logged out successfully",
"tokensRevoked": {
"refreshToken": True,
"accessToken": bool(access_token)
}
})
except Exception as e:
logger.error(f"⚠️ Logout error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ERROR", str(e))
)
@api_router.post("/auth/logout-all")
async def logout_all_devices(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Logout from all devices by revoking all tokens for the user"""
try:
redis_client = redis_manager.get_client()
# Set a timestamp that invalidates all tokens issued before this moment
await redis_client.setex(
f"user_tokens_revoked:{current_user.id}",
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
datetime.now(UTC).isoformat()
)
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
return create_success_response({
"message": "Logged out from all devices successfully"
})
except Exception as e:
logger.error(f"⚠️ Logout all devices error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
)
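Note that /auth/logout-all only records a revocation timestamp; it takes effect once the commented-out check in verify_token_with_blacklist is enabled. A sketch of that comparison, inside the function's try block (assumes tokens carry an "iat" claim):

        revoked_at = await redis_client.get(f"user_tokens_revoked:{user_id}")
        if revoked_at:
            revoked_timestamp = datetime.fromisoformat(
                revoked_at.decode() if isinstance(revoked_at, bytes) else revoked_at
            )
            token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), UTC)
            if token_issued_at < revoked_timestamp:
                raise HTTPException(status_code=401, detail="All user tokens have been revoked")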
@api_router.post("/auth/refresh") @api_router.post("/auth/refresh")
async def refresh_token_endpoint( async def refresh_token_endpoint(
refreshToken: str = Body(..., alias="refreshToken"), refreshToken: str = Body(..., alias="refreshToken"),
@ -427,21 +587,33 @@ async def create_candidate(
content=create_error_response("CREATION_FAILED", str(e)) content=create_error_response("CREATION_FAILED", str(e))
) )
@api_router.get("/candidates/{candidate_id}") @api_router.get("/candidates/{username}")
async def get_candidate( async def get_candidate(
candidate_id: str = Path(...), username: str = Path(...),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database)
): ):
"""Get a candidate by ID""" """Get a candidate by username"""
try: try:
candidate_data = await database.get_candidate(candidate_id) all_candidates_data = await database.get_all_candidates()
if not candidate_data: candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
# Normalize username to lowercase for case-insensitive search
query_lower = username.lower()
# Filter by search query
candidates_list = [
c for c in candidates_list
if (query_lower == c.email.lower() or
query_lower == c.username.lower())
]
if not len(candidates_list):
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content=create_error_response("NOT_FOUND", "Candidate not found") content=create_error_response("NOT_FOUND", "Candidate not found")
) )
candidate = Candidate.model_validate(candidate_data) candidate = Candidate.model_validate(candidates_list[0])
return create_success_response(candidate.model_dump(by_alias=True)) return create_success_response(candidate.model_dump(by_alias=True))
except Exception as e: except Exception as e:
@ -558,6 +730,7 @@ async def search_candidates(
            if (query_lower in c.first_name.lower() or
                query_lower in c.last_name.lower() or
                query_lower in c.email.lower() or
                query_lower in c.username.lower() or
                any(query_lower in skill.name.lower() for skill in c.skills))
        ]
@ -727,6 +900,209 @@ async def search_jobs(
content=create_error_response("SEARCH_FAILED", str(e)) content=create_error_response("SEARCH_FAILED", str(e))
) )
# ============================
# Chat Endpoints
# ============================
@api_router.post("/chat/sessions")
async def create_chat_session(
session_data: Dict[str, Any] = Body(...),
current_user : BaseUserWithType = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Create a new chat session"""
try:
# Add required fields
session_data["id"] = str(uuid.uuid4())
session_data["createdAt"] = datetime.now(UTC).isoformat()
session_data["updatedAt"] = datetime.now(UTC).isoformat()
# Create chat session
chat_session = ChatSession.model_validate(session_data)
await database.set_chat_session(chat_session.id, chat_session.model_dump())
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}")
return create_success_response(chat_session.model_dump(by_alias=True))
except Exception as e:
logger.error(f"Chat session creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
@api_router.get("/chat/sessions/{session_id}")
async def get_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get a chat session by ID"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
chat_session = ChatSession.model_validate(chat_session_data)
return create_success_response(chat_session.model_dump(by_alias=True))
except Exception as e:
logger.error(f"Get chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
@api_router.get("/chat/sessions/{session_id}/messages")
async def get_chat_session_messages(
session_id: str = Path(...),
current_user = Depends(get_current_user),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None),
database: RedisDatabase = Depends(get_database)
):
"""Get a chat session by ID"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
chat_messages = await database.get_chat_messages(session_id)
# Convert messages to ChatMessage objects
messages_list = [ChatMessage.model_validate(msg) for msg in chat_messages]
# Apply filters and pagination
filter_dict = None
if filters:
filter_dict = json.loads(filters)
paginated_messages, total = filter_and_paginate(
messages_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_response = create_paginated_response(
[m.model_dump(by_alias=True) for m in paginated_messages],
page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"Get chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
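Client-side, the filters parameter is a JSON-encoded string. A hedged example of requesting a filtered, sorted page of messages (base_url, session_id, and access_token are placeholders; exact matching semantics depend on filter_and_paginate, which is not shown here):

import json
import requests

params = {
    "page": 1,
    "limit": 50,
    "sortBy": "timestamp",
    "sortOrder": "asc",
    "filters": json.dumps({"sender": "user"}),  # assumed to match on field equality
}
resp = requests.get(
    f"{base_url}/chat/sessions/{session_id}/messages",
    params=params,
    headers={"Authorization": f"Bearer {access_token}"},
)
print(resp.json())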
@api_router.post("/chat/sessions/{session_id}/messages/stream")
async def post_chat_session_message_stream(
session_id: str = Path(...),
data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
request: Request = Request, # For streaming response
):
"""Post a message to a chat session and stream the response"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
chat_type = chat_session_data.get("context", {}).get("type", "general")
logger.info(f"🔗 Chat session {session_id} type {chat_type} accessed by user {current_user.id}")
query = data.get("query")
if not query:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_QUERY", "Query cannot be empty")
)
chat_query = ChatQuery.model_validate(query)
chat_agent = agents.get_or_create_agent(agent_type=chat_type, prometheus_collector=prometheus_collector, database=database)
if not chat_agent:
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
)
async def message_stream_generator():
"""Generator to stream messages"""
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
query=chat_query,
session_id=session_id,
user_id=current_user.id,
):
json_data = message.model_dump(mode='json', by_alias=True)
json_str = json.dumps(json_data)
logger.info(f"🔗 Streaming message for session {session_id}: {json_str}")
yield json_str + "\n"
return StreamingResponse(
message_stream_generator(),
media_type="application/json",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
},
)
except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"Chat message stream error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
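The endpoint above streams newline-delimited JSON, one serialized ChatMessage per line. A minimal consumer sketch (base_url, session_id, and access_token are placeholders):

import json
import requests

body = {"query": {"prompt": "Summarize this candidate's experience"}}
with requests.post(
    f"{base_url}/chat/sessions/{session_id}/messages/stream",
    json=body,
    headers={"Authorization": f"Bearer {access_token}"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:
            continue  # skip keep-alive blank lines
        message = json.loads(line)
        # Each message carries a status and, while streaming, a chunk of text
        print(message.get("chunk", ""), end="", flush=True)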
@api_router.get("/chat/sessions")
async def get_chat_sessions(
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get paginated list of chat sessions"""
try:
filter_dict = None
if filters:
filter_dict = json.loads(filters)
# Get all chat sessions from Redis
all_sessions_data = await database.get_all_chat_sessions()
sessions_list = [ChatSession.model_validate(data) for data in all_sessions_data.values()]
paginated_sessions, total = filter_and_paginate(
sessions_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_response = create_paginated_response(
[s.model_dump(by_alias=True) for s in paginated_sessions],
page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"Get chat sessions error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("FETCH_FAILED", str(e))
)
# ============================
# Health Check and Info Endpoints
# ============================
@ -790,6 +1166,11 @@ async def redis_stats(redis_client: redis.Redis = Depends(get_redis)):
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
@api_router.get("/system-info")
async def get_system_info(request: Request):
from system_info import system_info # Import system_info function from system_info module
return JSONResponse(system_info())
@api_router.get("/") @api_router.get("/")
async def api_info(): async def api_info():
"""API information endpoint""" """API information endpoint"""

View File

@ -1,7 +1,7 @@
from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Generic, Annotated
from pydantic import BaseModel, Field, EmailStr, HttpUrl, validator # type: ignore
from pydantic.types import constr, conint # type: ignore
from datetime import datetime, date, UTC
from enum import Enum
import uuid
@ -68,10 +68,11 @@ class ChatSenderType(str, Enum):
SYSTEM = "system" SYSTEM = "system"
class ChatStatusType(str, Enum): class ChatStatusType(str, Enum):
PARTIAL = "partial" PREPARING = "preparing"
DONE = "done"
STREAMING = "streaming"
THINKING = "thinking" THINKING = "thinking"
PARTIAL = "partial"
STREAMING = "streaming"
DONE = "done"
ERROR = "error" ERROR = "error"
class ChatContextType(str, Enum): class ChatContextType(str, Enum):
@ -80,13 +81,12 @@ class ChatContextType(str, Enum):
    INTERVIEW_PREP = "interview_prep"
    RESUME_REVIEW = "resume_review"
    GENERAL = "general"
    GENERATE_PERSONA = "generate_persona"
    GENERATE_PROFILE = "generate_profile"

class AIModelType(str, Enum):
    QWEN2_5 = "qwen2.5"
    FLUX_SCHNELL = "flux-schnell"

class MFAMethod(str, Enum):
    APP = "app"
@ -520,47 +520,73 @@ class JobApplication(BaseModel):
    class Config:
        populate_by_name = True  # Allow both field names and aliases

class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True

class ChromaDBGetResponse(BaseModel):
    # Chroma fields
    ids: List[str] = []
    embeddings: List[List[float]] = Field(default=[])
    documents: List[str] = []
    metadatas: List[Dict[str, Any]] = []
    # Additional fields
    name: str = ""
    size: int = 0
    query: str = ""
    query_embedding: Optional[List[float]] = Field(default=None, alias="queryEmbedding")
    umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
    umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")

class ChatContext(BaseModel):
    type: ChatContextType
    related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
    related_entity_type: Optional[Literal["job", "candidate", "employer"]] = Field(None, alias="relatedEntityType")
    additional_context: Optional[Dict[str, Any]] = Field(None, alias="additionalContext")

    class Config:
        populate_by_name = True  # Allow both field names and aliases
class ChatOptions(BaseModel):
seed: Optional[int] = 8911
num_ctx: Optional[int] = Field(default=None, alias="numCtx") # Number of context tokens
temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage
class ChatMessageMetaData(BaseModel):
model: AIModelType = AIModelType.QWEN2_5
temperature: float = 0.7
max_tokens: int = Field(default=8092, alias="maxTokens")
top_p: float = Field(default=1, alias="topP")
frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
tunables: Optional[Tunables] = None
rag: List[ChromaDBGetResponse] = Field(default_factory=list)
eval_count: int = 0
eval_duration: int = 0
prompt_eval_count: int = 0
prompt_eval_duration: int = 0
options: Optional[ChatOptions] = None
tools: Optional[Dict[str, Any]] = None
timers: Optional[Dict[str, float]] = None
class Config:
populate_by_name = True # Allow both field names and aliases
class ChatMessage(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    session_id: str = Field(..., alias="sessionId")
    status: ChatStatusType
    sender: ChatSenderType
    sender_id: Optional[str] = Field(None, alias="senderId")
    prompt: str = ""
    content: str = ""
    chunk: str = ""
    timestamp: datetime
    #attachments: Optional[List[Attachment]] = None
    #reactions: Optional[List[MessageReaction]] = None
    is_edited: bool = Field(False, alias="isEdited")
    #edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
    metadata: Optional[ChatMessageMetaData] = None  # annotation must be Optional to permit a None default

    class Config:
        populate_by_name = True  # Allow both field names and aliases
@ -568,8 +594,8 @@ class ChatSession(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    user_id: Optional[str] = Field(None, alias="userId")
    guest_id: Optional[str] = Field(None, alias="guestId")
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
    last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
    title: Optional[str] = None
    context: ChatContext
    messages: Optional[List[ChatMessage]] = None
@ -615,7 +641,6 @@ class RAGConfiguration(BaseModel):
    retrieval_parameters: RetrievalParameters = Field(..., alias="retrievalParameters")
    created_at: datetime = Field(..., alias="createdAt")
    updated_at: datetime = Field(..., alias="updatedAt")
    version: int
    is_active: bool = Field(..., alias="isActive")

    class Config:
@ -672,7 +697,7 @@ class UserPreference(BaseModel):
# ============================
# API Request/Response Models
# ============================
class ChatQuery(BaseModel):
    prompt: str
    tunables: Optional[Tunables] = None
    agent_options: Optional[Dict[str, Any]] = Field(None, alias="agentOptions")
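A quick round-trip of the renamed model, showing the camelCase alias on agent_options (Tunables is omitted here):

q = ChatQuery.model_validate({
    "prompt": "What are this candidate's strengths?",
    "agentOptions": {"temperature": 0.5},
})
print(q.model_dump(by_alias=True))
# {'prompt': "What are this candidate's strengths?", 'tunables': None,
#  'agentOptions': {'temperature': 0.5}}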

View File

@ -0,0 +1,81 @@
import defines
import re
import subprocess
import math
def get_installed_ram():
try:
with open("/proc/meminfo", "r") as f:
meminfo = f.read()
match = re.search(r"MemTotal:\s+(\d+)", meminfo)
if match:
return f"{math.floor(int(match.group(1)) / 1000**2)}GB" # Convert KB to GB
except Exception as e:
return f"Error retrieving RAM: {e}"
def get_graphics_cards():
gpus = []
try:
# Run the ze-monitor utility
result = subprocess.run(
["ze-monitor"], capture_output=True, text=True, check=True
)
# Clean up the output (remove leading/trailing whitespace and newlines)
output = result.stdout.strip()
for index in range(len(output.splitlines())):
result = subprocess.run(
["ze-monitor", "--device", f"{index+1}", "--info"],
capture_output=True,
text=True,
check=True,
)
gpu_info = result.stdout.strip().splitlines()
gpu = {
"discrete": True, # Assume it's discrete initially
"name": None,
"memory": None,
}
gpus.append(gpu)
for line in gpu_info:
match = re.match(r"^Device: [^(]*\((.*)\)", line)
if match:
gpu["name"] = match.group(1)
continue
match = re.match(r"^\s*Memory: (.*)", line)
if match:
gpu["memory"] = match.group(1)
continue
match = re.match(r"^.*Is integrated with host: Yes.*", line)
if match:
gpu["discrete"] = False
continue
return gpus
except Exception as e:
return f"Error retrieving GPU info: {e}"
def get_cpu_info():
try:
with open("/proc/cpuinfo", "r") as f:
cpuinfo = f.read()
model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
if model_match and cores_match:
return f"{model_match.group(1)} with {len(cores_match)} cores"
except Exception as e:
return f"Error retrieving CPU info: {e}"
def system_info():
return {
"System RAM": get_installed_ram(),
"Graphics Card": get_graphics_cards(),
"CPU": get_cpu_info(),
"LLM Model": defines.model,
"Embedding Model": defines.embedding_model,
"Context length": defines.max_context,
}
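This module backs the /system-info endpoint added in main.py. A sample invocation (values are illustrative and depend on the host):

from system_info import system_info

print(system_info())
# e.g. {'System RAM': '64GB',
#       'Graphics Card': [{'discrete': True, 'name': '...', 'memory': '...'}],
#       'CPU': '... with 16 cores', 'LLM Model': '...', ...}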

View File

@ -1,207 +0,0 @@
#!/usr/bin/env python
"""
Focused test script that tests the most important functionality
without getting caught up in serialization format complexities
"""
import sys
from datetime import datetime
from models import (
UserStatus, UserType, SkillLevel, EmploymentType,
Candidate, Employer, Location, Skill, AIParameters, AIModelType
)
def test_model_creation():
"""Test that we can create models successfully"""
print("🧪 Testing model creation...")
# Create supporting objects
location = Location(city="Austin", country="USA")
skill = Skill(name="Python", category="Programming", level=SkillLevel.ADVANCED)
# Create candidate
candidate = Candidate(
email="test@example.com",
username="test_candidate",
createdAt=datetime.now(),
updatedAt=datetime.now(),
status=UserStatus.ACTIVE,
firstName="John",
lastName="Doe",
fullName="John Doe",
skills=[skill],
experience=[],
education=[],
preferredJobTypes=[EmploymentType.FULL_TIME],
location=location,
languages=[],
certifications=[]
)
# Create employer
employer = Employer(
email="hr@company.com",
username="test_employer",
createdAt=datetime.now(),
updatedAt=datetime.now(),
status=UserStatus.ACTIVE,
companyName="Test Company",
industry="Technology",
companySize="50-200",
companyDescription="A test company",
location=location
)
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
print(f"✅ Employer: {employer.company_name}")
print(f"✅ User types: {candidate.user_type}, {employer.user_type}")
return candidate, employer
def test_json_api_format():
"""Test JSON serialization in API format (the most important use case)"""
print("\n📡 Testing JSON API format...")
candidate, employer = test_model_creation()
# Serialize to JSON (API format)
candidate_json = candidate.model_dump_json(by_alias=True)
employer_json = employer.model_dump_json(by_alias=True)
print(f"✅ Candidate JSON: {len(candidate_json)} chars")
print(f"✅ Employer JSON: {len(employer_json)} chars")
# Deserialize from JSON
candidate_back = Candidate.model_validate_json(candidate_json)
employer_back = Employer.model_validate_json(employer_json)
# Verify data integrity
assert candidate_back.email == candidate.email
assert candidate_back.first_name == candidate.first_name
assert employer_back.company_name == employer.company_name
print(f"✅ JSON round-trip successful")
print(f"✅ Data integrity verified")
return True
def test_api_dict_format():
"""Test dictionary format with aliases (for API requests/responses)"""
print("\n📊 Testing API dictionary format...")
candidate, employer = test_model_creation()
# Create API format dictionaries
candidate_dict = candidate.model_dump(by_alias=True)
employer_dict = employer.model_dump(by_alias=True)
# Verify camelCase aliases are used
assert "firstName" in candidate_dict
assert "lastName" in candidate_dict
assert "createdAt" in candidate_dict
assert "companyName" in employer_dict
print(f"✅ API format dictionaries created")
print(f"✅ CamelCase aliases verified")
# Test deserializing from API format
candidate_back = Candidate.model_validate(candidate_dict)
employer_back = Employer.model_validate(employer_dict)
assert candidate_back.email == candidate.email
assert employer_back.company_name == employer.company_name
print(f"✅ API format round-trip successful")
return True
def test_validation_constraints():
"""Test that validation constraints work"""
print("\n🔒 Testing validation constraints...")
# Test AI Parameters with constraints
valid_params = AIParameters(
name="Test Config",
model=AIModelType.GPT_4,
temperature=0.7, # Valid: 0-1
maxTokens=2000, # Valid: > 0
topP=0.95, # Valid: 0-1
frequencyPenalty=0.0, # Valid: -2 to 2
presencePenalty=0.0, # Valid: -2 to 2
isDefault=True,
createdAt=datetime.now(),
updatedAt=datetime.now()
)
print(f"✅ Valid AI parameters created")
# Test constraint violation
try:
invalid_params = AIParameters(
name="Invalid Config",
model=AIModelType.GPT_4,
temperature=1.5, # Invalid: > 1
maxTokens=2000,
topP=0.95,
frequencyPenalty=0.0,
presencePenalty=0.0,
isDefault=True,
createdAt=datetime.now(),
updatedAt=datetime.now()
)
print("❌ Should have rejected invalid temperature")
return False
except Exception:
print(f"✅ Constraint validation working")
return True
def test_enum_values():
"""Test that enum values work correctly"""
print("\n📋 Testing enum values...")
# Test that enum values are properly handled
candidate, employer = test_model_creation()
# Check enum values in serialization
candidate_dict = candidate.model_dump(by_alias=True)
assert candidate_dict["status"] == "active"
assert candidate_dict["userType"] == "candidate"
assert employer.user_type == UserType.EMPLOYER
print(f"✅ Enum values correctly serialized")
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
return True
def main():
"""Run all focused tests"""
print("🎯 Focused Pydantic Model Tests")
print("=" * 40)
try:
test_model_creation()
test_json_api_format()
test_api_dict_format()
test_validation_constraints()
test_enum_values()
print(f"\n🎉 All focused tests passed!")
print("=" * 40)
print("✅ Models work correctly")
print("✅ JSON API format works")
print("✅ Validation constraints work")
print("✅ Enum values work")
print("✅ Ready for type generation!")
return True
except Exception as e:
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

update-types.sh Executable file
View File

@ -0,0 +1,2 @@
#!/bin/bash
docker compose exec backstory shell "python src/backend/generate_types.py --source src/backend/models.py --output frontend/src/types/types.ts ${*}"
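Any extra flags (for example --debug or --skip-test) are forwarded to generate_types.py via ${*}.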