Almost working?
parent 5806563777
commit 90a83a7313
@@ -293,8 +293,13 @@ RUN { \
     echo '    openssl req -x509 -nodes -days 365 -newkey rsa:2048 -keyout src/key.pem -out src/cert.pem -subj "/C=US/ST=OR/L=Portland/O=Development/CN=localhost"'; \
     echo '  fi' ; \
     echo '  while true; do'; \
-    echo '    echo "Launching Backstory server..."'; \
-    echo '    python src/server.py "${@}" || echo "Backstory server died. Restarting in 3 seconds."'; \
+    echo '    if [[ ! -e /opt/backstory/block-server ]]; then'; \
+    echo '      echo "Launching Backstory server..."'; \
+    echo '      python src/server.py "${@}" || echo "Backstory server died."'; \
+    echo '    else'; \
+    echo '      echo "block-server file exists. Not launching."'; \
+    echo '    fi' ; \
+    echo '    echo "Sleeping for 3 seconds."'; \
    echo '    sleep 3'; \
    echo '  done' ; \
    echo 'fi'; \
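Note: the new RUN block generates a supervision loop that a sentinel file can pause. Purely as an illustration (not part of the commit), the same pattern in Python, with the sentinel path taken from the Dockerfile:

# Sketch of the sentinel-file supervision pattern the Dockerfile scripts.
import os
import subprocess
import time

SENTINEL = "/opt/backstory/block-server"  # path from the Dockerfile

def supervise(cmd: list[str], delay: float = 3.0) -> None:
    while True:
        if not os.path.exists(SENTINEL):
            print("Launching Backstory server...")
            # Restart on any exit, mirroring `python src/server.py || echo ...`
            result = subprocess.run(cmd)
            if result.returncode != 0:
                print("Backstory server died.")
        else:
            print("block-server file exists. Not launching.")
        print("Sleeping for 3 seconds.")
        time.sleep(delay)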
@@ -2,12 +2,12 @@

 # Ensure input was provided
 if [[ -z "$1" ]]; then
     echo "Usage: $0 <path/to/python_script.py>"
-    exit 1
+    TARGET=$(readlink -f "src/server.py")
+else
+    TARGET=$(readlink -f "$1")
 fi

-# Resolve user-supplied path to absolute path
-TARGET=$(readlink -f "$1")

 if [[ ! -f "$TARGET" ]]; then
     echo "Target file '$TARGET' not found."
src/server.py (226 lines changed)
@@ -44,7 +44,8 @@ from sklearn.preprocessing import MinMaxScaler

 from utils import (
     rag as Rag,
-    Context, Conversation, Session, Message, Chat, Resume, JobDescription, FactCheck,
+    Context, Conversation, Message,
+    Agent,
     defines
 )

@@ -409,11 +410,10 @@ class WebServer:
         @self.app.get("/")
         async def root():
             context = self.create_context()
-            logging.info(f"Redirecting non-session to {context.id}")
+            logging.info(f"Redirecting non-context to {context.id}")
             return RedirectResponse(url=f"/{context.id}", status_code=307)
             #return JSONResponse({"redirect": f"/{context.id}"})
-

         @self.app.put("/api/umap/{context_id}")
         async def put_umap(context_id: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
@@ -487,16 +487,16 @@ class WebServer:
                 logging.error(e)
                 #return JSONResponse({"error": str(e)}, 500)

-        @self.app.put("/api/reset/{context_id}/{session_type}")
-        async def put_reset(context_id: str, session_type: str, request: Request):
+        @self.app.put("/api/reset/{context_id}/{agent_type}")
+        async def put_reset(context_id: str, agent_type: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
                 logging.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session(session_type)
-            if not session:
-                return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
+            agent = context.get_agent(agent_type)
+            if not agent:
+                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)

             data = await request.json()
             try:
@@ -505,7 +505,7 @@ class WebServer:
                 match reset_operation:
                     case "system_prompt":
                         logging.info(f"Resetting {reset_operation}")
-                        match session_type:
+                        match agent_type:
                             case "chat":
                                 prompt = system_message
                             case "job_description":
@@ -515,7 +515,7 @@ class WebServer:
                             case "fact_check":
                                 prompt = system_message

-                        session.system_prompt = prompt
+                        agent.system_prompt = prompt
                         response["system_prompt"] = { "system_prompt": prompt }
                     case "rags":
                         logging.info(f"Resetting {reset_operation}")
@@ -532,17 +532,17 @@ class WebServer:
                             "fact_check": ("job_description", "resume", "fact_check"),
                             "chat": ("chat",),
                         }
-                        resets = reset_map.get(session_type, ())
+                        resets = reset_map.get(agent_type, ())

                         for mode in resets:
-                            tmp = context.get_session(mode)
+                            tmp = context.get_agent(mode)
                             if not tmp:
                                 continue
                             logging.info(f"Resetting {reset_operation} for {mode}")
                             context.conversation = Conversation()
-                            context.context_tokens = round(len(str(session.system_prompt)) * 3 / 4) # Estimate context usage
+                            context.context_tokens = round(len(str(agent.system_prompt)) * 3 / 4) # Estimate context usage
                             response["history"] = []
-                            response["context_used"] = session.context_tokens
+                            response["context_used"] = agent.context_tokens
                     case "message_history_length":
                         logging.info(f"Resetting {reset_operation}")
                         context.message_history_length = DEFAULT_HISTORY_LENGTH
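Note: the reset path estimates context usage with a flat characters-to-tokens ratio rather than a real tokenizer; round(len(text) * 3 / 4) assumes roughly 0.75 tokens per character, a deliberately conservative overestimate. A minimal sketch of the heuristic:

def estimate_tokens(text: str) -> int:
    # Flat heuristic used above: ~0.75 tokens per character.
    # A real tokenizer would be more accurate but heavier.
    return round(len(text) * 3 / 4)

assert estimate_tokens("A" * 100) == 75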
@@ -564,8 +564,8 @@ class WebServer:
             context = self.upsert_context(context_id)

             data = await request.json()
-            session = context.get_session("chat")
-            if not session:
+            agent = context.get_agent("chat")
+            if not agent:
                 return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
             for k in data.keys():
                 match k:
@@ -600,7 +600,7 @@ class WebServer:
                         system_prompt = data[k].strip()
                         if not system_prompt:
                             return JSONResponse({ "status": "error", "message": "System prompt cannot be empty." })
-                        session.system_prompt = system_prompt
+                        agent.system_prompt = system_prompt
                         self.save_context(context_id)
                         return JSONResponse({ "system_prompt": system_prompt })
                     case "message_history_length":
@@ -621,11 +621,11 @@ class WebServer:
                 logging.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session("chat")
-            if not session:
+            agent = context.get_agent("chat")
+            if not agent:
                 return JSONResponse({ "error": f"chat is not recognized", "context": context.id }, status_code=404)
             return JSONResponse({
-                "system_prompt": session.system_prompt,
+                "system_prompt": agent.system_prompt,
                 "message_history_length": context.message_history_length,
                 "rags": context.rags,
                 "tools": [ {
@@ -639,8 +639,8 @@ class WebServer:
             logging.info(f"{request.method} {request.url.path}")
             return JSONResponse(system_info(self.model))

-        @self.app.post("/api/chat/{context_id}/{session_type}")
-        async def post_chat_endpoint(context_id: str, session_type: str, request: Request):
+        @self.app.post("/api/chat/{context_id}/{agent_type}")
+        async def post_chat_endpoint(context_id: str, agent_type: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             try:
                 if not is_valid_uuid(context_id):
@@ -651,18 +651,18 @@ class WebServer:

             try:
                 data = await request.json()
-                session = context.get_session(session_type)
-                if not session and session_type == "job_description":
-                    logging.info(f"Session {session_type} not found. Returning empty history.")
-                    # Create a new session if it doesn't exist
-                    session = context.get_or_create_session("job_description", system_prompt=system_generate_resume, job_description=data["content"])
+                agent = context.get_agent(agent_type)
+                if not agent and agent_type == "job_description":
+                    logging.info(f"Agent {agent_type} not found. Returning empty history.")
+                    # Create a new agent if it doesn't exist
+                    agent = context.get_or_create_agent("job_description", system_prompt=system_generate_resume, job_description=data["content"])
             except Exception as e:
-                logging.info(f"Attempt to create session type: {session_type} failed", e)
-                return JSONResponse({ "error": f"{session_type} is not recognized", "context": context.id }, status_code=404)
+                logging.info(f"Attempt to create agent type: {agent_type} failed", e)
+                return JSONResponse({ "error": f"{agent_type} is not recognized", "context": context.id }, status_code=404)

             # Create a custom generator that ensures flushing
             async def flush_generator():
-                async for message in self.generate_response(context=context, session=session, content=data["content"]):
+                async for message in self.generate_response(context=context, agent=agent, content=data["content"]):
                     # Convert to JSON and add newline
                     yield json.dumps(message) + "\n"
                     # Save the history as it's generated
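Note: flush_generator yields one JSON object per line so the client can parse the stream incrementally. The StreamingResponse wrapping falls outside this hunk, so treat the following as a sketch under that assumption:

import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/api/chat-demo")
async def chat_demo():
    async def flush_generator():
        for status in ("processing", "done"):
            # One JSON object per line lets the client parse incrementally.
            yield json.dumps({"status": status}) + "\n"

    return StreamingResponse(flush_generator(), media_type="application/x-ndjson")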
@@ -687,20 +687,20 @@ class WebServer:
         @self.app.post("/api/context")
         async def create_context():
             context = self.create_context()
-            logging.info(f"Generated new session as {context.id}")
+            logging.info(f"Generated new agent as {context.id}")
             return JSONResponse({ "id": context.id })

-        @self.app.get("/api/history/{context_id}/{session_type}")
-        async def get_history(context_id: str, session_type: str, request: Request):
+        @self.app.get("/api/history/{context_id}/{agent_type}")
+        async def get_history(context_id: str, agent_type: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             try:
                 context = self.upsert_context(context_id)
-                session = context.get_session(session_type)
-                if not session:
-                    logging.info(f"Session {session_type} not found. Returning empty history.")
+                agent = context.get_agent(agent_type)
+                if not agent:
+                    logging.info(f"Agent {agent_type} not found. Returning empty history.")
                     return JSONResponse({ "messages": [] })
-                logging.info(f"History for {session_type} contains {len(session.conversation.messages)} entries.")
-                return session.conversation
+                logging.info(f"History for {agent_type} contains {len(agent.conversation.messages)} entries.")
+                return agent.conversation
             except Exception as e:
                 logging.error(f"Error in get_history: {e}")
                 return JSONResponse({"error": str(e)}, status_code=404)
@@ -732,17 +732,17 @@ class WebServer:
                 return JSONResponse({ "status": "error" }, 405)


-        @self.app.get("/api/context-status/{context_id}/{session_type}")
-        async def get_context_status(context_id, session_type: str, request: Request):
+        @self.app.get("/api/context-status/{context_id}/{agent_type}")
+        async def get_context_status(context_id, agent_type: str, request: Request):
             logging.info(f"{request.method} {request.url.path}")
             if not is_valid_uuid(context_id):
                 logging.warning(f"Invalid context_id: {context_id}")
                 return JSONResponse({"error": "Invalid context_id"}, status_code=400)
             context = self.upsert_context(context_id)
-            session = context.get_session(session_type)
-            if not session:
+            agent = context.get_agent(agent_type)
+            if not agent:
                 return JSONResponse({"context_used": 0, "max_context": defines.max_context})
-            return JSONResponse({"context_used": session.context_tokens, "max_context": defines.max_context})
+            return JSONResponse({"context_used": agent.context_tokens, "max_context": defines.max_context})

         @self.app.get("/api/health")
         async def health_check():
@@ -757,52 +757,52 @@ class WebServer:
             logging.info(f"Serve index.html for {path}")
             return FileResponse(os.path.join(defines.static_content, "index.html"))

-    def save_context(self, session_id):
+    def save_context(self, agent_id):
         """
-        Serialize a Python dictionary to a file in the sessions directory.
+        Serialize a Python dictionary to a file in the agents directory.

         Args:
-            data: Dictionary containing the session data
-            session_id: UUID string for the context. If it doesn't exist, it is created
+            data: Dictionary containing the agent data
+            agent_id: UUID string for the context. If it doesn't exist, it is created

         Returns:
-            The session_id used for the file
+            The agent_id used for the file
         """
-        context = self.upsert_context(session_id)
+        context = self.upsert_context(agent_id)

-        # Create sessions directory if it doesn't exist
-        if not os.path.exists(defines.session_dir):
-            os.makedirs(defines.session_dir)
+        # Create agents directory if it doesn't exist
+        if not os.path.exists(defines.context_dir):
+            os.makedirs(defines.context_dir)

         # Create the full file path
-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, agent_id)

         # Serialize the data to JSON and write to file
         with open(file_path, "w") as f:
             f.write(context.model_dump_json())

-        return session_id
+        return agent_id

-    def load_context(self, session_id) -> Context:
+    def load_context(self, agent_id) -> Context:
         """
-        Load a context from a file in the sessions directory.
+        Load a context from a file in the agents directory.
         Args:
-            session_id: UUID string for the context. If it doesn't exist, a new context is created.
+            agent_id: UUID string for the context. If it doesn't exist, a new context is created.
         Returns:
             A Context object with the specified ID and default settings.
         """

-        file_path = os.path.join(defines.session_dir, session_id)
+        file_path = os.path.join(defines.context_dir, agent_id)

         # Check if the file exists
         if not os.path.exists(file_path):
-            self.contexts[session_id] = self.create_context(session_id)
+            self.contexts[agent_id] = self.create_context(agent_id)
         else:
             # Read and deserialize the data
             with open(file_path, "r") as f:
-                self.contexts[session_id] = Context.model_validate_json(f.read())
+                self.contexts[agent_id] = Context.model_validate_json(f.read())

-        return self.contexts[session_id]
+        return self.contexts[agent_id]

     def create_context(self, context_id = None) -> Context:
         """
@@ -816,14 +816,16 @@ class WebServer:

         if os.path.exists(defines.resume_doc):
            context.user_resume = open(defines.resume_doc, "r").read()
-        context.add_session(Chat(system_prompt = system_message))
-        # context.add_session(Resume(system_prompt = system_generate_resume))
-        # context.add_session(JobDescription(system_prompt = system_job_description))
-        # context.add_session(FactCheck(system_prompt = system_fact_check))
+        context.get_or_create_agent(
+            agent_type="chat",
+            system_prompt=system_message)
+        # context.add_agent(Resume(system_prompt = system_generate_resume))
+        # context.add_agent(JobDescription(system_prompt = system_job_description))
+        # context.add_agent(FactCheck(system_prompt = system_fact_check))
         context.tools = default_tools(tools)
         context.rags = rags.copy()

-        logging.info(f"{context.id} created and added to sessions.")
+        logging.info(f"{context.id} created and added to contexts.")
         self.contexts[context.id] = context
         self.save_context(context.id)
         return context
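Note: Context.get_or_create_agent is not shown in this commit; based on its call sites (lookup by agent_type, construct with keyword arguments when missing), a hypothetical implementation might look like:

# Hypothetical sketch only: get_agent, registry.get, and self.agents are
# assumptions inferred from the call sites, not code from this commit.
def get_or_create_agent(self, agent_type: str, **kwargs):
    agent = self.get_agent(agent_type)
    if agent is None:
        agent_cls = registry.get(agent_type)  # registered subclass
        agent = agent_cls(context=self, **kwargs)
        self.agents.append(agent)
    return agent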
@@ -956,15 +958,15 @@ class WebServer:
             else:
                 yield {"status": "complete", "message": "RAG processing complete"}

-    # session_type: chat
+    # agent_type: chat
     # * Q&A
     #
-    # session_type: job_description
+    # agent_type: job_description
     # * First message sets Job Description and generates Resume
     # * Has content (Job Description)
     # * Then Q&A of Job Description
     #
-    # session_type: resume
+    # agent_type: resume
     # * First message sets Resume and generates Fact Check
     # * Has no content
     # * Then Q&A of Resume
@@ -973,18 +975,18 @@ class WebServer:
     # * First message sets Fact Check and is Q&A
     # * Has content
     # * Then Q&A of Fact Check
-    async def generate_response(self, context : Context, session : Session, content : str):
+    async def generate_response(self, context : Context, agent : Agent, content : str):
         if not self.file_watcher:
             return

         if self.processing:
-            logging.info("TODO: Implement delay queuing; busy for same session, otherwise return queue size and estimated wait time")
+            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
             yield {"status": "error", "message": "Busy processing another request."}
             return

         self.processing = True

-        conversation : Conversation = session.conversation
+        conversation : Conversation = agent.conversation

         message = Message(prompt=content)
         del content # Prevent accidental use of content
@@ -999,18 +1001,18 @@ class WebServer:
                     enable_rag = False

         # RAG is disabled when asking questions about the resume
-        if session.session_type == "resume":
+        if agent.agent_type == "resume":
             enable_rag = False

-        # The first time through each session session_type a content_seed may be set for
-        # future chat sessions; use it once, then clear it
-        message.preamble = session.get_and_reset_content_seed()
-        system_prompt = session.system_prompt
+        # The first time through each agent agent_type a content_seed may be set for
+        # future chat agents; use it once, then clear it
+        message.preamble = agent.get_and_reset_content_seed()
+        system_prompt = agent.system_prompt

-        # After the first time a particular session session_type is used, it is handled as a chat.
-        # The number of messages indicating the session is ready for chat varies based on
-        # the session_type of session
-        process_type = session.session_type
+        # After the first time a particular agent agent_type is used, it is handled as a chat.
+        # The number of messages indicating the agent is ready for chat varies based on
+        # the agent_type of agent
+        process_type = agent.agent_type
         match process_type:
             case "job_description":
                 logging.info(f"job_description user_history len: {len(conversation.messages)}")
@@ -1021,7 +1023,7 @@ class WebServer:
                 if len(conversation.messages) >= 3: # USER, ASSISTANT, FACT_CHECK
                     process_type = "chat"
             case "fact_check":
-                process_type = "chat" # Fact Check is always a chat session
+                process_type = "chat" # Fact Check is always a chat agent

         match process_type:
             # Normal chat interactions with context history
@@ -1071,7 +1073,7 @@ class WebServer:
Use that information to respond to:"""

                 # Use the mode specific system_prompt instead of 'chat'
-                system_prompt = session.system_prompt
+                system_prompt = agent.system_prompt

                 # On first entry, a single job_description is provided ("user")
                 # Generate a resume to append to RESUME history
@@ -1110,10 +1112,10 @@ Use that information to respond to:"""
<|job_description|>
{message.prompt}
"""
                    tmp = context.get_session("job_description")
+                    tmp = context.get_agent("job_description")
                     if not tmp:
-                        raise Exception(f"Job description session not found.")
-                    # Set the content seed for the job_description session
+                        raise Exception(f"Job description agent not found.")
+                    # Set the content seed for the job_description agent
                     tmp.set_content_seed(message.preamble + "<|question|>\nUse the above information to respond to this prompt: ")

                     message.preamble += f"""
@@ -1126,7 +1128,7 @@ Use to the above information to respond to this prompt:
"""

                     # For all future calls to job_description, use the system_job_description
-                    session.system_prompt = system_job_description
+                    agent.system_prompt = system_job_description

                     # Seed the history for job_description
                     stuffingMessage = Message(prompt=message.prompt)
@@ -1137,21 +1139,21 @@ Use to the above information to respond to this prompt:

                     message.add_action("generate_resume")

-                    logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into session'resume'")
-                    logging.info("TODO: For subsequent runs, have the Session handler generate the follow up prompts so they can have correct context preamble")
+                    logging.info("TODO: Convert these to generators, eg generate_resume() and then manually add results into agent'resume'")
+                    logging.info("TODO: For subsequent runs, have the Agent handler generate the follow up prompts so they can have correct context preamble")

-                    # Switch to resume session for LLM responses
+                    # Switch to resume agent for LLM responses
                     # message.metadata["origin"] = "resume"
-                    # session = context.get_or_create_session("resume")
-                    # system_prompt = session.system_prompt
-                    # llm_history = session.llm_history = []
-                    # user_history = session.user_history = []
+                    # agent = context.get_or_create_agent("resume")
+                    # system_prompt = agent.system_prompt
+                    # llm_history = agent.llm_history = []
+                    # user_history = agent.user_history = []

                 # Ignore the passed in content and invoke Fact Check
                 case "resume":
-                    if len(context.get_or_create_session("resume").conversation.messages) < 2: # USER, **ASSISTANT**
+                    if len(context.get_or_create_agent("resume").conversation.messages) < 2: # USER, **ASSISTANT**
                         raise Exception(f"No resume found in user history.")
-                    resume = context.get_or_create_session("resume").conversation.messages[1]
+                    resume = context.get_or_create_agent("resume").conversation.messages[1]

                     # Generate RAG content if enabled, based on the content
                     rag_context = ""
@@ -1196,7 +1198,7 @@ Use to the above information to respond to this prompt:
<|question|>
"""

-                    context.get_or_create_session("resume").set_content_seed(f"""
+                    context.get_or_create_agent("resume").set_content_seed(f"""
<|resume|>
{resume["content"]}

@@ -1222,25 +1224,25 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                     conversation.add_message(stuffingMessage)

                     # For all future calls to job_description, use the system_job_description
-                    logging.info("TODO: Create a system_resume_QA prompt to use for the resume session")
-                    session.system_prompt = system_prompt
+                    logging.info("TODO: Create a system_resume_QA prompt to use for the resume agent")
+                    agent.system_prompt = system_prompt

-                    # Switch to fact_check session for LLM responses
+                    # Switch to fact_check agent for LLM responses
                     message.metadata["origin"] = "fact_check"
-                    session = context.get_or_create_session("fact_check", system_prompt=system_fact_check)
+                    agent = context.get_or_create_agent("fact_check", system_prompt=system_fact_check)

-                    llm_history = session.llm_history = []
-                    user_history = session.user_history = []
+                    llm_history = agent.llm_history = []
+                    user_history = agent.user_history = []

                 case _:
-                    raise Exception(f"Invalid chat session_type: {session_type}")
+                    raise Exception(f"Invalid chat agent_type: {agent_type}")

             conversation.add_message(message)
             # llm_history.append({"role": "user", "content": message.preamble + content})
             # user_history.append({"role": "user", "content": content, "origin": message.metadata["origin"]})
             # message.metadata["full_query"] = llm_history[-1]["content"]

-            # Uses cached system_prompt as session.system_prompt may have been updated for follow up questions
+            # Uses cached system_prompt as agent.system_prompt may have been updated for follow up questions
             messages = create_system_message(system_prompt)
             if context.message_history_length:
                 to_add = conversation.messages[-context.message_history_length:]
@@ -1272,12 +1274,12 @@ Use the above <|resume|> and <|job_description|> to answer this query:
{message.prompt}"""

             # Estimate token length of new messages
-            ctx_size = self.get_optimal_ctx_size(context.get_or_create_session(process_type).context_tokens, messages=message.prompt)
+            ctx_size = self.get_optimal_ctx_size(context.get_or_create_agent(process_type).context_tokens, messages=message.prompt)

             if len(conversation.messages) > 2:
                 processing_message = f"Processing {'RAG augmented ' if enable_rag else ''}query..."
             else:
-                match session.session_type:
+                match agent.agent_type:
                     case "job_description":
                         processing_message = f"Generating {'RAG augmented ' if enable_rag else ''}resume..."
                     case "resume":
@@ -1303,7 +1305,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                 message.metadata["eval_duration"] += response["eval_duration"]
                 message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
                 message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-                session.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+                agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]

                 tools_used = []

@@ -1347,7 +1349,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                     message.metadata["tools"] = tools_used

                     # Estimate token length of new messages
-                    ctx_size = self.get_optimal_ctx_size(session.context_tokens, messages=messages[pre_add_index:])
+                    ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
                     yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
                     # Decrease creativity when processing tool call requests
                     response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
@@ -1355,11 +1357,11 @@ Use the above <|resume|> and <|job_description|> to answer this query:
                 message.metadata["eval_duration"] += response["eval_duration"]
                 message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
                 message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
-                session.context_tokens = response["prompt_eval_count"] + response["eval_count"]
+                agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]

                 reply = response["message"]["content"]
                 message.response = reply
-                message.metadata["origin"] = session.session_type
+                message.metadata["origin"] = agent.agent_type
                 # final_message = {"role": "assistant", "content": reply }

                 # # history is provided to the LLM and should not have additional metadata
@@ -1379,7 +1381,7 @@ Use the above <|resume|> and <|job_description|> to answer this query:
             }

         # except Exception as e:
-        #     logging.exception({ "model": self.model, "origin": session_type, "content": content, "error": str(e) })
+        #     logging.exception({ "model": self.model, "origin": agent_type, "content": content, "error": str(e) })
         #     yield {"status": "error", "message": f"An error occurred: {str(e)}"}

         # finally:
src/utils/__init__.py
@@ -1,10 +1,17 @@
 # Import defines to make `utils.defines` accessible
 from . import defines

 # Import rest as `utils.*` accessible
 from .rag import ChromaDBFileWatcher, start_file_watcher

 from .message import Message
 from .conversation import Conversation
-from .session import Session, Chat, Resume, JobDescription, FactCheck
 from .context import Context
+from . import agents
+
+from .agents import Agent, __all__ as agents_all
+
+__all__ = [
+    'Agent',
+    'Context',
+    'Conversation',
+    'Message',
+    'ChromaDBFileWatcher',
+    'start_file_watcher'
+] + agents_all
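Note: the package extends its own __all__ with the agents subpackage's, so `from utils import *` picks up every agent class. For that to work, src/utils/agents/__init__.py (not shown in this view) would need to export an __all__ of its own, along these lines (hypothetical):

# Hypothetical src/utils/agents/__init__.py matching the import above
from .base import Agent
from .chat import Chat
from .fact_check import FactCheck
from .job_description import JobDescription

__all__ = ['Agent', 'Chat', 'FactCheck', 'JobDescription']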
src/utils/agent.py (new file, 256 lines)
@@ -0,0 +1,256 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging

from .types import AgentBase, registry

# Only import Context for type checking
if TYPE_CHECKING:
    from .context import Context

from .conversation import Conversation
from .message import Message

class Agent(AgentBase):
    """
    Base class for all agent types.
    This class defines the common attributes and methods for all agent types.
    """
    agent_type: str = Field(default="agent", const=True) # discriminator value

    def __init_subclass__(cls, **kwargs):
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
        if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type:
            registry.register(cls.agent_type, cls)

    def __init__(self, **data):
        # Set agent_type from class if not provided
        if 'agent_type' not in data:
            data['agent_type'] = self.__class__.agent_type
        super().__init__(**data)

    system_prompt: str # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0

    # Add a property for context if needed without creating a circular reference
    @property
    def context(self) -> Optional['Context']:
        if TYPE_CHECKING:
            from .context import Context
        # Implement logic to fetch context by ID if needed
        return None

    #context: Context

    _content_seed: str = PrivateAttr(default="")

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble
        """
        # Generate RAG content if enabled, based on the content
        rag_context = ""
        if not message.disable_rag:
            # Gather RAG results, yielding each result
            # as it becomes available
            for value in self.context.generate_rag_results(message):
                logging.info(f"RAG: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
                    message.status = "error"
                    message.response = value.response
                    yield message
                    return

            if message.metadata["rag"]:
                for rag_collection in message.metadata["rag"]:
                    for doc in rag_collection["documents"]:
                        rag_context += f"{doc}\n"

        if rag_context:
            message.preamble["context"] = rag_context

        if self.context.user_resume:
            message.preamble["resume"] = self.context.user_resume

        if message.preamble:
            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
            preamble_types_AND = " and ".join(preamble_types)
            preamble_types_OR = " or ".join(preamble_types)
            message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly; refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
            message.preamble["question"] = "Use that information to respond to:"
        else:
            message.preamble["question"] = "Respond to:"

        message.system_prompt = self.system_prompt
        message.status = "done"
        yield message
        return

    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
        if self.context.processing:
            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
            message.status = "error"
            message.response = "Busy processing another request."
            yield message
            return

        self.context.processing = True

        messages = []

        for value in self.llm.chat(
            model=self.model,
            messages=messages,
            #tools=llm_tools(context.tools) if message.enable_tools else None,
            options={ "num_ctx": message.ctx_size }
        ):
            logging.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                message.status = value.status
                message.response = value.response
                yield message
            if value.status == "error":
                return
            response = value

        message.metadata["eval_count"] += response["eval_count"]
        message.metadata["eval_duration"] += response["eval_duration"]
        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        tools_used = []

        yield {"status": "processing", "message": "Initial response received..."}

        if "tool_calls" in response.get("message", {}):
            yield {"status": "processing", "message": "Processing tool calls..."}

            tool_message = response["message"]
            tool_result = None

            # Process all yielded items from the handler
            async for item in self.handle_tool_calls(tool_message):
                if isinstance(item, tuple) and len(item) == 2:
                    # This is the final result tuple (tool_result, tools_used)
                    tool_result, tools_used = item
                else:
                    # This is a status update, forward it
                    yield item

            message_dict = {
                "role": tool_message.get("role", "assistant"),
                "content": tool_message.get("content", "")
            }

            if "tool_calls" in tool_message:
                message_dict["tool_calls"] = [
                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                    for tc in tool_message["tool_calls"]
                ]

            pre_add_index = len(messages)
            messages.append(message_dict)

            if isinstance(tool_result, list):
                messages.extend(tool_result)
            else:
                if tool_result:
                    messages.append(tool_result)

            message.metadata["tools"] = tools_used

            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
            # Decrease creativity when processing tool call requests
            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
            message.metadata["eval_count"] += response["eval_count"]
            message.metadata["eval_duration"] += response["eval_duration"]
            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
            self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        reply = response["message"]["content"]
        message.response = reply
        message.metadata["origin"] = self.agent_type
        # final_message = {"role": "assistant", "content": reply }

        # # history is provided to the LLM and should not have additional metadata
        # llm_history.append(final_message)

        # user_history is provided to the REST API and does not include CONTEXT
        # It does include metadata
        # final_message["metadata"] = message.metadata
        # user_history.append({**final_message, "origin": message.metadata["origin"]})

        # Return the REST API with metadata
        yield {
            "status": "done",
            "message": {
                **message.model_dump(mode='json'),
            }
        }

        self.context.processing = False
        return

    async def process_message(self, message: Message) -> AsyncGenerator[Message, None]:
        message.full_content = ""
        for i, p in enumerate(message.preamble.keys()):
            message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"

        # Estimate token length of new messages
        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)

        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
        message.status = "thinking"
        yield message

        async for value in self.generate_llm_response(message):
            logging.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                yield value
            if value.status == "error":
                return

    def get_and_reset_content_seed(self):
        tmp = self._content_seed
        self._content_seed = ""
        return tmp

    def set_content_seed(self, content: str) -> None:
        """Set the content seed for the agent."""
        self._content_seed = content

    def get_content_seed(self) -> str:
        """Get the content seed for the agent."""
        return self._content_seed

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))


# Register the base agent
registry.register(Agent.agent_type, Agent)

# Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses

from . import agents
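Note: the auto-registration relies on __init_subclass__, which Python invokes whenever a subclass is defined, so merely importing a module that defines an agent registers it. The registry itself lives in utils/types.py, which this view does not include; a self-contained sketch of the pattern under that assumption:

class Registry:
    def __init__(self):
        self._classes: dict[str, type] = {}

    def register(self, name: str, cls: type) -> None:
        self._classes[name] = cls

    def get(self, name: str) -> type:
        return self._classes[name]

registry = Registry()

class Base:
    agent_type = "base"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Every subclass that overrides agent_type registers itself on import
        if cls.agent_type != Base.agent_type:
            registry.register(cls.agent_type, cls)

class ChatDemo(Base):
    agent_type = "chat_demo"

assert registry.get("chat_demo") is ChatDemo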
src/utils/agents/base.py (new file, 241 lines)
@@ -0,0 +1,241 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from abc import ABC, abstractmethod
from typing_extensions import Annotated
import logging

# Only import Context for type checking
if TYPE_CHECKING:
    from ..context import Context

from .types import AgentBase, ContextBase, registry

from ..conversation import Conversation
from ..message import Message

class Agent(AgentBase):
    """
    Base class for all agent types.
    This class defines the common attributes and methods for all agent types.
    """
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type # Add this for registration

    def __init_subclass__(cls, **kwargs):
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
        if hasattr(cls, 'agent_type') and cls.agent_type != AgentBase._agent_type:
            registry.register(cls.agent_type, cls)

    def __init__(self, **data):
        # Set agent_type from class if not provided
        if 'agent_type' not in data:
            data['agent_type'] = self.__class__.agent_type
        super().__init__(**data)

    system_prompt: str # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0
    context: ContextBase # Avoid circular reference

    _content_seed: str = PrivateAttr(default="")

    async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
        """
        Prepare message with context information in message.preamble
        """
        # Generate RAG content if enabled, based on the content
        rag_context = ""
        if not message.disable_rag:
            # Gather RAG results, yielding each result
            # as it becomes available
            for value in self.context.generate_rag_results(message):
                logging.info(f"RAG: {value.status} - {value.response}")
                if value.status != "done":
                    yield value
                if value.status == "error":
                    message.status = "error"
                    message.response = value.response
                    yield message
                    return

            if message.metadata["rag"]:
                for rag_collection in message.metadata["rag"]:
                    for doc in rag_collection["documents"]:
                        rag_context += f"{doc}\n"

        if rag_context:
            message.preamble["context"] = rag_context

        if self.context.user_resume:
            message.preamble["resume"] = self.context.user_resume

        if message.preamble:
            preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
            preamble_types_AND = " and ".join(preamble_types)
            preamble_types_OR = " or ".join(preamble_types)
            message.preamble["rules"] = f"""\
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporating it seamlessly; refer to it using natural language instead of mentioning {preamble_types_OR} or quoting it directly.
- If there is no information in these sections, answer based on your knowledge.
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
"""
            message.preamble["question"] = "Use that information to respond to:"
        else:
            message.preamble["question"] = "Respond to:"

        message.system_prompt = self.system_prompt
        message.status = "done"
        yield message
        return

    async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
        if self.context.processing:
            logging.info("TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time")
            message.status = "error"
            message.response = "Busy processing another request."
            yield message
            return

        self.context.processing = True

        messages = []

        for value in self.llm.chat(
            model=self.model,
            messages=messages,
            #tools=llm_tools(context.tools) if message.enable_tools else None,
            options={ "num_ctx": message.ctx_size }
        ):
            logging.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                message.status = value.status
                message.response = value.response
                yield message
            if value.status == "error":
                return
            response = value

        message.metadata["eval_count"] += response["eval_count"]
        message.metadata["eval_duration"] += response["eval_duration"]
        message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
        message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
        self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        tools_used = []

        yield {"status": "processing", "message": "Initial response received..."}

        if "tool_calls" in response.get("message", {}):
            yield {"status": "processing", "message": "Processing tool calls..."}

            tool_message = response["message"]
            tool_result = None

            # Process all yielded items from the handler
            async for item in self.handle_tool_calls(tool_message):
                if isinstance(item, tuple) and len(item) == 2:
                    # This is the final result tuple (tool_result, tools_used)
                    tool_result, tools_used = item
                else:
                    # This is a status update, forward it
                    yield item

            message_dict = {
                "role": tool_message.get("role", "assistant"),
                "content": tool_message.get("content", "")
            }

            if "tool_calls" in tool_message:
                message_dict["tool_calls"] = [
                    {"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
                    for tc in tool_message["tool_calls"]
                ]

            pre_add_index = len(messages)
            messages.append(message_dict)

            if isinstance(tool_result, list):
                messages.extend(tool_result)
            else:
                if tool_result:
                    messages.append(tool_result)

            message.metadata["tools"] = tools_used

            # Estimate token length of new messages
            ctx_size = self.get_optimal_ctx_size(self.context_tokens, messages=messages[pre_add_index:])
            yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
            # Decrease creativity when processing tool call requests
            response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
            message.metadata["eval_count"] += response["eval_count"]
            message.metadata["eval_duration"] += response["eval_duration"]
            message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
            message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
            self.context_tokens = response["prompt_eval_count"] + response["eval_count"]

        reply = response["message"]["content"]
        message.response = reply
        message.metadata["origin"] = self.agent_type
        # final_message = {"role": "assistant", "content": reply }

        # # history is provided to the LLM and should not have additional metadata
        # llm_history.append(final_message)

        # user_history is provided to the REST API and does not include CONTEXT
        # It does include metadata
        # final_message["metadata"] = message.metadata
        # user_history.append({**final_message, "origin": message.metadata["origin"]})

        # Return the REST API with metadata
        yield {
            "status": "done",
            "message": {
                **message.model_dump(mode='json'),
            }
        }

        self.context.processing = False
        return

    async def process_message(self, message: Message) -> AsyncGenerator[Message, None]:
        message.full_content = ""
        for i, p in enumerate(message.preamble.keys()):
            message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"

        # Estimate token length of new messages
        message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)

        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
        message.status = "thinking"
        yield message

        async for value in self.generate_llm_response(message):
            logging.info(f"LLM: {value.status} - {value.response}")
            if value.status != "done":
                yield value
            if value.status == "error":
                return

    def get_and_reset_content_seed(self):
        tmp = self._content_seed
        self._content_seed = ""
        return tmp

    def set_content_seed(self, content: str) -> None:
        """Set the content seed for the agent."""
        self._content_seed = content

    def get_content_seed(self) -> str:
        """Get the content seed for the agent."""
        return self._content_seed

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))

# Register the base agent
registry.register(Agent._agent_type, Agent)
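Note: each subclass pins agent_type to a distinct Literal, which is exactly what Pydantic v2 needs for a discriminated union. A sketch (standalone models, hypothetical container) of how a Context could deserialize the right subclass from JSON:

from typing import Annotated, Literal, Union
from pydantic import BaseModel, Field

class ChatA(BaseModel):
    agent_type: Literal["chat"] = "chat"
    system_prompt: str

class FactCheckA(BaseModel):
    agent_type: Literal["fact_check"] = "fact_check"
    system_prompt: str
    facts: str = ""

# The discriminator picks the subclass from the agent_type value
AnyAgent = Annotated[Union[ChatA, FactCheckA], Field(discriminator="agent_type")]

class ContextModel(BaseModel):
    agents: list[AnyAgent] = []

ctx = ContextModel.model_validate(
    {"agents": [{"agent_type": "chat", "system_prompt": "hi"}]}
)
assert ctx.agents[0].__class__ is ChatA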
src/utils/agents/chat.py (new file, 243 lines)
@@ -0,0 +1,243 @@
|
||||
from pydantic import BaseModel, Field, model_validator, PrivateAttr
|
||||
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
|
||||
from typing_extensions import Annotated
|
||||
from abc import ABC, abstractmethod
|
||||
from typing_extensions import Annotated
|
||||
import logging
|
||||
from .base import Agent, registry
|
||||
from .. conversation import Conversation
|
||||
from .. message import Message
|
||||
|
||||
class Chat(Agent, ABC):
|
||||
"""
|
||||
Base class for all agent types.
|
||||
This class defines the common attributes and methods for all agent types.
|
||||
"""
|
||||
agent_type: Literal["chat"] = "chat"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""Auto-register subclasses"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
# Register this class if it has an agent_type
|
||||
if hasattr(cls, 'agent_type') and cls.agent_type != Agent.agent_type:
|
||||
registry.register(cls.agent_type, cls)
|
||||
|
||||
def __init__(self, **data):
|
||||
# Set agent_type from class if not provided
|
||||
if 'agent_type' not in data:
|
||||
data['agent_type'] = self.__class__.agent_type
|
||||
super().__init__(**data)
|
||||
|
||||
system_prompt: str # Mandatory
|
||||
conversation: Conversation = Conversation()
|
||||
context_tokens: int = 0
|
||||
|
||||
# Add a property for context if needed without creating a circular reference
|
||||
@property
|
||||
def context(self) -> Optional['Context']:
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
# Implement logic to fetch context by ID if needed
|
||||
return None
|
||||
|
||||
#context: Context
|
||||
|
||||
_content_seed: str = PrivateAttr(default="")
|
||||
|
||||
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
|
||||
"""
|
||||
Prepare message with context information in message.preamble
|
||||
"""
|
||||
# Generate RAG content if enabled, based on the content
|
||||
rag_context = ""
|
||||
if not message.disable_rag:
|
||||
# Gather RAG results, yielding each result
|
||||
# as it becomes available
|
||||
for value in self.context.generate_rag_results(message):
|
||||
logging.info(f"RAG: {value.status} - {value.response}")
|
||||
if value.status != "done":
|
||||
yield value
|
||||
if value.status == "error":
|
||||
message.status = "error"
|
||||
message.response = value.response
|
||||
yield message
|
||||
return
|
||||
|
||||
if message.metadata["rag"]:
|
||||
for rag_collection in message.metadata["rag"]:
|
||||
for doc in rag_collection["documents"]:
|
||||
rag_context += f"{doc}\n"
|
||||
|
||||
if rag_context:
|
||||
message["context"] = rag_context
|
||||
|
||||
if self.context.user_resume:
|
||||
message["resume"] = self.content.user_resume
|
||||
|
||||
if message.preamble:
|
||||
preamble_types = [f"<|{p}|>" for p in message.preamble.keys()]
|
||||
preamble_types_AND = " and ".join(preamble_types)
|
||||
preamble_types_OR = " or ".join(preamble_types)
|
||||
message.preamble["rules"] = f"""\
|
||||
- Answer the question based on the information provided in the {preamble_types_AND} sections by incorporate it seamlessly and refer to it using natural language instead of mentioning {preamble_or_types} or quoting it directly.
|
||||
- If there is no information in these sections, answer based on your knowledge.
|
||||
- Avoid phrases like 'According to the {preamble_types[0]}' or similar references to the {preamble_types_OR}.
|
||||
"""
|
||||
message.preamble["question"] = "Use that information to respond to:"
|
||||
else:
|
||||
message.preamble["question"] = "Respond to:"
|
||||
|
||||
message.system_prompt = self.system_prompt
|
||||
message.status = "done"
|
||||
yield message
|
||||
return
|
||||
|
||||
async def generate_llm_response(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||
if self.context.processing:
|
||||
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
|
||||
message.status = "error"
|
||||
message.response = "Busy processing another request."
|
||||
yield message
|
||||
return
|
||||
|
||||
self.context.processing = True
|
||||
|
||||
messages = []
|
||||
|
||||
for value in self.llm.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
#tools=llm_tools(context.tools) if message.enable_tools else None,
|
||||
options={ "num_ctx": message.ctx_size }
|
||||
):
|
||||
logging.info(f"LLM: {value.status} - {value.response}")
|
||||
if value.status != "done":
|
||||
message.status = value.status
|
||||
message.response = value.response
|
||||
yield message
|
||||
if value.status == "error":
|
||||
return
|
||||
response = value
|
||||
|
||||
message.metadata["eval_count"] += response["eval_count"]
|
||||
message.metadata["eval_duration"] += response["eval_duration"]
|
||||
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
|
||||
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
|
||||
agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
|
||||
|
||||
tools_used = []
|
||||
|
||||
yield {"status": "processing", "message": "Initial response received..."}
|
||||
|
||||
if "tool_calls" in response.get("message", {}):
|
||||
yield {"status": "processing", "message": "Processing tool calls..."}
|
||||
|
||||
tool_message = response["message"]
|
||||
tool_result = None
|
||||
|
||||
# Process all yielded items from the handler
|
||||
async for item in self.handle_tool_calls(tool_message):
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
# This is the final result tuple (tool_result, tools_used)
|
||||
tool_result, tools_used = item
|
||||
else:
|
||||
# This is a status update, forward it
|
||||
yield item
|
||||
|
||||
message_dict = {
|
||||
"role": tool_message.get("role", "assistant"),
|
||||
"content": tool_message.get("content", "")
|
||||
}
|
||||
|
||||
if "tool_calls" in tool_message:
|
||||
message_dict["tool_calls"] = [
|
||||
{"function": {"name": tc["function"]["name"], "arguments": tc["function"]["arguments"]}}
|
||||
for tc in tool_message["tool_calls"]
|
||||
]
|
||||
|
||||
pre_add_index = len(messages)
|
||||
messages.append(message_dict)
|
||||
|
||||
if isinstance(tool_result, list):
|
||||
messages.extend(tool_result)
|
||||
else:
|
||||
if tool_result:
|
||||
messages.append(tool_result)
|
||||
|
||||
message.metadata["tools"] = tools_used
|
||||
|
||||
# Estimate token length of new messages
|
||||
ctx_size = self.get_optimal_ctx_size(agent.context_tokens, messages=messages[pre_add_index:])
|
||||
yield {"status": "processing", "message": "Generating final response...", "num_ctx": ctx_size }
|
||||
# Decrease creativity when processing tool call requests
|
||||
response = self.llm.chat(model=self.model, messages=messages, stream=False, options={ "num_ctx": ctx_size }) #, "temperature": 0.5 })
|
||||
message.metadata["eval_count"] += response["eval_count"]
|
||||
message.metadata["eval_duration"] += response["eval_duration"]
|
||||
message.metadata["prompt_eval_count"] += response["prompt_eval_count"]
|
||||
message.metadata["prompt_eval_duration"] += response["prompt_eval_duration"]
|
||||
agent.context_tokens = response["prompt_eval_count"] + response["eval_count"]
|
||||
|
||||
reply = response["message"]["content"]
|
||||
message.response = reply
|
||||
message.metadata["origin"] = agent.agent_type
|
||||
# final_message = {"role": "assistant", "content": reply }
|
||||
|
||||
# # history is provided to the LLM and should not have additional metadata
|
||||
# llm_history.append(final_message)
|
||||
|
||||
# user_history is provided to the REST API and does not include CONTEXT
|
||||
# It does include metadata
|
||||
# final_message["metadata"] = message.metadata
|
||||
# user_history.append({**final_message, "origin": message.metadata["origin"]})
|
||||
|
||||
# Return the REST API with metadata
|
||||
yield {
|
||||
"status": "done",
|
||||
"message": {
|
||||
**message.model_dump(mode='json'),
|
||||
}
|
||||
}
|
||||
|
||||
self.context.processing = False
|
||||
return
|
||||
|
||||
async def process_message(self, message: Message) -> AsyncGenerator[Message, None]:
    message.full_content = ""
    for i, p in enumerate(message.preamble.keys()):
        # Parenthesized so only the separator is skipped for the first entry
        message.full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{message.preamble[p].strip()}\n"

    # Estimate token length of new messages
    message.ctx_size = self.context.get_optimal_ctx_size(self.context_tokens, messages=message.full_content)

    message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
    message.status = "thinking"
    yield message

    # generate_llm_response is an async generator, so it must be drained with async for
    async for value in self.generate_llm_response(message):
        logging.info(f"LLM: {value.status} - {value.response}")
        if value.status != "done":
            yield value
        if value.status == "error":
            return

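Worked example of the preamble rendering above (section names and values are illustrative):

preamble = {"context": "Relevant resume excerpts...", "question": "What roles fit me?"}
full_content = ""
for i, p in enumerate(preamble.keys()):
    full_content += ('' if i == 0 else '\n\n') + f"<|{p}|>{preamble[p].strip()}\n"
# Produces a "<|context|>..." block and a "<|question|>..." block separated by blank lines.
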
def get_and_reset_content_seed(self):
    tmp = self._content_seed
    self._content_seed = ""
    return tmp

def set_content_seed(self, content: str) -> None:
    """Set the content seed for the agent."""
    self._content_seed = content

def get_content_seed(self) -> str:
    """Get the content seed for the agent."""
    return self._content_seed

@classmethod
def valid_agent_types(cls) -> set[str]:
    """Return the set of valid agent_type values."""
    return set(get_args(cls.__annotations__["agent_type"]))

# Register the Chat agent
registry.register(Chat._agent_type, Chat)
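How valid_agent_types resolves the Literal annotation; a standalone sketch, assuming the base Agent declares the same four types the removed Session class did:

from typing import Literal, get_args

class AgentDemo:
    agent_type: Literal["resume", "job_description", "fact_check", "chat"]

# get_args unwraps the Literal into its allowed values
print(set(get_args(AgentDemo.__annotations__["agent_type"])))
# {'resume', 'job_description', 'fact_check', 'chat'}
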
24 src/utils/agents/fact_check.py (new file)
@ -0,0 +1,24 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
import logging

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message

class FactCheck(Agent):
    agent_type: Literal["fact_check"] = "fact_check"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    facts: str = ""

    @model_validator(mode="after")
    def validate_facts(self):
        # NOTE: the default "" fails this check, so facts must be supplied at construction
        if not self.facts.strip():
            raise ValueError("Facts cannot be empty")
        return self

# Register the FactCheck agent
registry.register(FactCheck._agent_type, FactCheck)
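Usage sketch for the after-validator (the import path follows this commit's layout and is an assumption, as is that any other required Agent base fields, e.g. a system prompt, are optional or supplied):

from pydantic import ValidationError
from utils.agents.fact_check import FactCheck  # assumed import path

fc = FactCheck(facts="Worked on Backstory.")  # passes validate_facts
try:
    FactCheck(facts="   ")  # whitespace-only content fails
except ValidationError as err:
    print(err)  # pydantic surfaces the ValueError raised inside the validator
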
24 src/utils/agents/job_description.py (new file)
@ -0,0 +1,24 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
import logging

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message

class JobDescription(Agent):
    agent_type: Literal["job_description"] = "job_description"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    job_description: str = ""

    @model_validator(mode="after")
    def validate_job_description(self):
        if not self.job_description.strip():
            raise ValueError("Job description cannot be empty")
        return self

# Register the JobDescription agent
registry.register(JobDescription._agent_type, JobDescription)
32 src/utils/agents/resume.py (new file)
@ -0,0 +1,32 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
from typing_extensions import Annotated
from abc import ABC, abstractmethod
import logging

from .base import Agent, registry
from ..conversation import Conversation
from ..message import Message

class Resume(Agent):
    agent_type: Literal["resume"] = "resume"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    resume: str = ""

    @model_validator(mode="after")
    def validate_resume(self):
        if not self.resume.strip():
            raise ValueError("Resume content cannot be empty")
        return self

    def get_resume(self) -> str:
        """Get the resume content."""
        return self.resume

    def set_resume(self, resume: str) -> None:
        """Set the resume content."""
        self.resume = resume

# Register the Resume agent
registry.register(Resume._agent_type, Resume)
@ -1,16 +1,31 @@
from pydantic import BaseModel, Field, model_validator
from uuid import uuid4
from typing import List, Optional
from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
from typing_extensions import Annotated, Union
from .session import AnySession, Session
import numpy as np
import logging

from .message import Message
from .rag import ChromaDBFileWatcher
from . import defines

from .agents import Agent, ContextBase

# Import only agent types, not actual classes
if TYPE_CHECKING:
    from .agents import Agent, AnyAgent, Chat, Resume, JobDescription, FactCheck

from .agents import AnyAgent

class Context(ContextBase):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher

class Context(BaseModel):
    id: str = Field(
        default_factory=lambda: str(uuid4()),
        pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
    )

    sessions: List[Annotated[Union[*Session.__subclasses__()], Field(discriminator="session_type")]] = Field(
    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
        default_factory=list
    )

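The agents field relies on a pydantic discriminated union: on deserialization, the agent_type value selects the concrete subclass. A self-contained sketch with simplified stand-in classes (not the project's own):

from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field

class ChatDemo(BaseModel):
    agent_type: Literal["chat"] = "chat"

class ResumeDemo(BaseModel):
    agent_type: Literal["resume"] = "resume"
    resume: str = ""

class ContextDemo(BaseModel):
    agents: List[Annotated[Union[ChatDemo, ResumeDemo], Field(discriminator="agent_type")]] = []

ctx = ContextDemo.model_validate({"agents": [{"agent_type": "resume", "resume": "..."}]})
print(type(ctx.agents[0]).__name__)  # ResumeDemo, selected by the discriminator
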
@ -21,78 +36,145 @@ class Context(BaseModel):
    rags: List[dict] = []
    message_history_length: int = 5
    context_tokens: int = 0
    file_watcher: ChromaDBFileWatcher = Field(default=None, exclude=True)

    def __init__(self, id: Optional[str] = None, **kwargs):
        super().__init__(id=id if id is not None else str(uuid4()), **kwargs)

    @model_validator(mode="after")
    def validate_unique_session_types(self):
        """Ensure at most one session per session_type."""
        session_types = [session.session_type for session in self.sessions]
        if len(session_types) != len(set(session_types)):
            raise ValueError("Context cannot contain multiple sessions of the same session_type")
    def validate_unique_agent_types(self):
        """Ensure at most one agent per agent_type."""
        agent_types = [agent.agent_type for agent in self.agents]
        if len(agent_types) != len(set(agent_types)):
            raise ValueError("Context cannot contain multiple agents of the same agent_type")
        return self

    def get_or_create_session(self, session_type: str, **kwargs) -> Session:
    def get_optimal_ctx_size(self, context, messages, ctx_buffer=4096):
        # Rough token estimate: ~3/4 token per character of the serialized messages
        ctx = round(context + len(str(messages)) * 3 / 4)
        # Clamp the buffered estimate between a 2048-token floor and the configured maximum
        return min(defines.max_context, max(2048, ctx + ctx_buffer))

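Worked numbers for the sizing heuristic above, using the defines.max_context value set later in this diff (2048*8*2 = 32768):

# 1000 tokens already in context; ~8000 characters of new messages
ctx = round(1000 + 8000 * 3 / 4)               # 7000
ctx_size = min(32768, max(2048, 7000 + 4096))  # 11096: buffered estimate, under the cap
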
    def generate_rag_results(self, message: Message) -> Generator[Message, None, None]:
        """
        Get or create and append a new session of the specified type, ensuring only one session per type exists.
        Generate RAG results for the given query.

        Args:
            session_type: The type of session to create (e.g., 'web', 'database').
            **kwargs: Additional fields required by the specific session subclass.
            query: The query string to generate RAG results for.

        Returns:
            The created session instance.
            A list of dictionaries containing the RAG results.
        """
        try:
            message.status = "processing"

            entries: int = 0

            if not self.file_watcher:
                message.response = "No RAG context available."
                del message.metadata["rag"]
                message.status = "done"
                yield message
                return

            for rag in self.rags:
                if not rag["enabled"]:
                    continue
                message.response = f"Checking RAG context {rag['name']}..."
                yield message
                chroma_results = self.file_watcher.find_similar(query=message.prompt, top_k=10)
                if chroma_results:
                    entries += len(chroma_results["documents"])

                    chroma_embedding = np.array(chroma_results["query_embedding"]).flatten()  # Ensure correct shape
                    print(f"Chroma embedding shape: {chroma_embedding.shape}")

                    umap_2d = self.file_watcher.umap_model_2d.transform([chroma_embedding])[0].tolist()
                    print(f"UMAP 2D output: {umap_2d}, length: {len(umap_2d)}")  # Debug output

                    umap_3d = self.file_watcher.umap_model_3d.transform([chroma_embedding])[0].tolist()
                    print(f"UMAP 3D output: {umap_3d}, length: {len(umap_3d)}")  # Debug output

                    message.metadata["rag"][rag["name"]] = {
                        **chroma_results,
                        "umap_embedding_2d": umap_2d,
                        "umap_embedding_3d": umap_3d
                    }
                    yield message

            if entries == 0:
                del message.metadata["rag"]

            message.response = f"RAG context gathered from {entries} documents."
            message.status = "done"
            yield message
            return
        except Exception as e:
            message.response = f"Error generating RAG results: {str(e)}"
            message.status = "error"
            logging.error(e)
            yield message
            return

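generate_rag_results reports progress by repeatedly yielding the same Message object as it mutates it. A driver-loop sketch, where ctx and message stand in for a real Context and Message:

for update in ctx.generate_rag_results(message):  # plain generator, not async
    print(update.status, "-", update.response)
    if update.status == "error":
        break
# On "done", message.metadata["rag"] holds the chroma matches plus the
# 2D/3D UMAP embeddings for each enabled RAG source.
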
    def get_or_create_agent(self, agent_type: str, **kwargs) -> Agent:
        """
        Get or create and append a new agent of the specified type, ensuring only one agent per type exists.

        Args:
            agent_type: The type of agent to create (e.g., 'web', 'database').
            **kwargs: Additional fields required by the specific agent subclass.

        Returns:
            The created agent instance.

        Raises:
            ValueError: If no matching session type is found or if a session of this type already exists.
            ValueError: If no matching agent type is found.
        """
        # Check if a session with the given session_type already exists
        for session in self.sessions:
            if session.session_type == session_type:
                return session
        # Check if an agent with the given agent_type already exists
        for agent in self.agents:
            if agent.agent_type == agent_type:
                return agent

        # Find the matching subclass
        for session_cls in Session.__subclasses__():
            if session_cls.model_fields["session_type"].default == session_type:
                # Create the session instance with provided kwargs
                session = session_cls(session_type=session_type, **kwargs)
                self.sessions.append(session)
                return session
        for agent_cls in Agent.__subclasses__():
            logging.info(f"Found class: {agent_cls.model_fields['agent_type'].default}")
            if agent_cls.model_fields["agent_type"].default == agent_type:
                # Create the agent instance with provided kwargs
                agent = agent_cls(agent_type=agent_type, **kwargs)
                self.agents.append(agent)
                return agent

        raise ValueError(f"No session class found for session_type: {session_type}")
        raise ValueError(f"No agent class found for agent_type: {agent_type}")

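Usage sketch for the lookup-or-create flow above (a sketch only: any additional fields the agent subclasses require, such as a system prompt, are assumed to be passed via kwargs or optional):

context = Context()
agent = context.get_or_create_agent("resume", resume="# Markdown resume ...")
assert context.get_or_create_agent("resume") is agent  # second call returns the existing instance
context.get_or_create_agent("bogus")  # raises ValueError: No agent class found for agent_type: bogus
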
    def add_session(self, session: AnySession) -> None:
        """Add a Session to the context, ensuring no duplicate session_type."""
        if any(s.session_type == session.session_type for s in self.sessions):
            raise ValueError(f"A session with session_type '{session.session_type}' already exists")
        self.sessions.append(session)
    def add_agent(self, agent: AnyAgent) -> None:
        """Add an Agent to the context, ensuring no duplicate agent_type."""
        if any(s.agent_type == agent.agent_type for s in self.agents):
            raise ValueError(f"An agent with agent_type '{agent.agent_type}' already exists")
        self.agents.append(agent)

    def get_session(self, session_type: str) -> Session | None:
        """Return the Session with the given session_type, or None if not found."""
        for session in self.sessions:
            if session.session_type == session_type:
                return session
    def get_agent(self, agent_type: str) -> Agent | None:
        """Return the Agent with the given agent_type, or None if not found."""
        for agent in self.agents:
            if agent.agent_type == agent_type:
                return agent
        return None

    def is_valid_session_type(self, session_type: str) -> bool:
        """Check if the given session_type is valid."""
        return session_type in Session.valid_session_types()
    def is_valid_agent_type(self, agent_type: str) -> bool:
        """Check if the given agent_type is valid."""
        return agent_type in Agent.valid_agent_types()

    def get_summary(self) -> str:
        """Return a summary of the context."""
        # NOTE: self.uuid assumes ContextBase provides it; the field declared on this model is id
        if not self.sessions:
            return f"Context {self.uuid}: No sessions."
        if not self.agents:
            return f"Context {self.uuid}: No agents."
        summary = f"Context {self.uuid}:\n"
        for i, session in enumerate(self.sessions, 1):
            summary += f"\nSession {i} ({session.session_type}):\n"
            summary += session.conversation.get_summary()
            if session.session_type == "resume":
                summary += f"\nResume: {session.get_resume()}\n"
            elif session.session_type == "job_description":
                summary += f"\nJob Description: {session.job_description}\n"
            elif session.session_type == "fact_check":
                summary += f"\nFacts: {session.facts}\n"
            elif session.session_type == "chat":
                summary += f"\nChat Name: {session.name}\n"
        for i, agent in enumerate(self.agents, 1):
            summary += f"\nAgent {i} ({agent.agent_type}):\n"
            summary += agent.conversation.get_summary()
            if agent.agent_type == "resume":
                summary += f"\nResume: {agent.get_resume()}\n"
            elif agent.agent_type == "job_description":
                summary += f"\nJob Description: {agent.job_description}\n"
            elif agent.agent_type == "fact_check":
                summary += f"\nFacts: {agent.facts}\n"
            elif agent.agent_type == "chat":
                summary += f"\nChat Name: {agent.name}\n"
        return summary
@ -9,7 +9,7 @@ embedding_model = os.getenv("EMBEDDING_MODEL_NAME", "mxbai-embed-large")
persist_directory = os.getenv("PERSIST_DIR", "/opt/backstory/chromadb")
max_context = 2048*8*2  # 32768 tokens
doc_dir = "/opt/backstory/docs/"
session_dir = "/opt/backstory/sessions"
context_dir = "/opt/backstory/sessions"
static_content = "/opt/backstory/frontend/deployed"
resume_doc = "/opt/backstory/docs/resume/generic.md"
# Only used for testing; backstory-prod will not use this
@ -3,10 +3,14 @@ from typing import Dict, List, Optional, Any
from datetime import datetime, timezone

class Message(BaseModel):
    prompt: str
    preamble: str = ""
    content: str = ""
    response: str = ""
    # Required
    prompt: str  # Query to be answered

    # Generated while processing message
    preamble: dict[str,str] = {}  # Preamble to be prepended to the prompt
    system_prompt: str = ""  # System prompt provided to the LLM
    full_content: str = ""  # Full content of the message (preamble + prompt)
    response: str = ""  # LLM response to the preamble + query
    metadata: dict[str, Any] = {
        "rag": { "documents": [] },
        "tools": [],
@ -15,7 +19,7 @@ class Message(BaseModel):
        "prompt_eval_count": 0,
        "prompt_eval_duration": 0,
    }
    actions: List[str] = []
    actions: List[str] = []  # Other session-modifying actions performed while processing the message
    timestamp: datetime = datetime.now(timezone.utc)  # NOTE: evaluated once at import; a default_factory would stamp each message

    def add_action(self, action: str | list[str]) -> None:
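Construction sketch for the reshaped Message (values illustrative; add_action's body is truncated in this hunk, so its behavior is assumed):

msg = Message(prompt="Summarize my resume.")
msg.preamble["resume"] = "## Experience ..."
msg.metadata["rag"]["documents"].append({"id": "doc-1"})
msg.add_action("rag")  # assumed to append to msg.actions
print(msg.model_dump(mode="json")["prompt"])  # "Summarize my resume."
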
@ -1,3 +1,4 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
import os
import glob
from pathlib import Path

@ -1,78 +0,0 @@
from pydantic import BaseModel, Field, model_validator, PrivateAttr
from typing import Literal, TypeAlias, get_args
from .conversation import Conversation

class Session(BaseModel):
    session_type: Literal["resume", "job_description", "fact_check", "chat"]
    system_prompt: str  # Mandatory
    conversation: Conversation = Conversation()
    context_tokens: int = 0

    _content_seed: str = PrivateAttr(default="")

    def get_and_reset_content_seed(self):
        tmp = self._content_seed
        self._content_seed = ""
        return tmp

    def set_content_seed(self, content: str) -> None:
        """Set the content seed for the session."""
        self._content_seed = content

    def get_content_seed(self) -> str:
        """Get the content seed for the session."""
        return self._content_seed

    @classmethod
    def valid_session_types(cls) -> set[str]:
        """Return the set of valid session_type values."""
        return set(get_args(cls.__annotations__["session_type"]))


# Type alias for Session or any subclass
AnySession: TypeAlias = Session  # BaseModel covers Session and subclasses

class Resume(Session):
    session_type: Literal["resume"] = "resume"
    resume: str = ""

    @model_validator(mode="after")
    def validate_resume(self):
        if not self.resume.strip():
            raise ValueError("Resume content cannot be empty")
        return self

    def get_resume(self) -> str:
        """Get the resume content."""
        return self.resume

    def set_resume(self, resume: str) -> None:
        """Set the resume content."""
        self.resume = resume

class JobDescription(Session):
    session_type: Literal["job_description"] = "job_description"
    job_description: str = ""

    @model_validator(mode="after")
    def validate_job_description(self):
        if not self.job_description.strip():
            raise ValueError("Job description cannot be empty")
        return self

class FactCheck(Session):
    session_type: Literal["fact_check"] = "fact_check"
    facts: str = ""

    @model_validator(mode="after")
    def validate_facts(self):
        if not self.facts.strip():
            raise ValueError("Facts cannot be empty")
        return self

class Chat(Session):
    session_type: Literal["chat"] = "chat"

    @model_validator(mode="after")
    def validate_name(self):
        return self