diff --git a/src/server.py b/src/server.py index f73c445..4ee2c83 100644 --- a/src/server.py +++ b/src/server.py @@ -1,6 +1,6 @@ from utils import logger -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar +from typing import AsyncGenerator # %% # Imports [standard] @@ -40,13 +40,13 @@ try_import("sklearn") import ollama import requests from contextlib import asynccontextmanager -from fastapi import FastAPI, Request, BackgroundTasks -from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse -from fastapi.middleware.cors import CORSMiddleware -import uvicorn -import numpy as np -import umap -from sklearn.preprocessing import MinMaxScaler +from fastapi import FastAPI, Request, BackgroundTasks # type: ignore +from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore +from fastapi.middleware.cors import CORSMiddleware # type: ignore +import uvicorn # type: ignore +import numpy as np # type: ignore +import umap # type: ignore +from sklearn.preprocessing import MinMaxScaler # type: ignore from utils import ( rag as Rag, @@ -449,6 +449,8 @@ class WebServer: prompt = system_generate_resume case "fact_check": prompt = system_message + case _: + prompt = system_message agent.system_prompt = prompt response["system_prompt"] = { "system_prompt": prompt } @@ -759,9 +761,9 @@ class WebServer: with open(file_path, "r") as f: content = f.read() logger.info(f"Loading context from {file_path}, content length: {len(content)}") + import json try: # Try parsing as JSON first to ensure valid JSON - import json json_data = json.loads(content) logger.info("JSON parsed successfully, attempting model validation") @@ -1426,7 +1428,7 @@ def main(): module="umap.*" ) - llm = ollama.Client(host=args.ollama_server) + llm = ollama.Client(host=args.ollama_server) # type: ignore model = args.ollama_model web_server = WebServer(llm, model) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 60cee98..51ea43d 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,18 +1,14 @@ from __future__ import annotations -from typing import Optional, Type import importlib -from pydantic import BaseModel -from typing import Type +from pydantic import BaseModel # type: ignore from . import defines -from . rag import ChromaDBFileWatcher, start_file_watcher -from . message import Message -from . conversation import Conversation from . context import Context -from . import agents +from . conversation import Conversation +from . message import Message +from . rag import ChromaDBFileWatcher, start_file_watcher from . 
setup_logging import setup_logging - from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all __all__ = [ @@ -21,9 +17,11 @@ __all__ = [ 'Conversation', 'Message', 'ChromaDBFileWatcher', - 'start_file_watcher' + 'start_file_watcher', 'logger', -] + agents_all +] + +__all__.extend(agents_all) # type: ignore # Resolve circular dependencies by rebuilding models # Call model_rebuild() on Agent and Context diff --git a/src/utils/agents/__init__.py b/src/utils/agents/__init__.py index 4ba0581..a3eaa99 100644 --- a/src/utils/agents/__init__.py +++ b/src/utils/agents/__init__.py @@ -39,11 +39,11 @@ for path in package_dir.glob("*.py"): ): class_registry[name] = (full_module_name, name) globals()[name] = obj - logger.info(f"Adding agent: {name} from {full_module_name}") - __all__.append(name) + logger.info(f"Adding agent: {name}") + __all__.append(name) # type: ignore except ImportError as e: logger.error(f"Error importing {full_module_name}: {e}") - continue + raise e except Exception as e: logger.error(f"Error processing {full_module_name}: {e}") raise e diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index c9e0c2b..94d4ea0 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -1,10 +1,18 @@ from __future__ import annotations -from pydantic import BaseModel, model_validator, PrivateAttr, Field -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any -from abc import ABC, abstractmethod -from typing_extensions import Annotated +from pydantic import BaseModel, PrivateAttr, Field # type: ignore +from typing import ( + Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any, +) +from abc import ABC from .. setup_logging import setup_logging from .. import defines +from abc import ABC +import logging +from .. message import Message +from .. import tools as Tools +import json +import time +import inspect logger = setup_logging() @@ -22,6 +30,9 @@ class Agent(BaseModel, ABC): Base class for all agent types. This class defines the common attributes and methods for all agent types. 
""" + # Agent management with pydantic + agent_type: Literal["base"] = "base" + _agent_type: ClassVar[str] = agent_type # Add this for registration # context_size is shared across all subclasses _context_size: ClassVar[int] = int(defines.max_context * 0.5) @@ -33,10 +44,6 @@ class Agent(BaseModel, ABC): def context_size(self, value: int): Agent._context_size = value - # Agent management with pydantic - agent_type: Literal["base"] = "base" - _agent_type: ClassVar[str] = agent_type # Add this for registration - # Agent properties system_prompt: str # Mandatory conversation: Conversation = Conversation() @@ -101,6 +108,316 @@ class Agent(BaseModel, ABC): # Agent methods def get_agent_type(self): return self._agent_type + + async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]: + logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") + + if not self.context: + raise ValueError("Context is not set for this agent.") + if not message.metadata["tools"]: + raise ValueError("tools field not initialized") + + tool_metadata = message.metadata["tools"] + tool_metadata["messages"] = messages + tool_metadata["tool_calls"] = [] + + message.status = "tooling" + + for i, tool_call in enumerate(tool_message.tool_calls): + arguments = tool_call.function.arguments + tool = tool_call.function.name + + # Yield status update before processing each tool + message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." + yield message + logging.info(f"LLM - {message.response}") + + # Process the tool based on its type + match tool: + case "TickerValue": + ticker = arguments.get("ticker") + if not ticker: + ret = None + else: + ret = Tools.TickerValue(ticker) + + case "AnalyzeSite": + url = arguments.get("url") + question = arguments.get("question", "what is the summary of this content?") + + # Additional status update for long-running operations + message.response = f"Retrieving and summarizing content from {url}..." + yield message + ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question) + + case "DateTime": + tz = arguments.get("timezone") + ret = Tools.DateTime(tz) + + case "WeatherForecast": + city = arguments.get("city") + state = arguments.get("state") + + message.response = f"Fetching weather data for {city}, {state}..." 
+ yield message + ret = Tools.WeatherForecast(city, state) + + case _: + ret = None + + # Build response for this tool + tool_response = { + "role": "tool", + "content": json.dumps(ret), + "name": tool_call.function.name + } + + tool_metadata["tool_calls"].append(tool_response) + + if len(tool_metadata["tool_calls"]) == 0: + message.status = "done" + yield message + return + + message_dict = { + "role": tool_message.get("role", "assistant"), + "content": tool_message.get("content", ""), + "tool_calls": [ { + "function": { + "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"] + } + } for tc in tool_message.tool_calls + ] + } + + messages.append(message_dict) + messages.extend(tool_metadata["tool_calls"]) + + message.status = "thinking" + + # Decrease creativity when processing tool call requests + message.response = "" + start_time = time.perf_counter() + for response in llm.chat( + model=model, + messages=messages, + stream=True, + options={ + **message.metadata["options"], + # "temperature": 0.5, + } + ): + # logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}") + message.status = "streaming" + message.response += response.message.content + if not response.done: + yield message + if response.done: + message.metadata["eval_count"] += response.eval_count + message.metadata["eval_duration"] += response.eval_duration + message.metadata["prompt_eval_count"] += response.prompt_eval_count + message.metadata["prompt_eval_duration"] += response.prompt_eval_duration + self.context_tokens = response.prompt_eval_count + response.eval_count + message.status = "done" + yield message + + end_time = time.perf_counter() + message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}" + return + + async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]: + logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") + + if not self.context: + raise ValueError("Context is not set for this agent.") + + messages = [ { "role": "system", "content": message.system_prompt } ] + messages.extend([ + item for m in self.conversation.messages + for item in [ + {"role": "user", "content": m.prompt.strip()}, + {"role": "assistant", "content": m.response.strip()} + ] + ]) + messages.append({ + "role": "user", + "content": message.context_prompt.strip(), + }) + message.metadata["messages"] = messages + message.metadata["options"]={ + "seed": 8911, + "num_ctx": self.context_size, + #"temperature": 0.9, # Higher temperature to encourage tool usage + } + + message.metadata["timers"] = {} + + use_tools = message.enable_tools and len(self.context.tools) > 0 + message.metadata["tools"] = { + "available": Tools.llm_tools(self.context.tools), + "used": False + } + tool_metadata = message.metadata["tools"] + + if use_tools: + message.status = "thinking" + message.response = f"Performing tool analysis step 1/2..." 
+ yield message + + logging.info("Checking for LLM tool usage") + start_time = time.perf_counter() + # Tools are enabled and available, so query the LLM with a short token target to see if it will + # use the tools + tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}] + response = llm.chat( + model=model, + messages=tool_metadata["messages"], + tools=tool_metadata["available"], + options={ + **message.metadata["options"], + #"num_predict": 1024, # "Low" token limit to cut off after tool call + }, + stream=False # No need to stream the probe + ) + end_time = time.perf_counter() + message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}" + if not response.message.tool_calls: + logging.info("LLM indicates tools will not be used") + # The LLM will not use tools, so disable use_tools so we can stream the full response + use_tools = False + + if use_tools: + logging.info("LLM indicates tools will be used") + + # Tools are enabled and available and the LLM indicated it will use them + tool_metadata["attempted"] = response.message.tool_calls + message.response = f"Performing tool analysis step 2/2 (tool use suspected)..." + yield message + + logging.info(f"Performing LLM call with tools") + start_time = time.perf_counter() + response = llm.chat( + model=model, + messages=tool_metadata["messages"], # messages, + tools=tool_metadata["available"], + options={ + **message.metadata["options"], + }, + stream=False + ) + end_time = time.perf_counter() + message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}" + + if not response: + message.status = "error" + message.response = "No response from LLM." + yield message + return + + if response.message.tool_calls: + tool_metadata["used"] = response.message.tool_calls + # Process all yielded items from the handler + start_time = time.perf_counter() + async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages): + if message.status == "error": + yield message + return + yield message + end_time = time.perf_counter() + message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}" + message.status = "done" + return + + logging.info("LLM indicated tools will be used, and then they weren't") + message.response = response.message.content + message.status = "done" + yield message + return + + # not use_tools + yield message + # Reset the response for streaming + message.response = "" + start_time = time.perf_counter() + for response in llm.chat( + model=model, + messages=messages, + options={ + **message.metadata["options"], + }, + stream=True, + ): + if not response: + message.status = "error" + message.response = "No response from LLM." 
+ yield message + return + + message.status = "streaming" + message.response += response.message.content + + if not response.done: + yield message + + if response.done: + message.metadata["eval_count"] += response.eval_count + message.metadata["eval_duration"] += response.eval_duration + message.metadata["prompt_eval_count"] += response.prompt_eval_count + message.metadata["prompt_eval_duration"] += response.prompt_eval_duration + self.context_tokens = response.prompt_eval_count + response.eval_count + message.status = "done" + yield message + + end_time = time.perf_counter() + message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}" + return + + async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]: + logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") + + if not self.context: + raise ValueError("Context is not set for this agent.") + + if self.context.processing: + logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") + message.status = "error" + message.response = "Busy processing another request." + yield message + return + + self.context.processing = True + + message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n" + message.context_prompt = "" + for p in message.preamble.keys(): + message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n" + message.context_prompt += f"{message.prompt}" + + # Estimate token length of new messages + message.response = f"Optimizing context..." + message.status = "thinking" + yield message + + message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt) + + message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." + message.status = "thinking" + yield message + + async for message in self.generate_llm_response(llm, model, message): + # logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}") + if message.status == "error": + yield message + self.context.processing = False + return + yield message + + # Done processing, add message to conversation + message.status = "done" + self.conversation.add_message(message) + self.context.processing = False + return # Register the base agent registry.register(Agent._agent_type, Agent) diff --git a/src/utils/agents/chat.py b/src/utils/agents/chat.py index e4c33c1..bc8711e 100644 --- a/src/utils/agents/chat.py +++ b/src/utils/agents/chat.py @@ -1,26 +1,15 @@ from __future__ import annotations -from pydantic import BaseModel, model_validator, PrivateAttr -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any -from typing_extensions import Annotated -from abc import ABC, abstractmethod -from typing_extensions import Annotated +from typing import Literal, AsyncGenerator, ClassVar, Optional import logging from . base import Agent, registry -from .. conversation import Conversation from .. message import Message -from .. import defines -from .. import tools as Tools -from ollama import ChatResponse -import json -import time import inspect -class Chat(Agent, ABC): +class Chat(Agent): """ - Base class for all agent types. - This class defines the common attributes and methods for all agent types. 
+ Chat Agent """ - agent_type: Literal["chat"] = "chat" + agent_type: Literal["chat"] = "chat" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]: @@ -74,313 +63,5 @@ class Chat(Agent, ABC): yield message return - async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]: - logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") - - if not self.context: - raise ValueError("Context is not set for this agent.") - if not message.metadata["tools"]: - raise ValueError("tools field not initialized") - - tool_metadata = message.metadata["tools"] - tool_metadata["messages"] = messages - tool_metadata["tool_calls"] = [] - - message.status = "tooling" - - for i, tool_call in enumerate(tool_message.tool_calls): - arguments = tool_call.function.arguments - tool = tool_call.function.name - - # Yield status update before processing each tool - message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..." - yield message - logging.info(f"LLM - {message.response}") - - # Process the tool based on its type - match tool: - case "TickerValue": - ticker = arguments.get("ticker") - if not ticker: - ret = None - else: - ret = Tools.TickerValue(ticker) - - case "AnalyzeSite": - url = arguments.get("url") - question = arguments.get("question", "what is the summary of this content?") - - # Additional status update for long-running operations - message.response = f"Retrieving and summarizing content from {url}..." - yield message - ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question) - - case "DateTime": - tz = arguments.get("timezone") - ret = Tools.DateTime(tz) - - case "WeatherForecast": - city = arguments.get("city") - state = arguments.get("state") - - message.response = f"Fetching weather data for {city}, {state}..." 
- yield message - ret = Tools.WeatherForecast(city, state) - - case _: - ret = None - - # Build response for this tool - tool_response = { - "role": "tool", - "content": json.dumps(ret), - "name": tool_call.function.name - } - - tool_metadata["tool_calls"].append(tool_response) - - if len(tool_metadata["tool_calls"]) == 0: - message.status = "done" - yield message - return - - message_dict = { - "role": tool_message.get("role", "assistant"), - "content": tool_message.get("content", ""), - "tool_calls": [ - { - "function": { - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"] - } - } - for tc in tool_message.tool_calls - ] - } - - messages.append(message_dict) - messages.extend(tool_metadata["tool_calls"]) - - message.status = "thinking" - - # Decrease creativity when processing tool call requests - message.response = "" - start_time = time.perf_counter() - for response in llm.chat( - model=model, - messages=messages, - stream=True, - options={ - **message.metadata["options"], - # "temperature": 0.5, - } - ): - # logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}") - message.status = "streaming" - message.response += response.message.content - if not response.done: - yield message - if response.done: - message.metadata["eval_count"] += response.eval_count - message.metadata["eval_duration"] += response.eval_duration - message.metadata["prompt_eval_count"] += response.prompt_eval_count - message.metadata["prompt_eval_duration"] += response.prompt_eval_duration - self.context_tokens = response.prompt_eval_count + response.eval_count - message.status = "done" - yield message - - end_time = time.perf_counter() - message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}" - return - - async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]: - logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") - - if not self.context: - raise ValueError("Context is not set for this agent.") - - messages = [ { "role": "system", "content": message.system_prompt } ] - messages.extend([ - item for m in self.conversation.messages - for item in [ - {"role": "user", "content": m.prompt.strip()}, - {"role": "assistant", "content": m.response.strip()} - ] - ]) - messages.append({ - "role": "user", - "content": message.context_prompt.strip(), - }) - message.metadata["messages"] = messages - message.metadata["options"]={ - "seed": 8911, - "num_ctx": self.context_size, - #"temperature": 0.9, # Higher temperature to encourage tool usage - } - - message.metadata["timers"] = {} - - use_tools = message.enable_tools and len(self.context.tools) > 0 - message.metadata["tools"] = { - "available": Tools.llm_tools(self.context.tools), - "used": False - } - tool_metadata = message.metadata["tools"] - - if use_tools: - message.status = "thinking" - message.response = f"Performing tool analysis step 1/2..." 
- yield message - - logging.info("Checking for LLM tool usage") - start_time = time.perf_counter() - # Tools are enabled and available, so query the LLM with a short token target to see if it will - # use the tools - tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}] - response = llm.chat( - model=model, - messages=tool_metadata["messages"], - tools=tool_metadata["available"], - options={ - **message.metadata["options"], - #"num_predict": 1024, # "Low" token limit to cut off after tool call - }, - stream=False # No need to stream the probe - ) - end_time = time.perf_counter() - message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}" - if not response.message.tool_calls: - logging.info("LLM indicates tools will not be used") - # The LLM will not use tools, so disable use_tools so we can stream the full response - use_tools = False - - if use_tools: - logging.info("LLM indicates tools will be used") - - # Tools are enabled and available and the LLM indicated it will use them - tool_metadata["attempted"] = response.message.tool_calls - message.response = f"Performing tool analysis step 2/2 (tool use suspected)..." - yield message - - logging.info(f"Performing LLM call with tools") - start_time = time.perf_counter() - response = llm.chat( - model=model, - messages=tool_metadata["messages"], # messages, - tools=tool_metadata["available"], - options={ - **message.metadata["options"], - }, - stream=False - ) - end_time = time.perf_counter() - message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}" - - if not response: - message.status = "error" - message.response = "No response from LLM." - yield message - return - - if response.message.tool_calls: - tool_metadata["used"] = response.message.tool_calls - # Process all yielded items from the handler - start_time = time.perf_counter() - async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages): - if message.status == "error": - yield message - return - yield message - end_time = time.perf_counter() - message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}" - message.status = "done" - return - - logging.info("LLM indicated tools will be used, and then they weren't") - message.response = response.message.content - message.status = "done" - yield message - return - - # not use_tools - yield message - # Reset the response for streaming - message.response = "" - start_time = time.perf_counter() - for response in llm.chat( - model=model, - messages=messages, - options={ - **message.metadata["options"], - }, - stream=True, - ): - if not response: - message.status = "error" - message.response = "No response from LLM." 
- yield message - return - message.status = "streaming" - message.response += response.message.content - if not response.done: - yield message - if response.done: - message.metadata["eval_count"] += response.eval_count - message.metadata["eval_duration"] += response.eval_duration - message.metadata["prompt_eval_count"] += response.prompt_eval_count - message.metadata["prompt_eval_duration"] += response.prompt_eval_duration - self.context_tokens = response.prompt_eval_count + response.eval_count - message.status = "done" - yield message - - end_time = time.perf_counter() - message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}" - return - - async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]: - logging.info(f"{self.agent_type} - {inspect.stack()[1].function}") - - if not self.context: - raise ValueError("Context is not set for this agent.") - - if self.context.processing: - logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time") - message.status = "error" - message.response = "Busy processing another request." - yield message - return - - self.context.processing = True - - message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n" - message.context_prompt = "" - for p in message.preamble.keys(): - message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n" - message.context_prompt += f"{message.prompt}" - - # Estimate token length of new messages - message.response = f"Optimizing context..." - message.status = "thinking" - yield message - message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt) - - message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..." - message.status = "thinking" - yield message - - async for message in self.generate_llm_response(llm, model, message): - # logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}") - if message.status == "error": - yield message - self.context.processing = False - return - yield message - - # Done processing, add message to conversation - message.status = "done" - self.conversation.add_message(message) - self.context.processing = False - return - # Register the base agent registry.register(Chat._agent_type, Chat) diff --git a/src/utils/agents/fact_check.py b/src/utils/agents/fact_check.py index 0800387..066ef5c 100644 --- a/src/utils/agents/fact_check.py +++ b/src/utils/agents/fact_check.py @@ -1,15 +1,9 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar -from typing_extensions import Annotated -from abc import ABC, abstractmethod -from typing_extensions import Annotated -import logging +from pydantic import model_validator # type: ignore +from typing import Literal, ClassVar, Optional from .base import Agent, registry -from .. conversation import Conversation -from .. 
message import Message class FactCheck(Agent): - agent_type: Literal["fact_check"] = "fact_check" + agent_type: Literal["fact_check"] = "fact_check" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration facts: str = "" diff --git a/src/utils/agents/job_description.py b/src/utils/agents/job_description.py index c5da2a1..671e5f7 100644 --- a/src/utils/agents/job_description.py +++ b/src/utils/agents/job_description.py @@ -1,15 +1,12 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr -from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar -from typing_extensions import Annotated -from abc import ABC, abstractmethod -from typing_extensions import Annotated -import logging +from pydantic import model_validator # type: ignore +from typing import Literal, ClassVar, Optional from .base import Agent, registry from .. conversation import Conversation from .. message import Message +from abc import ABC class JobDescription(Agent): - agent_type: Literal["job_description"] = "job_description" + agent_type: Literal["job_description"] = "job_description" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration job_description: str = "" diff --git a/src/utils/agents/resume.py b/src/utils/agents/resume.py index 16fec73..aa9b929 100644 --- a/src/utils/agents/resume.py +++ b/src/utils/agents/resume.py @@ -9,7 +9,7 @@ from .. conversation import Conversation from .. message import Message class Resume(Agent): - agent_type: Literal["resume"] = "resume" + agent_type: Literal["resume"] = "resume" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration resume: str = "" diff --git a/src/utils/agents/types.py b/src/utils/agents/types.py index 8938805..5baa9b6 100644 --- a/src/utils/agents/types.py +++ b/src/utils/agents/types.py @@ -1,8 +1,5 @@ from __future__ import annotations -from typing import List, Dict, Any, Union, ForwardRef, TypeVar, Optional, TYPE_CHECKING, Type, ClassVar, Literal -from typing_extensions import Annotated -from pydantic import Field, BaseModel -from abc import ABC, abstractmethod +from typing import List, Dict, ForwardRef, Optional, Type # Forward references AgentRef = ForwardRef('Agent') diff --git a/src/utils/chunk.py b/src/utils/chunk.py index e43a390..636f427 100644 --- a/src/utils/chunk.py +++ b/src/utils/chunk.py @@ -1,4 +1,4 @@ -import tiktoken +import tiktoken # type: ignore from . 
import defines from typing import List, Dict, Any, Union diff --git a/src/utils/context.py b/src/utils/context.py index f4bb32c..e88acaa 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -1,9 +1,9 @@ from __future__ import annotations -from pydantic import BaseModel, Field, model_validator, ValidationError +from pydantic import BaseModel, Field, model_validator# type: ignore from uuid import uuid4 -from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING +from typing import List, Optional, Generator from typing_extensions import Annotated, Union -import numpy as np +import numpy as np # type: ignore import logging from uuid import uuid4 import re @@ -34,7 +34,7 @@ class Context(BaseModel): rags: List[dict] = [] message_history_length: int = 5 # Class managed fields - agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( + agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore default_factory=list ) diff --git a/src/utils/message.py b/src/utils/message.py index d106bd2..6478f8a 100644 --- a/src/utils/message.py +++ b/src/utils/message.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, model_validator +from pydantic import BaseModel # type: ignore from typing import Dict, List, Optional, Any from datetime import datetime, timezone diff --git a/src/utils/rag.py b/src/utils/rag.py index 76c2db5..2a1cb9c 100644 --- a/src/utils/rag.py +++ b/src/utils/rag.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, model_validator, PrivateAttr +from typing import List import os import glob from pathlib import Path @@ -13,18 +13,18 @@ import hashlib import asyncio import json import pickle -import numpy as np +import numpy as np # type: ignore import re import chromadb import ollama -from langchain.text_splitter import CharacterTextSplitter -from sentence_transformers import SentenceTransformer -from langchain.schema import Document -from watchdog.observers import Observer -from watchdog.events import FileSystemEventHandler -import umap -from markitdown import MarkItDown +from langchain.text_splitter import CharacterTextSplitter # type: ignore +from sentence_transformers import SentenceTransformer # type: ignore +from langchain.schema import Document # type: ignore +from watchdog.observers import Observer # type: ignore +from watchdog.events import FileSystemEventHandler # type: ignore +import umap # type: ignore +from markitdown import MarkItDown # type: ignore # Import your existing modules if __name__ == "__main__": @@ -53,11 +53,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler): self.chunk_overlap = chunk_overlap self.loop = loop self._umap_collection = None - self._umap_embedding_2d = [] + self._umap_embedding_2d : List[int]= [] self._umap_embedding_3d = [] - self._umap_model_2d = None - self._umap_model_3d = None - self._collection = None + self._umap_model_2d : umap.UMAP = None + self._umap_model_3d : umap.UMAP = None self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins #self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') @@ -94,11 +93,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): return self._umap_collection @property - def umap_embedding_2d(self): + def umap_embedding_2d(self) -> List[int]: return self._umap_embedding_2d @property - def umap_embedding_3d(self): + def umap_embedding_3d(self) -> List[int]: return self._umap_embedding_3d @property @@ -289,9 +288,9 @@ class 
ChromaDBFileWatcher(FileSystemEventHandler): def _get_vector_collection(self, recreate=False): """Get or create a ChromaDB collection.""" # Initialize ChromaDB client - chroma_client = chromadb.PersistentClient( + chroma_client = chromadb.PersistentClient( # type: ignore path=self.persist_directory, - settings=chromadb.Settings(anonymized_telemetry=False) + settings=chromadb.Settings(anonymized_telemetry=False) # type: ignore ) # Check if the collection exists @@ -577,7 +576,7 @@ if __name__ == "__main__": import defines # Initialize Ollama client - llm = ollama.Client(host=defines.ollama_api_url) + llm = ollama.Client(host=defines.ollama_api_url) # type: ignore # Start the file watcher (with initialization) observer, file_watcher = start_file_watcher(
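
The hunk for src/utils/agents/__init__.py tightens the dynamic agent discovery loop: the log line drops the module path, and an ImportError now re-raises instead of silently skipping the module. Only part of that loop is visible in the diff, so the following standalone sketch shows the general shape of the pattern. Identifiers such as class_registry, Agent, package_dir, and __all__ come from the diff; the skip rule and the inspect-based member walk are assumptions added for illustration.

# Hypothetical sketch of a plugin-style agent discovery loop, intended as the
# body of a package __init__.py (not a standalone script).
import importlib
import inspect
from pathlib import Path

from .base import Agent  # Agent base class, as in the diff

class_registry: dict[str, tuple[str, str]] = {}
__all__: list[str] = []

package_dir = Path(__file__).parent

for path in package_dir.glob("*.py"):
    if path.stem.startswith("_"):
        continue  # assumed: skip __init__.py and private modules
    full_module_name = f"{__package__}.{path.stem}"
    # Per the diff, an ImportError is no longer swallowed; it propagates.
    module = importlib.import_module(full_module_name)
    for name, obj in inspect.getmembers(module, inspect.isclass):
        # Register only Agent subclasses defined in this module, not re-imports.
        if issubclass(obj, Agent) and obj.__module__ == full_module_name:
            class_registry[name] = (full_module_name, name)
            globals()[name] = obj       # re-export the class from the package
            __all__.append(name)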
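
The largest hunk moves process_tool_calls, generate_llm_response, and process_message from Chat into the Agent base class, so every agent type shares the same streaming behaviour. The core of that behaviour is an async generator that repeatedly yields one Message object as its status and response fields evolve ("thinking" → "streaming" → "done"), letting the web layer forward partial state to the client. The sketch below is a minimal, self-contained illustration of that pattern only; the Message shape and fake_llm_stream are simplified stand-ins, not the repository's real classes.

# Minimal sketch of the status-streaming async-generator pattern.
import asyncio
from dataclasses import dataclass, field

@dataclass
class Message:
    prompt: str
    response: str = ""
    status: str = "pending"
    metadata: dict = field(default_factory=dict)

async def fake_llm_stream(prompt: str):
    # Stand-in for a streaming LLM client; yields one token at a time.
    for token in ["Hello", ", ", "world", "."]:
        await asyncio.sleep(0)
        yield token

async def generate_llm_response(message: Message):
    message.status = "thinking"
    message.response = "Optimizing context..."
    yield message                        # status update before the real work

    message.response = ""
    async for token in fake_llm_stream(message.prompt):
        message.status = "streaming"
        message.response += token
        yield message                    # partial response after every token

    message.status = "done"
    yield message

async def main():
    message = Message(prompt="Say hello")
    async for update in generate_llm_response(message):
        print(update.status, repr(update.response))

asyncio.run(main())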
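
In src/utils/context.py, Context.agents remains a pydantic discriminated union built from Agent.__subclasses__(), which is why each subclass in the diff narrows agent_type to a Literal (with a # type: ignore to silence the override warning). Below is a minimal sketch of how that pattern behaves, assuming pydantic v2 and Python 3.11+ for the Union[*...] unpacking; the class bodies are reduced to the discriminator field and are not the repository's real agents.

# Minimal sketch of a discriminated union keyed on agent_type.
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field

class Agent(BaseModel):
    agent_type: Literal["base"] = "base"
    system_prompt: str = ""

class Chat(Agent):
    agent_type: Literal["chat"] = "chat"  # type: ignore  # narrows the base Literal

class FactCheck(Agent):
    agent_type: Literal["fact_check"] = "fact_check"  # type: ignore
    facts: str = ""

class Context(BaseModel):
    # The union is built from whatever subclasses have been imported; pydantic
    # routes each dict to the right subclass via the agent_type discriminator.
    agents: List[
        Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]  # type: ignore
    ] = Field(default_factory=list)

ctx = Context.model_validate(
    {"agents": [{"agent_type": "chat"}, {"agent_type": "fact_check", "facts": "x"}]}
)
assert isinstance(ctx.agents[0], Chat) and isinstance(ctx.agents[1], FactCheck)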
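
The rag.py hunks add type annotations to the watcher's _umap_model_2d/_umap_model_3d and embedding fields; together with the umap and MinMaxScaler imports retained in server.py, these appear to back a 2D/3D projection of the vector store for visualization. The sketch below shows that kind of pipeline in isolation; all parameter values and array shapes are illustrative and are not taken from the repository's actual reducer settings.

# Rough sketch: project stored embeddings to 2D/3D with UMAP, then normalize.
import numpy as np
import umap
from sklearn.preprocessing import MinMaxScaler

embeddings = np.random.rand(100, 384)   # stand-in for stored document embeddings

umap_model_2d = umap.UMAP(n_components=2, random_state=42)
umap_embedding_2d = umap_model_2d.fit_transform(embeddings)

umap_model_3d = umap.UMAP(n_components=3, random_state=42)
umap_embedding_3d = umap_model_3d.fit_transform(embeddings)

# Normalize each axis to [0, 1] so a front end can plot without rescaling.
scaled_2d = MinMaxScaler().fit_transform(umap_embedding_2d)
scaled_3d = MinMaxScaler().fit_transform(umap_embedding_3d)
print(scaled_2d.shape, scaled_3d.shape)  # (100, 2) (100, 3)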