Compare commits
2 Commits
a8a8d3738d
...
02a278736e
Author | SHA1 | Date | |
---|---|---|---|
02a278736e | |||
b5b3a1f5dc |
@ -1,15 +1,14 @@
|
|||||||
import React, { useEffect, useState, useRef, useCallback } from 'react';
|
import React, { useEffect, useState, useRef, useCallback } from 'react';
|
||||||
import { Route, Routes, useLocation, useNavigate } from 'react-router-dom';
|
import { Route, Routes, useLocation, useNavigate } from 'react-router-dom';
|
||||||
import { ThemeProvider } from '@mui/material/styles';
|
import { ThemeProvider } from '@mui/material/styles';
|
||||||
|
|
||||||
import { backstoryTheme } from './BackstoryTheme';
|
import { backstoryTheme } from './BackstoryTheme';
|
||||||
|
|
||||||
import { SeverityType } from 'components/Snack';
|
import { SeverityType } from 'components/Snack';
|
||||||
import { Query } from 'types/types';
|
|
||||||
import { ConversationHandle } from 'components/Conversation';
|
import { ConversationHandle } from 'components/Conversation';
|
||||||
import { UserProvider } from 'hooks/useUser';
|
import { UserProvider } from 'hooks/useUser';
|
||||||
import { CandidateRoute } from 'routes/CandidateRoute';
|
import { CandidateRoute } from 'routes/CandidateRoute';
|
||||||
import { BackstoryLayout } from 'components/layout/BackstoryLayout';
|
import { BackstoryLayout } from 'components/layout/BackstoryLayout';
|
||||||
|
import { ChatQuery } from 'types/types';
|
||||||
|
|
||||||
import './BackstoryApp.css';
|
import './BackstoryApp.css';
|
||||||
import '@fontsource/roboto/300.css';
|
import '@fontsource/roboto/300.css';
|
||||||
@ -17,13 +16,7 @@ import '@fontsource/roboto/400.css';
|
|||||||
import '@fontsource/roboto/500.css';
|
import '@fontsource/roboto/500.css';
|
||||||
import '@fontsource/roboto/700.css';
|
import '@fontsource/roboto/700.css';
|
||||||
|
|
||||||
import { debugConversion } from 'types/conversion';
|
|
||||||
import { User, Guest, Candidate } from 'types/types';
|
|
||||||
|
|
||||||
const BackstoryApp = () => {
|
const BackstoryApp = () => {
|
||||||
const [user, setUser] = useState<User | null>(null);
|
|
||||||
const [guest, setGuest] = useState<Guest | null>(null);
|
|
||||||
const [candidate, setCandidate] = useState<Candidate | null>(null);
|
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
const snackRef = useRef<any>(null);
|
const snackRef = useRef<any>(null);
|
||||||
@ -31,58 +24,13 @@ const BackstoryApp = () => {
|
|||||||
const setSnack = useCallback((message: string, severity?: SeverityType) => {
|
const setSnack = useCallback((message: string, severity?: SeverityType) => {
|
||||||
snackRef.current?.setSnack(message, severity);
|
snackRef.current?.setSnack(message, severity);
|
||||||
}, [snackRef]);
|
}, [snackRef]);
|
||||||
const submitQuery = (query: Query) => {
|
const submitQuery = (query: ChatQuery) => {
|
||||||
console.log(`handleSubmitChatQuery:`, query, chatRef.current ? ' sending' : 'no handler');
|
console.log(`handleSubmitChatQuery:`, query, chatRef.current ? ' sending' : 'no handler');
|
||||||
chatRef.current?.submitQuery(query);
|
chatRef.current?.submitQuery(query);
|
||||||
navigate('/chat');
|
navigate('/chat');
|
||||||
};
|
};
|
||||||
const [page, setPage] = useState<string>("");
|
const [page, setPage] = useState<string>("");
|
||||||
|
|
||||||
const createGuestSession = () => {
|
|
||||||
console.log("TODO: Convert this to query the server for the session instead of generating it.");
|
|
||||||
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
|
||||||
const guest: Guest = {
|
|
||||||
sessionId,
|
|
||||||
createdAt: new Date(),
|
|
||||||
lastActivity: new Date(),
|
|
||||||
ipAddress: 'unknown',
|
|
||||||
userAgent: navigator.userAgent
|
|
||||||
};
|
|
||||||
setGuest(guest);
|
|
||||||
debugConversion(guest, 'Guest Session');
|
|
||||||
};
|
|
||||||
|
|
||||||
const checkExistingAuth = () => {
|
|
||||||
const token = localStorage.getItem('accessToken');
|
|
||||||
const userData = localStorage.getItem('userData');
|
|
||||||
if (token && userData) {
|
|
||||||
try {
|
|
||||||
const user = JSON.parse(userData);
|
|
||||||
// Convert dates back to Date objects if they're stored as strings
|
|
||||||
if (user.createdAt && typeof user.createdAt === 'string') {
|
|
||||||
user.createdAt = new Date(user.createdAt);
|
|
||||||
}
|
|
||||||
if (user.updatedAt && typeof user.updatedAt === 'string') {
|
|
||||||
user.updatedAt = new Date(user.updatedAt);
|
|
||||||
}
|
|
||||||
if (user.lastLogin && typeof user.lastLogin === 'string') {
|
|
||||||
user.lastLogin = new Date(user.lastLogin);
|
|
||||||
}
|
|
||||||
setUser(user);
|
|
||||||
} catch (e) {
|
|
||||||
localStorage.removeItem('accessToken');
|
|
||||||
localStorage.removeItem('refreshToken');
|
|
||||||
localStorage.removeItem('userData');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Create guest session on component mount
|
|
||||||
useEffect(() => {
|
|
||||||
createGuestSession();
|
|
||||||
checkExistingAuth();
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const currentRoute = location.pathname.split("/")[1] ? `/${location.pathname.split("/")[1]}` : "/";
|
const currentRoute = location.pathname.split("/")[1] ? `/${location.pathname.split("/")[1]}` : "/";
|
||||||
setPage(currentRoute);
|
setPage(currentRoute);
|
||||||
@ -91,27 +39,20 @@ const BackstoryApp = () => {
|
|||||||
// Render appropriate routes based on user type
|
// Render appropriate routes based on user type
|
||||||
return (
|
return (
|
||||||
<ThemeProvider theme={backstoryTheme}>
|
<ThemeProvider theme={backstoryTheme}>
|
||||||
<UserProvider {...{ guest, user, candidate, setSnack }}>
|
<UserProvider {...{ setSnack }}>
|
||||||
<Routes>
|
<Routes>
|
||||||
<Route path="/u/:username" element={<CandidateRoute {...{ guest, candidate, setCandidate, setSnack }} />} />
|
<Route path="/u/:username" element={<CandidateRoute {...{ setSnack }} />} />
|
||||||
{/* Static/shared routes */}
|
{/* Static/shared routes */}
|
||||||
<Route
|
<Route
|
||||||
path="/*"
|
path="/*"
|
||||||
element={
|
element={
|
||||||
<BackstoryLayout
|
<BackstoryLayout {...{ setSnack, page, chatRef, snackRef, submitQuery }} />
|
||||||
setSnack={setSnack}
|
|
||||||
page={page}
|
|
||||||
chatRef={chatRef}
|
|
||||||
snackRef={snackRef}
|
|
||||||
submitQuery={submitQuery}
|
|
||||||
/>
|
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
</Routes>
|
</Routes>
|
||||||
</UserProvider>
|
</UserProvider>
|
||||||
</ThemeProvider>
|
</ThemeProvider>
|
||||||
);
|
);
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export {
|
export {
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
import Box from '@mui/material/Box';
|
import Box from '@mui/material/Box';
|
||||||
import Button from '@mui/material/Button';
|
import Button from '@mui/material/Button';
|
||||||
|
|
||||||
import { Query } from "../types/types";
|
import { ChatQuery } from "types/types";
|
||||||
|
|
||||||
type ChatSubmitQueryInterface = (query: Query) => void;
|
type ChatSubmitQueryInterface = (query: ChatQuery) => void;
|
||||||
|
|
||||||
interface ChatQueryInterface {
|
interface BackstoryQueryInterface {
|
||||||
query: Query,
|
query: ChatQuery,
|
||||||
submitQuery?: ChatSubmitQueryInterface
|
submitQuery?: ChatSubmitQueryInterface
|
||||||
}
|
}
|
||||||
|
|
||||||
const ChatQuery = (props : ChatQueryInterface) => {
|
const BackstoryQuery = (props : BackstoryQueryInterface) => {
|
||||||
const { query, submitQuery } = props;
|
const { query, submitQuery } = props;
|
||||||
|
|
||||||
if (submitQuery === undefined) {
|
if (submitQuery === undefined) {
|
||||||
@ -29,11 +29,11 @@ const ChatQuery = (props : ChatQueryInterface) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export type {
|
export type {
|
||||||
ChatQueryInterface,
|
BackstoryQueryInterface,
|
||||||
ChatSubmitQueryInterface,
|
ChatSubmitQueryInterface,
|
||||||
};
|
};
|
||||||
|
|
||||||
export {
|
export {
|
||||||
ChatQuery,
|
BackstoryQuery,
|
||||||
};
|
};
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
import React, { ReactElement, JSXElementConstructor } from 'react';
|
import React, { ReactElement, JSXElementConstructor } from 'react';
|
||||||
import Box from '@mui/material/Box';
|
import Box from '@mui/material/Box';
|
||||||
import { SxProps, Theme } from '@mui/material';
|
import { SxProps, Theme } from '@mui/material';
|
||||||
import { ChatSubmitQueryInterface } from './ChatQuery';
|
import { ChatSubmitQueryInterface } from './BackstoryQuery';
|
||||||
import { SetSnackType } from './Snack';
|
import { SetSnackType } from './Snack';
|
||||||
|
|
||||||
interface BackstoryElementProps {
|
interface BackstoryElementProps {
|
||||||
|
@ -14,8 +14,8 @@ import { BackstoryTextField, BackstoryTextFieldRef } from 'components/BackstoryT
|
|||||||
import { BackstoryElementProps } from './BackstoryTab';
|
import { BackstoryElementProps } from './BackstoryTab';
|
||||||
import { connectionBase } from 'utils/Global';
|
import { connectionBase } from 'utils/Global';
|
||||||
import { useUser } from "hooks/useUser";
|
import { useUser } from "hooks/useUser";
|
||||||
import { ApiClient, StreamingResponse } from 'types/api-client';
|
import { StreamingResponse } from 'types/api-client';
|
||||||
import { ChatMessage, ChatContext, ChatSession, AIParameters, Query } from 'types/types';
|
import { ChatMessage, ChatContext, ChatSession, ChatQuery } from 'types/types';
|
||||||
import { PaginatedResponse } from 'types/conversion';
|
import { PaginatedResponse } from 'types/conversion';
|
||||||
|
|
||||||
import './Conversation.css';
|
import './Conversation.css';
|
||||||
@ -29,7 +29,7 @@ const loadingMessage: ChatMessage = { ...defaultMessage, content: "Establishing
|
|||||||
type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check' | 'persona';
|
type ConversationMode = 'chat' | 'job_description' | 'resume' | 'fact_check' | 'persona';
|
||||||
|
|
||||||
interface ConversationHandle {
|
interface ConversationHandle {
|
||||||
submitQuery: (query: Query) => void;
|
submitQuery: (query: ChatQuery) => void;
|
||||||
fetchHistory: () => void;
|
fetchHistory: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -69,8 +69,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
|||||||
sx,
|
sx,
|
||||||
type,
|
type,
|
||||||
} = props;
|
} = props;
|
||||||
const apiClient = new ApiClient();
|
const { candidate, apiClient } = useUser()
|
||||||
const { candidate } = useUser()
|
|
||||||
const [processing, setProcessing] = useState<boolean>(false);
|
const [processing, setProcessing] = useState<boolean>(false);
|
||||||
const [countdown, setCountdown] = useState<number>(0);
|
const [countdown, setCountdown] = useState<number>(0);
|
||||||
const [conversation, setConversation] = useState<ChatMessage[]>([]);
|
const [conversation, setConversation] = useState<ChatMessage[]>([]);
|
||||||
@ -125,23 +124,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
|||||||
}
|
}
|
||||||
const createChatSession = async () => {
|
const createChatSession = async () => {
|
||||||
try {
|
try {
|
||||||
const aiParameters: AIParameters = {
|
const chatContext: ChatContext = { type: "general" };
|
||||||
name: '',
|
|
||||||
model: 'custom',
|
|
||||||
temperature: 0.7,
|
|
||||||
maxTokens: -1,
|
|
||||||
topP: 1,
|
|
||||||
frequencyPenalty: 0,
|
|
||||||
presencePenalty: 0,
|
|
||||||
isDefault: true,
|
|
||||||
createdAt: new Date(),
|
|
||||||
updatedAt: new Date()
|
|
||||||
};
|
|
||||||
|
|
||||||
const chatContext: ChatContext = {
|
|
||||||
type: "general",
|
|
||||||
aiParameters
|
|
||||||
};
|
|
||||||
const response: ChatSession = await apiClient.createChatSession(chatContext);
|
const response: ChatSession = await apiClient.createChatSession(chatContext);
|
||||||
setChatSession(response);
|
setChatSession(response);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@ -203,14 +186,14 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
|||||||
}, [chatSession]);
|
}, [chatSession]);
|
||||||
|
|
||||||
const handleEnter = (value: string) => {
|
const handleEnter = (value: string) => {
|
||||||
const query: Query = {
|
const query: ChatQuery = {
|
||||||
prompt: value
|
prompt: value
|
||||||
}
|
}
|
||||||
processQuery(query);
|
processQuery(query);
|
||||||
};
|
};
|
||||||
|
|
||||||
useImperativeHandle(ref, () => ({
|
useImperativeHandle(ref, () => ({
|
||||||
submitQuery: (query: Query) => {
|
submitQuery: (query: ChatQuery) => {
|
||||||
processQuery(query);
|
processQuery(query);
|
||||||
},
|
},
|
||||||
fetchHistory: () => { getChatMessages(); }
|
fetchHistory: () => { getChatMessages(); }
|
||||||
@ -255,7 +238,7 @@ const Conversation = forwardRef<ConversationHandle, ConversationProps>((props: C
|
|||||||
controllerRef.current = null;
|
controllerRef.current = null;
|
||||||
};
|
};
|
||||||
|
|
||||||
const processQuery = (query: Query) => {
|
const processQuery = (query: ChatQuery) => {
|
||||||
if (controllerRef.current || !chatSession || !chatSession.id) {
|
if (controllerRef.current || !chatSession || !chatSession.id) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2,14 +2,14 @@ import React from 'react';
|
|||||||
import { MuiMarkdown } from 'mui-markdown';
|
import { MuiMarkdown } from 'mui-markdown';
|
||||||
import { useTheme } from '@mui/material/styles';
|
import { useTheme } from '@mui/material/styles';
|
||||||
import { Link } from '@mui/material';
|
import { Link } from '@mui/material';
|
||||||
import { ChatQuery } from './ChatQuery';
|
import { BackstoryQuery } from 'components/BackstoryQuery';
|
||||||
import Box from '@mui/material/Box';
|
import Box from '@mui/material/Box';
|
||||||
import JsonView from '@uiw/react-json-view';
|
import JsonView from '@uiw/react-json-view';
|
||||||
import { vscodeTheme } from '@uiw/react-json-view/vscode';
|
import { vscodeTheme } from '@uiw/react-json-view/vscode';
|
||||||
import { Mermaid } from './Mermaid';
|
import { Mermaid } from 'components/Mermaid';
|
||||||
import { Scrollable } from './Scrollable';
|
import { Scrollable } from 'components/Scrollable';
|
||||||
import { jsonrepair } from 'jsonrepair';
|
import { jsonrepair } from 'jsonrepair';
|
||||||
import { GenerateImage } from './GenerateImage';
|
import { GenerateImage } from 'components/GenerateImage';
|
||||||
|
|
||||||
import './StyledMarkdown.css';
|
import './StyledMarkdown.css';
|
||||||
import { BackstoryElementProps } from './BackstoryTab';
|
import { BackstoryElementProps } from './BackstoryTab';
|
||||||
@ -98,13 +98,13 @@ const StyledMarkdown: React.FC<StyledMarkdownProps> = (props: StyledMarkdownProp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ChatQuery: {
|
BackstoryQuery: {
|
||||||
component: (props: { query: string }) => {
|
component: (props: { query: string }) => {
|
||||||
const queryString = props.query.replace(/(\w+):/g, '"$1":');
|
const queryString = props.query.replace(/(\w+):/g, '"$1":');
|
||||||
try {
|
try {
|
||||||
const query = JSON.parse(queryString);
|
const query = JSON.parse(queryString);
|
||||||
|
|
||||||
return <ChatQuery submitQuery={submitQuery} query={query} />
|
return <BackstoryQuery submitQuery={submitQuery} query={query} />
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log("StyledMarkdown error:", queryString, e);
|
console.log("StyledMarkdown error:", queryString, e);
|
||||||
return props.query;
|
return props.query;
|
||||||
|
@ -38,17 +38,17 @@ const DefaultNavItems: NavigationLinkType[] = [
|
|||||||
|
|
||||||
const CandidateNavItems : NavigationLinkType[]= [
|
const CandidateNavItems : NavigationLinkType[]= [
|
||||||
{ name: 'Chat', path: '/chat', icon: <ChatIcon /> },
|
{ name: 'Chat', path: '/chat', icon: <ChatIcon /> },
|
||||||
{ name: 'Job Analysis', path: '/job-analysis', icon: <WorkIcon /> },
|
// { name: 'Job Analysis', path: '/job-analysis', icon: <WorkIcon /> },
|
||||||
{ name: 'Resume Builder', path: '/resume-builder', icon: <WorkIcon /> },
|
{ name: 'Resume Builder', path: '/resume-builder', icon: <WorkIcon /> },
|
||||||
{ name: 'Knowledge Explorer', path: '/knowledge-explorer', icon: <WorkIcon /> },
|
// { name: 'Knowledge Explorer', path: '/knowledge-explorer', icon: <WorkIcon /> },
|
||||||
{ name: 'Find a Candidate', path: '/find-a-candidate', icon: <InfoIcon /> },
|
{ name: 'Find a Candidate', path: '/find-a-candidate', icon: <InfoIcon /> },
|
||||||
// { name: 'Dashboard', icon: <DashboardIcon />, path: '/dashboard' },
|
// { name: 'Dashboard', icon: <DashboardIcon />, path: '/dashboard' },
|
||||||
// { name: 'Profile', icon: <PersonIcon />, path: '/profile' },
|
// { name: 'Profile', icon: <PersonIcon />, path: '/profile' },
|
||||||
// { name: 'Backstory', icon: <HistoryIcon />, path: '/backstory' },
|
// { name: 'Backstory', icon: <HistoryIcon />, path: '/backstory' },
|
||||||
{ name: 'Resumes', icon: <DescriptionIcon />, path: '/resumes' },
|
// { name: 'Resumes', icon: <DescriptionIcon />, path: '/resumes' },
|
||||||
// { name: 'Q&A Setup', icon: <QuestionAnswerIcon />, path: '/qa-setup' },
|
// { name: 'Q&A Setup', icon: <QuestionAnswerIcon />, path: '/qa-setup' },
|
||||||
{ name: 'Analytics', icon: <BarChartIcon />, path: '/analytics' },
|
// { name: 'Analytics', icon: <BarChartIcon />, path: '/analytics' },
|
||||||
{ name: 'Settings', icon: <SettingsIcon />, path: '/settings' },
|
// { name: 'Settings', icon: <SettingsIcon />, path: '/settings' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const EmployerNavItems: NavigationLinkType[] = [
|
const EmployerNavItems: NavigationLinkType[] = [
|
||||||
@ -121,13 +121,16 @@ const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const BackstoryLayout: React.FC<{
|
interface BackstoryLayoutProps {
|
||||||
setSnack: SetSnackType;
|
setSnack: SetSnackType;
|
||||||
page: string;
|
page: string;
|
||||||
chatRef: React.Ref<any>;
|
chatRef: React.Ref<any>;
|
||||||
snackRef: React.Ref<any>;
|
snackRef: React.Ref<any>;
|
||||||
submitQuery: any;
|
submitQuery: any;
|
||||||
}> = ({ setSnack, page, chatRef, snackRef, submitQuery }) => {
|
};
|
||||||
|
|
||||||
|
const BackstoryLayout: React.FC<BackstoryLayoutProps> = (props: BackstoryLayoutProps) => {
|
||||||
|
const { setSnack, page, chatRef, snackRef, submitQuery } = props;
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const location = useLocation();
|
const location = useLocation();
|
||||||
const { user, guest, candidate } = useUser();
|
const { user, guest, candidate } = useUser();
|
||||||
|
@ -17,6 +17,7 @@ import { CandidateListingPage } from 'pages/FindCandidatePage';
|
|||||||
import { JobAnalysisPage } from 'pages/JobAnalysisPage';
|
import { JobAnalysisPage } from 'pages/JobAnalysisPage';
|
||||||
import { GenerateCandidate } from "pages/GenerateCandidate";
|
import { GenerateCandidate } from "pages/GenerateCandidate";
|
||||||
import { ControlsPage } from 'pages/ControlsPage';
|
import { ControlsPage } from 'pages/ControlsPage';
|
||||||
|
import { LoginPage } from "pages/LoginPage";
|
||||||
|
|
||||||
const ProfilePage = () => (<BetaPage><Typography variant="h4">Profile</Typography></BetaPage>);
|
const ProfilePage = () => (<BetaPage><Typography variant="h4">Profile</Typography></BetaPage>);
|
||||||
const BackstoryPage = () => (<BetaPage><Typography variant="h4">Backstory</Typography></BetaPage>);
|
const BackstoryPage = () => (<BetaPage><Typography variant="h4">Backstory</Typography></BetaPage>);
|
||||||
@ -27,7 +28,6 @@ const SavedPage = () => (<BetaPage><Typography variant="h4">Saved</Typography></
|
|||||||
const JobsPage = () => (<BetaPage><Typography variant="h4">Jobs</Typography></BetaPage>);
|
const JobsPage = () => (<BetaPage><Typography variant="h4">Jobs</Typography></BetaPage>);
|
||||||
const CompanyPage = () => (<BetaPage><Typography variant="h4">Company</Typography></BetaPage>);
|
const CompanyPage = () => (<BetaPage><Typography variant="h4">Company</Typography></BetaPage>);
|
||||||
const LogoutPage = () => (<BetaPage><Typography variant="h4">Logout page...</Typography></BetaPage>);
|
const LogoutPage = () => (<BetaPage><Typography variant="h4">Logout page...</Typography></BetaPage>);
|
||||||
const LoginPage = () => (<BetaPage><Typography variant="h4">Login page...</Typography></BetaPage>);
|
|
||||||
// const DashboardPage = () => (<BetaPage><Typography variant="h4">Dashboard</Typography></BetaPage>);
|
// const DashboardPage = () => (<BetaPage><Typography variant="h4">Dashboard</Typography></BetaPage>);
|
||||||
// const AnalyticsPage = () => (<BetaPage><Typography variant="h4">Analytics</Typography></BetaPage>);
|
// const AnalyticsPage = () => (<BetaPage><Typography variant="h4">Analytics</Typography></BetaPage>);
|
||||||
// const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typography></BetaPage>);
|
// const SettingsPage = () => (<BetaPage><Typography variant="h4">Settings</Typography></BetaPage>);
|
||||||
@ -57,7 +57,7 @@ const getBackstoryDynamicRoutes = (props: BackstoryDynamicRoutesProps): ReactNod
|
|||||||
routes.push(<Route key={`${index++}`} path="/login" element={<LoginPage />} />);
|
routes.push(<Route key={`${index++}`} path="/login" element={<LoginPage />} />);
|
||||||
routes.push(<Route key={`${index++}`} path="*" element={<BetaPage />} />);
|
routes.push(<Route key={`${index++}`} path="*" element={<BetaPage />} />);
|
||||||
} else {
|
} else {
|
||||||
|
routes.push(<Route key={`${index++}`} path="/login" element={<LoginPage />} />);
|
||||||
routes.push(<Route key={`${index++}`} path="/logout" element={<LogoutPage />} />);
|
routes.push(<Route key={`${index++}`} path="/logout" element={<LogoutPage />} />);
|
||||||
|
|
||||||
if (user.userType === 'candidate') {
|
if (user.userType === 'candidate') {
|
||||||
|
@ -87,7 +87,6 @@ const MobileDrawer = styled(Drawer)(({ theme }) => ({
|
|||||||
|
|
||||||
interface HeaderProps {
|
interface HeaderProps {
|
||||||
transparent?: boolean;
|
transparent?: boolean;
|
||||||
onLogout?: () => void;
|
|
||||||
className?: string;
|
className?: string;
|
||||||
navigate: NavigateFunction;
|
navigate: NavigateFunction;
|
||||||
navigationLinks: NavigationLinkType[];
|
navigationLinks: NavigationLinkType[];
|
||||||
@ -98,7 +97,7 @@ interface HeaderProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
||||||
const { user } = useUser();
|
const { user, setUser } = useUser();
|
||||||
const candidate: Candidate | null = (user && user.userType === "candidate") ? user as Candidate : null;
|
const candidate: Candidate | null = (user && user.userType === "candidate") ? user as Candidate : null;
|
||||||
const employer: Employer | null = (user && user.userType === "employer") ? user as Employer : null;
|
const employer: Employer | null = (user && user.userType === "employer") ? user as Employer : null;
|
||||||
const {
|
const {
|
||||||
@ -108,7 +107,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
|||||||
navigationLinks,
|
navigationLinks,
|
||||||
showLogin,
|
showLogin,
|
||||||
sessionId,
|
sessionId,
|
||||||
onLogout,
|
|
||||||
setSnack,
|
setSnack,
|
||||||
} = props;
|
} = props;
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
@ -177,9 +175,7 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
|||||||
|
|
||||||
const handleLogout = () => {
|
const handleLogout = () => {
|
||||||
handleUserMenuClose();
|
handleUserMenuClose();
|
||||||
if (onLogout) {
|
setUser(null);
|
||||||
onLogout();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDrawerToggle = () => {
|
const handleDrawerToggle = () => {
|
||||||
@ -245,14 +241,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
|||||||
>
|
>
|
||||||
Login
|
Login
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
|
||||||
variant="outlined"
|
|
||||||
color="secondary"
|
|
||||||
fullWidth
|
|
||||||
onClick={() => { navigate("/register"); }}
|
|
||||||
>
|
|
||||||
Register
|
|
||||||
</Button>
|
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
@ -279,14 +267,6 @@ const Header: React.FC<HeaderProps> = (props: HeaderProps) => {
|
|||||||
>
|
>
|
||||||
Login
|
Login
|
||||||
</Button>
|
</Button>
|
||||||
<Button
|
|
||||||
color="secondary"
|
|
||||||
variant="contained"
|
|
||||||
onClick={() => { navigate("/register"); }}
|
|
||||||
sx={{ display: { xs: 'none', sm: 'block' } }}
|
|
||||||
>
|
|
||||||
Register
|
|
||||||
</Button>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,16 @@
|
|||||||
import React, { createContext, useContext, useEffect, useState } from "react";
|
import React, { createContext, useContext, useEffect, useState } from "react";
|
||||||
import { SetSnackType } from '../components/Snack';
|
import { SetSnackType } from '../components/Snack';
|
||||||
import { User, Guest, Candidate } from 'types/types';
|
import { User, Guest, Candidate } from 'types/types';
|
||||||
|
import { ApiClient } from "types/api-client";
|
||||||
|
import { debugConversion } from "types/conversion";
|
||||||
|
|
||||||
type UserContextType = {
|
type UserContextType = {
|
||||||
|
apiClient: ApiClient;
|
||||||
user: User | null;
|
user: User | null;
|
||||||
guest: Guest;
|
guest: Guest;
|
||||||
candidate: Candidate | null;
|
candidate: Candidate | null;
|
||||||
|
setUser: (user: User | null) => void;
|
||||||
|
setCandidate: (candidate: Candidate | null) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||||
@ -18,19 +23,136 @@ const useUser = () => {
|
|||||||
|
|
||||||
interface UserProviderProps {
|
interface UserProviderProps {
|
||||||
children: React.ReactNode;
|
children: React.ReactNode;
|
||||||
candidate: Candidate | null;
|
|
||||||
user: User | null;
|
|
||||||
guest: Guest | null;
|
|
||||||
setSnack: SetSnackType;
|
setSnack: SetSnackType;
|
||||||
};
|
};
|
||||||
|
|
||||||
const UserProvider: React.FC<UserProviderProps> = (props: UserProviderProps) => {
|
const UserProvider: React.FC<UserProviderProps> = (props: UserProviderProps) => {
|
||||||
const { guest, user, children, candidate, setSnack } = props;
|
const { children, setSnack } = props;
|
||||||
|
const [apiClient, setApiClient] = useState<ApiClient>(new ApiClient());
|
||||||
|
const [candidate, setCandidate] = useState<Candidate | null>(null);
|
||||||
|
const [guest, setGuest] = useState<Guest | null>(null);
|
||||||
|
const [user, setUser] = useState<User | null>(null);
|
||||||
|
const [activeUser, setActiveUser] = useState<User | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
console.log("Candidate =>", candidate);
|
||||||
|
}, [candidate]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
console.log("Guest =>", guest);
|
||||||
|
}, [guest]);
|
||||||
|
|
||||||
|
/* If the user changes to a non-null value, create a new
|
||||||
|
* apiClient with the access token */
|
||||||
|
useEffect(() => {
|
||||||
|
console.log("User => ", user);
|
||||||
|
if (user === null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
/* This apiClient will persist until the user is changed
|
||||||
|
* or logged out */
|
||||||
|
const accessToken = localStorage.getItem('accessToken');
|
||||||
|
if (!accessToken) {
|
||||||
|
throw Error("accessToken is not set for user!");
|
||||||
|
}
|
||||||
|
setApiClient(new ApiClient(accessToken));
|
||||||
|
}, [user]);
|
||||||
|
|
||||||
|
/* Handle logout if any consumers of UserProvider setUser to NULL */
|
||||||
|
useEffect(() => {
|
||||||
|
/* If there is an active user and it is the same as the
|
||||||
|
* new user, do nothing */
|
||||||
|
if (activeUser && activeUser.email === user?.email) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const logout = async () => {
|
||||||
|
if (!activeUser) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
console.log(`Logging out ${activeUser.email}`);
|
||||||
|
try {
|
||||||
|
const accessToken = localStorage.getItem('accessToken');
|
||||||
|
const refreshToken = localStorage.getItem('refreshToken');
|
||||||
|
if (!accessToken || !refreshToken) {
|
||||||
|
setSnack("Authentication tokens are invalid.", "error");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const results = await apiClient.logout(accessToken, refreshToken);
|
||||||
|
if (results.error) {
|
||||||
|
console.error(results.error);
|
||||||
|
setSnack(results.error.message, "error")
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
setSnack(`Unable to logout: ${e}`, "error")
|
||||||
|
}
|
||||||
|
|
||||||
|
localStorage.removeItem('accessToken');
|
||||||
|
localStorage.removeItem('refreshToken');
|
||||||
|
localStorage.removeItem('userData');
|
||||||
|
createGuestSession();
|
||||||
|
setUser(null);
|
||||||
|
};
|
||||||
|
|
||||||
|
setActiveUser(user);
|
||||||
|
if (!user) {
|
||||||
|
logout();
|
||||||
|
}
|
||||||
|
}, [user, apiClient, activeUser]);
|
||||||
|
|
||||||
|
const createGuestSession = () => {
|
||||||
|
console.log("TODO: Convert this to query the server for the session instead of generating it.");
|
||||||
|
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
|
const guest: Guest = {
|
||||||
|
sessionId,
|
||||||
|
createdAt: new Date(),
|
||||||
|
lastActivity: new Date(),
|
||||||
|
ipAddress: 'unknown',
|
||||||
|
userAgent: navigator.userAgent
|
||||||
|
};
|
||||||
|
setGuest(guest);
|
||||||
|
debugConversion(guest, 'Guest Session');
|
||||||
|
};
|
||||||
|
|
||||||
|
const checkExistingAuth = () => {
|
||||||
|
const accessToken = localStorage.getItem('accessToken');
|
||||||
|
const userData = localStorage.getItem('userData');
|
||||||
|
if (accessToken && userData) {
|
||||||
|
try {
|
||||||
|
const user = JSON.parse(userData);
|
||||||
|
// Convert dates back to Date objects if they're stored as strings
|
||||||
|
if (user.createdAt && typeof user.createdAt === 'string') {
|
||||||
|
user.createdAt = new Date(user.createdAt);
|
||||||
|
}
|
||||||
|
if (user.updatedAt && typeof user.updatedAt === 'string') {
|
||||||
|
user.updatedAt = new Date(user.updatedAt);
|
||||||
|
}
|
||||||
|
if (user.lastLogin && typeof user.lastLogin === 'string') {
|
||||||
|
user.lastLogin = new Date(user.lastLogin);
|
||||||
|
}
|
||||||
|
setApiClient(new ApiClient(accessToken));
|
||||||
|
setUser(user);
|
||||||
|
} catch (e) {
|
||||||
|
localStorage.removeItem('accessToken');
|
||||||
|
localStorage.removeItem('refreshToken');
|
||||||
|
localStorage.removeItem('userData');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create guest session on component mount
|
||||||
|
useEffect(() => {
|
||||||
|
createGuestSession();
|
||||||
|
checkExistingAuth();
|
||||||
|
}, []);
|
||||||
|
|
||||||
if (guest === null) {
|
if (guest === null) {
|
||||||
return <></>;
|
return <></>;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<UserContext.Provider value={{ candidate, user, guest }}>
|
<UserContext.Provider value={{ apiClient, candidate, setCandidate, user, setUser, guest }}>
|
||||||
{children}
|
{children}
|
||||||
</UserContext.Provider>
|
</UserContext.Provider>
|
||||||
);
|
);
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import React, { useState, useEffect } from 'react';
|
import React, { useState, useEffect } from 'react';
|
||||||
import { useNavigate } from 'react-router-dom';
|
import { useNavigate, useLocation } from 'react-router-dom';
|
||||||
import {
|
import {
|
||||||
Box,
|
Box,
|
||||||
Container,
|
Container,
|
||||||
@ -35,7 +35,11 @@ const BetaPage: React.FC<BetaPageProps> = ({
|
|||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const [showSparkle, setShowSparkle] = useState<boolean>(false);
|
const [showSparkle, setShowSparkle] = useState<boolean>(false);
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
const location = useLocation();
|
||||||
|
|
||||||
|
if (!children) {
|
||||||
|
children = (<Box>Location: {location.pathname}</Box>);
|
||||||
|
}
|
||||||
console.log("BetaPage", children);
|
console.log("BetaPage", children);
|
||||||
|
|
||||||
// Enhanced sparkle effect for background elements
|
// Enhanced sparkle effect for background elements
|
||||||
|
@ -6,19 +6,18 @@ import MuiMarkdown from 'mui-markdown';
|
|||||||
|
|
||||||
import { BackstoryPageProps } from '../components/BackstoryTab';
|
import { BackstoryPageProps } from '../components/BackstoryTab';
|
||||||
import { Conversation, ConversationHandle } from '../components/Conversation';
|
import { Conversation, ConversationHandle } from '../components/Conversation';
|
||||||
import { ChatQuery } from '../components/ChatQuery';
|
import { BackstoryQuery } from '../components/BackstoryQuery';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/CandidateInfo';
|
||||||
import { useUser } from "../hooks/useUser";
|
import { useUser } from "../hooks/useUser";
|
||||||
import { Candidate } from "../types/types";
|
|
||||||
|
|
||||||
const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
||||||
const { setSnack, submitQuery } = props;
|
const { setSnack, submitQuery } = props;
|
||||||
|
const { candidate } = useUser();
|
||||||
const theme = useTheme();
|
const theme = useTheme();
|
||||||
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
|
const isMobile = useMediaQuery(theme.breakpoints.down('md'));
|
||||||
const [questions, setQuestions] = useState<React.ReactElement[]>([]);
|
const [questions, setQuestions] = useState<React.ReactElement[]>([]);
|
||||||
const { user } = useUser();
|
|
||||||
const candidate: Candidate | null = (user && user.userType === 'candidate') ? user as Candidate : null;
|
|
||||||
|
|
||||||
|
console.log("ChatPage candidate =>", candidate);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!candidate) {
|
if (!candidate) {
|
||||||
return;
|
return;
|
||||||
@ -27,7 +26,7 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
|
|||||||
setQuestions([
|
setQuestions([
|
||||||
<Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
|
<Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
|
||||||
{candidate.questions?.map(({ question, tunables }, i: number) =>
|
{candidate.questions?.map(({ question, tunables }, i: number) =>
|
||||||
<ChatQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
|
<BackstoryQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
|
||||||
)}
|
)}
|
||||||
</Box>,
|
</Box>,
|
||||||
<Box sx={{ p: 1 }}>
|
<Box sx={{ p: 1 }}>
|
||||||
@ -38,8 +37,8 @@ const ChatPage = forwardRef<ConversationHandle, BackstoryPageProps>((props: Back
|
|||||||
}, [candidate, isMobile, submitQuery]);
|
}, [candidate, isMobile, submitQuery]);
|
||||||
|
|
||||||
if (!candidate) {
|
if (!candidate) {
|
||||||
return (<></>);
|
return (<></>);
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
<CandidateInfo candidate={candidate} action="Chat with Backstory AI about " />
|
<CandidateInfo candidate={candidate} action="Chat with Backstory AI about " />
|
||||||
|
@ -100,7 +100,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
}
|
}
|
||||||
const sendSystemPrompt = async (prompt: string) => {
|
const sendSystemPrompt = async (prompt: string) => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/tunables`, {
|
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@ -126,7 +126,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
|
|
||||||
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
const reset = async (types: ("rags" | "tools" | "history" | "system_prompt")[], message: string = "Update successful.") => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/reset/`, {
|
const response = await fetch(connectionBase + `/api/1.0/reset/`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@ -178,7 +178,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
}
|
}
|
||||||
const fetchSystemInfo = async () => {
|
const fetchSystemInfo = async () => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/system-info`, {
|
const response = await fetch(connectionBase + `/api/1.0/system-info`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@ -210,13 +210,16 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
}, [systemInfo, setSystemInfo, setSnack])
|
}, [systemInfo, setSystemInfo, setSnack])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
if (!systemPrompt) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
setEditSystemPrompt(systemPrompt.trim());
|
setEditSystemPrompt(systemPrompt.trim());
|
||||||
}, [systemPrompt, setEditSystemPrompt]);
|
}, [systemPrompt, setEditSystemPrompt]);
|
||||||
|
|
||||||
const toggleRag = async (tool: Tool) => {
|
const toggleRag = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
tool.enabled = !tool.enabled
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/tunables`, {
|
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@ -238,7 +241,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
const toggleTool = async (tool: Tool) => {
|
const toggleTool = async (tool: Tool) => {
|
||||||
tool.enabled = !tool.enabled
|
tool.enabled = !tool.enabled
|
||||||
try {
|
try {
|
||||||
const response = await fetch(connectionBase + `/api/tunables`, {
|
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@ -265,7 +268,7 @@ const ControlsPage = (props: BackstoryPageProps) => {
|
|||||||
const fetchTunables = async () => {
|
const fetchTunables = async () => {
|
||||||
try {
|
try {
|
||||||
// Make the fetch request with proper headers
|
// Make the fetch request with proper headers
|
||||||
const response = await fetch(connectionBase + `/api/tunables`, {
|
const response = await fetch(connectionBase + `/api/1.0/tunables`, {
|
||||||
method: 'GET',
|
method: 'GET',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
|
@ -5,11 +5,11 @@ import Box from '@mui/material/Box';
|
|||||||
|
|
||||||
import { BackstoryPageProps } from '../components/BackstoryTab';
|
import { BackstoryPageProps } from '../components/BackstoryTab';
|
||||||
import { CandidateInfo } from 'components/CandidateInfo';
|
import { CandidateInfo } from 'components/CandidateInfo';
|
||||||
import { connectionBase } from '../utils/Global';
|
|
||||||
import { Candidate } from "../types/types";
|
import { Candidate } from "../types/types";
|
||||||
import { ApiClient } from 'types/api-client';
|
import { useUser } from 'hooks/useUser';
|
||||||
|
|
||||||
const CandidateListingPage = (props: BackstoryPageProps) => {
|
const CandidateListingPage = (props: BackstoryPageProps) => {
|
||||||
const apiClient = new ApiClient();
|
const { apiClient, setCandidate } = useUser();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
const { setSnack } = props;
|
const { setSnack } = props;
|
||||||
const [candidates, setCandidates] = useState<Candidate[] | null>(null);
|
const [candidates, setCandidates] = useState<Candidate[] | null>(null);
|
||||||
@ -44,27 +44,24 @@ const CandidateListingPage = (props: BackstoryPageProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Box sx={{display: "flex", flexDirection: "column"}}>
|
<Box sx={{display: "flex", flexDirection: "column"}}>
|
||||||
<Box sx={{ p: 1, textAlign: "center" }}>
|
<Box sx={{ p: 1, textAlign: "center" }}>
|
||||||
Not seeing a candidate you like?
|
Not seeing a candidate you like?
|
||||||
<Button
|
<Button
|
||||||
variant="contained"
|
variant="contained"
|
||||||
sx={{m: 1}}
|
sx={{ m: 1 }}
|
||||||
onClick={() => { navigate('/generate-candidate')}}>
|
onClick={() => { navigate('/generate-candidate') }}>
|
||||||
Generate your own perfect AI candidate!
|
Generate your own perfect AI candidate!
|
||||||
</Button>
|
</Button>
|
||||||
</Box>
|
</Box>
|
||||||
<Box sx={{ display: "flex", gap: 1, flexWrap: "wrap"}}>
|
<Box sx={{ display: "flex", gap: 1, flexWrap: "wrap" }}>
|
||||||
{candidates?.map((u, i) =>
|
{candidates?.map((u, i) =>
|
||||||
<Box key={`${u.username}`}
|
<Box key={`${u.username}`}
|
||||||
onClick={(event: React.MouseEvent<HTMLDivElement>) : void => {
|
onClick={() => { setCandidate(u); navigate("/chat"); }}
|
||||||
navigate(`/u/${u.username}`)
|
sx={{ cursor: "pointer" }}>
|
||||||
}}
|
|
||||||
sx={{ cursor: "pointer" }}
|
|
||||||
>
|
|
||||||
<CandidateInfo sx={{ maxWidth: "320px", "cursor": "pointer", "&:hover": { border: "2px solid orange" }, border: "2px solid transparent" }} candidate={u} />
|
<CandidateInfo sx={{ maxWidth: "320px", "cursor": "pointer", "&:hover": { border: "2px solid orange" }, border: "2px solid transparent" }} candidate={u} />
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
</Box>
|
</Box>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -10,9 +10,7 @@ import SendIcon from '@mui/icons-material/Send';
|
|||||||
import PropagateLoader from 'react-spinners/PropagateLoader';
|
import PropagateLoader from 'react-spinners/PropagateLoader';
|
||||||
import { jsonrepair } from 'jsonrepair';
|
import { jsonrepair } from 'jsonrepair';
|
||||||
|
|
||||||
|
|
||||||
import { CandidateInfo } from '../components/CandidateInfo';
|
import { CandidateInfo } from '../components/CandidateInfo';
|
||||||
import { Query } from '../types/types'
|
|
||||||
import { Quote } from 'components/Quote';
|
import { Quote } from 'components/Quote';
|
||||||
import { Candidate } from '../types/types';
|
import { Candidate } from '../types/types';
|
||||||
import { BackstoryElementProps } from 'components/BackstoryTab';
|
import { BackstoryElementProps } from 'components/BackstoryTab';
|
||||||
@ -21,6 +19,8 @@ import { StyledMarkdown } from 'components/StyledMarkdown';
|
|||||||
import { Scrollable } from '../components/Scrollable';
|
import { Scrollable } from '../components/Scrollable';
|
||||||
import { Pulse } from 'components/Pulse';
|
import { Pulse } from 'components/Pulse';
|
||||||
import { StreamingResponse } from 'types/api-client';
|
import { StreamingResponse } from 'types/api-client';
|
||||||
|
import { ChatContext, ChatSession, ChatQuery } from 'types/types';
|
||||||
|
import { useUser } from 'hooks/useUser';
|
||||||
|
|
||||||
const emptyUser: Candidate = {
|
const emptyUser: Candidate = {
|
||||||
description: "[blank]",
|
description: "[blank]",
|
||||||
@ -46,6 +46,7 @@ const emptyUser: Candidate = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const GenerateCandidate = (props: BackstoryElementProps) => {
|
const GenerateCandidate = (props: BackstoryElementProps) => {
|
||||||
|
const { apiClient } = useUser();
|
||||||
const { setSnack, submitQuery } = props;
|
const { setSnack, submitQuery } = props;
|
||||||
const [streaming, setStreaming] = useState<string>('');
|
const [streaming, setStreaming] = useState<string>('');
|
||||||
const [processing, setProcessing] = useState<boolean>(false);
|
const [processing, setProcessing] = useState<boolean>(false);
|
||||||
@ -57,16 +58,40 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
|
|||||||
const [timestamp, setTimestamp] = useState<number>(0);
|
const [timestamp, setTimestamp] = useState<number>(0);
|
||||||
const [state, setState] = useState<number>(0); // Replaced stateRef
|
const [state, setState] = useState<number>(0); // Replaced stateRef
|
||||||
const [shouldGenerateProfile, setShouldGenerateProfile] = useState<boolean>(false);
|
const [shouldGenerateProfile, setShouldGenerateProfile] = useState<boolean>(false);
|
||||||
|
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
|
||||||
|
|
||||||
// Only keep refs that are truly necessary
|
// Only keep refs that are truly necessary
|
||||||
const controllerRef = useRef<StreamingResponse>(null);
|
const controllerRef = useRef<StreamingResponse>(null);
|
||||||
const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
|
const backstoryTextRef = useRef<BackstoryTextFieldRef>(null);
|
||||||
|
|
||||||
const generatePersona = useCallback((query: Query) => {
|
/* Create the chat session */
|
||||||
if (controllerRef.current) {
|
useEffect(() => {
|
||||||
|
if (chatSession) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setPrompt(query.prompt);
|
|
||||||
|
const createChatSession = async () => {
|
||||||
|
try {
|
||||||
|
const chatContext: ChatContext = { type: "generate_persona" };
|
||||||
|
const response: ChatSession = await apiClient.createChatSession(chatContext);
|
||||||
|
setChatSession(response);
|
||||||
|
setSnack(`Chat session created for generate_persona: ${response.id}`);
|
||||||
|
} catch (e) {
|
||||||
|
console.error(e);
|
||||||
|
setSnack("Unable to create chat session.", "error");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
createChatSession();
|
||||||
|
}, [chatSession, setChatSession]);
|
||||||
|
|
||||||
|
const generatePersona = useCallback((query: ChatQuery) => {
|
||||||
|
if (!chatSession || !chatSession.id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const sessionId: string = chatSession.id;
|
||||||
|
|
||||||
|
setPrompt(query.prompt || '');
|
||||||
setState(0);
|
setState(0);
|
||||||
setStatus("Generating persona...");
|
setStatus("Generating persona...");
|
||||||
setUser(emptyUser);
|
setUser(emptyUser);
|
||||||
@ -76,6 +101,24 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
|
|||||||
setCanGenImage(false);
|
setCanGenImage(false);
|
||||||
setShouldGenerateProfile(false); // Reset the flag
|
setShouldGenerateProfile(false); // Reset the flag
|
||||||
|
|
||||||
|
const streamResponse = apiClient.sendMessageStream(sessionId, query, {
|
||||||
|
onPartialMessage: (content, messageId) => {
|
||||||
|
console.log('Partial content:', content);
|
||||||
|
// Update UI with partial content
|
||||||
|
},
|
||||||
|
onStatusChange: (status) => {
|
||||||
|
console.log('Status changed:', status);
|
||||||
|
// Update UI status indicator
|
||||||
|
},
|
||||||
|
onComplete: (finalMessage) => {
|
||||||
|
console.log('Final message:', finalMessage.content);
|
||||||
|
// Handle completed message
|
||||||
|
},
|
||||||
|
onError: (error) => {
|
||||||
|
console.error('Streaming error:', error);
|
||||||
|
// Handle error
|
||||||
|
}
|
||||||
|
});
|
||||||
// controllerRef.current = streamQueryResponse({
|
// controllerRef.current = streamQueryResponse({
|
||||||
// query,
|
// query,
|
||||||
// type: "persona",
|
// type: "persona",
|
||||||
@ -148,7 +191,7 @@ const GenerateCandidate = (props: BackstoryElementProps) => {
|
|||||||
if (processing) {
|
if (processing) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const query: Query = {
|
const query: ChatQuery = {
|
||||||
prompt: value,
|
prompt: value,
|
||||||
}
|
}
|
||||||
generatePersona(query);
|
generatePersona(query);
|
||||||
|
@ -22,9 +22,10 @@ import { Person, PersonAdd, AccountCircle, ExitToApp } from '@mui/icons-material
|
|||||||
import 'react-phone-number-input/style.css';
|
import 'react-phone-number-input/style.css';
|
||||||
import PhoneInput from 'react-phone-number-input';
|
import PhoneInput from 'react-phone-number-input';
|
||||||
import { E164Number } from 'libphonenumber-js/core';
|
import { E164Number } from 'libphonenumber-js/core';
|
||||||
import './PhoneInput.css';
|
import './LoginPage.css';
|
||||||
|
|
||||||
import { ApiClient } from 'types/api-client';
|
import { ApiClient } from 'types/api-client';
|
||||||
|
import { useUser } from 'hooks/useUser';
|
||||||
|
|
||||||
// Import conversion utilities
|
// Import conversion utilities
|
||||||
import {
|
import {
|
||||||
@ -35,11 +36,12 @@ import {
|
|||||||
isSuccessResponse,
|
isSuccessResponse,
|
||||||
debugConversion,
|
debugConversion,
|
||||||
type ApiResponse
|
type ApiResponse
|
||||||
} from './types/conversion';
|
} from 'types/conversion';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
AuthResponse, User, Guest, Candidate
|
AuthResponse, User, Guest, Candidate
|
||||||
} from './types/types'
|
} from 'types/types'
|
||||||
|
import { useNavigate } from 'react-router-dom';
|
||||||
|
|
||||||
interface LoginRequest {
|
interface LoginRequest {
|
||||||
login: string;
|
login: string;
|
||||||
@ -55,16 +57,17 @@ interface RegisterRequest {
|
|||||||
phone?: string;
|
phone?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
const BackstoryTestApp: React.FC = () => {
|
const apiClient = new ApiClient();
|
||||||
const apiClient = new ApiClient();
|
|
||||||
const [currentUser, setCurrentUser] = useState<User | null>(null);
|
const LoginPage: React.FC = () => {
|
||||||
const [guestSession, setGuestSession] = useState<Guest | null>(null);
|
const navigate = useNavigate();
|
||||||
|
const { user, setUser, guest } = useUser();
|
||||||
const [tabValue, setTabValue] = useState(0);
|
const [tabValue, setTabValue] = useState(0);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [error, setError] = useState<string | null>(null);
|
const [error, setError] = useState<string | null>(null);
|
||||||
const [success, setSuccess] = useState<string | null>(null);
|
const [success, setSuccess] = useState<string | null>(null);
|
||||||
const [phone, setPhone] = useState<E164Number | null>(null);
|
const [phone, setPhone] = useState<E164Number | null>(null);
|
||||||
const name = (currentUser?.userType === 'candidate' ? (currentUser as Candidate).username : currentUser?.email) || '';
|
const name = (user?.userType === 'candidate' ? (user as Candidate).username : user?.email) || '';
|
||||||
|
|
||||||
// Login form state
|
// Login form state
|
||||||
const [loginForm, setLoginForm] = useState<LoginRequest>({
|
const [loginForm, setLoginForm] = useState<LoginRequest>({
|
||||||
@ -82,12 +85,6 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
phone: ''
|
phone: ''
|
||||||
});
|
});
|
||||||
|
|
||||||
// Create guest session on component mount
|
|
||||||
useEffect(() => {
|
|
||||||
createGuestSession();
|
|
||||||
checkExistingAuth();
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (phone !== registerForm.phone && phone) {
|
if (phone !== registerForm.phone && phone) {
|
||||||
console.log({ phone });
|
console.log({ phone });
|
||||||
@ -95,44 +92,6 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
}
|
}
|
||||||
}, [phone, registerForm]);
|
}, [phone, registerForm]);
|
||||||
|
|
||||||
const createGuestSession = () => {
|
|
||||||
const sessionId = `guest_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
|
||||||
const guest: Guest = {
|
|
||||||
sessionId,
|
|
||||||
createdAt: new Date(),
|
|
||||||
lastActivity: new Date(),
|
|
||||||
ipAddress: 'unknown',
|
|
||||||
userAgent: navigator.userAgent
|
|
||||||
};
|
|
||||||
setGuestSession(guest);
|
|
||||||
debugConversion(guest, 'Guest Session');
|
|
||||||
};
|
|
||||||
|
|
||||||
const checkExistingAuth = () => {
|
|
||||||
const token = localStorage.getItem('accessToken');
|
|
||||||
const userData = localStorage.getItem('userData');
|
|
||||||
if (token && userData) {
|
|
||||||
try {
|
|
||||||
const user = JSON.parse(userData);
|
|
||||||
// Convert dates back to Date objects if they're stored as strings
|
|
||||||
if (user.createdAt && typeof user.createdAt === 'string') {
|
|
||||||
user.createdAt = new Date(user.createdAt);
|
|
||||||
}
|
|
||||||
if (user.updatedAt && typeof user.updatedAt === 'string') {
|
|
||||||
user.updatedAt = new Date(user.updatedAt);
|
|
||||||
}
|
|
||||||
if (user.lastLogin && typeof user.lastLogin === 'string') {
|
|
||||||
user.lastLogin = new Date(user.lastLogin);
|
|
||||||
}
|
|
||||||
setCurrentUser(user);
|
|
||||||
} catch (e) {
|
|
||||||
localStorage.removeItem('accessToken');
|
|
||||||
localStorage.removeItem('refreshToken');
|
|
||||||
localStorage.removeItem('userData');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleLogin = async (e: React.FormEvent) => {
|
const handleLogin = async (e: React.FormEvent) => {
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
@ -149,12 +108,11 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
localStorage.setItem('refreshToken', authResponse.refreshToken);
|
localStorage.setItem('refreshToken', authResponse.refreshToken);
|
||||||
localStorage.setItem('userData', JSON.stringify(authResponse.user));
|
localStorage.setItem('userData', JSON.stringify(authResponse.user));
|
||||||
|
|
||||||
setCurrentUser(authResponse.user);
|
|
||||||
setSuccess('Login successful!');
|
setSuccess('Login successful!');
|
||||||
|
navigate('/');
|
||||||
|
setUser(authResponse.user);
|
||||||
// Clear form
|
// Clear form
|
||||||
setLoginForm({ login: '', password: '' });
|
setLoginForm({ login: '', password: '' });
|
||||||
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Login error:', err);
|
console.error('Login error:', err);
|
||||||
setError(err instanceof Error ? err.message : 'Login failed');
|
setError(err instanceof Error ? err.message : 'Login failed');
|
||||||
@ -219,113 +177,68 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleLogout = () => {
|
|
||||||
localStorage.removeItem('accessToken');
|
|
||||||
localStorage.removeItem('refreshToken');
|
|
||||||
localStorage.removeItem('userData');
|
|
||||||
setCurrentUser(null);
|
|
||||||
setSuccess('Logged out successfully');
|
|
||||||
createGuestSession();
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleTabChange = (event: React.SyntheticEvent, newValue: number) => {
|
const handleTabChange = (event: React.SyntheticEvent, newValue: number) => {
|
||||||
setTabValue(newValue);
|
setTabValue(newValue);
|
||||||
setError(null);
|
setError(null);
|
||||||
setSuccess(null);
|
setSuccess(null);
|
||||||
};
|
};
|
||||||
|
|
||||||
// API helper function for authenticated requests
|
|
||||||
const makeAuthenticatedRequest = async (url: string, options: RequestInit = {}) => {
|
|
||||||
const token = localStorage.getItem('accessToken');
|
|
||||||
|
|
||||||
const headers = {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
...(token && { 'Authorization': `Bearer ${token}` }),
|
|
||||||
...options.headers,
|
|
||||||
};
|
|
||||||
|
|
||||||
const response = await fetch(url, {
|
|
||||||
...options,
|
|
||||||
headers,
|
|
||||||
});
|
|
||||||
|
|
||||||
return handleApiResponse(response);
|
|
||||||
};
|
|
||||||
|
|
||||||
// If user is logged in, show their profile
|
// If user is logged in, show their profile
|
||||||
if (currentUser) {
|
if (user) {
|
||||||
return (
|
return (
|
||||||
<Box sx={{ flexGrow: 1 }}>
|
<Container maxWidth="md" sx={{ mt: 4 }}>
|
||||||
<AppBar position="static">
|
<Card elevation={3}>
|
||||||
<Toolbar>
|
<CardContent>
|
||||||
<AccountCircle sx={{ mr: 2 }} />
|
<Box sx={{ display: 'flex', alignItems: 'center', mb: 3 }}>
|
||||||
<Typography variant="h6" component="div" sx={{ flexGrow: 1 }}>
|
<Avatar sx={{ mr: 2, bgcolor: 'primary.main' }}>
|
||||||
Welcome, {name}
|
<AccountCircle />
|
||||||
</Typography>
|
</Avatar>
|
||||||
<Button
|
<Typography variant="h4" component="h1">
|
||||||
color="inherit"
|
User Profile
|
||||||
onClick={handleLogout}
|
</Typography>
|
||||||
startIcon={<ExitToApp />}
|
</Box>
|
||||||
>
|
|
||||||
Logout
|
<Divider sx={{ mb: 3 }} />
|
||||||
</Button>
|
|
||||||
</Toolbar>
|
<Grid container spacing={3}>
|
||||||
</AppBar>
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
<Container maxWidth="md" sx={{ mt: 4 }}>
|
<strong>Username:</strong> {name}
|
||||||
<Card elevation={3}>
|
|
||||||
<CardContent>
|
|
||||||
<Box sx={{ display: 'flex', alignItems: 'center', mb: 3 }}>
|
|
||||||
<Avatar sx={{ mr: 2, bgcolor: 'primary.main' }}>
|
|
||||||
<AccountCircle />
|
|
||||||
</Avatar>
|
|
||||||
<Typography variant="h4" component="h1">
|
|
||||||
User Profile
|
|
||||||
</Typography>
|
</Typography>
|
||||||
</Box>
|
|
||||||
|
|
||||||
<Divider sx={{ mb: 3 }} />
|
|
||||||
|
|
||||||
<Grid container spacing={3}>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Username:</strong> {name}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Email:</strong> {currentUser.email}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Status:</strong> {currentUser.status}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Phone:</strong> {currentUser.phone || 'Not provided'}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Last Login:</strong> {
|
|
||||||
currentUser.lastLogin
|
|
||||||
? currentUser.lastLogin.toLocaleString()
|
|
||||||
: 'N/A'
|
|
||||||
}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
<Grid size={{ xs: 12, md: 6 }}>
|
|
||||||
<Typography variant="body1" sx={{ mb: 1 }}>
|
|
||||||
<strong>Member Since:</strong> {currentUser.createdAt.toLocaleDateString()}
|
|
||||||
</Typography>
|
|
||||||
</Grid>
|
|
||||||
</Grid>
|
</Grid>
|
||||||
</CardContent>
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
</Card>
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
</Container>
|
<strong>Email:</strong> {user.email}
|
||||||
</Box>
|
</Typography>
|
||||||
|
</Grid>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
|
<strong>Status:</strong> {user.status}
|
||||||
|
</Typography>
|
||||||
|
</Grid>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
|
<strong>Phone:</strong> {user.phone || 'Not provided'}
|
||||||
|
</Typography>
|
||||||
|
</Grid>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
|
<strong>Last Login:</strong> {
|
||||||
|
user.lastLogin
|
||||||
|
? user.lastLogin.toLocaleString()
|
||||||
|
: 'N/A'
|
||||||
|
}
|
||||||
|
</Typography>
|
||||||
|
</Grid>
|
||||||
|
<Grid size={{ xs: 12, md: 6 }}>
|
||||||
|
<Typography variant="body1" sx={{ mb: 1 }}>
|
||||||
|
<strong>Member Since:</strong> {user.createdAt.toLocaleDateString()}
|
||||||
|
</Typography>
|
||||||
|
</Grid>
|
||||||
|
</Grid>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</Container>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -352,20 +265,20 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
<Container maxWidth="sm" sx={{ mt: 4 }}>
|
<Container maxWidth="sm" sx={{ mt: 4 }}>
|
||||||
<Paper elevation={3} sx={{ p: 4 }}>
|
<Paper elevation={3} sx={{ p: 4 }}>
|
||||||
<Typography variant="h4" component="h1" gutterBottom align="center" color="primary">
|
<Typography variant="h4" component="h1" gutterBottom align="center" color="primary">
|
||||||
Backstory Platform
|
Backstory
|
||||||
</Typography>
|
</Typography>
|
||||||
|
|
||||||
{guestSession && (
|
{guest && (
|
||||||
<Card sx={{ mb: 3, bgcolor: 'grey.50' }} elevation={1}>
|
<Card sx={{ mb: 3, bgcolor: 'grey.50' }} elevation={1}>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
<Typography variant="h6" gutterBottom color="primary">
|
<Typography variant="h6" gutterBottom color="primary">
|
||||||
Guest Session Active
|
Guest Session Active
|
||||||
</Typography>
|
</Typography>
|
||||||
<Typography variant="body2" color="text.secondary" sx={{ mb: 0.5 }}>
|
<Typography variant="body2" color="text.secondary" sx={{ mb: 0.5 }}>
|
||||||
Session ID: {guestSession.sessionId}
|
Session ID: {guest.sessionId}
|
||||||
</Typography>
|
</Typography>
|
||||||
<Typography variant="body2" color="text.secondary">
|
<Typography variant="body2" color="text.secondary">
|
||||||
Created: {guestSession.createdAt.toLocaleString()}
|
Created: {guest.createdAt.toLocaleString()}
|
||||||
</Typography>
|
</Typography>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
@ -537,4 +450,4 @@ const BackstoryTestApp: React.FC = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export { BackstoryTestApp };
|
export { LoginPage };
|
@ -6,11 +6,11 @@ import {
|
|||||||
} from '@mui/material';
|
} from '@mui/material';
|
||||||
import { SxProps } from '@mui/material';
|
import { SxProps } from '@mui/material';
|
||||||
|
|
||||||
import { ChatQuery } from '../components/ChatQuery';
|
import { BackstoryQuery } from 'components/BackstoryQuery';
|
||||||
import { MessageList, BackstoryMessage } from '../components/Message';
|
import { MessageList, BackstoryMessage } from 'components/Message';
|
||||||
import { Conversation } from '../components/Conversation';
|
import { Conversation } from 'components/Conversation';
|
||||||
import { BackstoryPageProps } from '../components/BackstoryTab';
|
import { BackstoryPageProps } from 'components/BackstoryTab';
|
||||||
import { Query } from "../types/types";
|
import { ChatQuery } from "types/types";
|
||||||
|
|
||||||
import './ResumeBuilderPage.css';
|
import './ResumeBuilderPage.css';
|
||||||
|
|
||||||
@ -43,17 +43,17 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
|
|||||||
setActiveTab(newValue);
|
setActiveTab(newValue);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleJobQuery = (query: Query) => {
|
const handleJobQuery = (query: ChatQuery) => {
|
||||||
console.log(`handleJobQuery: ${query.prompt} -- `, jobConversationRef.current ? ' sending' : 'no handler');
|
console.log(`handleJobQuery: ${query.prompt} -- `, jobConversationRef.current ? ' sending' : 'no handler');
|
||||||
jobConversationRef.current?.submitQuery(query);
|
jobConversationRef.current?.submitQuery(query);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleResumeQuery = (query: Query) => {
|
const handleResumeQuery = (query: ChatQuery) => {
|
||||||
console.log(`handleResumeQuery: ${query.prompt} -- `, resumeConversationRef.current ? ' sending' : 'no handler');
|
console.log(`handleResumeQuery: ${query.prompt} -- `, resumeConversationRef.current ? ' sending' : 'no handler');
|
||||||
resumeConversationRef.current?.submitQuery(query);
|
resumeConversationRef.current?.submitQuery(query);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleFactsQuery = (query: Query) => {
|
const handleFactsQuery = (query: ChatQuery) => {
|
||||||
console.log(`handleFactsQuery: ${query.prompt} -- `, factsConversationRef.current ? ' sending' : 'no handler');
|
console.log(`handleFactsQuery: ${query.prompt} -- `, factsConversationRef.current ? ' sending' : 'no handler');
|
||||||
factsConversationRef.current?.submitQuery(query);
|
factsConversationRef.current?.submitQuery(query);
|
||||||
};
|
};
|
||||||
@ -202,8 +202,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
|
|||||||
// console.log('renderJobDescriptionView');
|
// console.log('renderJobDescriptionView');
|
||||||
// const jobDescriptionQuestions = [
|
// const jobDescriptionQuestions = [
|
||||||
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
||||||
// <ChatQuery query={{ prompt: "What are the key skills necessary for this position?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
|
// <BackstoryQuery query={{ prompt: "What are the key skills necessary for this position?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
|
||||||
// <ChatQuery query={{ prompt: "How much should this position pay (accounting for inflation)?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
|
// <BackstoryQuery query={{ prompt: "How much should this position pay (accounting for inflation)?", tunables: { enableTools: false } }} submitQuery={handleJobQuery} />
|
||||||
// </Box>,
|
// </Box>,
|
||||||
// ];
|
// ];
|
||||||
|
|
||||||
@ -213,7 +213,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
|
|||||||
|
|
||||||
// 1. **Job Analysis**: LLM extracts requirements from '\`Job Description\`' to generate a list of desired '\`Skills\`'.
|
// 1. **Job Analysis**: LLM extracts requirements from '\`Job Description\`' to generate a list of desired '\`Skills\`'.
|
||||||
// 2. **Candidate Analysis**: LLM determines candidate qualifications by performing skill assessments.
|
// 2. **Candidate Analysis**: LLM determines candidate qualifications by performing skill assessments.
|
||||||
|
|
||||||
// For each '\`Skill\`' from **Job Analysis** phase:
|
// For each '\`Skill\`' from **Job Analysis** phase:
|
||||||
|
|
||||||
// 1. **RAG**: Retrieval Augmented Generation collection is queried for context related content for each '\`Skill\`'.
|
// 1. **RAG**: Retrieval Augmented Generation collection is queried for context related content for each '\`Skill\`'.
|
||||||
@ -274,8 +274,8 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
|
|||||||
// const renderResumeView = useCallback((sx?: SxProps) => {
|
// const renderResumeView = useCallback((sx?: SxProps) => {
|
||||||
// const resumeQuestions = [
|
// const resumeQuestions = [
|
||||||
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
||||||
// <ChatQuery query={{ prompt: "Is this resume a good fit for the provided job description?", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
|
// <BackstoryQuery query={{ prompt: "Is this resume a good fit for the provided job description?", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
|
||||||
// <ChatQuery query={{ prompt: "Provide a more concise resume.", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
|
// <BackstoryQuery query={{ prompt: "Provide a more concise resume.", tunables: { enableTools: false } }} submitQuery={handleResumeQuery} />
|
||||||
// </Box>,
|
// </Box>,
|
||||||
// ];
|
// ];
|
||||||
|
|
||||||
@ -323,7 +323,7 @@ const ResumeBuilderPage: React.FC<BackstoryPageProps> = (props: BackstoryPagePro
|
|||||||
// const renderFactCheckView = useCallback((sx?: SxProps) => {
|
// const renderFactCheckView = useCallback((sx?: SxProps) => {
|
||||||
// const factsQuestions = [
|
// const factsQuestions = [
|
||||||
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
// <Box sx={{ display: "flex", flexDirection: "column" }}>
|
||||||
// <ChatQuery query={{ prompt: "Rewrite the resume to address any discrepancies.", tunables: { enableTools: false } }} submitQuery={handleFactsQuery} />
|
// <BackstoryQuery query={{ prompt: "Rewrite the resume to address any discrepancies.", tunables: { enableTools: false } }} submitQuery={handleFactsQuery} />
|
||||||
// </Box>,
|
// </Box>,
|
||||||
// ];
|
// ];
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import React, { useEffect, useState } from "react";
|
import React, { useEffect, useState } from "react";
|
||||||
import { useParams, useNavigate } from "react-router-dom";
|
import { useParams, useNavigate } from "react-router-dom";
|
||||||
import { useUser } from "../hooks/useUser";
|
|
||||||
import { Box } from "@mui/material";
|
import { Box } from "@mui/material";
|
||||||
|
|
||||||
import { SetSnackType } from '../components/Snack';
|
import { SetSnackType } from '../components/Snack';
|
||||||
import { LoadingComponent } from "../components/LoadingComponent";
|
import { LoadingComponent } from "../components/LoadingComponent";
|
||||||
import { User, Guest, Candidate } from 'types/types';
|
import { User, Guest, Candidate } from 'types/types';
|
||||||
import { ApiClient } from "types/api-client";
|
import { useUser } from "hooks/useUser";
|
||||||
|
|
||||||
interface CandidateRouteProps {
|
interface CandidateRouteProps {
|
||||||
guest?: Guest | null;
|
guest?: Guest | null;
|
||||||
@ -15,7 +14,7 @@ interface CandidateRouteProps {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
|
const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProps) => {
|
||||||
const apiClient = new ApiClient();
|
const { apiClient } = useUser();
|
||||||
const { setSnack } = props;
|
const { setSnack } = props;
|
||||||
const { username } = useParams<{ username: string }>();
|
const { username } = useParams<{ username: string }>();
|
||||||
const [candidate, setCandidate] = useState<Candidate|null>(null);
|
const [candidate, setCandidate] = useState<Candidate|null>(null);
|
||||||
@ -32,11 +31,12 @@ const CandidateRoute: React.FC<CandidateRouteProps> = (props: CandidateRouteProp
|
|||||||
navigate('/chat');
|
navigate('/chat');
|
||||||
} catch {
|
} catch {
|
||||||
setSnack(`Unable to obtain information for ${username}.`, "error");
|
setSnack(`Unable to obtain information for ${username}.`, "error");
|
||||||
|
navigate('/');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
getCandidate(username);
|
getCandidate(username);
|
||||||
}, [candidate, username, setCandidate]);
|
}, [candidate, username, setCandidate, navigate, setSnack, apiClient]);
|
||||||
|
|
||||||
if (candidate === null) {
|
if (candidate === null) {
|
||||||
return (<Box>
|
return (<Box>
|
||||||
|
@ -9,14 +9,14 @@
|
|||||||
import * as Types from './types';
|
import * as Types from './types';
|
||||||
import {
|
import {
|
||||||
formatApiRequest,
|
formatApiRequest,
|
||||||
parseApiResponse,
|
// parseApiResponse,
|
||||||
parsePaginatedResponse,
|
// parsePaginatedResponse,
|
||||||
handleApiResponse,
|
handleApiResponse,
|
||||||
handlePaginatedApiResponse,
|
handlePaginatedApiResponse,
|
||||||
createPaginatedRequest,
|
createPaginatedRequest,
|
||||||
toUrlParams,
|
toUrlParams,
|
||||||
extractApiData,
|
// extractApiData,
|
||||||
ApiResponse,
|
// ApiResponse,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
PaginatedRequest
|
PaginatedRequest
|
||||||
} from './conversion';
|
} from './conversion';
|
||||||
@ -59,7 +59,7 @@ class ApiClient {
|
|||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
private defaultHeaders: Record<string, string>;
|
private defaultHeaders: Record<string, string>;
|
||||||
|
|
||||||
constructor(authToken?: string) {
|
constructor(accessToken?: string) {
|
||||||
const loc = window.location;
|
const loc = window.location;
|
||||||
if (!loc.host.match(/.*battle-linux.*/)) {
|
if (!loc.host.match(/.*battle-linux.*/)) {
|
||||||
this.baseUrl = loc.protocol + "//" + loc.host + "/api/1.0";
|
this.baseUrl = loc.protocol + "//" + loc.host + "/api/1.0";
|
||||||
@ -68,7 +68,7 @@ class ApiClient {
|
|||||||
}
|
}
|
||||||
this.defaultHeaders = {
|
this.defaultHeaders = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
...(authToken && { 'Authorization': `Bearer ${authToken}` })
|
...(accessToken && { 'Authorization': `Bearer ${accessToken}` })
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,16 +76,27 @@ class ApiClient {
|
|||||||
// Authentication Methods
|
// Authentication Methods
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
async login(email: string, password: string): Promise<Types.AuthResponse> {
|
async login(login: string, password: string): Promise<Types.AuthResponse> {
|
||||||
const response = await fetch(`${this.baseUrl}/auth/login`, {
|
const response = await fetch(`${this.baseUrl}/auth/login`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: this.defaultHeaders,
|
headers: this.defaultHeaders,
|
||||||
body: JSON.stringify(formatApiRequest({ email, password }))
|
body: JSON.stringify(formatApiRequest({ login, password }))
|
||||||
});
|
});
|
||||||
|
|
||||||
return handleApiResponse<Types.AuthResponse>(response);
|
return handleApiResponse<Types.AuthResponse>(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async logout(accessToken: string, refreshToken: string): Promise<Types.ApiResponse> {
|
||||||
|
console.log(this.defaultHeaders);
|
||||||
|
const response = await fetch(`${this.baseUrl}/auth/logout`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: this.defaultHeaders,
|
||||||
|
body: JSON.stringify(formatApiRequest({ accessToken, refreshToken }))
|
||||||
|
});
|
||||||
|
|
||||||
|
return handleApiResponse<Types.ApiResponse>(response);
|
||||||
|
}
|
||||||
|
|
||||||
async refreshToken(refreshToken: string): Promise<Types.AuthResponse> {
|
async refreshToken(refreshToken: string): Promise<Types.AuthResponse> {
|
||||||
const response = await fetch(`${this.baseUrl}/auth/refresh`, {
|
const response = await fetch(`${this.baseUrl}/auth/refresh`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@ -110,8 +121,8 @@ class ApiClient {
|
|||||||
return handleApiResponse<Types.Candidate>(response);
|
return handleApiResponse<Types.Candidate>(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
async getCandidate(id: string): Promise<Types.Candidate> {
|
async getCandidate(username: string): Promise<Types.Candidate> {
|
||||||
const response = await fetch(`${this.baseUrl}/candidates/${id}`, {
|
const response = await fetch(`${this.baseUrl}/candidates/${username}`, {
|
||||||
headers: this.defaultHeaders
|
headers: this.defaultHeaders
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -315,11 +326,11 @@ class ApiClient {
|
|||||||
/**
|
/**
|
||||||
* Send message with standard response (non-streaming)
|
* Send message with standard response (non-streaming)
|
||||||
*/
|
*/
|
||||||
async sendMessage(sessionId: string, query: Types.Query): Promise<Types.ChatMessage> {
|
async sendMessage(sessionId: string, query: Types.ChatQuery): Promise<Types.ChatMessage> {
|
||||||
const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, {
|
const response = await fetch(`${this.baseUrl}/chat/sessions/${sessionId}/messages`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: this.defaultHeaders,
|
headers: this.defaultHeaders,
|
||||||
body: JSON.stringify(formatApiRequest({ query }))
|
body: JSON.stringify(formatApiRequest({query}))
|
||||||
});
|
});
|
||||||
|
|
||||||
return handleApiResponse<Types.ChatMessage>(response);
|
return handleApiResponse<Types.ChatMessage>(response);
|
||||||
@ -330,7 +341,7 @@ class ApiClient {
|
|||||||
*/
|
*/
|
||||||
sendMessageStream(
|
sendMessageStream(
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
query: Types.Query,
|
query: Types.ChatQuery,
|
||||||
options: StreamingOptions = {}
|
options: StreamingOptions = {}
|
||||||
): StreamingResponse {
|
): StreamingResponse {
|
||||||
const abortController = new AbortController();
|
const abortController = new AbortController();
|
||||||
@ -479,7 +490,7 @@ class ApiClient {
|
|||||||
*/
|
*/
|
||||||
async sendMessageAuto(
|
async sendMessageAuto(
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
query: Types.Query,
|
query: Types.ChatQuery,
|
||||||
options?: StreamingOptions
|
options?: StreamingOptions
|
||||||
): Promise<Types.ChatMessage> {
|
): Promise<Types.ChatMessage> {
|
||||||
// If streaming options are provided, use streaming
|
// If streaming options are provided, use streaming
|
||||||
@ -503,36 +514,6 @@ class ApiClient {
|
|||||||
return handlePaginatedApiResponse<Types.ChatMessage>(response);
|
return handlePaginatedApiResponse<Types.ChatMessage>(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================
|
|
||||||
// AI Configuration Methods
|
|
||||||
// ============================
|
|
||||||
|
|
||||||
async createAIParameters(params: Omit<Types.AIParameters, 'id' | 'createdAt' | 'updatedAt'>): Promise<Types.AIParameters> {
|
|
||||||
const response = await fetch(`${this.baseUrl}/ai/parameters`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: this.defaultHeaders,
|
|
||||||
body: JSON.stringify(formatApiRequest(params))
|
|
||||||
});
|
|
||||||
|
|
||||||
return handleApiResponse<Types.AIParameters>(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
async getAIParameters(id: string): Promise<Types.AIParameters> {
|
|
||||||
const response = await fetch(`${this.baseUrl}/ai/parameters/${id}`, {
|
|
||||||
headers: this.defaultHeaders
|
|
||||||
});
|
|
||||||
|
|
||||||
return handleApiResponse<Types.AIParameters>(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
async getUserAIParameters(userId: string): Promise<Types.AIParameters[]> {
|
|
||||||
const response = await fetch(`${this.baseUrl}/users/${userId}/ai/parameters`, {
|
|
||||||
headers: this.defaultHeaders
|
|
||||||
});
|
|
||||||
|
|
||||||
return handleApiResponse<Types.AIParameters[]>(response);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Error Handling Helper
|
// Error Handling Helper
|
||||||
// ============================
|
// ============================
|
||||||
@ -580,7 +561,7 @@ export function useStreamingChat(sessionId: string) {
|
|||||||
const apiClient = useApiClient();
|
const apiClient = useApiClient();
|
||||||
const streamingRef = useRef<StreamingResponse | null>(null);
|
const streamingRef = useRef<StreamingResponse | null>(null);
|
||||||
|
|
||||||
const sendMessage = useCallback(async (query: Types.Query) => {
|
const sendMessage = useCallback(async (query: Types.ChatQuery) => {
|
||||||
setError(null);
|
setError(null);
|
||||||
setIsStreaming(true);
|
setIsStreaming(true);
|
||||||
setCurrentMessage(null);
|
setCurrentMessage(null);
|
||||||
|
@ -1,23 +1,23 @@
|
|||||||
// Generated TypeScript types from Pydantic models
|
// Generated TypeScript types from Pydantic models
|
||||||
// Source: src/backend/models.py
|
// Source: src/backend/models.py
|
||||||
// Generated on: 2025-05-28T21:47:08.590102
|
// Generated on: 2025-05-29T05:47:25.809967
|
||||||
// DO NOT EDIT MANUALLY - This file is auto-generated
|
// DO NOT EDIT MANUALLY - This file is auto-generated
|
||||||
|
|
||||||
// ============================
|
// ============================
|
||||||
// Enums
|
// Enums
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
export type AIModelType = "gpt-4" | "gpt-3.5-turbo" | "claude-3" | "claude-3-opus" | "custom";
|
export type AIModelType = "qwen2.5" | "flux-schnell";
|
||||||
|
|
||||||
export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";
|
export type ActivityType = "login" | "search" | "view_job" | "apply_job" | "message" | "update_profile" | "chat";
|
||||||
|
|
||||||
export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";
|
export type ApplicationStatus = "applied" | "reviewing" | "interview" | "offer" | "rejected" | "accepted" | "withdrawn";
|
||||||
|
|
||||||
export type ChatContextType = "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general";
|
export type ChatContextType = "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile";
|
||||||
|
|
||||||
export type ChatSenderType = "user" | "ai" | "system";
|
export type ChatSenderType = "user" | "ai" | "system";
|
||||||
|
|
||||||
export type ChatStatusType = "partial" | "done" | "streaming" | "thinking" | "error";
|
export type ChatStatusType = "preparing" | "thinking" | "partial" | "streaming" | "done" | "error";
|
||||||
|
|
||||||
export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";
|
export type ColorBlindMode = "protanopia" | "deuteranopia" | "tritanopia" | "none";
|
||||||
|
|
||||||
@ -63,24 +63,6 @@ export type VectorStoreType = "pinecone" | "qdrant" | "faiss" | "milvus" | "weav
|
|||||||
// Interfaces
|
// Interfaces
|
||||||
// ============================
|
// ============================
|
||||||
|
|
||||||
export interface AIParameters {
|
|
||||||
id?: string;
|
|
||||||
userId?: string;
|
|
||||||
name: string;
|
|
||||||
description?: string;
|
|
||||||
model: "gpt-4" | "gpt-3.5-turbo" | "claude-3" | "claude-3-opus" | "custom";
|
|
||||||
temperature: number;
|
|
||||||
maxTokens: number;
|
|
||||||
topP: number;
|
|
||||||
frequencyPenalty: number;
|
|
||||||
presencePenalty: number;
|
|
||||||
systemPrompt?: string;
|
|
||||||
isDefault: boolean;
|
|
||||||
createdAt: Date;
|
|
||||||
updatedAt: Date;
|
|
||||||
customModelConfig?: Record<string, any>;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface AccessibilitySettings {
|
export interface AccessibilitySettings {
|
||||||
fontSize: "small" | "medium" | "large";
|
fontSize: "small" | "medium" | "large";
|
||||||
highContrast: boolean;
|
highContrast: boolean;
|
||||||
@ -240,34 +222,63 @@ export interface Certification {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatContext {
|
export interface ChatContext {
|
||||||
type: "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general";
|
type: "job_search" | "candidate_screening" | "interview_prep" | "resume_review" | "general" | "generate_persona" | "generate_profile";
|
||||||
relatedEntityId?: string;
|
relatedEntityId?: string;
|
||||||
relatedEntityType?: "job" | "candidate" | "employer";
|
relatedEntityType?: "job" | "candidate" | "employer";
|
||||||
aiParameters: AIParameters;
|
|
||||||
additionalContext?: Record<string, any>;
|
additionalContext?: Record<string, any>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatMessage {
|
export interface ChatMessage {
|
||||||
id?: string;
|
id?: string;
|
||||||
sessionId: string;
|
sessionId: string;
|
||||||
status: "partial" | "done" | "streaming" | "thinking" | "error";
|
status: "preparing" | "thinking" | "partial" | "streaming" | "done" | "error";
|
||||||
sender: "user" | "ai" | "system";
|
sender: "user" | "ai" | "system";
|
||||||
senderId?: string;
|
senderId?: string;
|
||||||
content: string;
|
prompt?: string;
|
||||||
|
content?: string;
|
||||||
|
chunk?: string;
|
||||||
timestamp: Date;
|
timestamp: Date;
|
||||||
attachments?: Array<Attachment>;
|
|
||||||
reactions?: Array<MessageReaction>;
|
|
||||||
isEdited?: boolean;
|
isEdited?: boolean;
|
||||||
editHistory?: Array<EditHistory>;
|
metadata?: ChatMessageMetaData;
|
||||||
metadata?: Record<string, any>;
|
}
|
||||||
|
|
||||||
|
export interface ChatMessageMetaData {
|
||||||
|
model?: "qwen2.5" | "flux-schnell";
|
||||||
|
temperature?: number;
|
||||||
|
maxTokens?: number;
|
||||||
|
topP?: number;
|
||||||
|
frequencyPenalty?: number;
|
||||||
|
presencePenalty?: number;
|
||||||
|
stopSequences?: Array<string>;
|
||||||
|
tunables?: Tunables;
|
||||||
|
rag?: Array<ChromaDBGetResponse>;
|
||||||
|
evalCount?: number;
|
||||||
|
evalDuration?: number;
|
||||||
|
promptEvalCount?: number;
|
||||||
|
promptEvalDuration?: number;
|
||||||
|
options?: ChatOptions;
|
||||||
|
tools?: Record<string, any>;
|
||||||
|
timers?: Record<string, number>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChatOptions {
|
||||||
|
seed?: number;
|
||||||
|
numCtx?: number;
|
||||||
|
temperature?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChatQuery {
|
||||||
|
prompt: string;
|
||||||
|
tunables?: Tunables;
|
||||||
|
agentOptions?: Record<string, any>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatSession {
|
export interface ChatSession {
|
||||||
id?: string;
|
id?: string;
|
||||||
userId?: string;
|
userId?: string;
|
||||||
guestId?: string;
|
guestId?: string;
|
||||||
createdAt: Date;
|
createdAt?: Date;
|
||||||
lastActivity: Date;
|
lastActivity?: Date;
|
||||||
title?: string;
|
title?: string;
|
||||||
context: ChatContext;
|
context: ChatContext;
|
||||||
messages?: Array<ChatMessage>;
|
messages?: Array<ChatMessage>;
|
||||||
@ -275,6 +286,19 @@ export interface ChatSession {
|
|||||||
systemPrompt?: string;
|
systemPrompt?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ChromaDBGetResponse {
|
||||||
|
ids?: Array<string>;
|
||||||
|
embeddings?: Array<Array<number>>;
|
||||||
|
documents?: Array<string>;
|
||||||
|
metadatas?: Array<Record<string, any>>;
|
||||||
|
name?: string;
|
||||||
|
size?: number;
|
||||||
|
query?: string;
|
||||||
|
queryEmbedding?: Array<number>;
|
||||||
|
umapEmbedding2D?: Array<number>;
|
||||||
|
umapEmbedding3D?: Array<number>;
|
||||||
|
}
|
||||||
|
|
||||||
export interface CustomQuestion {
|
export interface CustomQuestion {
|
||||||
question: string;
|
question: string;
|
||||||
answer: string;
|
answer: string;
|
||||||
@ -511,12 +535,6 @@ export interface ProcessingStep {
|
|||||||
dependsOn?: Array<string>;
|
dependsOn?: Array<string>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Query {
|
|
||||||
prompt: string;
|
|
||||||
tunables?: Tunables;
|
|
||||||
agentOptions?: Record<string, any>;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface RAGConfiguration {
|
export interface RAGConfiguration {
|
||||||
id?: string;
|
id?: string;
|
||||||
userId: string;
|
userId: string;
|
||||||
@ -528,11 +546,16 @@ export interface RAGConfiguration {
|
|||||||
retrievalParameters: RetrievalParameters;
|
retrievalParameters: RetrievalParameters;
|
||||||
createdAt: Date;
|
createdAt: Date;
|
||||||
updatedAt: Date;
|
updatedAt: Date;
|
||||||
isDefault: boolean;
|
|
||||||
version: number;
|
version: number;
|
||||||
isActive: boolean;
|
isActive: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface RagEntry {
|
||||||
|
name: string;
|
||||||
|
description?: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
export interface RefreshToken {
|
export interface RefreshToken {
|
||||||
token: string;
|
token: string;
|
||||||
expiresAt: Date;
|
expiresAt: Date;
|
||||||
|
97
src/backend/agents/__init__.py
Normal file
97
src/backend/agents/__init__.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from pydantic import BaseModel, Field # type: ignore
|
||||||
|
from typing import (
|
||||||
|
Literal,
|
||||||
|
get_args,
|
||||||
|
List,
|
||||||
|
AsyncGenerator,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Optional,
|
||||||
|
ClassVar,
|
||||||
|
Any,
|
||||||
|
TypeAlias,
|
||||||
|
Dict,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
import importlib
|
||||||
|
import pathlib
|
||||||
|
import inspect
|
||||||
|
from prometheus_client import CollectorRegistry # type: ignore
|
||||||
|
|
||||||
|
from database import RedisDatabase
|
||||||
|
from . base import Agent
|
||||||
|
from logger import logger
|
||||||
|
|
||||||
|
_agents: List[Agent] = []
|
||||||
|
|
||||||
|
def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, database: RedisDatabase, **kwargs) -> Agent:
|
||||||
|
"""
|
||||||
|
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_type: The type of agent to create (e.g., 'web', 'database').
|
||||||
|
**kwargs: Additional fields required by the specific agent subclass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The created agent instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no matching agent type is found or if a agent of this type already exists.
|
||||||
|
"""
|
||||||
|
# Check if a agent with the given agent_type already exists
|
||||||
|
for agent in _agents:
|
||||||
|
if agent.agent_type == agent_type:
|
||||||
|
return agent
|
||||||
|
|
||||||
|
# Find the matching subclass
|
||||||
|
for agent_cls in Agent.__subclasses__():
|
||||||
|
if agent_cls.model_fields["agent_type"].default == agent_type:
|
||||||
|
# Create the agent instance with provided kwargs
|
||||||
|
agent = agent_cls(agent_type=agent_type, prometheus_collector=prometheus_collector, database=database, **kwargs)
|
||||||
|
# if agent.agent_persist: # If an agent is not set to persist, do not add it to the list
|
||||||
|
_agents.append(agent)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
||||||
|
|
||||||
|
# Type alias for Agent or any subclass
|
||||||
|
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
|
||||||
|
|
||||||
|
# Maps class_name to (module_name, class_name)
|
||||||
|
class_registry: Dict[str, Tuple[str, str]] = (
|
||||||
|
{}
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ['get_or_create_agent']
|
||||||
|
|
||||||
|
package_dir = pathlib.Path(__file__).parent
|
||||||
|
package_name = __name__
|
||||||
|
|
||||||
|
for path in package_dir.glob("*.py"):
|
||||||
|
if path.name in ("__init__.py", "base.py") or path.name.startswith("_"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_name = path.stem
|
||||||
|
full_module_name = f"{package_name}.{module_name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(full_module_name)
|
||||||
|
|
||||||
|
# Find all Agent subclasses in the module
|
||||||
|
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
|
if (
|
||||||
|
issubclass(obj, AnyAgent)
|
||||||
|
and obj is not AnyAgent
|
||||||
|
and obj is not Agent
|
||||||
|
and name not in class_registry
|
||||||
|
):
|
||||||
|
class_registry[name] = (full_module_name, name)
|
||||||
|
globals()[name] = obj
|
||||||
|
logger.info(f"Adding agent: {name}")
|
||||||
|
__all__.append(name) # type: ignore
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Error importing {full_module_name}: {e}")
|
||||||
|
raise e
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {full_module_name}: {e}")
|
||||||
|
raise e
|
605
src/backend/agents/base.py
Normal file
605
src/backend/agents/base.py
Normal file
@ -0,0 +1,605 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||||
|
from typing import (
|
||||||
|
Literal,
|
||||||
|
get_args,
|
||||||
|
List,
|
||||||
|
AsyncGenerator,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Optional,
|
||||||
|
ClassVar,
|
||||||
|
Any,
|
||||||
|
TypeAlias,
|
||||||
|
Dict,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import inspect
|
||||||
|
from abc import ABC
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||||
|
|
||||||
|
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType, ChatMessageMetaData)
|
||||||
|
from logger import logger
|
||||||
|
import defines
|
||||||
|
from .registry import agent_registry
|
||||||
|
from metrics import Metrics
|
||||||
|
from database import RedisDatabase # type: ignore
|
||||||
|
|
||||||
|
class LLMMessage(BaseModel):
|
||||||
|
role: str = Field(default="")
|
||||||
|
content: str = Field(default="")
|
||||||
|
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
||||||
|
|
||||||
|
class Agent(BaseModel, ABC):
|
||||||
|
"""
|
||||||
|
Base class for all agent types.
|
||||||
|
This class defines the common attributes and methods for all agent types.
|
||||||
|
"""
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
|
||||||
|
|
||||||
|
# Agent management with pydantic
|
||||||
|
agent_type: Literal["base"] = "base"
|
||||||
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
agent_persist: bool = True # Whether this agent will persist in the database
|
||||||
|
|
||||||
|
database: RedisDatabase = Field(
|
||||||
|
...,
|
||||||
|
description="Database connection for this agent, used to store and retrieve data."
|
||||||
|
)
|
||||||
|
prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)
|
||||||
|
|
||||||
|
# Tunables (sets default for new Messages attached to this agent)
|
||||||
|
tunables: Tunables = Field(default_factory=Tunables)
|
||||||
|
metrics: Metrics = Field(
|
||||||
|
None, description="Metrics collector for this agent, used to track performance and usage."
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def initialize_metrics(self) -> "Agent":
|
||||||
|
if self.metrics is None:
|
||||||
|
self.metrics = Metrics(prometheus_collector=self.prometheus_collector)
|
||||||
|
return self
|
||||||
|
|
||||||
|
# Agent properties
|
||||||
|
system_prompt: str # Mandatory
|
||||||
|
context_tokens: int = 0
|
||||||
|
|
||||||
|
# context_size is shared across all subclasses
|
||||||
|
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
||||||
|
|
||||||
|
conversation: List[ChatMessage] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Conversation history for this agent, used to maintain context across messages."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_size(self) -> int:
|
||||||
|
return Agent._context_size
|
||||||
|
|
||||||
|
@context_size.setter
|
||||||
|
def context_size(self, value: int):
|
||||||
|
Agent._context_size = value
|
||||||
|
|
||||||
|
def set_optimal_context_size(
|
||||||
|
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||||
|
) -> int:
|
||||||
|
# Most models average 1.3-1.5 tokens per word
|
||||||
|
word_count = len(prompt.split())
|
||||||
|
tokens = int(word_count * 1.4)
|
||||||
|
|
||||||
|
# Add buffer for safety
|
||||||
|
total_ctx = tokens + ctx_buffer
|
||||||
|
|
||||||
|
if total_ctx > self.context_size:
|
||||||
|
logger.info(
|
||||||
|
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Grow the context size if necessary
|
||||||
|
self.context_size = max(self.context_size, total_ctx)
|
||||||
|
# Use actual model maximum context size
|
||||||
|
return self.context_size
|
||||||
|
|
||||||
|
# Class and pydantic model management
|
||||||
|
def __init_subclass__(cls, **kwargs) -> None:
|
||||||
|
"""Auto-register subclasses"""
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
# Register this class if it has an agent_type
|
||||||
|
if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
|
||||||
|
agent_registry.register(cls.agent_type, cls)
|
||||||
|
|
||||||
|
def model_dump(self, *args, **kwargs) -> Any:
|
||||||
|
# Ensure context is always excluded, even with exclude_unset=True
|
||||||
|
kwargs.setdefault("exclude", set())
|
||||||
|
if isinstance(kwargs["exclude"], set):
|
||||||
|
kwargs["exclude"].add("context")
|
||||||
|
elif isinstance(kwargs["exclude"], dict):
|
||||||
|
kwargs["exclude"]["context"] = True
|
||||||
|
return super().model_dump(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def valid_agent_types(cls) -> set[str]:
|
||||||
|
"""Return the set of valid agent_type values."""
|
||||||
|
return set(get_args(cls.__annotations__["agent_type"]))
|
||||||
|
|
||||||
|
# Agent methods
|
||||||
|
def get_agent_type(self):
|
||||||
|
return self._agent_type
|
||||||
|
|
||||||
|
# async def prepare_message(self, message: ChatMessage) -> AsyncGenerator[ChatMessage, None]:
|
||||||
|
# """
|
||||||
|
# Prepare message with context information in message.preamble
|
||||||
|
# """
|
||||||
|
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
# self.metrics.prepare_count.labels(agent=self.agent_type).inc()
|
||||||
|
# with self.metrics.prepare_duration.labels(agent=self.agent_type).time():
|
||||||
|
# if not self.context:
|
||||||
|
# raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
# # Generate RAG content if enabled, based on the content
|
||||||
|
# rag_context = ""
|
||||||
|
# if message.tunables.enable_rag and message.prompt:
|
||||||
|
# # Gather RAG results, yielding each result
|
||||||
|
# # as it becomes available
|
||||||
|
# for message in self.context.user.generate_rag_results(message):
|
||||||
|
# logger.info(f"RAG: {message.status} - {message.content}")
|
||||||
|
# if message.status == "error":
|
||||||
|
# yield message
|
||||||
|
# return
|
||||||
|
# if message.status != "done":
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# # for rag in message.metadata.rag:
|
||||||
|
# # for doc in rag.documents:
|
||||||
|
# # rag_context += f"{doc}\n"
|
||||||
|
|
||||||
|
# message.preamble = {}
|
||||||
|
|
||||||
|
# if rag_context:
|
||||||
|
# message.preamble["context"] = f"The following is context information about {self.context.user.full_name}:\n{rag_context}"
|
||||||
|
|
||||||
|
# if message.tunables.enable_context and self.context.user_resume:
|
||||||
|
# message.preamble["resume"] = self.context.user_resume
|
||||||
|
|
||||||
|
# message.system_prompt = self.system_prompt
|
||||||
|
# message.status = ChatStatusType.DONE
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# return
|
||||||
|
|
||||||
|
# async def process_tool_calls(
|
||||||
|
# self,
|
||||||
|
# llm: Any,
|
||||||
|
# model: str,
|
||||||
|
# message: ChatMessage,
|
||||||
|
# tool_message: Any, # llama response message
|
||||||
|
# messages: List[LLMMessage],
|
||||||
|
# ) -> AsyncGenerator[ChatMessage, None]:
|
||||||
|
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
# self.metrics.tool_count.labels(agent=self.agent_type).inc()
|
||||||
|
# with self.metrics.tool_duration.labels(agent=self.agent_type).time():
|
||||||
|
|
||||||
|
# if not self.context:
|
||||||
|
# raise ValueError("Context is not set for this agent.")
|
||||||
|
# if not message.metadata.tools:
|
||||||
|
# raise ValueError("tools field not initialized")
|
||||||
|
|
||||||
|
# tool_metadata = message.metadata.tools
|
||||||
|
# tool_metadata["tool_calls"] = []
|
||||||
|
|
||||||
|
# message.status = "tooling"
|
||||||
|
|
||||||
|
# for i, tool_call in enumerate(tool_message.tool_calls):
|
||||||
|
# arguments = tool_call.function.arguments
|
||||||
|
# tool = tool_call.function.name
|
||||||
|
|
||||||
|
# # Yield status update before processing each tool
|
||||||
|
# message.content = (
|
||||||
|
# f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
|
||||||
|
# )
|
||||||
|
# yield message
|
||||||
|
# logger.info(f"LLM - {message.content}")
|
||||||
|
|
||||||
|
# # Process the tool based on its type
|
||||||
|
# match tool:
|
||||||
|
# case "TickerValue":
|
||||||
|
# ticker = arguments.get("ticker")
|
||||||
|
# if not ticker:
|
||||||
|
# ret = None
|
||||||
|
# else:
|
||||||
|
# ret = TickerValue(ticker)
|
||||||
|
|
||||||
|
# case "AnalyzeSite":
|
||||||
|
# url = arguments.get("url")
|
||||||
|
# question = arguments.get(
|
||||||
|
# "question", "what is the summary of this content?"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # Additional status update for long-running operations
|
||||||
|
# message.content = (
|
||||||
|
# f"Retrieving and summarizing content from {url}..."
|
||||||
|
# )
|
||||||
|
# yield message
|
||||||
|
# ret = await AnalyzeSite(
|
||||||
|
# llm=llm, model=model, url=url, question=question
|
||||||
|
# )
|
||||||
|
|
||||||
|
# case "GenerateImage":
|
||||||
|
# prompt = arguments.get("prompt", None)
|
||||||
|
# if not prompt:
|
||||||
|
# logger.info("No prompt supplied to GenerateImage")
|
||||||
|
# ret = { "error": "No prompt supplied to GenerateImage" }
|
||||||
|
|
||||||
|
# # Additional status update for long-running operations
|
||||||
|
# message.content = (
|
||||||
|
# f"Generating image for {prompt}..."
|
||||||
|
# )
|
||||||
|
# yield message
|
||||||
|
# ret = await GenerateImage(
|
||||||
|
# llm=llm, model=model, prompt=prompt
|
||||||
|
# )
|
||||||
|
# logger.info("GenerateImage returning", ret)
|
||||||
|
|
||||||
|
# case "DateTime":
|
||||||
|
# tz = arguments.get("timezone")
|
||||||
|
# ret = DateTime(tz)
|
||||||
|
|
||||||
|
# case "WeatherForecast":
|
||||||
|
# city = arguments.get("city")
|
||||||
|
# state = arguments.get("state")
|
||||||
|
|
||||||
|
# message.content = (
|
||||||
|
# f"Fetching weather data for {city}, {state}..."
|
||||||
|
# )
|
||||||
|
# yield message
|
||||||
|
# ret = WeatherForecast(city, state)
|
||||||
|
|
||||||
|
# case _:
|
||||||
|
# logger.error(f"Requested tool {tool} does not exist")
|
||||||
|
# ret = None
|
||||||
|
|
||||||
|
# # Build response for this tool
|
||||||
|
# tool_response = {
|
||||||
|
# "role": "tool",
|
||||||
|
# "content": json.dumps(ret),
|
||||||
|
# "name": tool_call.function.name,
|
||||||
|
# }
|
||||||
|
|
||||||
|
# tool_metadata["tool_calls"].append(tool_response)
|
||||||
|
|
||||||
|
# if len(tool_metadata["tool_calls"]) == 0:
|
||||||
|
# message.status = "done"
|
||||||
|
# yield message
|
||||||
|
# return
|
||||||
|
|
||||||
|
# message_dict = LLMMessage(
|
||||||
|
# role=tool_message.get("role", "assistant"),
|
||||||
|
# content=tool_message.get("content", ""),
|
||||||
|
# tool_calls=[
|
||||||
|
# {
|
||||||
|
# "function": {
|
||||||
|
# "name": tc["function"]["name"],
|
||||||
|
# "arguments": tc["function"]["arguments"],
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# for tc in tool_message.tool_calls
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
|
||||||
|
# messages.append(message_dict)
|
||||||
|
# messages.extend(tool_metadata["tool_calls"])
|
||||||
|
|
||||||
|
# message.status = "thinking"
|
||||||
|
# message.content = "Incorporating tool results into response..."
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# # Decrease creativity when processing tool call requests
|
||||||
|
# message.content = ""
|
||||||
|
# start_time = time.perf_counter()
|
||||||
|
# for response in llm.chat(
|
||||||
|
# model=model,
|
||||||
|
# messages=messages,
|
||||||
|
# options={
|
||||||
|
# **message.metadata.options,
|
||||||
|
# },
|
||||||
|
# stream=True,
|
||||||
|
# ):
|
||||||
|
# # logger.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
|
||||||
|
# message.status = "streaming"
|
||||||
|
# message.chunk = response.message.content
|
||||||
|
# message.content += message.chunk
|
||||||
|
|
||||||
|
# if not response.done:
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# if response.done:
|
||||||
|
# self.collect_metrics(response)
|
||||||
|
# message.metadata.eval_count += response.eval_count
|
||||||
|
# message.metadata.eval_duration += response.eval_duration
|
||||||
|
# message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||||
|
# message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||||
|
# self.context_tokens = (
|
||||||
|
# response.prompt_eval_count + response.eval_count
|
||||||
|
# )
|
||||||
|
# message.status = "done"
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# end_time = time.perf_counter()
|
||||||
|
# message.metadata.timers["llm_with_tools"] = end_time - start_time
|
||||||
|
# return
|
||||||
|
|
||||||
|
def collect_metrics(self, response):
|
||||||
|
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
||||||
|
response.prompt_eval_count
|
||||||
|
)
|
||||||
|
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.eval_count)
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self, llm: Any, model: str, query: ChatQuery, session_id: str, user_id: str, temperature=0.7
|
||||||
|
) -> AsyncGenerator[ChatMessage, None]:
|
||||||
|
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
chat_message = ChatMessage(
|
||||||
|
session_id=session_id,
|
||||||
|
prompt=query.prompt,
|
||||||
|
tunables=query.tunables,
|
||||||
|
status=ChatStatusType.PREPARING,
|
||||||
|
sender="user",
|
||||||
|
content="",
|
||||||
|
timestamp=datetime.now(UTC)
|
||||||
|
)
|
||||||
|
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||||
|
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||||
|
# Create a pruned down message list based purely on the prompt and responses,
|
||||||
|
# discarding the full preamble generated by prepare_message
|
||||||
|
messages: List[LLMMessage] = [
|
||||||
|
LLMMessage(role="system", content=self.system_prompt)
|
||||||
|
]
|
||||||
|
messages.extend(
|
||||||
|
[
|
||||||
|
item
|
||||||
|
for m in self.conversation
|
||||||
|
for item in [
|
||||||
|
LLMMessage(role="user", content=m.prompt.strip() if m.prompt else ""),
|
||||||
|
LLMMessage(role="assistant", content=m.response.strip()),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Only the actual user query is provided with the full context message
|
||||||
|
messages.append(
|
||||||
|
LLMMessage(role="user", content=query.prompt.strip())
|
||||||
|
)
|
||||||
|
|
||||||
|
# message.messages = messages
|
||||||
|
chat_message.metadata = ChatMessageMetaData()
|
||||||
|
chat_message.metadata.options = {
|
||||||
|
"seed": 8911,
|
||||||
|
"num_ctx": self.context_size,
|
||||||
|
"temperature": temperature, # Higher temperature to encourage tool usage
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create a dict for storing various timing stats
|
||||||
|
chat_message.metadata.timers = {}
|
||||||
|
|
||||||
|
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||||
|
# message.metadata.tools = {
|
||||||
|
# "available": llm_tools(self.context.tools),
|
||||||
|
# "used": False,
|
||||||
|
# }
|
||||||
|
# tool_metadata = message.metadata.tools
|
||||||
|
|
||||||
|
# if use_tools:
|
||||||
|
# message.status = "thinking"
|
||||||
|
# message.content = f"Performing tool analysis step 1/2..."
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# logger.info("Checking for LLM tool usage")
|
||||||
|
# start_time = time.perf_counter()
|
||||||
|
# # Tools are enabled and available, so query the LLM with a short context of messages
|
||||||
|
# # in case the LLM did something like ask "Do you want me to run the tool?" and the
|
||||||
|
# # user said "Yes" -- need to keep the context in the thread.
|
||||||
|
# tool_metadata["messages"] = (
|
||||||
|
# [{"role": "system", "content": self.system_prompt}] + messages[-6:]
|
||||||
|
# if len(messages) >= 7
|
||||||
|
# else messages
|
||||||
|
# )
|
||||||
|
|
||||||
|
# response = llm.chat(
|
||||||
|
# model=model,
|
||||||
|
# messages=tool_metadata["messages"],
|
||||||
|
# tools=tool_metadata["available"],
|
||||||
|
# options={
|
||||||
|
# **message.metadata.options,
|
||||||
|
# },
|
||||||
|
# stream=False, # No need to stream the probe
|
||||||
|
# )
|
||||||
|
# self.collect_metrics(response)
|
||||||
|
|
||||||
|
# end_time = time.perf_counter()
|
||||||
|
# message.metadata.timers["tool_check"] = end_time - start_time
|
||||||
|
# if not response.message.tool_calls:
|
||||||
|
# logger.info("LLM indicates tools will not be used")
|
||||||
|
# # The LLM will not use tools, so disable use_tools so we can stream the full response
|
||||||
|
# use_tools = False
|
||||||
|
# else:
|
||||||
|
# tool_metadata["attempted"] = response.message.tool_calls
|
||||||
|
|
||||||
|
# if use_tools:
|
||||||
|
# logger.info("LLM indicates tools will be used")
|
||||||
|
|
||||||
|
# # Tools are enabled and available and the LLM indicated it will use them
|
||||||
|
# message.content = (
|
||||||
|
# f"Performing tool analysis step 2/2 (tool use suspected)..."
|
||||||
|
# )
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# logger.info(f"Performing LLM call with tools")
|
||||||
|
# start_time = time.perf_counter()
|
||||||
|
# response = llm.chat(
|
||||||
|
# model=model,
|
||||||
|
# messages=tool_metadata["messages"], # messages,
|
||||||
|
# tools=tool_metadata["available"],
|
||||||
|
# options={
|
||||||
|
# **message.metadata.options,
|
||||||
|
# },
|
||||||
|
# stream=False,
|
||||||
|
# )
|
||||||
|
# self.collect_metrics(response)
|
||||||
|
|
||||||
|
# end_time = time.perf_counter()
|
||||||
|
# message.metadata.timers["non_streaming"] = end_time - start_time
|
||||||
|
|
||||||
|
# if not response:
|
||||||
|
# message.status = "error"
|
||||||
|
# message.content = "No response from LLM."
|
||||||
|
# yield message
|
||||||
|
# return
|
||||||
|
|
||||||
|
# if response.message.tool_calls:
|
||||||
|
# tool_metadata["used"] = response.message.tool_calls
|
||||||
|
# # Process all yielded items from the handler
|
||||||
|
# start_time = time.perf_counter()
|
||||||
|
# async for message in self.process_tool_calls(
|
||||||
|
# llm=llm,
|
||||||
|
# model=model,
|
||||||
|
# message=message,
|
||||||
|
# tool_message=response.message,
|
||||||
|
# messages=messages,
|
||||||
|
# ):
|
||||||
|
# if message.status == "error":
|
||||||
|
# yield message
|
||||||
|
# return
|
||||||
|
# yield message
|
||||||
|
# end_time = time.perf_counter()
|
||||||
|
# message.metadata.timers["process_tool_calls"] = end_time - start_time
|
||||||
|
# message.status = "done"
|
||||||
|
# return
|
||||||
|
|
||||||
|
# logger.info("LLM indicated tools will be used, and then they weren't")
|
||||||
|
# message.content = response.message.content
|
||||||
|
# message.status = "done"
|
||||||
|
# yield message
|
||||||
|
# return
|
||||||
|
|
||||||
|
# not use_tools
|
||||||
|
chat_message.status = ChatStatusType.THINKING
|
||||||
|
chat_message.content = f"Generating response..."
|
||||||
|
yield chat_message
|
||||||
|
# Reset the response for streaming
|
||||||
|
chat_message.content = ""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
for response in llm.chat(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
options={
|
||||||
|
**chat_message.metadata.options,
|
||||||
|
},
|
||||||
|
stream=True,
|
||||||
|
):
|
||||||
|
if not response:
|
||||||
|
chat_message.status = ChatStatusType.ERROR
|
||||||
|
chat_message.content = "No response from LLM."
|
||||||
|
yield chat_message
|
||||||
|
return
|
||||||
|
|
||||||
|
chat_message.status = ChatStatusType.STREAMING
|
||||||
|
chat_message.chunk = response.message.content
|
||||||
|
chat_message.content += chat_message.chunk
|
||||||
|
|
||||||
|
if not response.done:
|
||||||
|
yield chat_message
|
||||||
|
|
||||||
|
if response.done:
|
||||||
|
self.collect_metrics(response)
|
||||||
|
chat_message.metadata.eval_count += response.eval_count
|
||||||
|
chat_message.metadata.eval_duration += response.eval_duration
|
||||||
|
chat_message.metadata.prompt_eval_count += response.prompt_eval_count
|
||||||
|
chat_message.metadata.prompt_eval_duration += response.prompt_eval_duration
|
||||||
|
self.context_tokens = (
|
||||||
|
response.prompt_eval_count + response.eval_count
|
||||||
|
)
|
||||||
|
chat_message.status = ChatStatusType.DONE
|
||||||
|
yield chat_message
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
chat_message.metadata.timers["streamed"] = end_time - start_time
|
||||||
|
chat_message.status = ChatStatusType.DONE
|
||||||
|
self.conversation.append(chat_message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# async def process_message(
|
||||||
|
# self, llm: Any, model: str, message: Message
|
||||||
|
# ) -> AsyncGenerator[Message, None]:
|
||||||
|
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
|
||||||
|
# self.metrics.process_count.labels(agent=self.agent_type).inc()
|
||||||
|
# with self.metrics.process_duration.labels(agent=self.agent_type).time():
|
||||||
|
|
||||||
|
# if not self.context:
|
||||||
|
# raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
# logger.info(
|
||||||
|
# "TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time"
|
||||||
|
# )
|
||||||
|
# spinner: List[str] = ["\\", "|", "/", "-"]
|
||||||
|
# tick: int = 0
|
||||||
|
# while self.context.processing:
|
||||||
|
# message.status = "waiting"
|
||||||
|
# message.content = (
|
||||||
|
# f"Busy processing another request. Please wait. {spinner[tick]}"
|
||||||
|
# )
|
||||||
|
# tick = (tick + 1) % len(spinner)
|
||||||
|
# yield message
|
||||||
|
# await asyncio.sleep(1) # Allow the event loop to process the write
|
||||||
|
|
||||||
|
# self.context.processing = True
|
||||||
|
|
||||||
|
# message.system_prompt = (
|
||||||
|
# f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
|
||||||
|
# )
|
||||||
|
# message.context_prompt = ""
|
||||||
|
# for p in message.preamble.keys():
|
||||||
|
# message.context_prompt += (
|
||||||
|
# f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
|
||||||
|
# )
|
||||||
|
# message.context_prompt += f"{message.prompt}"
|
||||||
|
|
||||||
|
# # Estimate token length of new messages
|
||||||
|
# message.content = f"Optimizing context..."
|
||||||
|
# message.status = "thinking"
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# message.context_size = self.set_optimal_context_size(
|
||||||
|
# llm, model, prompt=message.context_prompt
|
||||||
|
# )
|
||||||
|
|
||||||
|
# message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
|
||||||
|
# message.status = "thinking"
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# async for message in self.generate_llm_response(
|
||||||
|
# llm=llm, model=model, message=message
|
||||||
|
# ):
|
||||||
|
# # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}")
|
||||||
|
# if message.status == "error":
|
||||||
|
# yield message
|
||||||
|
# self.context.processing = False
|
||||||
|
# return
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# # Done processing, add message to conversation
|
||||||
|
# message.status = "done"
|
||||||
|
# self.conversation.add(message)
|
||||||
|
# self.context.processing = False
|
||||||
|
|
||||||
|
# return
|
||||||
|
|
||||||
|
|
||||||
|
# Register the base agent
|
||||||
|
agent_registry.register(Agent._agent_type, Agent)
|
88
src/backend/agents/general.py
Normal file
88
src/backend/agents/general.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import Literal, AsyncGenerator, ClassVar, Optional, Any
|
||||||
|
from datetime import datetime
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from .base import Agent, agent_registry
|
||||||
|
from logger import logger
|
||||||
|
|
||||||
|
from .registry import agent_registry
|
||||||
|
from models import ( ChatQuery, ChatMessage, Tunables, ChatStatusType)
|
||||||
|
|
||||||
|
system_message = f"""
|
||||||
|
Launched on {datetime.now().isoformat()}.
|
||||||
|
|
||||||
|
When answering queries, follow these steps:
|
||||||
|
|
||||||
|
- First analyze the query to determine if real-time information from the tools might be helpful
|
||||||
|
- Even when <|context|> or <|resume|> is provided, consider whether the tools would provide more current or comprehensive information
|
||||||
|
- Use the provided tools whenever they would enhance your response, regardless of whether context is also available
|
||||||
|
- When presenting weather forecasts, include relevant emojis immediately before the corresponding text. For example, for a sunny day, say \"☀️ Sunny\" or if the forecast says there will be \"rain showers, say \"🌧️ Rain showers\". Use this mapping for weather emojis: Sunny: ☀️, Cloudy: ☁️, Rainy: 🌧️, Snowy: ❄️
|
||||||
|
- When any combination of <|context|>, <|resume|> and tool outputs are relevant, synthesize information from all sources to provide the most complete answer
|
||||||
|
- Always prioritize the most up-to-date and relevant information, whether it comes from <|context|>, <|resume|> or tools
|
||||||
|
- If <|context|> and tool outputs contain conflicting information, prefer the tool outputs as they likely represent more current data
|
||||||
|
- If there is information in the <|context|> or <|resume|> sections to enhance the answer, incorporate it seamlessly and refer to it as 'the latest information' or 'recent data' instead of mentioning '<|context|>' (etc.) or quoting it directly.
|
||||||
|
- Avoid phrases like 'According to the <|context|>' or similar references to the <|context|> or <|resume|>.
|
||||||
|
|
||||||
|
CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
|
||||||
|
|
||||||
|
1. When the user requests to generate an image, inject the following into the response: <GenerateImage prompt="USER-PROMPT"/>. Do this when users request images, drawings, or visual content.
|
||||||
|
3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage prompt="{{USER-PROMPT}}"/>
|
||||||
|
4. FORBIDDEN: DO NOT use markdown image syntax 
|
||||||
|
5. FORBIDDEN: DO NOT create fake URLs or file paths
|
||||||
|
6. FORBIDDEN: DO NOT use any other image embedding format
|
||||||
|
|
||||||
|
CORRECT EXAMPLE:
|
||||||
|
User: "Draw a cat"
|
||||||
|
Your response: "<GenerateImage prompt='Draw a cat'/>"
|
||||||
|
|
||||||
|
WRONG EXAMPLES (DO NOT DO THIS):
|
||||||
|
- 
|
||||||
|
- 
|
||||||
|
- <img src="...">
|
||||||
|
|
||||||
|
The <GenerateImage prompt="{{USER-PROMPT}}"/> format is the ONLY way to display images in this system.
|
||||||
|
DO NOT make up a URL for an image or provide markdown syntax for embedding an image. Only use <GenerateImage prompt="{{USER-PROMPT}}".
|
||||||
|
|
||||||
|
Always use tools, <|resume|>, and <|context|> when possible. Be concise, and never make up information. If you do not know the answer, say so.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class Chat(Agent):
|
||||||
|
"""
|
||||||
|
Chat Agent
|
||||||
|
"""
|
||||||
|
|
||||||
|
agent_type: Literal["general"] = "general" # type: ignore
|
||||||
|
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||||
|
|
||||||
|
system_prompt: str = system_message
|
||||||
|
|
||||||
|
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||||
|
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||||
|
# if not self.context:
|
||||||
|
# raise ValueError("Context is not set for this agent.")
|
||||||
|
|
||||||
|
# async for message in super().prepare_message(message):
|
||||||
|
# if message.status != "done":
|
||||||
|
# yield message
|
||||||
|
|
||||||
|
# if message.preamble:
|
||||||
|
# excluded = {}
|
||||||
|
# preamble_types = [
|
||||||
|
# f"<|{p}|>" for p in message.preamble.keys() if p not in excluded
|
||||||
|
# ]
|
||||||
|
# preamble_types_AND = " and ".join(preamble_types)
|
||||||
|
# preamble_types_OR = " or ".join(preamble_types)
|
||||||
|
# message.preamble[
|
||||||
|
# "rules"
|
||||||
|
# ] = f"""\
|
||||||
|
# - Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
|
||||||
|
# - If there is no information in these sections, answer based on your knowledge, or use any available tools.
|
||||||
|
# - Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
|
||||||
|
# """
|
||||||
|
# message.preamble["question"] = "Respond to:"
|
||||||
|
|
||||||
|
|
||||||
|
# Register the base agent
|
||||||
|
agent_registry.register(Chat._agent_type, Chat)
|
33
src/backend/agents/registry.py
Normal file
33
src/backend/agents/registry.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import List, Dict, Optional, Type
|
||||||
|
|
||||||
|
# We'll use a registry pattern rather than hardcoded strings
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Registry for agent types and classes"""
|
||||||
|
|
||||||
|
_registry: Dict[str, Type] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(cls, agent_type: str, agent_class: Type) -> Type:
|
||||||
|
"""Register an agent class with its type"""
|
||||||
|
cls._registry[agent_type] = agent_class
|
||||||
|
return agent_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_class(cls, agent_type: str) -> Optional[Type]:
|
||||||
|
"""Get the class for a given agent type"""
|
||||||
|
return cls._registry.get(agent_type)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_types(cls) -> List[str]:
|
||||||
|
"""Get all registered agent types"""
|
||||||
|
return list(cls._registry.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_classes(cls) -> Dict[str, Type]:
|
||||||
|
"""Get all registered agent classes"""
|
||||||
|
return cls._registry.copy()
|
||||||
|
|
||||||
|
|
||||||
|
# Create a singleton instance
|
||||||
|
agent_registry = AgentRegistry()
|
@ -8,9 +8,10 @@ import sys
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from models import (
|
from models import (
|
||||||
UserStatus, UserType, SkillLevel, EmploymentType,
|
UserStatus, UserType, SkillLevel, EmploymentType,
|
||||||
Candidate, Employer, Location, Skill, AIParameters, AIModelType
|
Candidate, Employer, Location, Skill, AIModelType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_model_creation():
|
def test_model_creation():
|
||||||
"""Test that we can create models successfully"""
|
"""Test that we can create models successfully"""
|
||||||
print("🧪 Testing model creation...")
|
print("🧪 Testing model creation...")
|
||||||
@ -118,41 +119,23 @@ def test_api_dict_format():
|
|||||||
def test_validation_constraints():
|
def test_validation_constraints():
|
||||||
"""Test that validation constraints work"""
|
"""Test that validation constraints work"""
|
||||||
print("\n🔒 Testing validation constraints...")
|
print("\n🔒 Testing validation constraints...")
|
||||||
|
|
||||||
# Test AI Parameters with constraints
|
|
||||||
valid_params = AIParameters(
|
|
||||||
name="Test Config",
|
|
||||||
model=AIModelType.GPT_4,
|
|
||||||
temperature=0.7, # Valid: 0-1
|
|
||||||
maxTokens=2000, # Valid: > 0
|
|
||||||
topP=0.95, # Valid: 0-1
|
|
||||||
frequencyPenalty=0.0, # Valid: -2 to 2
|
|
||||||
presencePenalty=0.0, # Valid: -2 to 2
|
|
||||||
isDefault=True,
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now()
|
|
||||||
)
|
|
||||||
print(f"✅ Valid AI parameters created")
|
|
||||||
|
|
||||||
# Test constraint violation
|
|
||||||
try:
|
try:
|
||||||
invalid_params = AIParameters(
|
# Create a candidate with invalid email
|
||||||
name="Invalid Config",
|
invalid_candidate = Candidate(
|
||||||
model=AIModelType.GPT_4,
|
email="invalid-email",
|
||||||
temperature=1.5, # Invalid: > 1
|
username="test_invalid",
|
||||||
maxTokens=2000,
|
|
||||||
topP=0.95,
|
|
||||||
frequencyPenalty=0.0,
|
|
||||||
presencePenalty=0.0,
|
|
||||||
isDefault=True,
|
|
||||||
createdAt=datetime.now(),
|
createdAt=datetime.now(),
|
||||||
updatedAt=datetime.now()
|
updatedAt=datetime.now(),
|
||||||
|
status=UserStatus.ACTIVE,
|
||||||
|
firstName="Jane",
|
||||||
|
lastName="Doe",
|
||||||
|
fullName="Jane Doe"
|
||||||
)
|
)
|
||||||
print("❌ Should have rejected invalid temperature")
|
print("❌ Validation should have failed but didn't")
|
||||||
return False
|
return False
|
||||||
except Exception:
|
except ValueError as e:
|
||||||
print(f"✅ Constraint validation working")
|
print(f"✅ Validation error caught: {e}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def test_enum_values():
|
def test_enum_values():
|
||||||
@ -200,6 +183,7 @@ def main():
|
|||||||
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
|
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
print(f"\n❌ {traceback.format_exc()}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -13,6 +13,7 @@ from datetime import datetime
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import stat
|
import stat
|
||||||
|
|
||||||
def run_command(command: str, description: str, cwd: str | None = None) -> bool:
|
def run_command(command: str, description: str, cwd: str | None = None) -> bool:
|
||||||
"""Run a command and return success status"""
|
"""Run a command and return success status"""
|
||||||
try:
|
try:
|
||||||
@ -68,9 +69,34 @@ except ImportError as e:
|
|||||||
print("Make sure pydantic is installed: pip install pydantic")
|
print("Make sure pydantic is installed: pip install pydantic")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def python_type_to_typescript(python_type: Any) -> str:
|
def unwrap_annotated_type(python_type: Any) -> Any:
|
||||||
|
"""Unwrap Annotated types to get the actual type"""
|
||||||
|
# Handle typing_extensions.Annotated and typing.Annotated
|
||||||
|
origin = get_origin(python_type)
|
||||||
|
args = get_args(python_type)
|
||||||
|
|
||||||
|
# Check for Annotated types - more robust detection
|
||||||
|
if origin is not None and args:
|
||||||
|
origin_str = str(origin)
|
||||||
|
if 'Annotated' in origin_str or (hasattr(origin, '__name__') and origin.__name__ == 'Annotated'):
|
||||||
|
# Return the first argument (the actual type)
|
||||||
|
return unwrap_annotated_type(args[0]) # Recursive unwrap in case of nested annotations
|
||||||
|
|
||||||
|
return python_type
|
||||||
|
|
||||||
|
def python_type_to_typescript(python_type: Any, debug: bool = False) -> str:
|
||||||
"""Convert a Python type to TypeScript type string"""
|
"""Convert a Python type to TypeScript type string"""
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" 🔍 Converting type: {python_type} (type: {type(python_type)})")
|
||||||
|
|
||||||
|
# First unwrap any Annotated types
|
||||||
|
original_type = python_type
|
||||||
|
python_type = unwrap_annotated_type(python_type)
|
||||||
|
|
||||||
|
if debug and original_type != python_type:
|
||||||
|
print(f" 🔄 Unwrapped: {original_type} -> {python_type}")
|
||||||
|
|
||||||
# Handle None/null
|
# Handle None/null
|
||||||
if python_type is type(None):
|
if python_type is type(None):
|
||||||
return "null"
|
return "null"
|
||||||
@ -79,6 +105,8 @@ def python_type_to_typescript(python_type: Any) -> str:
|
|||||||
if python_type == str:
|
if python_type == str:
|
||||||
return "string"
|
return "string"
|
||||||
elif python_type == int or python_type == float:
|
elif python_type == int or python_type == float:
|
||||||
|
if debug:
|
||||||
|
print(f" ✅ Converting {python_type} to number")
|
||||||
return "number"
|
return "number"
|
||||||
elif python_type == bool:
|
elif python_type == bool:
|
||||||
return "boolean"
|
return "boolean"
|
||||||
@ -91,30 +119,33 @@ def python_type_to_typescript(python_type: Any) -> str:
|
|||||||
origin = get_origin(python_type)
|
origin = get_origin(python_type)
|
||||||
args = get_args(python_type)
|
args = get_args(python_type)
|
||||||
|
|
||||||
|
if debug and origin:
|
||||||
|
print(f" 🔍 Generic type - origin: {origin}, args: {args}")
|
||||||
|
|
||||||
if origin is Union:
|
if origin is Union:
|
||||||
# Handle Optional (Union[T, None])
|
# Handle Optional (Union[T, None])
|
||||||
if len(args) == 2 and type(None) in args:
|
if len(args) == 2 and type(None) in args:
|
||||||
non_none_type = next(arg for arg in args if arg is not type(None))
|
non_none_type = next(arg for arg in args if arg is not type(None))
|
||||||
return python_type_to_typescript(non_none_type)
|
return python_type_to_typescript(non_none_type, debug)
|
||||||
|
|
||||||
# Handle other unions
|
# Handle other unions
|
||||||
union_types = [python_type_to_typescript(arg) for arg in args if arg is not type(None)]
|
union_types = [python_type_to_typescript(arg, debug) for arg in args if arg is not type(None)]
|
||||||
return " | ".join(union_types)
|
return " | ".join(union_types)
|
||||||
|
|
||||||
elif origin is list or origin is List:
|
elif origin is list or origin is List:
|
||||||
if args:
|
if args:
|
||||||
item_type = python_type_to_typescript(args[0])
|
item_type = python_type_to_typescript(args[0], debug)
|
||||||
return f"Array<{item_type}>"
|
return f"Array<{item_type}>"
|
||||||
return "Array<any>"
|
return "Array<any>"
|
||||||
|
|
||||||
elif origin is dict or origin is Dict:
|
elif origin is dict or origin is Dict:
|
||||||
if len(args) == 2:
|
if len(args) == 2:
|
||||||
key_type = python_type_to_typescript(args[0])
|
key_type = python_type_to_typescript(args[0], debug)
|
||||||
value_type = python_type_to_typescript(args[1])
|
value_type = python_type_to_typescript(args[1], debug)
|
||||||
return f"Record<{key_type}, {value_type}>"
|
return f"Record<{key_type}, {value_type}>"
|
||||||
return "Record<string, any>"
|
return "Record<string, any>"
|
||||||
|
|
||||||
# Handle Literal types - UPDATED SECTION
|
# Handle Literal types
|
||||||
if hasattr(python_type, '__origin__') and str(python_type.__origin__).endswith('Literal'):
|
if hasattr(python_type, '__origin__') and str(python_type.__origin__).endswith('Literal'):
|
||||||
if args:
|
if args:
|
||||||
literal_values = []
|
literal_values = []
|
||||||
@ -155,6 +186,8 @@ def python_type_to_typescript(python_type: Any) -> str:
|
|||||||
return "string"
|
return "string"
|
||||||
|
|
||||||
# Default fallback
|
# Default fallback
|
||||||
|
if debug:
|
||||||
|
print(f" ⚠️ Falling back to 'any' for type: {python_type}")
|
||||||
return "any"
|
return "any"
|
||||||
|
|
||||||
def snake_to_camel(snake_str: str) -> str:
|
def snake_to_camel(snake_str: str) -> str:
|
||||||
@ -162,32 +195,148 @@ def snake_to_camel(snake_str: str) -> str:
|
|||||||
components = snake_str.split('_')
|
components = snake_str.split('_')
|
||||||
return components[0] + ''.join(x.title() for x in components[1:])
|
return components[0] + ''.join(x.title() for x in components[1:])
|
||||||
|
|
||||||
def process_pydantic_model(model_class) -> Dict[str, Any]:
|
def is_field_optional(field_info: Any, field_type: Any, debug: bool = False) -> bool:
|
||||||
|
"""Determine if a field should be optional in TypeScript"""
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" 🔍 Analyzing field optionality:")
|
||||||
|
|
||||||
|
# First, check if the type itself is Optional (Union with None)
|
||||||
|
origin = get_origin(field_type)
|
||||||
|
args = get_args(field_type)
|
||||||
|
is_union_with_none = origin is Union and type(None) in args
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" └─ Type is Optional[T]: {is_union_with_none}")
|
||||||
|
|
||||||
|
# If the type is Optional[T], it's always optional regardless of Field settings
|
||||||
|
if is_union_with_none:
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Optional (type is Optional[T])")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For non-Optional types, check Field settings and defaults
|
||||||
|
|
||||||
|
# Check for default factory (makes field optional)
|
||||||
|
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
|
||||||
|
if debug:
|
||||||
|
print(f" └─ Has default factory: {has_default_factory}")
|
||||||
|
|
||||||
|
if has_default_factory:
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Optional (has default factory)")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check the default value - this is the tricky part
|
||||||
|
if hasattr(field_info, 'default'):
|
||||||
|
default_val = field_info.default
|
||||||
|
if debug:
|
||||||
|
print(f" └─ Has default attribute: {repr(default_val)} (type: {type(default_val)})")
|
||||||
|
|
||||||
|
# Check for different types of "no default" markers
|
||||||
|
# Pydantic uses various markers for "no default"
|
||||||
|
if default_val is ...: # Ellipsis
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Required (default is Ellipsis)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for Pydantic's internal "PydanticUndefined" or similar markers
|
||||||
|
default_str = str(default_val)
|
||||||
|
default_type_str = str(type(default_val))
|
||||||
|
|
||||||
|
# Common patterns for "undefined" in Pydantic
|
||||||
|
undefined_patterns = [
|
||||||
|
'PydanticUndefined',
|
||||||
|
'Undefined',
|
||||||
|
'_Unset',
|
||||||
|
'UNSET',
|
||||||
|
'NotSet',
|
||||||
|
'_MISSING'
|
||||||
|
]
|
||||||
|
|
||||||
|
is_undefined_marker = any(pattern in default_str or pattern in default_type_str
|
||||||
|
for pattern in undefined_patterns)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" └─ Checking for undefined markers in: {default_str} | {default_type_str}")
|
||||||
|
print(f" └─ Is undefined marker: {is_undefined_marker}")
|
||||||
|
|
||||||
|
if is_undefined_marker:
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Required (default is undefined marker)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Any other actual default value makes it optional
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Optional (has actual default value)")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if debug:
|
||||||
|
print(f" └─ No default attribute found")
|
||||||
|
|
||||||
|
# If no default attribute exists, check Pydantic's required flag
|
||||||
|
if hasattr(field_info, 'is_required'):
|
||||||
|
try:
|
||||||
|
is_required = field_info.is_required()
|
||||||
|
if debug:
|
||||||
|
print(f" └─ is_required(): {is_required}")
|
||||||
|
return not is_required
|
||||||
|
except:
|
||||||
|
if debug:
|
||||||
|
print(f" └─ is_required() failed")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check the 'required' attribute (Pydantic v1 style)
|
||||||
|
if hasattr(field_info, 'required'):
|
||||||
|
is_required = field_info.required
|
||||||
|
if debug:
|
||||||
|
print(f" └─ required attribute: {is_required}")
|
||||||
|
return not is_required
|
||||||
|
|
||||||
|
# Default: if type is not Optional and no clear default, it's required (not optional)
|
||||||
|
if debug:
|
||||||
|
print(f" └─ RESULT: Required (fallback - no Optional type, no default)")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def process_pydantic_model(model_class, debug: bool = False) -> Dict[str, Any]:
|
||||||
"""Process a Pydantic model and return TypeScript interface definition"""
|
"""Process a Pydantic model and return TypeScript interface definition"""
|
||||||
interface_name = model_class.__name__
|
interface_name = model_class.__name__
|
||||||
properties = []
|
properties = []
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" 🔍 Processing model: {interface_name}")
|
||||||
|
|
||||||
# Get fields from the model
|
# Get fields from the model
|
||||||
if hasattr(model_class, 'model_fields'):
|
if hasattr(model_class, 'model_fields'):
|
||||||
# Pydantic v2
|
# Pydantic v2
|
||||||
fields = model_class.model_fields
|
fields = model_class.model_fields
|
||||||
for field_name, field_info in fields.items():
|
for field_name, field_info in fields.items():
|
||||||
ts_name = snake_to_camel(field_name)
|
if debug:
|
||||||
|
print(f" 📝 Field: {field_name}")
|
||||||
|
print(f" Field info: {field_info}")
|
||||||
|
print(f" Default: {getattr(field_info, 'default', 'NO_DEFAULT')}")
|
||||||
|
|
||||||
# Check for alias
|
# Use alias if available, otherwise convert snake_case to camelCase
|
||||||
if hasattr(field_info, 'alias') and field_info.alias:
|
if hasattr(field_info, 'alias') and field_info.alias:
|
||||||
ts_name = field_info.alias
|
ts_name = field_info.alias
|
||||||
|
else:
|
||||||
|
ts_name = snake_to_camel(field_name)
|
||||||
|
|
||||||
# Get type annotation
|
# Get type annotation
|
||||||
field_type = getattr(field_info, 'annotation', str)
|
field_type = getattr(field_info, 'annotation', str)
|
||||||
ts_type = python_type_to_typescript(field_type)
|
if debug:
|
||||||
|
print(f" Raw type: {field_type}")
|
||||||
|
|
||||||
|
ts_type = python_type_to_typescript(field_type, debug)
|
||||||
|
|
||||||
# Check if optional
|
# Check if optional
|
||||||
is_optional = False
|
is_optional = is_field_optional(field_info, field_type, debug)
|
||||||
if hasattr(field_info, 'is_required'):
|
|
||||||
is_optional = not field_info.is_required()
|
if debug:
|
||||||
elif hasattr(field_info, 'default'):
|
print(f" TS name: {ts_name}")
|
||||||
is_optional = field_info.default is not None
|
print(f" TS type: {ts_type}")
|
||||||
|
print(f" Optional: {is_optional}")
|
||||||
|
print()
|
||||||
|
|
||||||
properties.append({
|
properties.append({
|
||||||
'name': ts_name,
|
'name': ts_name,
|
||||||
@ -199,17 +348,45 @@ def process_pydantic_model(model_class) -> Dict[str, Any]:
|
|||||||
# Pydantic v1
|
# Pydantic v1
|
||||||
fields = model_class.__fields__
|
fields = model_class.__fields__
|
||||||
for field_name, field_info in fields.items():
|
for field_name, field_info in fields.items():
|
||||||
ts_name = snake_to_camel(field_name)
|
if debug:
|
||||||
|
print(f" 📝 Field: {field_name} (Pydantic v1)")
|
||||||
|
print(f" Field info: {field_info}")
|
||||||
|
|
||||||
|
# Use alias if available, otherwise convert snake_case to camelCase
|
||||||
if hasattr(field_info, 'alias') and field_info.alias:
|
if hasattr(field_info, 'alias') and field_info.alias:
|
||||||
ts_name = field_info.alias
|
ts_name = field_info.alias
|
||||||
|
else:
|
||||||
|
ts_name = snake_to_camel(field_name)
|
||||||
|
|
||||||
field_type = getattr(field_info, 'annotation', getattr(field_info, 'type_', str))
|
field_type = getattr(field_info, 'annotation', getattr(field_info, 'type_', str))
|
||||||
ts_type = python_type_to_typescript(field_type)
|
if debug:
|
||||||
|
print(f" Raw type: {field_type}")
|
||||||
|
|
||||||
is_optional = not getattr(field_info, 'required', True)
|
ts_type = python_type_to_typescript(field_type, debug)
|
||||||
if hasattr(field_info, 'default') and field_info.default is not None:
|
|
||||||
is_optional = True
|
# For Pydantic v1, check required and default
|
||||||
|
is_optional = is_field_optional(field_info, field_type)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f" TS name: {ts_name}")
|
||||||
|
print(f" TS type: {ts_type}")
|
||||||
|
print(f" Optional: {is_optional}")
|
||||||
|
|
||||||
|
# Debug the optional logic
|
||||||
|
origin = get_origin(field_type)
|
||||||
|
args = get_args(field_type)
|
||||||
|
is_union_with_none = origin is Union and type(None) in args
|
||||||
|
has_default = hasattr(field_info, 'default')
|
||||||
|
has_default_factory = hasattr(field_info, 'default_factory') and field_info.default_factory is not None
|
||||||
|
|
||||||
|
print(f" └─ Type is Optional: {is_union_with_none}")
|
||||||
|
if has_default:
|
||||||
|
default_val = field_info.default
|
||||||
|
print(f" └─ Has default: {default_val} (is ...? {default_val is ...})")
|
||||||
|
else:
|
||||||
|
print(f" └─ No default attribute")
|
||||||
|
print(f" └─ Has default factory: {has_default_factory}")
|
||||||
|
print()
|
||||||
|
|
||||||
properties.append({
|
properties.append({
|
||||||
'name': ts_name,
|
'name': ts_name,
|
||||||
@ -233,7 +410,7 @@ def process_enum(enum_class) -> Dict[str, Any]:
|
|||||||
'values': " | ".join(values)
|
'values': " | ".join(values)
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate_typescript_interfaces(source_file: str):
|
def generate_typescript_interfaces(source_file: str, debug: bool = False):
|
||||||
"""Generate TypeScript interfaces from models"""
|
"""Generate TypeScript interfaces from models"""
|
||||||
|
|
||||||
print(f"📖 Scanning {source_file} for Pydantic models and enums...")
|
print(f"📖 Scanning {source_file} for Pydantic models and enums...")
|
||||||
@ -270,7 +447,7 @@ def generate_typescript_interfaces(source_file: str):
|
|||||||
issubclass(obj, BaseModel) and
|
issubclass(obj, BaseModel) and
|
||||||
obj != BaseModel):
|
obj != BaseModel):
|
||||||
|
|
||||||
interface = process_pydantic_model(obj)
|
interface = process_pydantic_model(obj, debug)
|
||||||
interfaces.append(interface)
|
interfaces.append(interface)
|
||||||
print(f" ✅ Found Pydantic model: {name}")
|
print(f" ✅ Found Pydantic model: {name}")
|
||||||
|
|
||||||
@ -284,6 +461,9 @@ def generate_typescript_interfaces(source_file: str):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f" ⚠️ Warning: Error processing {name}: {e}")
|
print(f" ⚠️ Warning: Error processing {name}: {e}")
|
||||||
|
if debug:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print(f"\n📊 Found {len(interfaces)} interfaces and {len(enums)} enums")
|
print(f"\n📊 Found {len(interfaces)} interfaces and {len(enums)} enums")
|
||||||
@ -362,7 +542,8 @@ Examples:
|
|||||||
python generate_types.py --source models.py --output types.ts # Specify files
|
python generate_types.py --source models.py --output types.ts # Specify files
|
||||||
python generate_types.py --skip-test # Skip model validation
|
python generate_types.py --skip-test # Skip model validation
|
||||||
python generate_types.py --skip-compile # Skip TS compilation
|
python generate_types.py --skip-compile # Skip TS compilation
|
||||||
python generate_types.py --source models.py --output types.ts --skip-test --skip-compile
|
python generate_types.py --debug # Enable debug output
|
||||||
|
python generate_types.py --source models.py --output types.ts --skip-test --skip-compile --debug
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -390,6 +571,12 @@ Examples:
|
|||||||
help='Skip TypeScript compilation check after generation'
|
help='Skip TypeScript compilation check after generation'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--debug',
|
||||||
|
action='store_true',
|
||||||
|
help='Enable debug output to troubleshoot type conversion issues'
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--version', '-v',
|
'--version', '-v',
|
||||||
action='version',
|
action='version',
|
||||||
@ -422,7 +609,11 @@ Examples:
|
|||||||
|
|
||||||
# Step 3: Generate TypeScript content
|
# Step 3: Generate TypeScript content
|
||||||
print("🔄 Generating TypeScript types...")
|
print("🔄 Generating TypeScript types...")
|
||||||
ts_content = generate_typescript_interfaces(args.source)
|
if args.debug:
|
||||||
|
print("🐛 Debug mode enabled - detailed output follows:")
|
||||||
|
print()
|
||||||
|
|
||||||
|
ts_content = generate_typescript_interfaces(args.source, args.debug)
|
||||||
|
|
||||||
if ts_content is None:
|
if ts_content is None:
|
||||||
print("❌ Failed to generate TypeScript content")
|
print("❌ Failed to generate TypeScript content")
|
||||||
|
41
src/backend/llm_manager.py
Normal file
41
src/backend/llm_manager.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import ollama
|
||||||
|
import defines
|
||||||
|
|
||||||
|
_llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
|
||||||
|
|
||||||
|
class llm_manager:
|
||||||
|
"""
|
||||||
|
A class to manage LLM operations using the Ollama client.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def get_llm() -> ollama.Client: # type: ignore
|
||||||
|
"""
|
||||||
|
Get the Ollama client instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of the Ollama client.
|
||||||
|
"""
|
||||||
|
return _llm
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_models() -> list[str]:
|
||||||
|
"""
|
||||||
|
Get a list of available models from the Ollama client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names.
|
||||||
|
"""
|
||||||
|
return _llm.models()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_model_info(model_name: str) -> dict:
|
||||||
|
"""
|
||||||
|
Get information about a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: The name of the model to retrieve information for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary containing model information.
|
||||||
|
"""
|
||||||
|
return _llm.model(model_name)
|
@ -1,7 +1,7 @@
|
|||||||
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request # type: ignore
|
from fastapi import FastAPI, HTTPException, Depends, Query, Path, Body, status, APIRouter, Request # type: ignore
|
||||||
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
||||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials # type: ignore
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials # type: ignore
|
||||||
from fastapi.responses import JSONResponse # type: ignore
|
from fastapi.responses import JSONResponse, StreamingResponse# type: ignore
|
||||||
from fastapi.staticfiles import StaticFiles # type: ignore
|
from fastapi.staticfiles import StaticFiles # type: ignore
|
||||||
import uvicorn # type: ignore
|
import uvicorn # type: ignore
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
@ -15,6 +15,7 @@ import re
|
|||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
import json
|
import json
|
||||||
|
import traceback
|
||||||
|
|
||||||
# Prometheus
|
# Prometheus
|
||||||
from prometheus_client import Summary # type: ignore
|
from prometheus_client import Summary # type: ignore
|
||||||
@ -24,25 +25,24 @@ from prometheus_client import CollectorRegistry, Counter # type: ignore
|
|||||||
# Import Pydantic models
|
# Import Pydantic models
|
||||||
from models import (
|
from models import (
|
||||||
# User models
|
# User models
|
||||||
Candidate, Employer, BaseUser, Guest, Authentication, AuthResponse,
|
Candidate, Employer, BaseUserWithType, BaseUser, Guest, Authentication, AuthResponse,
|
||||||
|
|
||||||
# Job models
|
# Job models
|
||||||
Job, JobApplication, ApplicationStatus,
|
Job, JobApplication, ApplicationStatus,
|
||||||
|
|
||||||
# Chat models
|
# Chat models
|
||||||
ChatSession, ChatMessage, ChatContext,
|
ChatSession, ChatMessage, ChatContext, ChatQuery,
|
||||||
|
|
||||||
# AI models
|
|
||||||
AIParameters,
|
|
||||||
|
|
||||||
# Supporting models
|
# Supporting models
|
||||||
Location, Skill, WorkExperience, Education
|
Location, Skill, WorkExperience, Education
|
||||||
)
|
)
|
||||||
|
|
||||||
import defines
|
import defines
|
||||||
|
import agents
|
||||||
from logger import logger
|
from logger import logger
|
||||||
from database import RedisDatabase, redis_manager, DatabaseManager
|
from database import RedisDatabase, redis_manager, DatabaseManager
|
||||||
from metrics import Metrics
|
from metrics import Metrics
|
||||||
|
from llm_manager import llm_manager
|
||||||
|
|
||||||
# Initialize FastAPI app
|
# Initialize FastAPI app
|
||||||
# ============================
|
# ============================
|
||||||
@ -140,20 +140,45 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
|||||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
|
|
||||||
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||||
|
"""Verify token and check if it's blacklisted"""
|
||||||
try:
|
try:
|
||||||
|
# First decode the token
|
||||||
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
user_id: str = payload.get("sub")
|
user_id: str = payload.get("sub")
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||||
|
|
||||||
|
# Check if token is blacklisted
|
||||||
|
redis_client = redis_manager.get_client()
|
||||||
|
blacklist_key = f"blacklisted_token:{credentials.credentials}"
|
||||||
|
|
||||||
|
is_blacklisted = await redis_client.exists(blacklist_key)
|
||||||
|
if is_blacklisted:
|
||||||
|
logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
|
||||||
|
raise HTTPException(status_code=401, detail="Token has been revoked")
|
||||||
|
|
||||||
|
# Optional: Check if all user tokens are revoked (for "logout from all devices")
|
||||||
|
# user_revoked_key = f"user_tokens_revoked:{user_id}"
|
||||||
|
# user_tokens_revoked_at = await redis_client.get(user_revoked_key)
|
||||||
|
# if user_tokens_revoked_at:
|
||||||
|
# revoked_timestamp = datetime.fromisoformat(user_tokens_revoked_at.decode())
|
||||||
|
# token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), UTC)
|
||||||
|
# if token_issued_at < revoked_timestamp:
|
||||||
|
# raise HTTPException(status_code=401, detail="All user tokens have been revoked")
|
||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
except jwt.PyJWTError:
|
except jwt.PyJWTError:
|
||||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Token verification error: {e}")
|
||||||
|
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
user_id: str = Depends(verify_token),
|
user_id: str = Depends(verify_token_with_blacklist),
|
||||||
database: RedisDatabase = Depends(lambda: db_manager.get_database())
|
database: RedisDatabase = Depends(lambda: db_manager.get_database())
|
||||||
):
|
) -> BaseUserWithType:
|
||||||
"""Get current user from database"""
|
"""Get current user from database"""
|
||||||
try:
|
try:
|
||||||
# Check candidates first
|
# Check candidates first
|
||||||
@ -321,12 +346,147 @@ async def login(
|
|||||||
return create_success_response(auth_response.model_dump(by_alias=True))
|
return create_success_response(auth_response.model_dump(by_alias=True))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Login error: {e}")
|
logger.error(f"⚠️ Login error: {e}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
content=create_error_response("LOGIN_ERROR", str(e))
|
content=create_error_response("LOGIN_ERROR", str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@api_router.post("/auth/logout")
|
||||||
|
async def logout(
|
||||||
|
access_token: str = Body(..., alias="accessToken"),
|
||||||
|
refresh_token: str = Body(..., alias="refreshToken"),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Logout endpoint - revokes both access and refresh tokens"""
|
||||||
|
logger.info(f"🔑 User {current_user.id} is logging out")
|
||||||
|
try:
|
||||||
|
# Verify refresh token
|
||||||
|
try:
|
||||||
|
refresh_payload = jwt.decode(refresh_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
user_id = refresh_payload.get("sub")
|
||||||
|
token_type = refresh_payload.get("type")
|
||||||
|
refresh_exp = refresh_payload.get("exp")
|
||||||
|
|
||||||
|
if not user_id or token_type != "refresh":
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||||
|
)
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
logger.warning(f"Invalid refresh token during logout: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify that the refresh token belongs to the current user
|
||||||
|
if user_id != current_user.id:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get Redis client
|
||||||
|
redis_client = redis_manager.get_client()
|
||||||
|
|
||||||
|
# Revoke refresh token (blacklist it until its natural expiration)
|
||||||
|
refresh_ttl = max(0, refresh_exp - int(datetime.now(UTC).timestamp()))
|
||||||
|
if refresh_ttl > 0:
|
||||||
|
await redis_client.setex(
|
||||||
|
f"blacklisted_token:{refresh_token}",
|
||||||
|
refresh_ttl,
|
||||||
|
json.dumps({
|
||||||
|
"user_id": user_id,
|
||||||
|
"token_type": "refresh",
|
||||||
|
"revoked_at": datetime.now(UTC).isoformat(),
|
||||||
|
"reason": "user_logout"
|
||||||
|
})
|
||||||
|
)
|
||||||
|
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
|
||||||
|
|
||||||
|
# If access token is provided, revoke it too
|
||||||
|
if access_token:
|
||||||
|
try:
|
||||||
|
access_payload = jwt.decode(access_token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
access_user_id = access_payload.get("sub")
|
||||||
|
access_exp = access_payload.get("exp")
|
||||||
|
|
||||||
|
# Verify access token belongs to same user
|
||||||
|
if access_user_id == user_id:
|
||||||
|
access_ttl = max(0, access_exp - int(datetime.now(UTC).timestamp()))
|
||||||
|
if access_ttl > 0:
|
||||||
|
await redis_client.setex(
|
||||||
|
f"blacklisted_token:{access_token}",
|
||||||
|
access_ttl,
|
||||||
|
json.dumps({
|
||||||
|
"user_id": user_id,
|
||||||
|
"token_type": "access",
|
||||||
|
"revoked_at": datetime.now(UTC).isoformat(),
|
||||||
|
"reason": "user_logout"
|
||||||
|
})
|
||||||
|
)
|
||||||
|
logger.info(f"🔒 Blacklisted access token for user {user_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Access token user mismatch during logout: {access_user_id} != {user_id}")
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
logger.warning(f"Invalid access token during logout (non-critical): {e}")
|
||||||
|
# Don't fail logout if access token is invalid
|
||||||
|
|
||||||
|
# Optional: Revoke all tokens for this user (for "logout from all devices")
|
||||||
|
# Uncomment the following lines if you want to implement this feature:
|
||||||
|
#
|
||||||
|
# await redis_client.setex(
|
||||||
|
# f"user_tokens_revoked:{user_id}",
|
||||||
|
# timedelta(days=30).total_seconds(), # Max refresh token lifetime
|
||||||
|
# datetime.now(UTC).isoformat()
|
||||||
|
# )
|
||||||
|
|
||||||
|
logger.info(f"🔑 User {user_id} logged out successfully")
|
||||||
|
return create_success_response({
|
||||||
|
"message": "Logged out successfully",
|
||||||
|
"tokensRevoked": {
|
||||||
|
"refreshToken": True,
|
||||||
|
"accessToken": bool(access_token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"⚠️ Logout error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("LOGOUT_ERROR", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
@api_router.post("/auth/logout-all")
|
||||||
|
async def logout_all_devices(
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Logout from all devices by revoking all tokens for the user"""
|
||||||
|
try:
|
||||||
|
redis_client = redis_manager.get_client()
|
||||||
|
|
||||||
|
# Set a timestamp that invalidates all tokens issued before this moment
|
||||||
|
await redis_client.setex(
|
||||||
|
f"user_tokens_revoked:{current_user.id}",
|
||||||
|
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
|
||||||
|
datetime.now(UTC).isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
|
||||||
|
return create_success_response({
|
||||||
|
"message": "Logged out from all devices successfully"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"⚠️ Logout all devices error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
@api_router.post("/auth/refresh")
|
@api_router.post("/auth/refresh")
|
||||||
async def refresh_token_endpoint(
|
async def refresh_token_endpoint(
|
||||||
refreshToken: str = Body(..., alias="refreshToken"),
|
refreshToken: str = Body(..., alias="refreshToken"),
|
||||||
@ -427,21 +587,33 @@ async def create_candidate(
|
|||||||
content=create_error_response("CREATION_FAILED", str(e))
|
content=create_error_response("CREATION_FAILED", str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
@api_router.get("/candidates/{candidate_id}")
|
@api_router.get("/candidates/{username}")
|
||||||
async def get_candidate(
|
async def get_candidate(
|
||||||
candidate_id: str = Path(...),
|
username: str = Path(...),
|
||||||
database: RedisDatabase = Depends(get_database)
|
database: RedisDatabase = Depends(get_database)
|
||||||
):
|
):
|
||||||
"""Get a candidate by ID"""
|
"""Get a candidate by username"""
|
||||||
try:
|
try:
|
||||||
candidate_data = await database.get_candidate(candidate_id)
|
all_candidates_data = await database.get_all_candidates()
|
||||||
if not candidate_data:
|
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
|
||||||
|
|
||||||
|
# Normalize username to lowercase for case-insensitive search
|
||||||
|
query_lower = username.lower()
|
||||||
|
|
||||||
|
# Filter by search query
|
||||||
|
candidates_list = [
|
||||||
|
c for c in candidates_list
|
||||||
|
if (query_lower == c.email.lower() or
|
||||||
|
query_lower == c.username.lower())
|
||||||
|
]
|
||||||
|
|
||||||
|
if not len(candidates_list):
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
content=create_error_response("NOT_FOUND", "Candidate not found")
|
content=create_error_response("NOT_FOUND", "Candidate not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
candidate = Candidate.model_validate(candidate_data)
|
candidate = Candidate.model_validate(candidates_list[0])
|
||||||
return create_success_response(candidate.model_dump(by_alias=True))
|
return create_success_response(candidate.model_dump(by_alias=True))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -558,6 +730,7 @@ async def search_candidates(
|
|||||||
if (query_lower in c.first_name.lower() or
|
if (query_lower in c.first_name.lower() or
|
||||||
query_lower in c.last_name.lower() or
|
query_lower in c.last_name.lower() or
|
||||||
query_lower in c.email.lower() or
|
query_lower in c.email.lower() or
|
||||||
|
query_lower in c.username.lower() or
|
||||||
any(query_lower in skill.name.lower() for skill in c.skills))
|
any(query_lower in skill.name.lower() for skill in c.skills))
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -727,6 +900,209 @@ async def search_jobs(
|
|||||||
content=create_error_response("SEARCH_FAILED", str(e))
|
content=create_error_response("SEARCH_FAILED", str(e))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ============================
|
||||||
|
# Chat Endpoints
|
||||||
|
# ============================
|
||||||
|
@api_router.post("/chat/sessions")
|
||||||
|
async def create_chat_session(
|
||||||
|
session_data: Dict[str, Any] = Body(...),
|
||||||
|
current_user : BaseUserWithType = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Create a new chat session"""
|
||||||
|
try:
|
||||||
|
# Add required fields
|
||||||
|
session_data["id"] = str(uuid.uuid4())
|
||||||
|
session_data["createdAt"] = datetime.now(UTC).isoformat()
|
||||||
|
session_data["updatedAt"] = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
|
# Create chat session
|
||||||
|
chat_session = ChatSession.model_validate(session_data)
|
||||||
|
await database.set_chat_session(chat_session.id, chat_session.model_dump())
|
||||||
|
|
||||||
|
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}")
|
||||||
|
return create_success_response(chat_session.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Chat session creation error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=create_error_response("CREATION_FAILED", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
@api_router.get("/chat/sessions/{session_id}")
|
||||||
|
async def get_chat_session(
|
||||||
|
session_id: str = Path(...),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Get a chat session by ID"""
|
||||||
|
try:
|
||||||
|
chat_session_data = await database.get_chat_session(session_id)
|
||||||
|
if not chat_session_data:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_session = ChatSession.model_validate(chat_session_data)
|
||||||
|
return create_success_response(chat_session.model_dump(by_alias=True))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Get chat session error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("FETCH_ERROR", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
@api_router.get("/chat/sessions/{session_id}/messages")
|
||||||
|
async def get_chat_session_messages(
|
||||||
|
session_id: str = Path(...),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
limit: int = Query(20, ge=1, le=100),
|
||||||
|
sortBy: Optional[str] = Query(None, alias="sortBy"),
|
||||||
|
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
|
||||||
|
filters: Optional[str] = Query(None),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Get a chat session by ID"""
|
||||||
|
try:
|
||||||
|
chat_session_data = await database.get_chat_session(session_id)
|
||||||
|
if not chat_session_data:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_messages = await database.get_chat_messages(session_id)
|
||||||
|
# Convert messages to ChatMessage objects
|
||||||
|
messages_list = [ChatMessage.model_validate(msg) for msg in chat_messages]
|
||||||
|
# Apply filters and pagination
|
||||||
|
filter_dict = None
|
||||||
|
if filters:
|
||||||
|
filter_dict = json.loads(filters)
|
||||||
|
paginated_messages, total = filter_and_paginate(
|
||||||
|
messages_list, page, limit, sortBy, sortOrder, filter_dict
|
||||||
|
)
|
||||||
|
paginated_response = create_paginated_response(
|
||||||
|
[m.model_dump(by_alias=True) for m in paginated_messages],
|
||||||
|
page, limit, total
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_success_response(paginated_response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Get chat session error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("FETCH_ERROR", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
@api_router.post("/chat/sessions/{session_id}/messages/stream")
|
||||||
|
async def post_chat_session_message_stream(
|
||||||
|
session_id: str = Path(...),
|
||||||
|
data: Dict[str, Any] = Body(...),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database),
|
||||||
|
request: Request = Request, # For streaming response
|
||||||
|
):
|
||||||
|
"""Post a message to a chat session and stream the response"""
|
||||||
|
try:
|
||||||
|
chat_session_data = await database.get_chat_session(session_id)
|
||||||
|
if not chat_session_data:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=404,
|
||||||
|
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_type = chat_session_data.get("context", {}).get("type", "general")
|
||||||
|
|
||||||
|
logger.info(f"🔗 Chat session {session_id} type {chat_type} accessed by user {current_user.id}")
|
||||||
|
query = data.get("query")
|
||||||
|
if not query:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=create_error_response("INVALID_QUERY", "Query cannot be empty")
|
||||||
|
)
|
||||||
|
chat_query = ChatQuery.model_validate(query)
|
||||||
|
chat_agent = agents.get_or_create_agent(agent_type=chat_type, prometheus_collector=prometheus_collector, database=database)
|
||||||
|
if not chat_agent:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
|
||||||
|
)
|
||||||
|
async def message_stream_generator():
|
||||||
|
"""Generator to stream messages"""
|
||||||
|
async for message in chat_agent.generate(
|
||||||
|
llm=llm_manager.get_llm(),
|
||||||
|
model=defines.model,
|
||||||
|
query=chat_query,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
):
|
||||||
|
json_data = message.model_dump(mode='json', by_alias=True)
|
||||||
|
json_str = json.dumps(json_data)
|
||||||
|
logger.info(f"🔗 Streaming message for session {session_id}: {json_str}")
|
||||||
|
yield json_str + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
message_stream_generator(),
|
||||||
|
media_type="application/json",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no", # Prevents Nginx buffering if you're using it
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
logger.error(f"Get chat session error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=500,
|
||||||
|
content=create_error_response("FETCH_ERROR", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api_router.get("/chat/sessions")
|
||||||
|
async def get_chat_sessions(
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
limit: int = Query(20, ge=1, le=100),
|
||||||
|
sortBy: Optional[str] = Query(None, alias="sortBy"),
|
||||||
|
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
|
||||||
|
filters: Optional[str] = Query(None),
|
||||||
|
current_user = Depends(get_current_user),
|
||||||
|
database: RedisDatabase = Depends(get_database)
|
||||||
|
):
|
||||||
|
"""Get paginated list of chat sessions"""
|
||||||
|
try:
|
||||||
|
filter_dict = None
|
||||||
|
if filters:
|
||||||
|
filter_dict = json.loads(filters)
|
||||||
|
|
||||||
|
# Get all chat sessions from Redis
|
||||||
|
all_sessions_data = await database.get_all_chat_sessions()
|
||||||
|
sessions_list = [ChatSession.model_validate(data) for data in all_sessions_data.values()]
|
||||||
|
|
||||||
|
paginated_sessions, total = filter_and_paginate(
|
||||||
|
sessions_list, page, limit, sortBy, sortOrder, filter_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
paginated_response = create_paginated_response(
|
||||||
|
[s.model_dump(by_alias=True) for s in paginated_sessions],
|
||||||
|
page, limit, total
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_success_response(paginated_response)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Get chat sessions error: {e}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content=create_error_response("FETCH_FAILED", str(e))
|
||||||
|
)
|
||||||
|
|
||||||
# ============================
|
# ============================
|
||||||
# Health Check and Info Endpoints
|
# Health Check and Info Endpoints
|
||||||
# ============================
|
# ============================
|
||||||
@ -790,6 +1166,11 @@ async def redis_stats(redis_client: redis.Redis = Depends(get_redis)):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
|
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
|
||||||
|
|
||||||
|
@api_router.get("/system-info")
|
||||||
|
async def get_system_info(request: Request):
|
||||||
|
from system_info import system_info # Import system_info function from system_info module
|
||||||
|
return JSONResponse(system_info())
|
||||||
|
|
||||||
@api_router.get("/")
|
@api_router.get("/")
|
||||||
async def api_info():
|
async def api_info():
|
||||||
"""API information endpoint"""
|
"""API information endpoint"""
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Generic, Annotated
|
from typing import List, Dict, Optional, Any, Union, Literal, TypeVar, Generic, Annotated
|
||||||
from pydantic import BaseModel, Field, EmailStr, HttpUrl, validator # type: ignore
|
from pydantic import BaseModel, Field, EmailStr, HttpUrl, validator # type: ignore
|
||||||
from pydantic.types import constr, conint # type: ignore
|
from pydantic.types import constr, conint # type: ignore
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date, UTC
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@ -68,10 +68,11 @@ class ChatSenderType(str, Enum):
|
|||||||
SYSTEM = "system"
|
SYSTEM = "system"
|
||||||
|
|
||||||
class ChatStatusType(str, Enum):
|
class ChatStatusType(str, Enum):
|
||||||
PARTIAL = "partial"
|
PREPARING = "preparing"
|
||||||
DONE = "done"
|
|
||||||
STREAMING = "streaming"
|
|
||||||
THINKING = "thinking"
|
THINKING = "thinking"
|
||||||
|
PARTIAL = "partial"
|
||||||
|
STREAMING = "streaming"
|
||||||
|
DONE = "done"
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
class ChatContextType(str, Enum):
|
class ChatContextType(str, Enum):
|
||||||
@ -80,13 +81,12 @@ class ChatContextType(str, Enum):
|
|||||||
INTERVIEW_PREP = "interview_prep"
|
INTERVIEW_PREP = "interview_prep"
|
||||||
RESUME_REVIEW = "resume_review"
|
RESUME_REVIEW = "resume_review"
|
||||||
GENERAL = "general"
|
GENERAL = "general"
|
||||||
|
GENERATE_PERSONA = "generate_persona"
|
||||||
|
GENERATE_PROFILE = "generate_profile"
|
||||||
|
|
||||||
class AIModelType(str, Enum):
|
class AIModelType(str, Enum):
|
||||||
GPT_4 = "gpt-4"
|
QWEN2_5 = "qwen2.5"
|
||||||
GPT_35_TURBO = "gpt-3.5-turbo"
|
FLUX_SCHNELL = "flux-schnell"
|
||||||
CLAUDE_3 = "claude-3"
|
|
||||||
CLAUDE_3_OPUS = "claude-3-opus"
|
|
||||||
CUSTOM = "custom"
|
|
||||||
|
|
||||||
class MFAMethod(str, Enum):
|
class MFAMethod(str, Enum):
|
||||||
APP = "app"
|
APP = "app"
|
||||||
@ -520,47 +520,73 @@ class JobApplication(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
populate_by_name = True # Allow both field names and aliases
|
populate_by_name = True # Allow both field names and aliases
|
||||||
|
|
||||||
class AIParameters(BaseModel):
|
class RagEntry(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
||||||
user_id: Optional[str] = Field(None, alias="userId")
|
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str = ""
|
||||||
model: AIModelType
|
enabled: bool = True
|
||||||
temperature: Annotated[float, Field(ge=0, le=1)]
|
|
||||||
max_tokens: Annotated[int, Field(gt=0)] = Field(..., alias="maxTokens")
|
class ChromaDBGetResponse(BaseModel):
|
||||||
top_p: Annotated[float, Field(ge=0, le=1)] = Field(..., alias="topP")
|
# Chroma fields
|
||||||
frequency_penalty: Annotated[float, Field(ge=-2, le=2)] = Field(..., alias="frequencyPenalty")
|
ids: List[str] = []
|
||||||
presence_penalty: Annotated[float, Field(ge=-2, le=2)] = Field(..., alias="presencePenalty")
|
embeddings: List[List[float]] = Field(default=[])
|
||||||
system_prompt: Optional[str] = Field(None, alias="systemPrompt")
|
documents: List[str] = []
|
||||||
is_default: bool = Field(..., alias="isDefault")
|
metadatas: List[Dict[str, Any]] = []
|
||||||
created_at: datetime = Field(..., alias="createdAt")
|
# Additional fields
|
||||||
updated_at: datetime = Field(..., alias="updatedAt")
|
name: str = ""
|
||||||
custom_model_config: Optional[Dict[str, Any]] = Field(None, alias="customModelConfig")
|
size: int = 0
|
||||||
class Config:
|
query: str = ""
|
||||||
populate_by_name = True # Allow both field names and aliases
|
query_embedding: Optional[List[float]] = Field(default=None, alias="queryEmbedding")
|
||||||
|
umap_embedding_2d: Optional[List[float]] = Field(default=None, alias="umapEmbedding2D")
|
||||||
|
umap_embedding_3d: Optional[List[float]] = Field(default=None, alias="umapEmbedding3D")
|
||||||
|
|
||||||
class ChatContext(BaseModel):
|
class ChatContext(BaseModel):
|
||||||
type: ChatContextType
|
type: ChatContextType
|
||||||
related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
|
related_entity_id: Optional[str] = Field(None, alias="relatedEntityId")
|
||||||
related_entity_type: Optional[Literal["job", "candidate", "employer"]] = Field(None, alias="relatedEntityType")
|
related_entity_type: Optional[Literal["job", "candidate", "employer"]] = Field(None, alias="relatedEntityType")
|
||||||
ai_parameters: AIParameters = Field(..., alias="aiParameters")
|
|
||||||
additional_context: Optional[Dict[str, Any]] = Field(None, alias="additionalContext")
|
additional_context: Optional[Dict[str, Any]] = Field(None, alias="additionalContext")
|
||||||
class Config:
|
class Config:
|
||||||
populate_by_name = True # Allow both field names and aliases
|
populate_by_name = True # Allow both field names and aliases
|
||||||
|
|
||||||
|
class ChatOptions(BaseModel):
|
||||||
|
seed: Optional[int] = 8911
|
||||||
|
num_ctx: Optional[int] = Field(default=None, alias="numCtx") # Number of context tokens
|
||||||
|
temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage
|
||||||
|
|
||||||
|
class ChatMessageMetaData(BaseModel):
|
||||||
|
model: AIModelType = AIModelType.QWEN2_5
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = Field(default=8092, alias="maxTokens")
|
||||||
|
top_p: float = Field(default=1, alias="topP")
|
||||||
|
frequency_penalty: Optional[float] = Field(None, alias="frequencyPenalty")
|
||||||
|
presence_penalty: Optional[float] = Field(None, alias="presencePenalty")
|
||||||
|
stop_sequences: Optional[List[str]] = Field(None, alias="stopSequences")
|
||||||
|
tunables: Optional[Tunables] = None
|
||||||
|
rag: List[ChromaDBGetResponse] = Field(default_factory=list)
|
||||||
|
eval_count: int = 0
|
||||||
|
eval_duration: int = 0
|
||||||
|
prompt_eval_count: int = 0
|
||||||
|
prompt_eval_duration: int = 0
|
||||||
|
options: Optional[ChatOptions] = None
|
||||||
|
tools: Optional[Dict[str, Any]] = None
|
||||||
|
timers: Optional[Dict[str, float]] = None
|
||||||
|
class Config:
|
||||||
|
populate_by_name = True # Allow both field names and aliases
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
session_id: str = Field(..., alias="sessionId")
|
session_id: str = Field(..., alias="sessionId")
|
||||||
status: ChatStatusType
|
status: ChatStatusType
|
||||||
sender: ChatSenderType
|
sender: ChatSenderType
|
||||||
sender_id: Optional[str] = Field(None, alias="senderId")
|
sender_id: Optional[str] = Field(None, alias="senderId")
|
||||||
content: str
|
prompt: str = ""
|
||||||
|
content: str = ""
|
||||||
|
chunk: str = ""
|
||||||
timestamp: datetime
|
timestamp: datetime
|
||||||
attachments: Optional[List[Attachment]] = None
|
#attachments: Optional[List[Attachment]] = None
|
||||||
reactions: Optional[List[MessageReaction]] = None
|
#reactions: Optional[List[MessageReaction]] = None
|
||||||
is_edited: bool = Field(False, alias="isEdited")
|
is_edited: bool = Field(False, alias="isEdited")
|
||||||
edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
|
#edit_history: Optional[List[EditHistory]] = Field(None, alias="editHistory")
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: ChatMessageMetaData = Field(None)
|
||||||
class Config:
|
class Config:
|
||||||
populate_by_name = True # Allow both field names and aliases
|
populate_by_name = True # Allow both field names and aliases
|
||||||
|
|
||||||
@ -568,8 +594,8 @@ class ChatSession(BaseModel):
|
|||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
user_id: Optional[str] = Field(None, alias="userId")
|
user_id: Optional[str] = Field(None, alias="userId")
|
||||||
guest_id: Optional[str] = Field(None, alias="guestId")
|
guest_id: Optional[str] = Field(None, alias="guestId")
|
||||||
created_at: datetime = Field(..., alias="createdAt")
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="createdAt")
|
||||||
last_activity: datetime = Field(..., alias="lastActivity")
|
last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias="lastActivity")
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
context: ChatContext
|
context: ChatContext
|
||||||
messages: Optional[List[ChatMessage]] = None
|
messages: Optional[List[ChatMessage]] = None
|
||||||
@ -615,7 +641,6 @@ class RAGConfiguration(BaseModel):
|
|||||||
retrieval_parameters: RetrievalParameters = Field(..., alias="retrievalParameters")
|
retrieval_parameters: RetrievalParameters = Field(..., alias="retrievalParameters")
|
||||||
created_at: datetime = Field(..., alias="createdAt")
|
created_at: datetime = Field(..., alias="createdAt")
|
||||||
updated_at: datetime = Field(..., alias="updatedAt")
|
updated_at: datetime = Field(..., alias="updatedAt")
|
||||||
is_default: bool = Field(..., alias="isDefault")
|
|
||||||
version: int
|
version: int
|
||||||
is_active: bool = Field(..., alias="isActive")
|
is_active: bool = Field(..., alias="isActive")
|
||||||
class Config:
|
class Config:
|
||||||
@ -672,7 +697,7 @@ class UserPreference(BaseModel):
|
|||||||
# ============================
|
# ============================
|
||||||
# API Request/Response Models
|
# API Request/Response Models
|
||||||
# ============================
|
# ============================
|
||||||
class Query(BaseModel):
|
class ChatQuery(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
tunables: Optional[Tunables] = None
|
tunables: Optional[Tunables] = None
|
||||||
agent_options: Optional[Dict[str, Any]] = Field(None, alias="agentOptions")
|
agent_options: Optional[Dict[str, Any]] = Field(None, alias="agentOptions")
|
||||||
|
81
src/backend/system_info.py
Normal file
81
src/backend/system_info.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import defines
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import math
|
||||||
|
|
||||||
|
def get_installed_ram():
|
||||||
|
try:
|
||||||
|
with open("/proc/meminfo", "r") as f:
|
||||||
|
meminfo = f.read()
|
||||||
|
match = re.search(r"MemTotal:\s+(\d+)", meminfo)
|
||||||
|
if match:
|
||||||
|
return f"{math.floor(int(match.group(1)) / 1000**2)}GB" # Convert KB to GB
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving RAM: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_graphics_cards():
|
||||||
|
gpus = []
|
||||||
|
try:
|
||||||
|
# Run the ze-monitor utility
|
||||||
|
result = subprocess.run(
|
||||||
|
["ze-monitor"], capture_output=True, text=True, check=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up the output (remove leading/trailing whitespace and newlines)
|
||||||
|
output = result.stdout.strip()
|
||||||
|
for index in range(len(output.splitlines())):
|
||||||
|
result = subprocess.run(
|
||||||
|
["ze-monitor", "--device", f"{index+1}", "--info"],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
|
gpu_info = result.stdout.strip().splitlines()
|
||||||
|
gpu = {
|
||||||
|
"discrete": True, # Assume it's discrete initially
|
||||||
|
"name": None,
|
||||||
|
"memory": None,
|
||||||
|
}
|
||||||
|
gpus.append(gpu)
|
||||||
|
for line in gpu_info:
|
||||||
|
match = re.match(r"^Device: [^(]*\((.*)\)", line)
|
||||||
|
if match:
|
||||||
|
gpu["name"] = match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
match = re.match(r"^\s*Memory: (.*)", line)
|
||||||
|
if match:
|
||||||
|
gpu["memory"] = match.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
match = re.match(r"^.*Is integrated with host: Yes.*", line)
|
||||||
|
if match:
|
||||||
|
gpu["discrete"] = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
return gpus
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving GPU info: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_info():
|
||||||
|
try:
|
||||||
|
with open("/proc/cpuinfo", "r") as f:
|
||||||
|
cpuinfo = f.read()
|
||||||
|
model_match = re.search(r"model name\s+:\s+(.+)", cpuinfo)
|
||||||
|
cores_match = re.findall(r"processor\s+:\s+\d+", cpuinfo)
|
||||||
|
if model_match and cores_match:
|
||||||
|
return f"{model_match.group(1)} with {len(cores_match)} cores"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error retrieving CPU info: {e}"
|
||||||
|
|
||||||
|
def system_info():
|
||||||
|
return {
|
||||||
|
"System RAM": get_installed_ram(),
|
||||||
|
"Graphics Card": get_graphics_cards(),
|
||||||
|
"CPU": get_cpu_info(),
|
||||||
|
"LLM Model": defines.model,
|
||||||
|
"Embedding Model": defines.embedding_model,
|
||||||
|
"Context length": defines.max_context,
|
||||||
|
}
|
@ -1,207 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
"""
|
|
||||||
Focused test script that tests the most important functionality
|
|
||||||
without getting caught up in serialization format complexities
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
from models import (
|
|
||||||
UserStatus, UserType, SkillLevel, EmploymentType,
|
|
||||||
Candidate, Employer, Location, Skill, AIParameters, AIModelType
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_model_creation():
|
|
||||||
"""Test that we can create models successfully"""
|
|
||||||
print("🧪 Testing model creation...")
|
|
||||||
|
|
||||||
# Create supporting objects
|
|
||||||
location = Location(city="Austin", country="USA")
|
|
||||||
skill = Skill(name="Python", category="Programming", level=SkillLevel.ADVANCED)
|
|
||||||
|
|
||||||
# Create candidate
|
|
||||||
candidate = Candidate(
|
|
||||||
email="test@example.com",
|
|
||||||
username="test_candidate",
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now(),
|
|
||||||
status=UserStatus.ACTIVE,
|
|
||||||
firstName="John",
|
|
||||||
lastName="Doe",
|
|
||||||
fullName="John Doe",
|
|
||||||
skills=[skill],
|
|
||||||
experience=[],
|
|
||||||
education=[],
|
|
||||||
preferredJobTypes=[EmploymentType.FULL_TIME],
|
|
||||||
location=location,
|
|
||||||
languages=[],
|
|
||||||
certifications=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create employer
|
|
||||||
employer = Employer(
|
|
||||||
email="hr@company.com",
|
|
||||||
username="test_employer",
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now(),
|
|
||||||
status=UserStatus.ACTIVE,
|
|
||||||
companyName="Test Company",
|
|
||||||
industry="Technology",
|
|
||||||
companySize="50-200",
|
|
||||||
companyDescription="A test company",
|
|
||||||
location=location
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
|
|
||||||
print(f"✅ Employer: {employer.company_name}")
|
|
||||||
print(f"✅ User types: {candidate.user_type}, {employer.user_type}")
|
|
||||||
|
|
||||||
return candidate, employer
|
|
||||||
|
|
||||||
def test_json_api_format():
|
|
||||||
"""Test JSON serialization in API format (the most important use case)"""
|
|
||||||
print("\n📡 Testing JSON API format...")
|
|
||||||
|
|
||||||
candidate, employer = test_model_creation()
|
|
||||||
|
|
||||||
# Serialize to JSON (API format)
|
|
||||||
candidate_json = candidate.model_dump_json(by_alias=True)
|
|
||||||
employer_json = employer.model_dump_json(by_alias=True)
|
|
||||||
|
|
||||||
print(f"✅ Candidate JSON: {len(candidate_json)} chars")
|
|
||||||
print(f"✅ Employer JSON: {len(employer_json)} chars")
|
|
||||||
|
|
||||||
# Deserialize from JSON
|
|
||||||
candidate_back = Candidate.model_validate_json(candidate_json)
|
|
||||||
employer_back = Employer.model_validate_json(employer_json)
|
|
||||||
|
|
||||||
# Verify data integrity
|
|
||||||
assert candidate_back.email == candidate.email
|
|
||||||
assert candidate_back.first_name == candidate.first_name
|
|
||||||
assert employer_back.company_name == employer.company_name
|
|
||||||
|
|
||||||
print(f"✅ JSON round-trip successful")
|
|
||||||
print(f"✅ Data integrity verified")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def test_api_dict_format():
|
|
||||||
"""Test dictionary format with aliases (for API requests/responses)"""
|
|
||||||
print("\n📊 Testing API dictionary format...")
|
|
||||||
|
|
||||||
candidate, employer = test_model_creation()
|
|
||||||
|
|
||||||
# Create API format dictionaries
|
|
||||||
candidate_dict = candidate.model_dump(by_alias=True)
|
|
||||||
employer_dict = employer.model_dump(by_alias=True)
|
|
||||||
|
|
||||||
# Verify camelCase aliases are used
|
|
||||||
assert "firstName" in candidate_dict
|
|
||||||
assert "lastName" in candidate_dict
|
|
||||||
assert "createdAt" in candidate_dict
|
|
||||||
assert "companyName" in employer_dict
|
|
||||||
|
|
||||||
print(f"✅ API format dictionaries created")
|
|
||||||
print(f"✅ CamelCase aliases verified")
|
|
||||||
|
|
||||||
# Test deserializing from API format
|
|
||||||
candidate_back = Candidate.model_validate(candidate_dict)
|
|
||||||
employer_back = Employer.model_validate(employer_dict)
|
|
||||||
|
|
||||||
assert candidate_back.email == candidate.email
|
|
||||||
assert employer_back.company_name == employer.company_name
|
|
||||||
|
|
||||||
print(f"✅ API format round-trip successful")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def test_validation_constraints():
|
|
||||||
"""Test that validation constraints work"""
|
|
||||||
print("\n🔒 Testing validation constraints...")
|
|
||||||
|
|
||||||
# Test AI Parameters with constraints
|
|
||||||
valid_params = AIParameters(
|
|
||||||
name="Test Config",
|
|
||||||
model=AIModelType.GPT_4,
|
|
||||||
temperature=0.7, # Valid: 0-1
|
|
||||||
maxTokens=2000, # Valid: > 0
|
|
||||||
topP=0.95, # Valid: 0-1
|
|
||||||
frequencyPenalty=0.0, # Valid: -2 to 2
|
|
||||||
presencePenalty=0.0, # Valid: -2 to 2
|
|
||||||
isDefault=True,
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now()
|
|
||||||
)
|
|
||||||
print(f"✅ Valid AI parameters created")
|
|
||||||
|
|
||||||
# Test constraint violation
|
|
||||||
try:
|
|
||||||
invalid_params = AIParameters(
|
|
||||||
name="Invalid Config",
|
|
||||||
model=AIModelType.GPT_4,
|
|
||||||
temperature=1.5, # Invalid: > 1
|
|
||||||
maxTokens=2000,
|
|
||||||
topP=0.95,
|
|
||||||
frequencyPenalty=0.0,
|
|
||||||
presencePenalty=0.0,
|
|
||||||
isDefault=True,
|
|
||||||
createdAt=datetime.now(),
|
|
||||||
updatedAt=datetime.now()
|
|
||||||
)
|
|
||||||
print("❌ Should have rejected invalid temperature")
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
print(f"✅ Constraint validation working")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def test_enum_values():
|
|
||||||
"""Test that enum values work correctly"""
|
|
||||||
print("\n📋 Testing enum values...")
|
|
||||||
|
|
||||||
# Test that enum values are properly handled
|
|
||||||
candidate, employer = test_model_creation()
|
|
||||||
|
|
||||||
# Check enum values in serialization
|
|
||||||
candidate_dict = candidate.model_dump(by_alias=True)
|
|
||||||
|
|
||||||
assert candidate_dict["status"] == "active"
|
|
||||||
assert candidate_dict["userType"] == "candidate"
|
|
||||||
assert employer.user_type == UserType.EMPLOYER
|
|
||||||
|
|
||||||
print(f"✅ Enum values correctly serialized")
|
|
||||||
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Run all focused tests"""
|
|
||||||
print("🎯 Focused Pydantic Model Tests")
|
|
||||||
print("=" * 40)
|
|
||||||
|
|
||||||
try:
|
|
||||||
test_model_creation()
|
|
||||||
test_json_api_format()
|
|
||||||
test_api_dict_format()
|
|
||||||
test_validation_constraints()
|
|
||||||
test_enum_values()
|
|
||||||
|
|
||||||
print(f"\n🎉 All focused tests passed!")
|
|
||||||
print("=" * 40)
|
|
||||||
print("✅ Models work correctly")
|
|
||||||
print("✅ JSON API format works")
|
|
||||||
print("✅ Validation constraints work")
|
|
||||||
print("✅ Enum values work")
|
|
||||||
print("✅ Ready for type generation!")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
success = main()
|
|
||||||
sys.exit(0 if success else 1)
|
|
2
update-types.sh
Executable file
2
update-types.sh
Executable file
@ -0,0 +1,2 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
docker compose exec backstory shell "python src/backend/generate_types.py --source src/backend/models.py --output frontend/src/types/types.ts ${*}"
|
Loading…
x
Reference in New Issue
Block a user