Tools are working and shared context is in use across all agents

This commit is contained in:
James Ketr 2025-05-02 16:46:42 -07:00
parent cc0f6974ff
commit 3fe2cfd9ef
13 changed files with 384 additions and 399 deletions
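For orientation while reading the diff: the "shared context" in the commit message refers to a single Context object that every agent holds a reference to. Below is a minimal sketch of that pattern, loosely simplified from the Context and Agent models touched in this commit; field names and defaults are illustrative, not the project's exact API.

# --- illustrative sketch (not part of the commit) ---
from typing import List, Optional
from pydantic import BaseModel

class Context(BaseModel):
    processing: bool = False   # guards against concurrent requests
    tools: List[dict] = []     # tool definitions visible to every agent

class Agent(BaseModel):
    agent_type: str = "base"
    context: Optional[Context] = None  # every agent points at the same Context instance

ctx = Context(tools=[{"name": "DateTime"}])
chat = Agent(agent_type="chat", context=ctx)
fact_check = Agent(agent_type="fact_check", context=ctx)
assert chat.context is fact_check.context  # shared state across agents
# --- end sketch ---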

View File

@ -1,6 +1,6 @@
from utils import logger
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing import AsyncGenerator
# %%
# Imports [standard]
@ -40,13 +40,13 @@ try_import("sklearn")
import ollama
import requests
from contextlib import asynccontextmanager
-from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi import FastAPI, Request, BackgroundTasks # type: ignore
-from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
+from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.cors import CORSMiddleware # type: ignore
-import uvicorn
+import uvicorn # type: ignore
-import numpy as np
+import numpy as np # type: ignore
-import umap
+import umap # type: ignore
-from sklearn.preprocessing import MinMaxScaler
+from sklearn.preprocessing import MinMaxScaler # type: ignore
from utils import (
rag as Rag,
@ -449,6 +449,8 @@ class WebServer:
prompt = system_generate_resume
case "fact_check":
prompt = system_message
+case _:
+prompt = system_message
agent.system_prompt = prompt
response["system_prompt"] = { "system_prompt": prompt }
@ -759,9 +761,9 @@ class WebServer:
with open(file_path, "r") as f:
content = f.read()
logger.info(f"Loading context from {file_path}, content length: {len(content)}")
+import json
try:
# Try parsing as JSON first to ensure valid JSON
-import json
json_data = json.loads(content)
logger.info("JSON parsed successfully, attempting model validation")
@ -1426,7 +1428,7 @@ def main():
module="umap.*"
)
-llm = ollama.Client(host=args.ollama_server)
+llm = ollama.Client(host=args.ollama_server) # type: ignore
model = args.ollama_model
web_server = WebServer(llm, model)

View File

@ -1,18 +1,14 @@
from __future__ import annotations
-from typing import Optional, Type
import importlib
-from pydantic import BaseModel
+from pydantic import BaseModel # type: ignore
+from typing import Type
from . import defines
-from . rag import ChromaDBFileWatcher, start_file_watcher
-from . message import Message
-from . conversation import Conversation
from . context import Context
-from . import agents
+from . conversation import Conversation
+from . message import Message
+from . rag import ChromaDBFileWatcher, start_file_watcher
from . setup_logging import setup_logging
from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all
__all__ = [
@ -21,9 +17,11 @@ __all__ = [
'Conversation',
'Message',
'ChromaDBFileWatcher',
-'start_file_watcher'
+'start_file_watcher',
'logger',
-] + agents_all
+]
+__all__.extend(agents_all) # type: ignore
# Resolve circular dependencies by rebuilding models
# Call model_rebuild() on Agent and Context
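The two comment lines above point at pydantic's model_rebuild() as the way the package breaks the Agent/Context circular dependency. A minimal, self-contained illustration of that pattern follows, using hypothetical stand-in models rather than the package's real Agent and Context.

# --- illustrative sketch (not part of the commit) ---
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel

class Agent(BaseModel):
    context: Optional["Context"] = None  # forward reference; "Context" does not exist yet

class Context(BaseModel):
    agents: List[Agent] = []

# Once both classes exist, model_rebuild() resolves the forward reference
Agent.model_rebuild()
Context.model_rebuild()

ctx = Context(agents=[Agent()])  # validates now that both models are fully defined
# --- end sketch ---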

View File

@ -39,11 +39,11 @@ for path in package_dir.glob("*.py"):
):
class_registry[name] = (full_module_name, name)
globals()[name] = obj
-logger.info(f"Adding agent: {name} from {full_module_name}")
+logger.info(f"Adding agent: {name}")
-__all__.append(name)
+__all__.append(name) # type: ignore
except ImportError as e:
logger.error(f"Error importing {full_module_name}: {e}")
-continue
+raise e
except Exception as e:
logger.error(f"Error processing {full_module_name}: {e}")
raise e

View File

@ -1,10 +1,18 @@
from __future__ import annotations
-from pydantic import BaseModel, model_validator, PrivateAttr, Field
+from pydantic import BaseModel, PrivateAttr, Field # type: ignore
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
+from typing import (
+Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
+)
+from abc import ABC
from .. setup_logging import setup_logging
from .. import defines
from abc import ABC
import logging
from .. message import Message
from .. import tools as Tools
import json
import time
import inspect
logger = setup_logging()
@ -22,6 +30,9 @@ class Agent(BaseModel, ABC):
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
+# Agent management with pydantic
+agent_type: Literal["base"] = "base"
+_agent_type: ClassVar[str] = agent_type # Add this for registration
# context_size is shared across all subclasses
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
@ -33,10 +44,6 @@ class Agent(BaseModel, ABC):
def context_size(self, value: int):
Agent._context_size = value
-# Agent management with pydantic
-agent_type: Literal["base"] = "base"
-_agent_type: ClassVar[str] = agent_type # Add this for registration
# Agent properties
system_prompt: str # Mandatory
conversation: Conversation = Conversation()
@ -102,6 +109,316 @@ class Agent(BaseModel, ABC):
def get_agent_type(self):
return self._agent_type
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if not message.metadata["tools"]:
raise ValueError("tools field not initialized")
tool_metadata = message.metadata["tools"]
tool_metadata["messages"] = messages
tool_metadata["tool_calls"] = []
message.status = "tooling"
for i, tool_call in enumerate(tool_message.tool_calls):
arguments = tool_call.function.arguments
tool = tool_call.function.name
# Yield status update before processing each tool
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
yield message
logging.info(f"LLM - {message.response}")
# Process the tool based on its type
match tool:
case "TickerValue":
ticker = arguments.get("ticker")
if not ticker:
ret = None
else:
ret = Tools.TickerValue(ticker)
case "AnalyzeSite":
url = arguments.get("url")
question = arguments.get("question", "what is the summary of this content?")
# Additional status update for long-running operations
message.response = f"Retrieving and summarizing content from {url}..."
yield message
ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
case "DateTime":
tz = arguments.get("timezone")
ret = Tools.DateTime(tz)
case "WeatherForecast":
city = arguments.get("city")
state = arguments.get("state")
message.response = f"Fetching weather data for {city}, {state}..."
yield message
ret = Tools.WeatherForecast(city, state)
case _:
ret = None
# Build response for this tool
tool_response = {
"role": "tool",
"content": json.dumps(ret),
"name": tool_call.function.name
}
tool_metadata["tool_calls"].append(tool_response)
if len(tool_metadata["tool_calls"]) == 0:
message.status = "done"
yield message
return
message_dict = {
"role": tool_message.get("role", "assistant"),
"content": tool_message.get("content", ""),
"tool_calls": [ {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
} for tc in tool_message.tool_calls
]
}
messages.append(message_dict)
messages.extend(tool_metadata["tool_calls"])
message.status = "thinking"
# Decrease creativity when processing tool call requests
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
stream=True,
options={
**message.metadata["options"],
# "temperature": 0.5,
}
):
# logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
return
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
messages = [ { "role": "system", "content": message.system_prompt } ]
messages.extend([
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt.strip()},
{"role": "assistant", "content": m.response.strip()}
]
])
messages.append({
"role": "user",
"content": message.context_prompt.strip(),
})
message.metadata["messages"] = messages
message.metadata["options"]={
"seed": 8911,
"num_ctx": self.context_size,
#"temperature": 0.9, # Higher temperature to encourage tool usage
}
message.metadata["timers"] = {}
use_tools = message.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = {
"available": Tools.llm_tools(self.context.tools),
"used": False
}
tool_metadata = message.metadata["tools"]
if use_tools:
message.status = "thinking"
message.response = f"Performing tool analysis step 1/2..."
yield message
logging.info("Checking for LLM tool usage")
start_time = time.perf_counter()
# Tools are enabled and available, so query the LLM with a short token target to see if it will
# use the tools
tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
response = llm.chat(
model=model,
messages=tool_metadata["messages"],
tools=tool_metadata["available"],
options={
**message.metadata["options"],
#"num_predict": 1024, # "Low" token limit to cut off after tool call
},
stream=False # No need to stream the probe
)
end_time = time.perf_counter()
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
if not response.message.tool_calls:
logging.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response
use_tools = False
if use_tools:
logging.info("LLM indicates tools will be used")
# Tools are enabled and available and the LLM indicated it will use them
tool_metadata["attempted"] = response.message.tool_calls
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
yield message
logging.info(f"Performing LLM call with tools")
start_time = time.perf_counter()
response = llm.chat(
model=model,
messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"],
options={
**message.metadata["options"],
},
stream=False
)
end_time = time.perf_counter()
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
if response.message.tool_calls:
tool_metadata["used"] = response.message.tool_calls
# Process all yielded items from the handler
start_time = time.perf_counter()
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
if message.status == "error":
yield message
return
yield message
end_time = time.perf_counter()
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
message.status = "done"
return
logging.info("LLM indicated tools will be used, and then they weren't")
message.response = response.message.content
message.status = "done"
yield message
return
# not use_tools
yield message
# Reset the response for streaming
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
options={
**message.metadata["options"],
},
stream=True,
):
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
return
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
message.context_prompt = ""
for p in message.preamble.keys():
message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
message.context_prompt += f"{message.prompt}"
# Estimate token length of new messages
message.response = f"Optimizing context..."
message.status = "thinking"
yield message
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for message in self.generate_llm_response(llm, model, message):
# logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
if message.status == "error":
yield message
self.context.processing = False
return
yield message
# Done processing, add message to conversation
message.status = "done"
self.conversation.add_message(message)
self.context.processing = False
return
# Register the base agent
registry.register(Agent._agent_type, Agent)

View File

@ -1,26 +1,15 @@
from __future__ import annotations
-from pydantic import BaseModel, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
+from typing import Literal, AsyncGenerator, ClassVar, Optional
import logging
from . base import Agent, registry
-from .. conversation import Conversation
from .. message import Message
-from .. import defines
-from .. import tools as Tools
-from ollama import ChatResponse
-import json
-import time
import inspect
-class Chat(Agent, ABC):
+class Chat(Agent):
"""
-Base class for all agent types.
-This class defines the common attributes and methods for all agent types.
+Chat Agent
"""
-agent_type: Literal["chat"] = "chat"
+agent_type: Literal["chat"] = "chat" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
@ -74,313 +63,5 @@ class Chat(Agent, ABC):
yield message
return
async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if not message.metadata["tools"]:
raise ValueError("tools field not initialized")
tool_metadata = message.metadata["tools"]
tool_metadata["messages"] = messages
tool_metadata["tool_calls"] = []
message.status = "tooling"
for i, tool_call in enumerate(tool_message.tool_calls):
arguments = tool_call.function.arguments
tool = tool_call.function.name
# Yield status update before processing each tool
message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
yield message
logging.info(f"LLM - {message.response}")
# Process the tool based on its type
match tool:
case "TickerValue":
ticker = arguments.get("ticker")
if not ticker:
ret = None
else:
ret = Tools.TickerValue(ticker)
case "AnalyzeSite":
url = arguments.get("url")
question = arguments.get("question", "what is the summary of this content?")
# Additional status update for long-running operations
message.response = f"Retrieving and summarizing content from {url}..."
yield message
ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
case "DateTime":
tz = arguments.get("timezone")
ret = Tools.DateTime(tz)
case "WeatherForecast":
city = arguments.get("city")
state = arguments.get("state")
message.response = f"Fetching weather data for {city}, {state}..."
yield message
ret = Tools.WeatherForecast(city, state)
case _:
ret = None
# Build response for this tool
tool_response = {
"role": "tool",
"content": json.dumps(ret),
"name": tool_call.function.name
}
tool_metadata["tool_calls"].append(tool_response)
if len(tool_metadata["tool_calls"]) == 0:
message.status = "done"
yield message
return
message_dict = {
"role": tool_message.get("role", "assistant"),
"content": tool_message.get("content", ""),
"tool_calls": [
{
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
}
for tc in tool_message.tool_calls
]
}
messages.append(message_dict)
messages.extend(tool_metadata["tool_calls"])
message.status = "thinking"
# Decrease creativity when processing tool call requests
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
stream=True,
options={
**message.metadata["options"],
# "temperature": 0.5,
}
):
# logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
return
async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
messages = [ { "role": "system", "content": message.system_prompt } ]
messages.extend([
item for m in self.conversation.messages
for item in [
{"role": "user", "content": m.prompt.strip()},
{"role": "assistant", "content": m.response.strip()}
]
])
messages.append({
"role": "user",
"content": message.context_prompt.strip(),
})
message.metadata["messages"] = messages
message.metadata["options"]={
"seed": 8911,
"num_ctx": self.context_size,
#"temperature": 0.9, # Higher temperature to encourage tool usage
}
message.metadata["timers"] = {}
use_tools = message.enable_tools and len(self.context.tools) > 0
message.metadata["tools"] = {
"available": Tools.llm_tools(self.context.tools),
"used": False
}
tool_metadata = message.metadata["tools"]
if use_tools:
message.status = "thinking"
message.response = f"Performing tool analysis step 1/2..."
yield message
logging.info("Checking for LLM tool usage")
start_time = time.perf_counter()
# Tools are enabled and available, so query the LLM with a short token target to see if it will
# use the tools
tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
response = llm.chat(
model=model,
messages=tool_metadata["messages"],
tools=tool_metadata["available"],
options={
**message.metadata["options"],
#"num_predict": 1024, # "Low" token limit to cut off after tool call
},
stream=False # No need to stream the probe
)
end_time = time.perf_counter()
message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
if not response.message.tool_calls:
logging.info("LLM indicates tools will not be used")
# The LLM will not use tools, so disable use_tools so we can stream the full response
use_tools = False
if use_tools:
logging.info("LLM indicates tools will be used")
# Tools are enabled and available and the LLM indicated it will use them
tool_metadata["attempted"] = response.message.tool_calls
message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
yield message
logging.info(f"Performing LLM call with tools")
start_time = time.perf_counter()
response = llm.chat(
model=model,
messages=tool_metadata["messages"], # messages,
tools=tool_metadata["available"],
options={
**message.metadata["options"],
},
stream=False
)
end_time = time.perf_counter()
message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
if response.message.tool_calls:
tool_metadata["used"] = response.message.tool_calls
# Process all yielded items from the handler
start_time = time.perf_counter()
async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
if message.status == "error":
yield message
return
yield message
end_time = time.perf_counter()
message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
message.status = "done"
return
logging.info("LLM indicated tools will be used, and then they weren't")
message.response = response.message.content
message.status = "done"
yield message
return
# not use_tools
yield message
# Reset the response for streaming
message.response = ""
start_time = time.perf_counter()
for response in llm.chat(
model=model,
messages=messages,
options={
**message.metadata["options"],
},
stream=True,
):
if not response:
message.status = "error"
message.response = "No response from LLM."
yield message
return
message.status = "streaming"
message.response += response.message.content
if not response.done:
yield message
if response.done:
message.metadata["eval_count"] += response.eval_count
message.metadata["eval_duration"] += response.eval_duration
message.metadata["prompt_eval_count"] += response.prompt_eval_count
message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
self.context_tokens = response.prompt_eval_count + response.eval_count
message.status = "done"
yield message
end_time = time.perf_counter()
message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
return
async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
if not self.context:
raise ValueError("Context is not set for this agent.")
if self.context.processing:
logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
message.status = "error"
message.response = "Busy processing another request."
yield message
return
self.context.processing = True
message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
message.context_prompt = ""
for p in message.preamble.keys():
message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
message.context_prompt += f"{message.prompt}"
# Estimate token length of new messages
message.response = f"Optimizing context..."
message.status = "thinking"
yield message
message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
message.status = "thinking"
yield message
async for message in self.generate_llm_response(llm, model, message):
# logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
if message.status == "error":
yield message
self.context.processing = False
return
yield message
# Done processing, add message to conversation
message.status = "done"
self.conversation.add_message(message)
self.context.processing = False
return
# Register the base agent
registry.register(Chat._agent_type, Chat)

View File

@ -1,15 +1,9 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
+from pydantic import model_validator # type: ignore
+from typing import Literal, ClassVar, Optional
from .base import Agent, registry
-from .. conversation import Conversation
-from .. message import Message
class FactCheck(Agent):
-agent_type: Literal["fact_check"] = "fact_check"
+agent_type: Literal["fact_check"] = "fact_check" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
facts: str = ""

View File

@ -1,15 +1,12 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
+from pydantic import model_validator # type: ignore
+from typing import Literal, ClassVar, Optional
from .base import Agent, registry
from .. conversation import Conversation
from .. message import Message
+from abc import ABC
class JobDescription(Agent):
-agent_type: Literal["job_description"] = "job_description"
+agent_type: Literal["job_description"] = "job_description" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
job_description: str = ""

View File

@ -9,7 +9,7 @@ from .. conversation import Conversation
from .. message import Message
class Resume(Agent):
-agent_type: Literal["resume"] = "resume"
+agent_type: Literal["resume"] = "resume" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
resume: str = ""

View File

@ -1,8 +1,5 @@
from __future__ import annotations
-from typing import List, Dict, Any, Union, ForwardRef, TypeVar, Optional, TYPE_CHECKING, Type, ClassVar, Literal
-from typing_extensions import Annotated
-from pydantic import Field, BaseModel
-from abc import ABC, abstractmethod
+from typing import List, Dict, ForwardRef, Optional, Type
# Forward references
AgentRef = ForwardRef('Agent')

View File

@ -1,4 +1,4 @@
-import tiktoken
+import tiktoken # type: ignore
from . import defines
from typing import List, Dict, Any, Union

View File

@ -1,9 +1,9 @@
from __future__ import annotations
-from pydantic import BaseModel, Field, model_validator, ValidationError
+from pydantic import BaseModel, Field, model_validator# type: ignore
from uuid import uuid4
-from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
+from typing import List, Optional, Generator
from typing_extensions import Annotated, Union
-import numpy as np
+import numpy as np # type: ignore
import logging
from uuid import uuid4
import re
@ -34,7 +34,7 @@ class Context(BaseModel):
rags: List[dict] = []
message_history_length: int = 5
# Class managed fields
-agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
default_factory=list
)
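The agents field above relies on pydantic's discriminated-union support: each Agent subclass declares a distinct agent_type literal, and pydantic uses that value to pick the concrete class during validation. A small self-contained sketch of the idea, with stand-in models rather than the package's real agents:

# --- illustrative sketch (not part of the commit) ---
from typing import List, Literal, Union
from typing_extensions import Annotated
from pydantic import BaseModel, Field

class Chat(BaseModel):
    agent_type: Literal["chat"] = "chat"

class Resume(BaseModel):
    agent_type: Literal["resume"] = "resume"

class Context(BaseModel):
    # pydantic selects Chat or Resume at validation time based on agent_type
    agents: List[Annotated[Union[Chat, Resume], Field(discriminator="agent_type")]] = []

ctx = Context.model_validate({"agents": [{"agent_type": "resume"}]})
assert isinstance(ctx.agents[0], Resume)
# --- end sketch ---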

View File

@ -1,4 +1,4 @@
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel # type: ignore
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone

View File

@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from typing import List
import os
import glob
from pathlib import Path
@ -13,18 +13,18 @@ import hashlib
import asyncio
import json
import pickle
-import numpy as np
+import numpy as np # type: ignore
import re
import chromadb
import ollama
-from langchain.text_splitter import CharacterTextSplitter
+from langchain.text_splitter import CharacterTextSplitter # type: ignore
-from sentence_transformers import SentenceTransformer
+from sentence_transformers import SentenceTransformer # type: ignore
-from langchain.schema import Document
+from langchain.schema import Document # type: ignore
-from watchdog.observers import Observer
+from watchdog.observers import Observer # type: ignore
-from watchdog.events import FileSystemEventHandler
+from watchdog.events import FileSystemEventHandler # type: ignore
-import umap
+import umap # type: ignore
-from markitdown import MarkItDown
+from markitdown import MarkItDown # type: ignore
# Import your existing modules
if __name__ == "__main__":
@ -53,11 +53,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
self.chunk_overlap = chunk_overlap
self.loop = loop
self._umap_collection = None
-self._umap_embedding_2d = []
+self._umap_embedding_2d : List[int]= []
self._umap_embedding_3d = []
-self._umap_model_2d = None
+self._umap_model_2d : umap.UMAP = None
-self._umap_model_3d = None
+self._umap_model_3d : umap.UMAP = None
-self._collection = None
self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
#self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
@ -94,11 +93,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return self._umap_collection
@property
-def umap_embedding_2d(self):
+def umap_embedding_2d(self) -> List[int]:
return self._umap_embedding_2d
@property
-def umap_embedding_3d(self):
+def umap_embedding_3d(self) -> List[int]:
return self._umap_embedding_3d
@property
@ -289,9 +288,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
def _get_vector_collection(self, recreate=False):
"""Get or create a ChromaDB collection."""
# Initialize ChromaDB client
-chroma_client = chromadb.PersistentClient(
+chroma_client = chromadb.PersistentClient( # type: ignore
path=self.persist_directory,
-settings=chromadb.Settings(anonymized_telemetry=False)
+settings=chromadb.Settings(anonymized_telemetry=False) # type: ignore
)
# Check if the collection exists
@ -577,7 +576,7 @@ if __name__ == "__main__":
import defines
# Initialize Ollama client
-llm = ollama.Client(host=defines.ollama_api_url)
+llm = ollama.Client(host=defines.ollama_api_url) # type: ignore
# Start the file watcher (with initialization)
observer, file_watcher = start_file_watcher(