Tools are working and shared context is in use across all agents
parent cc0f6974ff
commit 3fe2cfd9ef
@@ -1,6 +1,6 @@
 from utils import logger

-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
+from typing import AsyncGenerator

 # %%
 # Imports [standard]
@ -40,13 +40,13 @@ try_import("sklearn")
|
|||||||
import ollama
|
import ollama
|
||||||
import requests
|
import requests
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI, Request, BackgroundTasks
|
from fastapi import FastAPI, Request, BackgroundTasks # type: ignore
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse, RedirectResponse # type: ignore
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware # type: ignore
|
||||||
import uvicorn
|
import uvicorn # type: ignore
|
||||||
import numpy as np
|
import numpy as np # type: ignore
|
||||||
import umap
|
import umap # type: ignore
|
||||||
from sklearn.preprocessing import MinMaxScaler
|
from sklearn.preprocessing import MinMaxScaler # type: ignore
|
||||||
|
|
||||||
from utils import (
|
from utils import (
|
||||||
rag as Rag,
|
rag as Rag,
|
||||||
@@ -449,6 +449,8 @@ class WebServer:
                 prompt = system_generate_resume
             case "fact_check":
                 prompt = system_message
+            case _:
+                prompt = system_message

         agent.system_prompt = prompt
         response["system_prompt"] = { "system_prompt": prompt }
@@ -759,9 +761,9 @@ class WebServer:
         with open(file_path, "r") as f:
             content = f.read()
         logger.info(f"Loading context from {file_path}, content length: {len(content)}")
+        import json
         try:
             # Try parsing as JSON first to ensure valid JSON
-            import json
             json_data = json.loads(content)
             logger.info("JSON parsed successfully, attempting model validation")

@@ -1426,7 +1428,7 @@ def main():
         module="umap.*"
     )

-    llm = ollama.Client(host=args.ollama_server)
+    llm = ollama.Client(host=args.ollama_server) # type: ignore
    model = args.ollama_model

    web_server = WebServer(llm, model)
@@ -1,18 +1,14 @@
 from __future__ import annotations

-from typing import Optional, Type
 import importlib
-from pydantic import BaseModel
-from typing import Type
+from pydantic import BaseModel # type: ignore

 from . import defines
-from . rag import ChromaDBFileWatcher, start_file_watcher
-from . message import Message
-from . conversation import Conversation
 from . context import Context
-from . import agents
+from . conversation import Conversation
+from . message import Message
+from . rag import ChromaDBFileWatcher, start_file_watcher
 from . setup_logging import setup_logging

 from .agents import class_registry, AnyAgent, Agent, __all__ as agents_all

 __all__ = [
@@ -21,9 +17,11 @@ __all__ = [
     'Conversation',
     'Message',
     'ChromaDBFileWatcher',
-    'start_file_watcher'
+    'start_file_watcher',
     'logger',
-] + agents_all
+]
+
+__all__.extend(agents_all) # type: ignore

 # Resolve circular dependencies by rebuilding models
 # Call model_rebuild() on Agent and Context
@@ -39,11 +39,11 @@ for path in package_dir.glob("*.py"):
         ):
             class_registry[name] = (full_module_name, name)
             globals()[name] = obj
-            logger.info(f"Adding agent: {name} from {full_module_name}")
-            __all__.append(name)
+            logger.info(f"Adding agent: {name}")
+            __all__.append(name) # type: ignore
     except ImportError as e:
         logger.error(f"Error importing {full_module_name}: {e}")
-        continue
+        raise e
     except Exception as e:
         logger.error(f"Error processing {full_module_name}: {e}")
         raise e
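Note the error-handling change above: a failed agent import now raises instead of being skipped with `continue`, so a broken agent module fails loudly at startup rather than silently vanishing from the registry. For context, the discovery loop this hunk edits follows the usual importlib scan-and-register pattern; a minimal self-contained sketch of that pattern (names are illustrative, not this repo's exact code):

```python
# Minimal sketch of importlib-based agent discovery (illustrative,
# not this repo's exact code).
import importlib
import inspect
from pathlib import Path

class_registry: dict = {}
package_dir = Path(__file__).parent

for path in package_dir.glob("*.py"):
    if path.stem.startswith("_"):
        continue  # skip __init__.py and private modules
    full_module_name = f"{__package__}.{path.stem}"
    try:
        module = importlib.import_module(full_module_name)
    except ImportError as e:
        raise e  # fail loudly, matching the change above
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if obj.__module__ == full_module_name:
            class_registry[name] = (full_module_name, name)
            globals()[name] = obj
```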
@@ -1,10 +1,18 @@
 from __future__ import annotations
-from pydantic import BaseModel, model_validator, PrivateAttr, Field
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, ForwardRef, Any
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
+from pydantic import BaseModel, PrivateAttr, Field # type: ignore
+from typing import (
+    Literal, get_args, List, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any,
+)
+from abc import ABC
 from .. setup_logging import setup_logging
 from .. import defines
+from abc import ABC
+import logging
+from .. message import Message
+from .. import tools as Tools
+import json
+import time
+import inspect

 logger = setup_logging()

@@ -22,6 +30,9 @@ class Agent(BaseModel, ABC):
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """
+    # Agent management with pydantic
+    agent_type: Literal["base"] = "base"
+    _agent_type: ClassVar[str] = agent_type # Add this for registration

     # context_size is shared across all subclasses
     _context_size: ClassVar[int] = int(defines.max_context * 0.5)
@@ -33,10 +44,6 @@ class Agent(BaseModel, ABC):
     def context_size(self, value: int):
         Agent._context_size = value

-    # Agent management with pydantic
-    agent_type: Literal["base"] = "base"
-    _agent_type: ClassVar[str] = agent_type # Add this for registration
-
     # Agent properties
     system_prompt: str # Mandatory
     conversation: Conversation = Conversation()
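These two hunks just move `agent_type`/`_agent_type` up next to the class docstring; `_context_size` stays a `ClassVar`, so the `context_size` setter above mutates state shared by every agent instance and subclass. A standalone illustration of that sharing (plain Python, not the commit's pydantic model):

```python
# Standalone illustration of the shared-ClassVar pattern (plain Python).
from typing import ClassVar

class Agent:
    _context_size: ClassVar[int] = 2048

    @property
    def context_size(self) -> int:
        return Agent._context_size

    @context_size.setter
    def context_size(self, value: int) -> None:
        Agent._context_size = value  # mutates class state shared by all agents

class Chat(Agent):
    pass

a, b = Agent(), Chat()
a.context_size = 4096
assert b.context_size == 4096  # the change is visible to every agent type
```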
@@ -101,6 +108,316 @@ class Agent(BaseModel, ABC):
     # Agent methods
     def get_agent_type(self):
         return self._agent_type

+    async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
+        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
+
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+        if not message.metadata["tools"]:
+            raise ValueError("tools field not initialized")
+
+        tool_metadata = message.metadata["tools"]
+        tool_metadata["messages"] = messages
+        tool_metadata["tool_calls"] = []
+
+        message.status = "tooling"
+
+        for i, tool_call in enumerate(tool_message.tool_calls):
+            arguments = tool_call.function.arguments
+            tool = tool_call.function.name
+
+            # Yield status update before processing each tool
+            message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
+            yield message
+            logging.info(f"LLM - {message.response}")
+
+            # Process the tool based on its type
+            match tool:
+                case "TickerValue":
+                    ticker = arguments.get("ticker")
+                    if not ticker:
+                        ret = None
+                    else:
+                        ret = Tools.TickerValue(ticker)
+
+                case "AnalyzeSite":
+                    url = arguments.get("url")
+                    question = arguments.get("question", "what is the summary of this content?")
+
+                    # Additional status update for long-running operations
+                    message.response = f"Retrieving and summarizing content from {url}..."
+                    yield message
+                    ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
+
+                case "DateTime":
+                    tz = arguments.get("timezone")
+                    ret = Tools.DateTime(tz)
+
+                case "WeatherForecast":
+                    city = arguments.get("city")
+                    state = arguments.get("state")
+
+                    message.response = f"Fetching weather data for {city}, {state}..."
+                    yield message
+                    ret = Tools.WeatherForecast(city, state)
+
+                case _:
+                    ret = None
+
+            # Build response for this tool
+            tool_response = {
+                "role": "tool",
+                "content": json.dumps(ret),
+                "name": tool_call.function.name
+            }
+
+            tool_metadata["tool_calls"].append(tool_response)
+
+        if len(tool_metadata["tool_calls"]) == 0:
+            message.status = "done"
+            yield message
+            return
+
+        message_dict = {
+            "role": tool_message.get("role", "assistant"),
+            "content": tool_message.get("content", ""),
+            "tool_calls": [ {
+                "function": {
+                    "name": tc["function"]["name"],
+                    "arguments": tc["function"]["arguments"]
+                }
+            } for tc in tool_message.tool_calls
+            ]
+        }
+
+        messages.append(message_dict)
+        messages.extend(tool_metadata["tool_calls"])
+
+        message.status = "thinking"
+
+        # Decrease creativity when processing tool call requests
+        message.response = ""
+        start_time = time.perf_counter()
+        for response in llm.chat(
+            model=model,
+            messages=messages,
+            stream=True,
+            options={
+                **message.metadata["options"],
+                # "temperature": 0.5,
+            }
+        ):
+            # logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
+            message.status = "streaming"
+            message.response += response.message.content
+            if not response.done:
+                yield message
+            if response.done:
+                message.metadata["eval_count"] += response.eval_count
+                message.metadata["eval_duration"] += response.eval_duration
+                message.metadata["prompt_eval_count"] += response.prompt_eval_count
+                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
+                self.context_tokens = response.prompt_eval_count + response.eval_count
+                message.status = "done"
+                yield message
+
+        end_time = time.perf_counter()
+        message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
+        return
+
+    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
+        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
+
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
+        messages = [ { "role": "system", "content": message.system_prompt } ]
+        messages.extend([
+            item for m in self.conversation.messages
+            for item in [
+                {"role": "user", "content": m.prompt.strip()},
+                {"role": "assistant", "content": m.response.strip()}
+            ]
+        ])
+        messages.append({
+            "role": "user",
+            "content": message.context_prompt.strip(),
+        })
+        message.metadata["messages"] = messages
+        message.metadata["options"]={
+            "seed": 8911,
+            "num_ctx": self.context_size,
+            #"temperature": 0.9, # Higher temperature to encourage tool usage
+        }
+
+        message.metadata["timers"] = {}
+
+        use_tools = message.enable_tools and len(self.context.tools) > 0
+        message.metadata["tools"] = {
+            "available": Tools.llm_tools(self.context.tools),
+            "used": False
+        }
+        tool_metadata = message.metadata["tools"]
+
+        if use_tools:
+            message.status = "thinking"
+            message.response = f"Performing tool analysis step 1/2..."
+            yield message
+
+            logging.info("Checking for LLM tool usage")
+            start_time = time.perf_counter()
+            # Tools are enabled and available, so query the LLM with a short token target to see if it will
+            # use the tools
+            tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
+            response = llm.chat(
+                model=model,
+                messages=tool_metadata["messages"],
+                tools=tool_metadata["available"],
+                options={
+                    **message.metadata["options"],
+                    #"num_predict": 1024, # "Low" token limit to cut off after tool call
+                },
+                stream=False # No need to stream the probe
+            )
+            end_time = time.perf_counter()
+            message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
+            if not response.message.tool_calls:
+                logging.info("LLM indicates tools will not be used")
+                # The LLM will not use tools, so disable use_tools so we can stream the full response
+                use_tools = False
+
+        if use_tools:
+            logging.info("LLM indicates tools will be used")
+
+            # Tools are enabled and available and the LLM indicated it will use them
+            tool_metadata["attempted"] = response.message.tool_calls
+            message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
+            yield message
+
+            logging.info(f"Performing LLM call with tools")
+            start_time = time.perf_counter()
+            response = llm.chat(
+                model=model,
+                messages=tool_metadata["messages"], # messages,
+                tools=tool_metadata["available"],
+                options={
+                    **message.metadata["options"],
+                },
+                stream=False
+            )
+            end_time = time.perf_counter()
+            message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
+
+            if not response:
+                message.status = "error"
+                message.response = "No response from LLM."
+                yield message
+                return
+
+            if response.message.tool_calls:
+                tool_metadata["used"] = response.message.tool_calls
+                # Process all yielded items from the handler
+                start_time = time.perf_counter()
+                async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
+                    if message.status == "error":
+                        yield message
+                        return
+                    yield message
+                end_time = time.perf_counter()
+                message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
+                message.status = "done"
+                return
+
+            logging.info("LLM indicated tools will be used, and then they weren't")
+            message.response = response.message.content
+            message.status = "done"
+            yield message
+            return
+
+        # not use_tools
+        yield message
+        # Reset the response for streaming
+        message.response = ""
+        start_time = time.perf_counter()
+        for response in llm.chat(
+            model=model,
+            messages=messages,
+            options={
+                **message.metadata["options"],
+            },
+            stream=True,
+        ):
+            if not response:
+                message.status = "error"
+                message.response = "No response from LLM."
+                yield message
+                return
+
+            message.status = "streaming"
+            message.response += response.message.content
+
+            if not response.done:
+                yield message
+
+            if response.done:
+                message.metadata["eval_count"] += response.eval_count
+                message.metadata["eval_duration"] += response.eval_duration
+                message.metadata["prompt_eval_count"] += response.prompt_eval_count
+                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
+                self.context_tokens = response.prompt_eval_count + response.eval_count
+                message.status = "done"
+                yield message
+
+        end_time = time.perf_counter()
+        message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
+        return
+
+    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
+        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
+
+        if not self.context:
+            raise ValueError("Context is not set for this agent.")
+
+        if self.context.processing:
+            logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
+            message.status = "error"
+            message.response = "Busy processing another request."
+            yield message
+            return
+
+        self.context.processing = True
+
+        message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
+        message.context_prompt = ""
+        for p in message.preamble.keys():
+            message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
+        message.context_prompt += f"{message.prompt}"
+
+        # Estimate token length of new messages
+        message.response = f"Optimizing context..."
+        message.status = "thinking"
+        yield message
+
+        message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
+
+        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
+        message.status = "thinking"
+        yield message
+
+        async for message in self.generate_llm_response(llm, model, message):
+            # logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
+            if message.status == "error":
+                yield message
+                self.context.processing = False
+                return
+            yield message
+
+        # Done processing, add message to conversation
+        message.status = "done"
+        self.conversation.add_message(message)
+        self.context.processing = False
+        return
+
 # Register the base agent
 registry.register(Agent._agent_type, Agent)
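The methods added here form a streaming pipeline: `process_message` builds the prompt and delegates to `generate_llm_response`, which probes the model once (non-streaming) to see whether it will call tools, routes through `process_tool_calls` if it does, and otherwise streams the reply; every stage yields the same `Message` as its `status` moves through thinking/tooling/streaming/done. A rough consumer sketch (the host, model name, and object construction are placeholders, not this repo's exact API):

```python
# Rough consumer of the streaming pipeline; host, model, and object
# construction are placeholders, not this repo's exact API.
import asyncio
import ollama

async def run(agent, message):
    llm = ollama.Client(host="http://localhost:11434")  # assumed default host
    async for update in agent.process_message(llm, "llama3.1", message):
        if update.status in ("thinking", "tooling"):
            print(f"[{update.status}] {update.response}")  # progress updates
        elif update.status == "streaming":
            pass  # partial text accumulates in update.response
    print(update.response)  # final text once status reaches "done"

# asyncio.run(run(agent, message))  # with an Agent and Message built elsewhere
```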
@@ -1,26 +1,15 @@
 from __future__ import annotations
-from pydantic import BaseModel, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar, Any
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
+from typing import Literal, AsyncGenerator, ClassVar, Optional
 import logging
 from . base import Agent, registry
-from .. conversation import Conversation
 from .. message import Message
-from .. import defines
-from .. import tools as Tools
-from ollama import ChatResponse
-import json
-import time
 import inspect

-class Chat(Agent, ABC):
+class Chat(Agent):
     """
-    Base class for all agent types.
-    This class defines the common attributes and methods for all agent types.
+    Chat Agent
     """
-    agent_type: Literal["chat"] = "chat"
+    agent_type: Literal["chat"] = "chat" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     async def prepare_message(self, message:Message) -> AsyncGenerator[Message, None]:
@@ -74,313 +63,5 @@ class Chat(Agent, ABC):
         yield message
         return

-    async def process_tool_calls(self, llm: Any, model: str, message: Message, tool_message: Any, messages: List[Any]) -> AsyncGenerator[Message, None]:
-        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
-
-        if not self.context:
-            raise ValueError("Context is not set for this agent.")
-        if not message.metadata["tools"]:
-            raise ValueError("tools field not initialized")
-
-        tool_metadata = message.metadata["tools"]
-        tool_metadata["messages"] = messages
-        tool_metadata["tool_calls"] = []
-
-        message.status = "tooling"
-
-        for i, tool_call in enumerate(tool_message.tool_calls):
-            arguments = tool_call.function.arguments
-            tool = tool_call.function.name
-
-            # Yield status update before processing each tool
-            message.response = f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
-            yield message
-            logging.info(f"LLM - {message.response}")
-
-            # Process the tool based on its type
-            match tool:
-                case "TickerValue":
-                    ticker = arguments.get("ticker")
-                    if not ticker:
-                        ret = None
-                    else:
-                        ret = Tools.TickerValue(ticker)
-
-                case "AnalyzeSite":
-                    url = arguments.get("url")
-                    question = arguments.get("question", "what is the summary of this content?")
-
-                    # Additional status update for long-running operations
-                    message.response = f"Retrieving and summarizing content from {url}..."
-                    yield message
-                    ret = await Tools.AnalyzeSite(llm=llm, model=model, url=url, question=question)
-
-                case "DateTime":
-                    tz = arguments.get("timezone")
-                    ret = Tools.DateTime(tz)
-
-                case "WeatherForecast":
-                    city = arguments.get("city")
-                    state = arguments.get("state")
-
-                    message.response = f"Fetching weather data for {city}, {state}..."
-                    yield message
-                    ret = Tools.WeatherForecast(city, state)
-
-                case _:
-                    ret = None
-
-            # Build response for this tool
-            tool_response = {
-                "role": "tool",
-                "content": json.dumps(ret),
-                "name": tool_call.function.name
-            }
-
-            tool_metadata["tool_calls"].append(tool_response)
-
-        if len(tool_metadata["tool_calls"]) == 0:
-            message.status = "done"
-            yield message
-            return
-
-        message_dict = {
-            "role": tool_message.get("role", "assistant"),
-            "content": tool_message.get("content", ""),
-            "tool_calls": [
-                {
-                    "function": {
-                        "name": tc["function"]["name"],
-                        "arguments": tc["function"]["arguments"]
-                    }
-                }
-                for tc in tool_message.tool_calls
-            ]
-        }
-
-        messages.append(message_dict)
-        messages.extend(tool_metadata["tool_calls"])
-
-        message.status = "thinking"
-
-        # Decrease creativity when processing tool call requests
-        message.response = ""
-        start_time = time.perf_counter()
-        for response in llm.chat(
-            model=model,
-            messages=messages,
-            stream=True,
-            options={
-                **message.metadata["options"],
-                # "temperature": 0.5,
-            }
-        ):
-            # logging.info(f"LLM::Tools: {'done' if response.done else 'processing'} - {response.message}")
-            message.status = "streaming"
-            message.response += response.message.content
-            if not response.done:
-                yield message
-            if response.done:
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
-                self.context_tokens = response.prompt_eval_count + response.eval_count
-                message.status = "done"
-                yield message
-
-        end_time = time.perf_counter()
-        message.metadata["timers"]["llm_with_tools"] = f"{(end_time - start_time):.4f}"
-        return
-
-    async def generate_llm_response(self, llm: Any, model: str, message: Message) -> AsyncGenerator[Message, None]:
-        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
-
-        if not self.context:
-            raise ValueError("Context is not set for this agent.")
-
-        messages = [ { "role": "system", "content": message.system_prompt } ]
-        messages.extend([
-            item for m in self.conversation.messages
-            for item in [
-                {"role": "user", "content": m.prompt.strip()},
-                {"role": "assistant", "content": m.response.strip()}
-            ]
-        ])
-        messages.append({
-            "role": "user",
-            "content": message.context_prompt.strip(),
-        })
-        message.metadata["messages"] = messages
-        message.metadata["options"]={
-            "seed": 8911,
-            "num_ctx": self.context_size,
-            #"temperature": 0.9, # Higher temperature to encourage tool usage
-        }
-
-        message.metadata["timers"] = {}
-
-        use_tools = message.enable_tools and len(self.context.tools) > 0
-        message.metadata["tools"] = {
-            "available": Tools.llm_tools(self.context.tools),
-            "used": False
-        }
-        tool_metadata = message.metadata["tools"]
-
-        if use_tools:
-            message.status = "thinking"
-            message.response = f"Performing tool analysis step 1/2..."
-            yield message
-
-            logging.info("Checking for LLM tool usage")
-            start_time = time.perf_counter()
-            # Tools are enabled and available, so query the LLM with a short token target to see if it will
-            # use the tools
-            tool_metadata["messages"] = [{ "role": "system", "content": self.system_prompt}, {"role": "user", "content": message.prompt}]
-            response = llm.chat(
-                model=model,
-                messages=tool_metadata["messages"],
-                tools=tool_metadata["available"],
-                options={
-                    **message.metadata["options"],
-                    #"num_predict": 1024, # "Low" token limit to cut off after tool call
-                },
-                stream=False # No need to stream the probe
-            )
-            end_time = time.perf_counter()
-            message.metadata["timers"]["tool_check"] = f"{(end_time - start_time):.4f}"
-            if not response.message.tool_calls:
-                logging.info("LLM indicates tools will not be used")
-                # The LLM will not use tools, so disable use_tools so we can stream the full response
-                use_tools = False
-
-        if use_tools:
-            logging.info("LLM indicates tools will be used")
-
-            # Tools are enabled and available and the LLM indicated it will use them
-            tool_metadata["attempted"] = response.message.tool_calls
-            message.response = f"Performing tool analysis step 2/2 (tool use suspected)..."
-            yield message
-
-            logging.info(f"Performing LLM call with tools")
-            start_time = time.perf_counter()
-            response = llm.chat(
-                model=model,
-                messages=tool_metadata["messages"], # messages,
-                tools=tool_metadata["available"],
-                options={
-                    **message.metadata["options"],
-                },
-                stream=False
-            )
-            end_time = time.perf_counter()
-            message.metadata["timers"]["non_streaming"] = f"{(end_time - start_time):.4f}"
-
-            if not response:
-                message.status = "error"
-                message.response = "No response from LLM."
-                yield message
-                return
-
-            if response.message.tool_calls:
-                tool_metadata["used"] = response.message.tool_calls
-                # Process all yielded items from the handler
-                start_time = time.perf_counter()
-                async for message in self.process_tool_calls(llm=llm, model=model, message=message, tool_message=response.message, messages=messages):
-                    if message.status == "error":
-                        yield message
-                        return
-                    yield message
-                end_time = time.perf_counter()
-                message.metadata["timers"]["process_tool_calls"] = f"{(end_time - start_time):.4f}"
-                message.status = "done"
-                return
-
-            logging.info("LLM indicated tools will be used, and then they weren't")
-            message.response = response.message.content
-            message.status = "done"
-            yield message
-            return
-
-        # not use_tools
-        yield message
-        # Reset the response for streaming
-        message.response = ""
-        start_time = time.perf_counter()
-        for response in llm.chat(
-            model=model,
-            messages=messages,
-            options={
-                **message.metadata["options"],
-            },
-            stream=True,
-        ):
-            if not response:
-                message.status = "error"
-                message.response = "No response from LLM."
-                yield message
-                return
-            message.status = "streaming"
-            message.response += response.message.content
-            if not response.done:
-                yield message
-            if response.done:
-                message.metadata["eval_count"] += response.eval_count
-                message.metadata["eval_duration"] += response.eval_duration
-                message.metadata["prompt_eval_count"] += response.prompt_eval_count
-                message.metadata["prompt_eval_duration"] += response.prompt_eval_duration
-                self.context_tokens = response.prompt_eval_count + response.eval_count
-                message.status = "done"
-                yield message
-
-        end_time = time.perf_counter()
-        message.metadata["timers"]["streamed"] = f"{(end_time - start_time):.4f}"
-        return
-
-    async def process_message(self, llm: Any, model: str, message:Message) -> AsyncGenerator[Message, None]:
-        logging.info(f"{self.agent_type} - {inspect.stack()[1].function}")
-
-        if not self.context:
-            raise ValueError("Context is not set for this agent.")
-
-        if self.context.processing:
-            logging.info("TODO: Implement delay queing; busy for same agent, otherwise return queue size and estimated wait time")
-            message.status = "error"
-            message.response = "Busy processing another request."
-            yield message
-            return
-
-        self.context.processing = True
-
-        message.metadata["system_prompt"] = f"<|system|>\n{self.system_prompt.strip()}\n"
-        message.context_prompt = ""
-        for p in message.preamble.keys():
-            message.context_prompt += f"\n<|{p}|>\n{message.preamble[p].strip()}\n"
-        message.context_prompt += f"{message.prompt}"
-
-        # Estimate token length of new messages
-        message.response = f"Optimizing context..."
-        message.status = "thinking"
-        yield message
-        message.metadata["context_size"] = self.set_optimal_context_size(llm, model, prompt=message.context_prompt)
-
-        message.response = f"Processing {'RAG augmented ' if message.metadata['rag'] else ''}query..."
-        message.status = "thinking"
-        yield message
-
-        async for message in self.generate_llm_response(llm, model, message):
-            # logging.info(f"LLM: {message.status} - {f'...{message.response[-20:]}' if len(message.response) > 20 else message.response}")
-            if message.status == "error":
-                yield message
-                self.context.processing = False
-                return
-            yield message
-
-        # Done processing, add message to conversation
-        message.status = "done"
-        self.conversation.add_message(message)
-        self.context.processing = False
-        return
-
 # Register the base agent
 registry.register(Chat._agent_type, Chat)
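This removal is the payoff of the refactor: `process_tool_calls`, `generate_llm_response`, and `process_message` were near-verbatim copies of what now lives on `Agent`, so `Chat` shrinks to its discriminator plus `prepare_message`. Roughly what remains (condensed, for illustration only):

```python
# Condensed view of a subclass after this refactor (illustrative).
from typing import Literal, ClassVar
from . base import Agent, registry

class Chat(Agent):
    """Chat Agent"""
    agent_type: Literal["chat"] = "chat"  # type: ignore
    _agent_type: ClassVar[str] = agent_type

    # prepare_message(...) stays here; process_message / generate_llm_response /
    # process_tool_calls are inherited from Agent.

registry.register(Chat._agent_type, Chat)
```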
@@ -1,15 +1,9 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
+from pydantic import model_validator # type: ignore
+from typing import Literal, ClassVar, Optional
 from .base import Agent, registry
-from .. conversation import Conversation
-from .. message import Message

 class FactCheck(Agent):
-    agent_type: Literal["fact_check"] = "fact_check"
+    agent_type: Literal["fact_check"] = "fact_check" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     facts: str = ""
@@ -1,15 +1,12 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
-from typing import Literal, TypeAlias, get_args, List, Generator, Iterator, AsyncGenerator, TYPE_CHECKING, Optional, ClassVar
-from typing_extensions import Annotated
-from abc import ABC, abstractmethod
-from typing_extensions import Annotated
-import logging
+from pydantic import model_validator # type: ignore
+from typing import Literal, ClassVar, Optional
 from .base import Agent, registry
 from .. conversation import Conversation
 from .. message import Message
+from abc import ABC

 class JobDescription(Agent):
-    agent_type: Literal["job_description"] = "job_description"
+    agent_type: Literal["job_description"] = "job_description" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     job_description: str = ""
@@ -9,7 +9,7 @@ from .. conversation import Conversation
 from .. message import Message

 class Resume(Agent):
-    agent_type: Literal["resume"] = "resume"
+    agent_type: Literal["resume"] = "resume" # type: ignore
     _agent_type: ClassVar[str] = agent_type # Add this for registration

     resume: str = ""
@@ -1,8 +1,5 @@
 from __future__ import annotations
-from typing import List, Dict, Any, Union, ForwardRef, TypeVar, Optional, TYPE_CHECKING, Type, ClassVar, Literal
-from typing_extensions import Annotated
-from pydantic import Field, BaseModel
-from abc import ABC, abstractmethod
+from typing import List, Dict, ForwardRef, Optional, Type

 # Forward references
 AgentRef = ForwardRef('Agent')
@@ -1,4 +1,4 @@
-import tiktoken
+import tiktoken # type: ignore
 from . import defines
 from typing import List, Dict, Any, Union

@@ -1,9 +1,9 @@
 from __future__ import annotations
-from pydantic import BaseModel, Field, model_validator, ValidationError
+from pydantic import BaseModel, Field, model_validator # type: ignore
 from uuid import uuid4
-from typing import List, Dict, Any, Optional, Generator, TYPE_CHECKING
+from typing import List, Optional, Generator
 from typing_extensions import Annotated, Union
-import numpy as np
+import numpy as np # type: ignore
 import logging
 from uuid import uuid4
 import re
@@ -34,7 +34,7 @@ class Context(BaseModel):
     rags: List[dict] = []
     message_history_length: int = 5
     # Class managed fields
-    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field(
+    agents: List[Annotated[Union[*Agent.__subclasses__()], Field(discriminator="agent_type")]] = Field( # type: ignore
         default_factory=list
     )

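The `agents` field relies on pydantic's discriminated unions: every subclass declares a distinct `agent_type: Literal[...]`, so serialized agents round-trip into the correct class, and `Union[*Agent.__subclasses__()]` picks up whatever the discovery loop registered. A self-contained sketch of the mechanism (simplified models, not this repo's):

```python
# Self-contained sketch of the discriminated-union pattern (simplified).
from typing import Literal, Union, List
from typing_extensions import Annotated
from pydantic import BaseModel, Field

class Chat(BaseModel):
    agent_type: Literal["chat"] = "chat"
    system_prompt: str = ""

class Resume(BaseModel):
    agent_type: Literal["resume"] = "resume"
    resume: str = ""

AnyAgent = Annotated[Union[Chat, Resume], Field(discriminator="agent_type")]

class Context(BaseModel):
    agents: List[AnyAgent] = Field(default_factory=list)

ctx = Context.model_validate({"agents": [{"agent_type": "resume", "resume": "..."}]})
assert isinstance(ctx.agents[0], Resume)  # the discriminator picked the right class
```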
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel # type: ignore
 from typing import Dict, List, Optional, Any
 from datetime import datetime, timezone

@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field, model_validator, PrivateAttr
+from typing import List
 import os
 import glob
 from pathlib import Path
@@ -13,18 +13,18 @@ import hashlib
 import asyncio
 import json
 import pickle
-import numpy as np
+import numpy as np # type: ignore
 import re

 import chromadb
 import ollama
-from langchain.text_splitter import CharacterTextSplitter
-from sentence_transformers import SentenceTransformer
-from langchain.schema import Document
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
-import umap
-from markitdown import MarkItDown
+from langchain.text_splitter import CharacterTextSplitter # type: ignore
+from sentence_transformers import SentenceTransformer # type: ignore
+from langchain.schema import Document # type: ignore
+from watchdog.observers import Observer # type: ignore
+from watchdog.events import FileSystemEventHandler # type: ignore
+import umap # type: ignore
+from markitdown import MarkItDown # type: ignore

 # Import your existing modules
 if __name__ == "__main__":
@@ -53,11 +53,10 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         self.chunk_overlap = chunk_overlap
         self.loop = loop
         self._umap_collection = None
-        self._umap_embedding_2d = []
+        self._umap_embedding_2d : List[int]= []
         self._umap_embedding_3d = []
-        self._umap_model_2d = None
-        self._umap_model_3d = None
-        self._collection = None
+        self._umap_model_2d : umap.UMAP = None
+        self._umap_model_3d : umap.UMAP = None
         self.md = MarkItDown(enable_plugins=False) # Set to True to enable plugins

         #self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
@@ -94,11 +93,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
         return self._umap_collection

     @property
-    def umap_embedding_2d(self):
+    def umap_embedding_2d(self) -> List[int]:
         return self._umap_embedding_2d

     @property
-    def umap_embedding_3d(self):
+    def umap_embedding_3d(self) -> List[int]:
         return self._umap_embedding_3d

     @property
@@ -289,9 +288,9 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
     def _get_vector_collection(self, recreate=False):
         """Get or create a ChromaDB collection."""
         # Initialize ChromaDB client
-        chroma_client = chromadb.PersistentClient(
+        chroma_client = chromadb.PersistentClient( # type: ignore
             path=self.persist_directory,
-            settings=chromadb.Settings(anonymized_telemetry=False)
+            settings=chromadb.Settings(anonymized_telemetry=False) # type: ignore
         )

         # Check if the collection exists
@@ -577,7 +576,7 @@ if __name__ == "__main__":
     import defines

     # Initialize Ollama client
-    llm = ollama.Client(host=defines.ollama_api_url)
+    llm = ollama.Client(host=defines.ollama_api_url) # type: ignore

     # Start the file watcher (with initialization)
     observer, file_watcher = start_file_watcher(
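For reference, `start_file_watcher` builds on watchdog's observer/handler pair; a generic sketch of that pattern (function name and arguments assumed, not this repo's exact signature):

```python
# Generic watchdog pattern behind start_file_watcher (sketch; names assumed).
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

class Watcher(FileSystemEventHandler):
    def on_modified(self, event):
        if not event.is_directory:
            print(f"re-index {event.src_path}")  # e.g. re-chunk and re-embed the file

def start_file_watcher(path: str):
    handler = Watcher()
    observer = Observer()
    observer.schedule(handler, path, recursive=True)
    observer.start()  # runs in a background thread
    return observer, handler
```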