diff --git a/frontend/src/Conversation.tsx b/frontend/src/Conversation.tsx index c6a318d..5235b02 100644 --- a/frontend/src/Conversation.tsx +++ b/frontend/src/Conversation.tsx @@ -243,7 +243,7 @@ const Conversation = forwardRef(({ 'Content-Type': 'application/json', 'Accept': 'application/json', }, - body: JSON.stringify({ reset: 'history' }) + body: JSON.stringify({ reset: ['history'] }) }); if (!response.ok) { @@ -503,7 +503,7 @@ const Conversation = forwardRef(({ edge="start" color="inherit" disabled={sessionId === undefined || processingMessage !== undefined} - onClick={() => { reset(); }} + onClick={() => { reset(); resetAction && resetAction(); }} > diff --git a/frontend/src/ResumeBuilder.tsx b/frontend/src/ResumeBuilder.tsx index e316dbe..2818dda 100644 --- a/frontend/src/ResumeBuilder.tsx +++ b/frontend/src/ResumeBuilder.tsx @@ -221,6 +221,21 @@ const ResumeBuilder: React.FC = ({ return message; }, []); + const resetJobDescription = useCallback(() => { + setHasJobDescription(false); + setHasResume(false); + setHasFacts(false); + }, [setHasJobDescription, setHasResume, setHasFacts]); + + const resetResume = useCallback(() => { + setHasResume(false); + setHasFacts(false); + }, [setHasResume, setHasFacts]); + + const resetFacts = useCallback(() => { + setHasFacts(false); + }, [setHasFacts]); + const renderJobDescriptionView = useCallback((small: boolean) => { console.log('renderJobDescriptionView'); const jobDescriptionQuestions = [ @@ -239,6 +254,7 @@ const ResumeBuilder: React.FC = ({ prompt: "Paste a job description, then click Generate...", multiline: true, messageFilter: filterJobDescriptionMessages, + resetAction: resetJobDescription, onResponse: jobResponse, sessionId, connectionBase, @@ -255,6 +271,7 @@ const ResumeBuilder: React.FC = ({ prompt: "Ask a question about this job description...", messageFilter: filterJobDescriptionMessages, defaultPrompts: jobDescriptionQuestions, + resetAction: resetJobDescription, onResponse: jobResponse, sessionId, connectionBase, @@ -262,7 +279,7 @@ const ResumeBuilder: React.FC = ({ }} /> } - }, [connectionBase, filterJobDescriptionMessages, hasJobDescription, sessionId, setSnack, jobResponse]); + }, [connectionBase, filterJobDescriptionMessages, hasJobDescription, sessionId, setSnack, jobResponse, resetJobDescription]); /** * Renders the resume view with loading indicator @@ -284,6 +301,7 @@ const ResumeBuilder: React.FC = ({ type: "resume", messageFilter: filterResumeMessages, onResponse: resumeResponse, + resetAction: resetResume, sessionId, connectionBase, setSnack, @@ -298,6 +316,7 @@ const ResumeBuilder: React.FC = ({ prompt: "Ask a question about this job resume...", messageFilter: filterResumeMessages, defaultPrompts: resumeQuestions, + resetAction: resetResume, onResponse: resumeResponse, sessionId, connectionBase, @@ -305,7 +324,7 @@ const ResumeBuilder: React.FC = ({ }} /> } - }, [connectionBase, filterResumeMessages, hasFacts, sessionId, setSnack, resumeResponse]); + }, [connectionBase, filterResumeMessages, hasFacts, sessionId, setSnack, resumeResponse, resetResume]); /** * Renders the fact check view @@ -325,14 +344,14 @@ const ResumeBuilder: React.FC = ({ prompt: "Ask a question about any discrepencies...", messageFilter: filterFactsMessages, defaultPrompts: factsQuestions, + resetAction: resetFacts, onResponse: factsResponse, sessionId, connectionBase, setSnack, }} /> - }, [connectionBase, sessionId, setSnack, factsResponse, filterFactsMessages]); - + }, [connectionBase, sessionId, setSnack, factsResponse, filterFactsMessages, resetFacts]); /** * Gets the appropriate content based on active state for Desktop diff --git a/src/server.py b/src/server.py index a992a6d..e18332e 100644 --- a/src/server.py +++ b/src/server.py @@ -539,37 +539,61 @@ class WebServer: logging.error(e) #return JSONResponse({"error": str(e)}, 500) - @self.app.put("/api/reset/{context_id}/{type}") - async def put_reset(context_id: str, type: str, request: Request): + @self.app.put("/api/reset/{context_id}/{session_type}") + async def put_reset(context_id: str, session_type: str, request: Request): if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - if type not in context["sessions"]: - return JSONResponse({ "error": f"{type} is not recognized", "context": context }, status_code=404) + if session_type not in context["sessions"]: + return JSONResponse({ "error": f"{session_type} is not recognized", "context": context }, status_code=404) data = await request.json() try: - session = context["sessions"][type] + session = context["sessions"][session_type] response = {} - for reset in data["reset"]: - match reset: + for reset_operation in data["reset"]: + match reset_operation: case "system_prompt": - session["system_prompt"] = system_message - response["system_prompt"] = { "system_prompt": system_message } + logging.info(f"Resetting {reset_operation}") + match session_type: + case "chat": + prompt = system_message + case "job_description": + prompt = system_job_description + case "resume": + prompt = system_generate_resume + case "fact_check": + prompt = system_fact_check + + session["system_prompt"] = prompt + response["system_prompt"] = { "system_prompt": prompt } case "rags": + logging.info(f"Resetting {reset_operation}") context["rags"] = rags.copy() response["rags"] = context["rags"] case "tools": + logging.info(f"Resetting {reset_operation}") context["tools"] = default_tools(tools) response["tools"] = context["tools"] case "history": - session["llm_history"] = [] - session["user_history"] = [] - session["context_tokens"] = round(len(str(session["system_prompt"])) * 3 / 4) # Estimate context usage + reset_map = { + "job_description": ("job_description", "resume", "fact_check"), + "resume": ("resume", "fact_check"), + "fact_check": ("fact_check",), + "chat": ("chat",), + } + resets = reset_map.get(session_type, ()) + + for mode in resets: + logging.info(f"Resetting {reset_operation} for {mode}") + context["sessions"][mode]["llm_history"] = [] + context["sessions"][mode]["user_history"] = [] + context["sessions"][mode]["context_tokens"] = round(len(str(context["sessions"][mode]["system_prompt"])) * 3 / 4) # Estimate context usage response["history"] = [] response["context_used"] = session["context_tokens"] case "message_history_length": + logging.info(f"Resetting {reset_operation}") context["message_history_length"] = DEFAULT_HISTORY_LENGTH response["message_history_length"] = DEFAULT_HISTORY_LENGTH @@ -622,21 +646,21 @@ class WebServer: async def get_system_info(context_id: str): return JSONResponse(system_info(self.model)) - @self.app.post("/api/chat/{context_id}/{type}") - async def chat_endpoint(context_id: str, type: str, request: Request): + @self.app.post("/api/chat/{context_id}/{session_type}") + async def chat_endpoint(context_id: str, session_type: str, request: Request): if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - if type not in context["sessions"]: - return JSONResponse({ "error": f"{type} is not recognized", "context": context }, status_code=404) + if session_type not in context["sessions"]: + return JSONResponse({ "error": f"{session_type} is not recognized", "context": context }, status_code=404) data = await request.json() # Create a custom generator that ensures flushing async def flush_generator(): - async for message in self.chat(context=context, type=type, content=data["content"]): + async for message in self.chat(context=context, session_type=session_type, content=data["content"]): # Convert to JSON and add newline yield json.dumps(message) + "\n" # Save the history as its generated @@ -661,12 +685,12 @@ class WebServer: self.logging.info(f"Generated new session as {context['id']}") return JSONResponse(context) - @self.app.get("/api/history/{context_id}/{type}") - async def get_history(context_id: str, type: str): + @self.app.get("/api/history/{context_id}/{session_type}") + async def get_history(context_id: str, session_type: str): context = self.upsert_context(context_id) - if type not in context["sessions"]: - return JSONResponse({ "error": f"{type} is not recognized", "context": context }, status_code=404) - return JSONResponse(context["sessions"][type]["user_history"]) + if session_type not in context["sessions"]: + return JSONResponse({ "error": f"{session_type} is not recognized", "context": context }, status_code=404) + return JSONResponse(context["sessions"][session_type]["user_history"]) @self.app.get("/api/tools/{context_id}") async def get_tools(context_id: str): @@ -716,15 +740,15 @@ class WebServer: except: return JSONResponse({ "status": "error" }), 405 - @self.app.get("/api/context-status/{context_id}/{type}") - async def get_context_status(context_id, type: str): + @self.app.get("/api/context-status/{context_id}/{session_type}") + async def get_context_status(context_id, session_type: str): if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - if type not in context["sessions"]: - return JSONResponse({ "error": f"{type} is not recognized", "context": context }, status_code=404) - return JSONResponse({"context_used": context["sessions"][type]["context_tokens"], "max_context": defines.max_context}) + if session_type not in context["sessions"]: + return JSONResponse({ "error": f"{session_type} is not recognized", "context": context }, status_code=404) + return JSONResponse({"context_used": context["sessions"][session_type]["context_tokens"], "max_context": defines.max_context}) @self.app.get("/api/health") async def health_check(): @@ -785,7 +809,7 @@ class WebServer: # "version": 2, # "id": context_id, # "sessions": { - # **TYPE**: { # chat, job-description, resume, fact-check + # **session_type**: { # chat, job-description, resume, fact-check # "system_prompt": **SYSTEM_MESSAGE**, # "llm_history": [], # "user_history": [], @@ -954,15 +978,15 @@ class WebServer: else: yield {"status": "complete", "message": "RAG processing complete"} - # type: chat + # session_type: chat # * Q&A # - # type: job_description + # session_type: job_description # * First message sets Job Description and generates Resume # * Has content (Job Description) # * Then Q&A of Job Description # - # type: resume + # session_type: resume # * First message sets Resume and generates Fact Check # * Has no content # * Then Q&A of Resume @@ -972,7 +996,7 @@ class WebServer: # * Has content # * Then Q&A of Fact Check - async def chat(self, context, type, content): + async def chat(self, context, session_type, content): if not self.file_watcher: return @@ -985,11 +1009,11 @@ class WebServer: self.processing = True try: - session = context["sessions"][type] + session = context["sessions"][session_type] llm_history = session["llm_history"] user_history = session["user_history"] metadata = { - "origin": type, + "origin": session_type, "rag": { "documents": [] }, "tools": [], "eval_count": 0, @@ -1008,10 +1032,10 @@ class WebServer: enable_rag = False # RAG is disabled when asking questions about the resume - if type == "resume": + if session_type == "resume": enable_rag = False - # The first time through each session type a content_seed may be set for + # The first time through each session session_type a content_seed may be set for # future chat sessions; use it once, then clear it if session["content_seed"]: preamble = f"{session['content_seed']}" @@ -1019,10 +1043,10 @@ class WebServer: else: preamble = "" - # After the first time a particular session type is used, it is handled as a chat. + # After the first time a particular session session_type is used, it is handled as a chat. # The number of messages indicating the session is ready for chat varies based on - # the type of session - process_type = type + # the session_type of session + process_type = session_type match process_type: case "job_description": logging.info(f"job_description user_history len: {len(user_history)}") @@ -1073,7 +1097,7 @@ class WebServer: Use that information to respond to:""" # Use the mode specific system_prompt instead of 'chat' - system_prompt = context["sessions"][type]["system_prompt"] + system_prompt = context["sessions"][session_type]["system_prompt"] # On first entry, a single job_description is provided ("user") # Generate a resume to append to RESUME history @@ -1201,7 +1225,7 @@ Use the above [RESUME] to answer this query: user_history = session["user_history"] = [] case _: - raise Exception(f"Invalid chat type: {type}") + raise Exception(f"Invalid chat session_type: {session_type}") llm_history.append({"role": "user", "content": preamble + content}) user_history.append({"role": "user", "content": content, "origin": metadata["origin"]}) @@ -1218,7 +1242,7 @@ Use the above [RESUME] to answer this query: if len(user_history) > 2: processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..." else: - match type: + match session_type: case "job_description": processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..." case "resume": @@ -1312,7 +1336,7 @@ Use the above [RESUME] to answer this query: yield {"status": "done", "message": final_message } except Exception as e: - logging.exception({ "model": self.model, "origin": type, "content": content, "error": str(e) }) + logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) }) yield {"status": "error", "message": f"An error occurred: {str(e)}"} finally: