From 90a83a7313332223f9b230888005598496aaa2e4 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Tue, 29 Apr 2025 15:53:04 -0700 Subject: [PATCH] Almost working? --- Dockerfile | 9 +- src/kill-server.sh | 6 +- src/server.py | 226 ++++++++++++------------ src/utils/__init__.py | 19 ++- src/utils/agent.py | 256 ++++++++++++++++++++++++++++ src/utils/agents/base.py | 241 ++++++++++++++++++++++++++ src/utils/agents/chat.py | 243 ++++++++++++++++++++++++++ src/utils/agents/fact_check.py | 24 +++ src/utils/agents/job_description.py | 24 +++ src/utils/agents/resume.py | 32 ++++ src/utils/context.py | 186 ++++++++++++++------ src/utils/defines.py | 2 +- src/utils/message.py | 14 +- src/utils/rag.py | 1 + src/utils/session.py | 78 --------- 15 files changed, 1102 insertions(+), 259 deletions(-) create mode 100644 src/utils/agent.py create mode 100644 src/utils/agents/base.py create mode 100644 src/utils/agents/chat.py create mode 100644 src/utils/agents/fact_check.py create mode 100644 src/utils/agents/job_description.py create mode 100644 src/utils/agents/resume.py delete mode 100644 src/utils/session.py diff --git a/Dockerfile b/Dockerfile index d3ae56f..edc1aa3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -293,8 +293,13 @@ RUN { \ echo ' openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout src/key.pem -out src/cert.pem -subj "/C=US/ST=OR/L=Portland/O=Development/CN=localhost"'; \ echo ' fi' ; \ echo ' while true; do'; \ - echo ' echo "Launching Backstory server..."'; \ - echo ' python src/server.py "${@}" || echo "Backstory server died. Restarting in 3 seconds."'; \ + echo ' if [[ ! -e /opt/backstory/block-server ]]; then' \ + echo ' echo "Launching Backstory server..."'; \ + echo ' python src/server.py "${@}" || echo "Backstory server died."'; \ + echo ' else'; \ + echo ' echo "block-server file exists. Not launching."'; \ + echo ' fi' ; \ + echo ' echo "Sleeping for 3 seconds."'; \ echo ' sleep 3'; \ echo ' done' ; \ echo 'fi'; \ diff --git a/src/kill-server.sh b/src/kill-server.sh index 39cc63b..8859181 100755 --- a/src/kill-server.sh +++ b/src/kill-server.sh @@ -2,12 +2,12 @@ # Ensure input was provided if [[ -z "$1" ]]; then - echo "Usage: $0 " - exit 1 + TARGET=$(readlink -f "src/server.py") +else + TARGET=$(readlink -f "$1") fi # Resolve user-supplied path to absolute path -TARGET=$(readlink -f "$1") if [[ ! -f "$TARGET" ]]; then echo "Target file '$TARGET' not found." diff --git a/src/server.py b/src/server.py index c6b0298..df0a7ce 100644 --- a/src/server.py +++ b/src/server.py @@ -44,7 +44,8 @@ from sklearn.preprocessing import MinMaxScaler from utils import ( rag as Rag, - Context, Conversation, Session, Message, Chat, Resume, JobDescription, FactCheck, + Context, Conversation, Message, + Agent, defines ) @@ -409,11 +410,10 @@ class WebServer: @self.app.get("/") async def root(): context = self.create_context() - logging.info(f"Redirecting non-session to {context.id}") + logging.info(f"Redirecting non-context to {context.id}") return RedirectResponse(url=f"/{context.id}", status_code=307) #return JSONResponse({"redirect": f"/{context.id}"}) - @self.app.put("/api/umap/{context_id}") async def put_umap(context_id: str, request: Request): logging.info(f"{request.method} {request.url.path}") @@ -487,16 +487,16 @@ class WebServer: logging.error(e) #return JSONResponse({"error": str(e)}, 500) - @self.app.put("/api/reset/{context_id}/{session_type}") - async def put_reset(context_id: str, session_type: str, request: Request): + @self.app.put("/api/reset/{context_id}/{agent_type}") + async def put_reset(context_id: str, agent_type: str, request: Request): logging.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - session = context.get_session(session_type) - if not session: - return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404) + agent = context.get_agent(agent_type) + if not agent: + return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404) data = await request.json() try: @@ -505,7 +505,7 @@ class WebServer: match reset_operation: case "system_prompt": logging.info(f"Resetting {reset_operation}") - match session_type: + match agent_type: case "chat": prompt = system_message case "job_description": @@ -515,7 +515,7 @@ class WebServer: case "fact_check": prompt = system_message - session.system_prompt = prompt + agent.system_prompt = prompt response["system_prompt"] = { "system_prompt": prompt } case "rags": logging.info(f"Resetting {reset_operation}") @@ -532,17 +532,17 @@ class WebServer: "fact_check": ("job_description", "resume", "fact_check"), "chat": ("chat",), } - resets = reset_map.get(session_type, ()) + resets = reset_map.get(agent_type, ()) for mode in resets: - tmp = context.get_session(mode) + tmp = context.get_agent(mode) if not tmp: continue logging.info(f"Resetting {reset_operation} for {mode}") context.conversation = Conversation() - context.context_tokens = round(len(str(session.system_prompt)) * 3 / 4) # Estimate context usage + context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage response["history"] = [] - response["context_used"] = session.context_tokens + response["context_used"] = agent.context_tokens case "message_history_length": logging.info(f"Resetting {reset_operation}") context.message_history_length = DEFAULT_HISTORY_LENGTH @@ -564,8 +564,8 @@ class WebServer: context = self.upsert_context(context_id) data = await request.json() - session = context.get_session("chat") - if not session: + agent = context.get_agent("chat") + if not agent: return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404) for k in data.keys(): match k: @@ -600,7 +600,7 @@ class WebServer: system_prompt = data[k].strip() if not system_prompt: return JSONResponse({ "status": "error", "message": "System prompt can not be empty." }) - session.system_prompt = system_prompt + agent.system_prompt = system_prompt self.save_context(context_id) return JSONResponse({ "system_prompt": system_prompt }) case "message_history_length": @@ -621,11 +621,11 @@ class WebServer: logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - session = context.get_session("chat") - if not session: + agent = context.get_agent("chat") + if not agent: return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404) return JSONResponse({ - "system_prompt": session.system_prompt, + "system_prompt": agent.system_prompt, "message_history_length": context.message_history_length, "rags": context.rags, "tools": [ { @@ -639,8 +639,8 @@ class WebServer: logging.info(f"{request.method} {request.url.path}") return JSONResponse(system_info(self.model)) - @self.app.post("/api/chat/{context_id}/{session_type}") - async def post_chat_endpoint(context_id: str, session_type: str, request: Request): + @self.app.post("/api/chat/{context_id}/{agent_type}") + async def post_chat_endpoint(context_id: str, agent_type: str, request: Request): logging.info(f"{request.method} {request.url.path}") try: if not is_valid_uuid(context_id): @@ -651,18 +651,18 @@ class WebServer: try: data = await request.json() - session = context.get_session(session_type) - if not session and session_type == "job_description": - logging.info(f"Session {session_type} not found. Returning empty history.") - # Create a new session if it doesn't exist - session = context.get_or_create_session("job_description", system_prompt=system_generate_resume, job_description=data["content"]) + agent = context.get_agent(agent_type) + if not agent and agent_type == "job_description": + logging.info(f"Agent {agent_type} not found. Returning empty history.") + # Create a new agent if it doesn't exist + agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"]) except Exception as e: - logging.info(f"Attempt to create session type: {session_type} failed", e) - return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404) + logging.info(f"Attempt to create agent type: {agent_type} failed", e) + return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404) # Create a custom generator that ensures flushing async def flush_generator(): - async for message in self.generate_response(context=context, session=session, content=data["content"]): + async for message in self.generate_response(context=context, agent=agent, content=data["content"]): # Convert to JSON and add newline yield json.dumps(message) + "\n" # Save the history as its generated @@ -687,20 +687,20 @@ class WebServer: @self.app.post("/api/context") async def create_context(): context = self.create_context() - logging.info(f"Generated new session as {context.id}") + logging.info(f"Generated new agent as {context.id}") return JSONResponse({ "id": context.id }) - @self.app.get("/api/history/{context_id}/{session_type}") - async def get_history(context_id: str, session_type: str, request: Request): + @self.app.get("/api/history/{context_id}/{agent_type}") + async def get_history(context_id: str, agent_type: str, request: Request): logging.info(f"{request.method} {request.url.path}") try: context = self.upsert_context(context_id) - session = context.get_session(session_type) - if not session: - logging.info(f"Session {session_type} not found. Returning empty history.") + agent = context.get_agent(agent_type) + if not agent: + logging.info(f"Agent {agent_type} not found. Returning empty history.") return JSONResponse({ "messages": [] }) - logging.info(f"History for {session_type} contains {len(session.conversation.messages)} entries.") - return session.conversation + logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.") + return agent.conversation except Exception as e: logging.error(f"Error in get_history: {e}") return JSONResponse({"error": str(e)}, status_code=404) @@ -732,17 +732,17 @@ class WebServer: return JSONResponse({ "status": "error" }, 405) - @self.app.get("/api/context-status/{context_id}/{session_type}") - async def get_context_status(context_id, session_type: str, request: Request): + @self.app.get("/api/context-status/{context_id}/{agent_type}") + async def get_context_status(context_id, agent_type: str, request: Request): logging.info(f"{request.method} {request.url.path}") if not is_valid_uuid(context_id): logging.warning(f"Invalid context_id: {context_id}") return JSONResponse({"error": "Invalid context_id"}, status_code=400) context = self.upsert_context(context_id) - session = context.get_session(session_type) - if not session: + agent = context.get_agent(agent_type) + if not agent: return JSONResponse({"context_used": 0, "max_context": defines.max_context}) - return JSONResponse({"context_used": session.context_tokens, "max_context": defines.max_context}) + return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context}) @self.app.get("/api/health") async def health_check(): @@ -757,52 +757,52 @@ class WebServer: logging.info(f"Serve index.html for {path}") return FileResponse(os.path.join(defines.static_content, "index.html")) - def save_context(self, session_id): + def save_context(self, agent_id): """ - Serialize a Python dictionary to a file in the sessions directory. + Serialize a Python dictionary to a file in the agents directory. Args: - data: Dictionary containing the session data - session_id: UUID string for the context. If it doesn't exist, it is created + data: Dictionary containing the agent data + agent_id: UUID string for the context. If it doesn't exist, it is created Returns: - The session_id used for the file + The agent_id used for the file """ - context = self.upsert_context(session_id) + context = self.upsert_context(agent_id) - # Create sessions directory if it doesn't exist - if not os.path.exists(defines.session_dir): - os.makedirs(defines.session_dir) + # Create agents directory if it doesn't exist + if not os.path.exists(defines.context_dir): + os.makedirs(defines.context_dir) # Create the full file path - file_path = os.path.join(defines.session_dir, session_id) + file_path = os.path.join(defines.context_dir, agent_id) # Serialize the data to JSON and write to file with open(file_path, "w") as f: f.write(context.model_dump_json()) - return session_id + return agent_id - def load_context(self, session_id) -> Context: + def load_context(self, agent_id) -> Context: """ - Load a context from a file in the sessions directory. + Load a context from a file in the agents directory. Args: - session_id: UUID string for the context. If it doesn't exist, a new context is created. + agent_id: UUID string for the context. If it doesn't exist, a new context is created. Returns: A Context object with the specified ID and default settings. """ - file_path = os.path.join(defines.session_dir, session_id) + file_path = os.path.join(defines.context_dir, agent_id) # Check if the file exists if not os.path.exists(file_path): - self.contexts[session_id] = self.create_context(session_id) + self.contexts[agent_id] = self.create_context(agent_id) else: # Read and deserialize the data with open(file_path, "r") as f: - self.contexts[session_id] = Context.model_validate_json(f.read()) + self.contexts[agent_id] = Context.model_validate_json(f.read()) - return self.contexts[session_id] + return self.contexts[agent_id] def create_context(self, context_id = None) -> Context: """ @@ -816,14 +816,16 @@ class WebServer: if os.path.exists(defines.resume_doc): context.user_resume = open(defines.resume_doc, "r").read() - context.add_session(Chat(system_prompt = system_message)) - # context.add_session(Resume(system_prompt = system_generate_resume)) - # context.add_session(JobDescription(system_prompt = system_job_description)) - # context.add_session(FactCheck(system_prompt = system_fact_check)) + context.get_or_create_agent( + agent_type="chat", + system_prompt=system_message) + # context.add_agent(Resume(system_prompt = system_generate_resume)) + # context.add_agent(JobDescription(system_prompt = system_job_description)) + # context.add_agent(FactCheck(system_prompt = system_fact_check)) context.tools = default_tools(tools) context.rags = rags.copy() - logging.info(f"{context.id} created and added to sessions.") + logging.info(f"{context.id} created and added to contexts.") self.contexts[context.id] = context self.save_context(context.id) return context @@ -956,15 +958,15 @@ class WebServer: else: yield {"status": "complete", "message": "RAG processing complete"} - # session_type: chat + # agent_type: chat # * Q&A # - # session_type: job_description + # agent_type: job_description # * First message sets Job Description and generates Resume # * Has content (Job Description) # * Then Q&A of Job Description # - # session_type: resume + # agent_type: resume # * First message sets Resume and generates Fact Check # * Has no content # * Then Q&A of Resume @@ -973,18 +975,18 @@ class WebServer: # * First message sets Fact Check and is Q&A # * Has content # * Then Q&A of Fact Check - async def generate_response(self, context : Context, session : Session, content : str): + async def generate_response(self, context : Context, agent : Agent, content : str): if not self.file_watcher: return if self.processing: - logging.info("TODO: Implement delay queing; busy for same session, otherwise return queue size and estimated wait time") + logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") yield {"status": "error", "message": "Busy processing another request."} return self.processing = True - conversation : Conversation = session.conversation + conversation : Conversation = agent.conversation message = Message(prompt=content) del content # Prevent accidental use of content @@ -999,18 +1001,18 @@ class WebServer: enable_rag = False # RAG is disabled when asking questions about the resume - if session.session_type == "resume": + if agent.agent_type == "resume": enable_rag = False - # The first time through each session session_type a content_seed may be set for - # future chat sessions; use it once, then clear it - message.preamble = session.get_and_reset_content_seed() - system_prompt = session.system_prompt + # The first time through each agent agent_type a content_seed may be set for + # future chat agents; use it once, then clear it + message.preamble = agent.get_and_reset_content_seed() + system_prompt = agent.system_prompt - # After the first time a particular session session_type is used, it is handled as a chat. - # The number of messages indicating the session is ready for chat varies based on - # the session_type of session - process_type = session.session_type + # After the first time a particular agent agent_type is used, it is handled as a chat. + # The number of messages indicating the agent is ready for chat varies based on + # the agent_type of agent + process_type = agent.agent_type match process_type: case "job_description": logging.info(f"job_description user_history len: {len(conversation.messages)}") @@ -1021,7 +1023,7 @@ class WebServer: if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK process_type = "chat" case "fact_check": - process_type = "chat" # Fact Check is always a chat session + process_type = "chat" # Fact Check is always a chat agent match process_type: # Normal chat interactions with context history @@ -1071,7 +1073,7 @@ class WebServer: Use that information to respond to:""" # Use the mode specific system_prompt instead of 'chat' - system_prompt = session.system_prompt + system_prompt = agent.system_prompt # On first entry, a single job_description is provided ("user") # Generate a resume to append to RESUME history @@ -1110,10 +1112,10 @@ Use that information to respond to:""" <|job_description|> {message.prompt} """ - tmp = context.get_session("job_description") + tmp = context.get_agent("job_description") if not tmp: - raise Exception(f"Job description session not found.") - # Set the content seed for the job_description session + raise Exception(f"Job description agent not found.") + # Set the content seed for the job_description agent tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ") message.preamble += f""" @@ -1126,7 +1128,7 @@ Use to the above information to respond to this prompt: """ # For all future calls to job_description, use the system_job_description - session.system_prompt = system_job_description + agent.system_prompt = system_job_description # Seed the history for job_description stuffingMessage = Message(prompt=message.prompt) @@ -1137,21 +1139,21 @@ Use to the above information to respond to this prompt: message.add_action("generate_resume") - logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into session'resume'") - logging.info("TODO: For subsequent runs, have the Session handler generate the follow up prompts so they can have correct context preamble") + logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'") + logging.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble") - # Switch to resume session for LLM responses + # Switch to resume agent for LLM responses # message.metadata["origin"] = "resume" - # session = context.get_or_create_session("resume") - # system_prompt = session.system_prompt - # llm_history = session.llm_history = [] - # user_history = session.user_history = [] + # agent = context.get_or_create_agent("resume") + # system_prompt = agent.system_prompt + # llm_history = agent.llm_history = [] + # user_history = agent.user_history = [] # Ignore the passed in content and invoke Fact Check case "resume": - if len(context.get_or_create_session("resume").conversation.messages) < 2: # USER, **ASSISTANT** + if len(context.get_or_create_agent("resume").conversation.messages) < 2: # USER, **ASSISTANT** raise Exception(f"No resume found in user history.") - resume = context.get_or_create_session("resume").conversation.messages[1] + resume = context.get_or_create_agent("resume").conversation.messages[1] # Generate RAG content if enabled, based on the content rag_context = "" @@ -1196,7 +1198,7 @@ Use to the above information to respond to this prompt: <|question|> """ - context.get_or_create_session("resume").set_content_seed(f""" + context.get_or_create_agent("resume").set_content_seed(f""" <|resume|> {resume["content"]} @@ -1222,25 +1224,25 @@ Use the above <|resume|> and <|job_description|> to answer this query: conversation.add_message(stuffingMessage) # For all future calls to job_description, use the system_job_description - logging.info("TODO: Create a system_resume_QA prompt to use for the resume session") - session.system_prompt = system_prompt + logging.info("TODO: Create a system_resume_QA prompt to use for the resume agent") + agent.system_prompt = system_prompt - # Switch to fact_check session for LLM responses + # Switch to fact_check agent for LLM responses message.metadata["origin"] = "fact_check" - session = context.get_or_create_session("fact_check", system_prompt=system_fact_check) + agent = context.get_or_create_agent("fact_check", system_prompt=system_fact_check) - llm_history = session.llm_history = [] - user_history = session.user_history = [] + llm_history = agent.llm_history = [] + user_history = agent.user_history = [] case _: - raise Exception(f"Invalid chat session_type: {session_type}") + raise Exception(f"Invalid chat agent_type: {agent_type}") conversation.add_message(message) # llm_history.append({"role": "user", "content": message.preamble + content}) # user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]}) # message.metadata["full_query"] = llm_history[-1]["content"] - # Uses cached system_prompt as session.system_prompt may have been updated for follow up questions + # Uses cached system_prompt as agent.system_prompt may have been updated for follow up questions messages = create_system_message(system_prompt) if context.message_history_length: to_add = conversation.messages[-context.message_history_length:] @@ -1272,12 +1274,12 @@ Use the above <|resume|> and <|job_description|> to answer this query: {message.prompt}""" # Estimate token length of new messages - ctx_size = self.get_optimal_ctx_size(context.get_or_create_session(process_type).context_tokens, messages=message.prompt) + ctx_size = self.get_optimal_ctx_size(context.get_or_create_agent(process_type).context_tokens, messages=message.prompt) if len(conversation.messages) > 2: processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..." else: - match session.session_type: + match agent.agent_type: case "job_description": processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..." case "resume": @@ -1303,7 +1305,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: message.metadata["eval_duration"] += response["eval_duration"] message.metadata["prompt_eval_count"] += response["prompt_eval_count"] message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] - session.context_tokens = response["prompt_eval_count"] + response["eval_count"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] tools_used = [] @@ -1347,7 +1349,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: message.metadata["tools"] = tools_used # Estimate token length of new messages - ctx_size = self.get_optimal_ctx_size(session.context_tokens, messages=messages[pre_add_index:]) + ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:]) yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size } # Decrease creativity when processing tool call requests response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 }) @@ -1355,11 +1357,11 @@ Use the above <|resume|> and <|job_description|> to answer this query: message.metadata["eval_duration"] += response["eval_duration"] message.metadata["prompt_eval_count"] += response["prompt_eval_count"] message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] - session.context_tokens = response["prompt_eval_count"] + response["eval_count"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] reply = response["message"]["content"] message.response = reply - message.metadata["origin"] = session.session_type + message.metadata["origin"] = agent.agent_type # final_message = {"role": "assistant", "content": reply } # # history is provided to the LLM and should not have additional metadata @@ -1379,7 +1381,7 @@ Use the above <|resume|> and <|job_description|> to answer this query: } # except Exception as e: - # logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) }) + # logging.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) }) # yield {"status": "error", "message": f"An error occurred: {str(e)}"} # finally: diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 3b6a2da..e00511e 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,10 +1,17 @@ -# Import defines to make `utils.defines` accessible from . import defines - -# Import rest as `utils.*` accessible from .rag import ChromaDBFileWatcher, start_file_watcher - from .message import Message from .conversation import Conversation -from .session import Session, Chat, Resume, JobDescription, FactCheck -from .context import Context \ No newline at end of file +from .context import Context +from . import agents + +from .agents import Agent, __all__ as agents_all + +__all__ = [ + 'Agent', + 'Context', + 'Conversation', + 'Message', + 'ChromaDBFileWatcher', + 'start_file_watcher' +] + agents_all \ No newline at end of file diff --git a/src/utils/agent.py b/src/utils/agent.py new file mode 100644 index 0000000..e723f54 --- /dev/null +++ b/src/utils/agent.py @@ -0,0 +1,256 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging + +from .types import AgentBase, registry + +# Only import Context for type checking +if TYPE_CHECKING: + from .context import Context + +from .types import AgentBase + +from .conversation import Conversation +from .message import Message + +class Agent(AgentBase): + """ + Base class for all agent types. + This class defines the common attributes and methods for all agent types. + """ + agent_type: str = Field(default="agent", const=True) # discriminator value + + + def __init_subclass__(cls, **kwargs): + """Auto-register subclasses""" + super().__init_subclass__(**kwargs) + # Register this class if it has an agent_type + if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type: + registry.register(cls.agent_type, cls) + + def __init__(self, **data): + # Set agent_type from class if not provided + if 'agent_type' not in data: + data['agent_type'] = self.__class__.agent_type + super().__init__(**data) + + system_prompt: str # Mandatory + conversation: Conversation = Conversation() + context_tokens: int = 0 + + # Add a property for context if needed without creating a circular reference + @property + def context(self) -> Optional['Context']: + if TYPE_CHECKING: + from .context import Context + # Implement logic to fetch context by ID if needed + return None + + #context: Context + + _content_seed: str = PrivateAttr(default="") + + async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: + """ + Prepare message with context information in message.preamble + """ + # Generate RAG content if enabled, based on the content + rag_context = "" + if not message.disable_rag: + # Gather RAG results, yielding each result + # as it becomes available + for value in self.context.generate_rag_results(message): + logging.info(f"RAG: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + message.status = "error" + message.response = value.response + yield message + return + + if message.metadata["rag"]: + for rag_collection in message.metadata["rag"]: + for doc in rag_collection["documents"]: + rag_context += f"{doc}\n" + + if rag_context: + message["context"] = rag_context + + if self.context.user_resume: + message["resume"] = self.content.user_resume + + if message.preamble: + preamble_types = [f"<|{p}|>" for p in message.preamble.keys()] + preamble_types_AND = " and ".join(preamble_types) + preamble_types_OR = " or ".join(preamble_types) + message.preamble["rules"] = f"""\ +- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly. +- If there is no information in these sections, answer based on your knowledge. +- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}. +""" + message.preamble["question"] = "Use that information to respond to:" + else: + message.preamble["question"] = "Respond to:" + + message.system_prompt = self.system_prompt + message.status = "done" + yield message + return + + async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]: + if self.context.processing: + logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + message.status = "error" + message.response = "Busy processing another request." + yield message + return + + self.context.processing = True + + messages = [] + + for value in self.llm.chat( + model=self.model, + messages=messages, + #tools=llm_tools(context.tools) if message.enable_tools else None, + options={ "num_ctx": message.ctx_size } + ): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + message.status = value.status + message.response = value.response + yield message + if value.status == "error": + return + response = value + + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + tools_used = [] + + yield {"status": "processing", "message": "Initial response received..."} + + if "tool_calls" in response.get("message", {}): + yield {"status": "processing", "message": "Processing tool calls..."} + + tool_message = response["message"] + tool_result = None + + # Process all yielded items from the handler + async for item in self.handle_tool_calls(tool_message): + if isinstance(item, tuple) and len(item) == 2: + # This is the final result tuple (tool_result, tools_used) + tool_result, tools_used = item + else: + # This is a status update, forward it + yield item + + message_dict = { + "role": tool_message.get("role", "assistant"), + "content": tool_message.get("content", "") + } + + if "tool_calls" in tool_message: + message_dict["tool_calls"] = [ + {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}} + for tc in tool_message["tool_calls"] + ] + + pre_add_index = len(messages) + messages.append(message_dict) + + if isinstance(tool_result, list): + messages.extend(tool_result) + else: + if tool_result: + messages.append(tool_result) + + message.metadata["tools"] = tools_used + + # Estimate token length of new messages + ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:]) + yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size } + # Decrease creativity when processing tool call requests + response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 }) + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + reply = response["message"]["content"] + message.response = reply + message.metadata["origin"] = agent.agent_type + # final_message = {"role": "assistant", "content": reply } + + # # history is provided to the LLM and should not have additional metadata + # llm_history.append(final_message) + + # user_history is provided to the REST API and does not include CONTEXT + # It does include metadata + # final_message["metadata"] = message.metadata + # user_history.append({**final_message, "origin": message.metadata["origin"]}) + + # Return the REST API with metadata + yield { + "status": "done", + "message": { + **message.model_dump(mode='json'), + } + } + + self.context.processing = False + return + + async def process_message(self, message:Message) -> AsyncGenerator[Message, None]: + message.full_content = "" + for i, p in enumerate(message.preamble.keys()): + message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n" + + # Estimate token length of new messages + message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content) + + message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." + message.status = "thinking" + yield message + + for value in self.generate_llm_response(message): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + return + + def get_and_reset_content_seed(self): + tmp = self._content_seed + self._content_seed = "" + return tmp + + def set_content_seed(self, content: str) -> None: + """Set the content seed for the agent.""" + self._content_seed = content + + def get_content_seed(self) -> str: + """Get the content seed for the agent.""" + return self._content_seed + + @classmethod + def valid_agent_types(cls) -> set[str]: + """Return the set of valid agent_type values.""" + return set(get_args(cls.__annotations__["agent_type"])) + + +# Register the base agent +registry.register(Agent.agent_type, Agent) + +# Type alias for Agent or any subclass +AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses + +import ./agents \ No newline at end of file diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py new file mode 100644 index 0000000..489673d --- /dev/null +++ b/src/utils/agents/base.py @@ -0,0 +1,241 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging + +# Only import Context for type checking +if TYPE_CHECKING: + from .. context import Context + +from .types import AgentBase, ContextBase, registry + +from .. conversation import Conversation +from .. message import Message + +class Agent(AgentBase): + """ + Base class for all agent types. + This class defines the common attributes and methods for all agent types. + """ + agent_type: Literal["base"] = "base" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + def __init_subclass__(cls, **kwargs): + """Auto-register subclasses""" + super().__init_subclass__(**kwargs) + # Register this class if it has an agent_type + if hasattr(cls, 'agent_type') and cls.agent_type != AgentBase._agent_type: + registry.register(cls.agent_type, cls) + + def __init__(self, **data): + # Set agent_type from class if not provided + if 'agent_type' not in data: + data['agent_type'] = self.__class__.agent_type + super().__init__(**data) + + system_prompt: str # Mandatory + conversation: Conversation = Conversation() + context_tokens: int = 0 + context: ContextBase # Avoid circular reference + + _content_seed: str = PrivateAttr(default="") + + async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: + """ + Prepare message with context information in message.preamble + """ + # Generate RAG content if enabled, based on the content + rag_context = "" + if not message.disable_rag: + # Gather RAG results, yielding each result + # as it becomes available + for value in self.context.generate_rag_results(message): + logging.info(f"RAG: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + message.status = "error" + message.response = value.response + yield message + return + + if message.metadata["rag"]: + for rag_collection in message.metadata["rag"]: + for doc in rag_collection["documents"]: + rag_context += f"{doc}\n" + + if rag_context: + message["context"] = rag_context + + if self.context.user_resume: + message["resume"] = self.content.user_resume + + if message.preamble: + preamble_types = [f"<|{p}|>" for p in message.preamble.keys()] + preamble_types_AND = " and ".join(preamble_types) + preamble_types_OR = " or ".join(preamble_types) + message.preamble["rules"] = f"""\ +- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly. +- If there is no information in these sections, answer based on your knowledge. +- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}. +""" + message.preamble["question"] = "Use that information to respond to:" + else: + message.preamble["question"] = "Respond to:" + + message.system_prompt = self.system_prompt + message.status = "done" + yield message + return + + async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]: + if self.context.processing: + logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + message.status = "error" + message.response = "Busy processing another request." + yield message + return + + self.context.processing = True + + messages = [] + + for value in self.llm.chat( + model=self.model, + messages=messages, + #tools=llm_tools(context.tools) if message.enable_tools else None, + options={ "num_ctx": message.ctx_size } + ): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + message.status = value.status + message.response = value.response + yield message + if value.status == "error": + return + response = value + + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + tools_used = [] + + yield {"status": "processing", "message": "Initial response received..."} + + if "tool_calls" in response.get("message", {}): + yield {"status": "processing", "message": "Processing tool calls..."} + + tool_message = response["message"] + tool_result = None + + # Process all yielded items from the handler + async for item in self.handle_tool_calls(tool_message): + if isinstance(item, tuple) and len(item) == 2: + # This is the final result tuple (tool_result, tools_used) + tool_result, tools_used = item + else: + # This is a status update, forward it + yield item + + message_dict = { + "role": tool_message.get("role", "assistant"), + "content": tool_message.get("content", "") + } + + if "tool_calls" in tool_message: + message_dict["tool_calls"] = [ + {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}} + for tc in tool_message["tool_calls"] + ] + + pre_add_index = len(messages) + messages.append(message_dict) + + if isinstance(tool_result, list): + messages.extend(tool_result) + else: + if tool_result: + messages.append(tool_result) + + message.metadata["tools"] = tools_used + + # Estimate token length of new messages + ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:]) + yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size } + # Decrease creativity when processing tool call requests + response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 }) + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + reply = response["message"]["content"] + message.response = reply + message.metadata["origin"] = agent.agent_type + # final_message = {"role": "assistant", "content": reply } + + # # history is provided to the LLM and should not have additional metadata + # llm_history.append(final_message) + + # user_history is provided to the REST API and does not include CONTEXT + # It does include metadata + # final_message["metadata"] = message.metadata + # user_history.append({**final_message, "origin": message.metadata["origin"]}) + + # Return the REST API with metadata + yield { + "status": "done", + "message": { + **message.model_dump(mode='json'), + } + } + + self.context.processing = False + return + + async def process_message(self, message:Message) -> AsyncGenerator[Message, None]: + message.full_content = "" + for i, p in enumerate(message.preamble.keys()): + message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n" + + # Estimate token length of new messages + message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content) + + message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." + message.status = "thinking" + yield message + + for value in self.generate_llm_response(message): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + return + + def get_and_reset_content_seed(self): + tmp = self._content_seed + self._content_seed = "" + return tmp + + def set_content_seed(self, content: str) -> None: + """Set the content seed for the agent.""" + self._content_seed = content + + def get_content_seed(self) -> str: + """Get the content seed for the agent.""" + return self._content_seed + + @classmethod + def valid_agent_types(cls) -> set[str]: + """Return the set of valid agent_type values.""" + return set(get_args(cls.__annotations__["agent_type"])) + +# Register the base agent +registry.register(Agent._agent_type, Agent) + + diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py new file mode 100644 index 0000000..26b4618 --- /dev/null +++ b/src/utils/agents/chat.py @@ -0,0 +1,243 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from typing_extensions import Annotated +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging +from .base import Agent, registry +from .. conversation import Conversation +from .. message import Message + +class Chat(Agent, ABC): + """ + Base class for all agent types. + This class defines the common attributes and methods for all agent types. + """ + agent_type: Literal["chat"] = "chat" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + def __init_subclass__(cls, **kwargs): + """Auto-register subclasses""" + super().__init_subclass__(**kwargs) + # Register this class if it has an agent_type + if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type: + registry.register(cls.agent_type, cls) + + def __init__(self, **data): + # Set agent_type from class if not provided + if 'agent_type' not in data: + data['agent_type'] = self.__class__.agent_type + super().__init__(**data) + + system_prompt: str # Mandatory + conversation: Conversation = Conversation() + context_tokens: int = 0 + + # Add a property for context if needed without creating a circular reference + @property + def context(self) -> Optional['Context']: + if TYPE_CHECKING: + from .context import Context + # Implement logic to fetch context by ID if needed + return None + + #context: Context + + _content_seed: str = PrivateAttr(default="") + + async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: + """ + Prepare message with context information in message.preamble + """ + # Generate RAG content if enabled, based on the content + rag_context = "" + if not message.disable_rag: + # Gather RAG results, yielding each result + # as it becomes available + for value in self.context.generate_rag_results(message): + logging.info(f"RAG: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + message.status = "error" + message.response = value.response + yield message + return + + if message.metadata["rag"]: + for rag_collection in message.metadata["rag"]: + for doc in rag_collection["documents"]: + rag_context += f"{doc}\n" + + if rag_context: + message["context"] = rag_context + + if self.context.user_resume: + message["resume"] = self.content.user_resume + + if message.preamble: + preamble_types = [f"<|{p}|>" for p in message.preamble.keys()] + preamble_types_AND = " and ".join(preamble_types) + preamble_types_OR = " or ".join(preamble_types) + message.preamble["rules"] = f"""\ +- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly. +- If there is no information in these sections, answer based on your knowledge. +- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}. +""" + message.preamble["question"] = "Use that information to respond to:" + else: + message.preamble["question"] = "Respond to:" + + message.system_prompt = self.system_prompt + message.status = "done" + yield message + return + + async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]: + if self.context.processing: + logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + message.status = "error" + message.response = "Busy processing another request." + yield message + return + + self.context.processing = True + + messages = [] + + for value in self.llm.chat( + model=self.model, + messages=messages, + #tools=llm_tools(context.tools) if message.enable_tools else None, + options={ "num_ctx": message.ctx_size } + ): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + message.status = value.status + message.response = value.response + yield message + if value.status == "error": + return + response = value + + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + tools_used = [] + + yield {"status": "processing", "message": "Initial response received..."} + + if "tool_calls" in response.get("message", {}): + yield {"status": "processing", "message": "Processing tool calls..."} + + tool_message = response["message"] + tool_result = None + + # Process all yielded items from the handler + async for item in self.handle_tool_calls(tool_message): + if isinstance(item, tuple) and len(item) == 2: + # This is the final result tuple (tool_result, tools_used) + tool_result, tools_used = item + else: + # This is a status update, forward it + yield item + + message_dict = { + "role": tool_message.get("role", "assistant"), + "content": tool_message.get("content", "") + } + + if "tool_calls" in tool_message: + message_dict["tool_calls"] = [ + {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}} + for tc in tool_message["tool_calls"] + ] + + pre_add_index = len(messages) + messages.append(message_dict) + + if isinstance(tool_result, list): + messages.extend(tool_result) + else: + if tool_result: + messages.append(tool_result) + + message.metadata["tools"] = tools_used + + # Estimate token length of new messages + ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:]) + yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size } + # Decrease creativity when processing tool call requests + response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 }) + message.metadata["eval_count"] += response["eval_count"] + message.metadata["eval_duration"] += response["eval_duration"] + message.metadata["prompt_eval_count"] += response["prompt_eval_count"] + message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"] + agent.context_tokens = response["prompt_eval_count"] + response["eval_count"] + + reply = response["message"]["content"] + message.response = reply + message.metadata["origin"] = agent.agent_type + # final_message = {"role": "assistant", "content": reply } + + # # history is provided to the LLM and should not have additional metadata + # llm_history.append(final_message) + + # user_history is provided to the REST API and does not include CONTEXT + # It does include metadata + # final_message["metadata"] = message.metadata + # user_history.append({**final_message, "origin": message.metadata["origin"]}) + + # Return the REST API with metadata + yield { + "status": "done", + "message": { + **message.model_dump(mode='json'), + } + } + + self.context.processing = False + return + + async def process_message(self, message:Message) -> AsyncGenerator[Message, None]: + message.full_content = "" + for i, p in enumerate(message.preamble.keys()): + message.full_content += '' if i == 0 else '\n\n' + f"<|{p}|>{message.preamble[p].strip()}\n" + + # Estimate token length of new messages + message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content) + + message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." + message.status = "thinking" + yield message + + for value in self.generate_llm_response(message): + logging.info(f"LLM: {value.status} - {value.response}") + if value.status != "done": + yield value + if value.status == "error": + return + + def get_and_reset_content_seed(self): + tmp = self._content_seed + self._content_seed = "" + return tmp + + def set_content_seed(self, content: str) -> None: + """Set the content seed for the agent.""" + self._content_seed = content + + def get_content_seed(self) -> str: + """Get the content seed for the agent.""" + return self._content_seed + + @classmethod + def valid_agent_types(cls) -> set[str]: + """Return the set of valid agent_type values.""" + return set(get_args(cls.__annotations__["agent_type"])) + +# Register the base agent +registry.register(Chat._agent_type, Chat) diff --git a/src/utils/agents/fact_check.py b/src/utils/agents/fact_check.py new file mode 100644 index 0000000..c6e5ff9 --- /dev/null +++ b/src/utils/agents/fact_check.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from typing_extensions import Annotated +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging +from .base import Agent, registry +from .. conversation import Conversation +from .. message import Message + +class FactCheck(Agent): + agent_type: Literal["fact_check"] = "fact_check" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + facts: str = "" + + @model_validator(mode="after") + def validate_facts(self): + if not self.facts.strip(): + raise ValueError("Facts cannot be empty") + return self + +# Register the base agent +registry.register(FactCheck._agent_type, FactCheck) diff --git a/src/utils/agents/job_description.py b/src/utils/agents/job_description.py new file mode 100644 index 0000000..360f677 --- /dev/null +++ b/src/utils/agents/job_description.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from typing_extensions import Annotated +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging +from .base import Agent, registry +from .. conversation import Conversation +from .. message import Message + +class JobDescription(Agent): + agent_type: Literal["job_description"] = "job_description" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + job_description: str = "" + + @model_validator(mode="after") + def validate_job_description(self): + if not self.job_description.strip(): + raise ValueError("Job description cannot be empty") + return self + +# Register the base agent +registry.register(JobDescription._agent_type, JobDescription) diff --git a/src/utils/agents/resume.py b/src/utils/agents/resume.py new file mode 100644 index 0000000..635da5c --- /dev/null +++ b/src/utils/agents/resume.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from typing_extensions import Annotated +from abc import ABC, abstractmethod +from typing_extensions import Annotated +import logging +from .base import Agent, registry +from .. conversation import Conversation +from .. message import Message + +class Resume(Agent): + agent_type: Literal["resume"] = "resume" + _agent_type: ClassVar[str] = agent_type # Add this for registration + + resume: str = "" + + @model_validator(mode="after") + def validate_resume(self): + if not self.resume.strip(): + raise ValueError("Resume content cannot be empty") + return self + + def get_resume(self) -> str: + """Get the resume content.""" + return self.resume + + def set_resume(self, resume: str) -> None: + """Set the resume content.""" + self.resume = resume + +# Register the base agent +registry.register(Resume._agent_type, Resume) diff --git a/src/utils/context.py b/src/utils/context.py index e06a15f..4afe65a 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -1,16 +1,31 @@ from pydantic import BaseModel, Field, model_validator from uuid import uuid4 -from typing import List, Optional +from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING from typing_extensions import Annotated, Union -from .session import AnySession, Session +import numpy as np +import logging + +from .message import Message +from .rag import ChromaDBFileWatcher +from . import defines + +from .agents import Agent, ContextBase + +# Import only agent types, not actual classes +if TYPE_CHECKING: + from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck + +from .agents import AnyAgent + +class Context(ContextBase): + model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher -class Context(BaseModel): id: str = Field( default_factory=lambda: str(uuid4()), pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" ) - sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field( + agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( default_factory=list ) @@ -21,78 +36,145 @@ class Context(BaseModel): rags: List[dict] = [] message_history_length: int = 5 context_tokens: int = 0 + file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True) def __init__(self, id: Optional[str] = None, **kwargs): super().__init__(id=id if id is not None else str(uuid4()), **kwargs) @model_validator(mode="after") - def validate_unique_session_types(self): - """Ensure at most one session per session_type.""" - session_types = [session.session_type for session in self.sessions] - if len(session_types) != len(set(session_types)): - raise ValueError("Context cannot contain multiple sessions of the same session_type") + def validate_unique_agent_types(self): + """Ensure at most one agent per agent_type.""" + agent_types = [agent.agent_type for agent in self.agents] + if len(agent_types) != len(set(agent_types)): + raise ValueError("Context cannot contain multiple agents of the same agent_type") return self - def get_or_create_session(self, session_type: str, **kwargs) -> Session: + def get_optimal_ctx_size(self, context, messages, ctx_buffer = 4096): + ctx = round(context + len(str(messages)) * 3 / 4) + return max(defines.max_context, min(2048, ctx + ctx_buffer)) + + def generate_rag_results(self, message: Message) -> Generator[Message, None, None]: """ - Get or create and append a new session of the specified type, ensuring only one session per type exists. + Generate RAG results for the given query. Args: - session_type: The type of session to create (e.g., 'web', 'database'). - **kwargs: Additional fields required by the specific session subclass. + query: The query string to generate RAG results for. Returns: - The created session instance. + A list of dictionaries containing the RAG results. + """ + try: + message.status = "processing" + + entries : int = 0 + + if not self.file_watcher: + message.response = "No RAG context available." + del message.metadata["rag"] + message.status = "done" + yield message + return + + for rag in self.rags: + if not rag["enabled"]: + continue + message.response = f"Checking RAG context {rag['name']}..." + yield message + chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10) + if chroma_results: + entries += len(chroma_results["documents"]) + + chroma_embedding = np.array(chroma_results["query_embedding"]).flatten() # Ensure correct shape + print(f"Chroma embedding shape: {chroma_embedding.shape}") + + umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist() + print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}") # Debug output + + umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist() + print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}") # Debug output + + message.metadata["rag"][rag["name"]] = { + **chroma_results, + "umap_embedding_2d": umap_2d, + "umap_embedding_3d": umap_3d + } + yield message + + if entries == 0: + del message.metadata["rag"] + + message.response = f"RAG context gathered from results from {entries} documents." + message.status = "done" + yield message + return + except Exception as e: + message.response = f"Error generating RAG results: {str(e)}" + message.status = "error" + logging.error(e) + yield message + return + + def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent: + """ + Get or create and append a new agent of the specified type, ensuring only one agent per type exists. + + Args: + agent_type: The type of agent to create (e.g., 'web', 'database'). + **kwargs: Additional fields required by the specific agent subclass. + + Returns: + The created agent instance. Raises: - ValueError: If no matching session type is found or if a session of this type already exists. + ValueError: If no matching agent type is found or if a agent of this type already exists. """ - # Check if a session with the given session_type already exists - for session in self.sessions: - if session.session_type == session_type: - return session + # Check if a agent with the given agent_type already exists + for agent in self.agents: + if agent.agent_type == agent_type: + return agent # Find the matching subclass - for session_cls in Session.__subclasses__(): - if session_cls.model_fields["session_type"].default == session_type: - # Create the session instance with provided kwargs - session = session_cls(session_type=session_type, **kwargs) - self.sessions.append(session) - return session + for agent_cls in Agent.__subclasses__(): + logging.info(f"Found class: {agent_cls.model_fields['agent_type'].default}") + if agent_cls.model_fields["agent_type"].default == agent_type: + # Create the agent instance with provided kwargs + agent = agent_cls(agent_type=agent_type, **kwargs) + self.agents.append(agent) + return agent - raise ValueError(f"No session class found for session_type: {session_type}") + raise ValueError(f"No agent class found for agent_type: {agent_type}") - def add_session(self, session: AnySession) -> None: - """Add a Session to the context, ensuring no duplicate session_type.""" - if any(s.session_type == session.session_type for s in self.sessions): - raise ValueError(f"A session with session_type '{session.session_type}' already exists") - self.sessions.append(session) + def add_agent(self, agent: AnyAgent) -> None: + """Add a Agent to the context, ensuring no duplicate agent_type.""" + if any(s.agent_type == agent.agent_type for s in self.agents): + raise ValueError(f"A agent with agent_type '{agent.agent_type}' already exists") + self.agents.append(agent) - def get_session(self, session_type: str) -> Session | None: - """Return the Session with the given session_type, or None if not found.""" - for session in self.sessions: - if session.session_type == session_type: - return session + def get_agent(self, agent_type: str) -> Agent | None: + """Return the Agent with the given agent_type, or None if not found.""" + for agent in self.agents: + if agent.agent_type == agent_type: + return agent return None - def is_valid_session_type(self, session_type: str) -> bool: - """Check if the given session_type is valid.""" - return session_type in Session.valid_session_types() + def is_valid_agent_type(self, agent_type: str) -> bool: + """Check if the given agent_type is valid.""" + return agent_type in Agent.valid_agent_types() def get_summary(self) -> str: """Return a summary of the context.""" - if not self.sessions: - return f"Context {self.uuid}: No sessions." + if not self.agents: + return f"Context {self.uuid}: No agents." summary = f"Context {self.uuid}:\n" - for i, session in enumerate(self.sessions, 1): - summary += f"\nSession {i} ({session.session_type}):\n" - summary += session.conversation.get_summary() - if session.session_type == "resume": - summary += f"\nResume: {session.get_resume()}\n" - elif session.session_type == "job_description": - summary += f"\nJob Description: {session.job_description}\n" - elif session.session_type == "fact_check": - summary += f"\nFacts: {session.facts}\n" - elif session.session_type == "chat": - summary += f"\nChat Name: {session.name}\n" + for i, agent in enumerate(self.agents, 1): + summary += f"\nAgent {i} ({agent.agent_type}):\n" + summary += agent.conversation.get_summary() + if agent.agent_type == "resume": + summary += f"\nResume: {agent.get_resume()}\n" + elif agent.agent_type == "job_description": + summary += f"\nJob Description: {agent.job_description}\n" + elif agent.agent_type == "fact_check": + summary += f"\nFacts: {agent.facts}\n" + elif agent.agent_type == "chat": + summary += f"\nChat Name: {agent.name}\n" return summary \ No newline at end of file diff --git a/src/utils/defines.py b/src/utils/defines.py index 36f85f3..97f061b 100644 --- a/src/utils/defines.py +++ b/src/utils/defines.py @@ -9,7 +9,7 @@ embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large") persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb") max_context = 2048*8*2 doc_dir = "/opt/backstory/docs/" -session_dir = "/opt/backstory/sessions" +context_dir = "/opt/backstory/sessions" static_content = "/opt/backstory/frontend/deployed" resume_doc = "/opt/backstory/docs/resume/generic.md" # Only used for testing; backstory-prod will not use this diff --git a/src/utils/message.py b/src/utils/message.py index ca81412..37c1f19 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -3,10 +3,14 @@ from typing import Dict, List, Optional, Any from datetime import datetime, timezone class Message(BaseModel): - prompt: str - preamble: str = "" - content: str = "" - response: str = "" + # Required + prompt: str # Query to be answered + + # Generated while processing message + preamble: dict[str,str] = {} # Preamble to be prepended to the prompt + system_prompt: str = "" # System prompt provided to the LLM + full_content: str = "" # Full content of the message (preamble + prompt) + response: str = "" # LLM response to the preamble + query metadata: dict[str, Any] = { "rag": { "documents": [] }, "tools": [], @@ -15,7 +19,7 @@ class Message(BaseModel): "prompt_eval_count": 0, "prompt_eval_duration": 0, } - actions: List[str] = [] + actions: List[str] = [] # Other session modifying actions performed while processing the message timestamp: datetime = datetime.now(timezone.utc) def add_action(self, action: str | list[str]) -> None: diff --git a/src/utils/rag.py b/src/utils/rag.py index c7ef8ca..9260ef7 100644 --- a/src/utils/rag.py +++ b/src/utils/rag.py @@ -1,3 +1,4 @@ +from pydantic import BaseModel, Field, model_validator, PrivateAttr import os import glob from pathlib import Path diff --git a/src/utils/session.py b/src/utils/session.py deleted file mode 100644 index f4bd3e5..0000000 --- a/src/utils/session.py +++ /dev/null @@ -1,78 +0,0 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr -from typing import Literal, TypeAlias, get_args -from .conversation import Conversation - -class Session(BaseModel): - session_type: Literal["resume", "job_description", "fact_check", "chat"] - system_prompt: str # Mandatory - conversation: Conversation = Conversation() - context_tokens: int = 0 - - _content_seed: str = PrivateAttr(default="") - - def get_and_reset_content_seed(self): - tmp = self._content_seed - self._content_seed = "" - return tmp - - def set_content_seed(self, content: str) -> None: - """Set the content seed for the session.""" - self._content_seed = content - - def get_content_seed(self) -> str: - """Get the content seed for the session.""" - return self._content_seed - - @classmethod - def valid_session_types(cls) -> set[str]: - """Return the set of valid session_type values.""" - return set(get_args(cls.__annotations__["session_type"])) - - -# Type alias for Session or any subclass -AnySession: TypeAlias = Session # BaseModel covers Session and subclasses - -class Resume(Session): - session_type: Literal["resume"] = "resume" - resume: str = "" - - @model_validator(mode="after") - def validate_resume(self): - if not self.resume.strip(): - raise ValueError("Resume content cannot be empty") - return self - - def get_resume(self) -> str: - """Get the resume content.""" - return self.resume - - def set_resume(self, resume: str) -> None: - """Set the resume content.""" - self.resume = resume - -class JobDescription(Session): - session_type: Literal["job_description"] = "job_description" - job_description: str = "" - - @model_validator(mode="after") - def validate_job_description(self): - if not self.job_description.strip(): - raise ValueError("Job description cannot be empty") - return self - -class FactCheck(Session): - session_type: Literal["fact_check"] = "fact_check" - facts: str = "" - - @model_validator(mode="after") - def validate_facts(self): - if not self.facts.strip(): - raise ValueError("Facts cannot be empty") - return self - -class Chat(Session): - session_type: Literal["chat"] = "chat" - - @model_validator(mode="after") - def validate_name(self): - return self \ No newline at end of file