from __future__ import annotations

import os
import re
import time
from abc import ABC
from datetime import datetime, UTC
from pathlib import Path
from typing import Any, AsyncGenerator, ClassVar, List, Literal, Optional, get_args
from uuid import uuid4

import numpy as np  # type: ignore
from prometheus_client import CollectorRegistry  # type: ignore
from pydantic import BaseModel, Field  # type: ignore

import defines
import json_extractor
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
from logger import logger
from models import (
    ApiActivityType,
    ApiMessage,
    ApiStatusType,
    Candidate,
    ChatContextType,
    ChatMessage,
    ChatMessageError,
    ChatMessageMetaData,
    ChatMessageRagSearch,
    ChatMessageStatus,
    ChatMessageStreaming,
    ChatMessageUser,
    ChatOptions,
    ChromaDBGetResponse,
    LLMMessage,
    RagEntry,
    Tunables,
)
from rag import start_file_watcher, ChromaDBFileWatcher
from utils.metrics import Metrics

from .registry import agent_registry

class CandidateEntity(Candidate):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc.

    id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the entity")
    last_accessed: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Last accessed timestamp")
    reference_count: int = Field(default=0, description="Number of active references to this entity")

    async def cleanup(self):
        """Cleanup resources associated with this entity"""

    # Internal instance members
    CandidateEntity__agents: List[Agent] = []
    CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
    CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
    CandidateEntity__metrics: Optional[Metrics] = Field(
        default=None,
        description="Metrics collector for this agent, used to track performance and usage.",
    )

    def __init__(self, candidate=None):
        if candidate is not None:
            # Copy attributes from the candidate instance
            super().__init__(**vars(candidate))
        else:
            raise ValueError("CandidateEntity must be initialized with a Candidate instance")

    @classmethod
    def exists(cls, username: str):
        # Validate username format (only allow safe characters)
        if not re.match(r"^[a-zA-Z0-9_-]+$", username):
            return False  # Invalid username characters

        # Check for minimum and maximum length
        if not (3 <= len(username) <= 32):
            return False  # Invalid username length

        # Use Path for safe path handling and normalization
        user_dir = Path(defines.user_dir) / username
        user_info_path = user_dir / defines.user_info_file

        # Ensure the final path is actually within the intended parent directory
        # to help prevent directory traversal attacks
        try:
            if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError):  # Potential exceptions from resolve()
            return False

        # Check if file exists
        return user_info_path.is_file()

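    # Example of the checks above (hypothetical usernames; assumes
    # defines.user_dir points at the on-disk user store):
    #
    #     CandidateEntity.exists("jane_doe")   # True if the user info file exists
    #     CandidateEntity.exists("../etc")     # False: '.' fails the character check
    #     CandidateEntity.exists("ab")         # False: shorter than 3 characters
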
    def get_or_create_agent(self, agent_type: ChatContextType) -> Agent:
        """
        Get or create an agent of the specified type for this candidate.

        Args:
            agent_type: The type of agent to create (default is 'candidate_chat').

        Returns:
            The existing agent of that type, or a newly created instance.
        """
        # Only instantiate one agent of each type per user
        for agent in self.CandidateEntity__agents:
            if agent.agent_type == agent_type:
                return agent

        # Track the new agent so subsequent calls return the same instance
        agent = get_or_create_agent(
            agent_type=agent_type,
            user=self,
            prometheus_collector=self.prometheus_collector,
        )
        self.CandidateEntity__agents.append(agent)
        return agent

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher.umap_collection

    # Fields managed by initialize()
    CandidateEntity__initialized: bool = Field(default=False, exclude=True)

    @property
    def metrics(self) -> Metrics:
        if not self.CandidateEntity__metrics:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__metrics

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        if not self.CandidateEntity__prometheus_collector:
            raise ValueError("initialize() has not been called with a prometheus_collector.")
        return self.CandidateEntity__prometheus_collector

    @property
    def observer(self) -> Any:
        if not self.CandidateEntity__observer:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    def collect_metrics(self, agent: Agent, response):
        # Check the backing field directly; the `metrics` property raises if unset.
        if not self.CandidateEntity__metrics:
            logger.warning("No metrics collector set for this agent.")
            return
        self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
            response.usage.prompt_eval_count
        )
        self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)

    async def initialize(
        self,
        prometheus_collector: CollectorRegistry,
        database: RedisDatabase,
    ):
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; multiple attempts mean
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that CandidateEntity__initialized is set to True
            # regardless of whether initialization succeeds. This prevents the server
            # from looping on failure.
            raise ValueError("initialize can only be attempted once")
        self.CandidateEntity__initialized = True

        if not self.username:
            raise ValueError("username can not be empty")

        if not prometheus_collector:
            raise ValueError("prometheus_collector can not be None")

        self.CandidateEntity__prometheus_collector = prometheus_collector
        self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)

        user_dir = os.path.join(defines.user_dir, self.username)
        vector_db_dir = os.path.join(user_dir, defines.persist_directory)
        rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)

        os.makedirs(vector_db_dir, exist_ok=True)
        os.makedirs(rag_content_dir, exist_ok=True)

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            user_id=self.id,
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            database=database,
            recreate=False,  # Don't recreate if exists
        )
        has_username_rag = any(item.name == self.username for item in self.rags)
        if not has_username_rag:
            self.rags.append(RagEntry(
                name=self.username,
                description=f"Expert data about {self.full_name}.",
            ))
        self.rag_content_size = self.file_watcher.collection.count()

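# Typical entity lifecycle (sketch; names such as `candidate`, `registry`, and
# `redis` are placeholders, not part of this module):
#
#     entity = CandidateEntity(candidate)
#     await entity.initialize(prometheus_collector=registry, database=redis)
#     agent = entity.get_or_create_agent(agent_type)  # agent_type: a ChatContextType value
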
class Agent(BaseModel, ABC):
    """
    Base class for all agent types.
    This class defines the common attributes and methods for all agent types.
    """

    class Config:
        arbitrary_types_allowed = True  # Allow arbitrary types like RedisDatabase

    # Agent management with pydantic
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    user: Optional[CandidateEntity] = None

    # Tunables (sets default for new Messages attached to this agent)
    tunables: Tunables = Field(default_factory=Tunables)

    # Agent properties
    system_prompt: str = ""
    context_tokens: int = 0

    # context_size is shared across all subclasses
    _context_size: ClassVar[int] = int(defines.max_context * 0.5)

    conversation: List[ChatMessageUser] = Field(
        default_factory=list,
        description="Conversation history for this agent, used to maintain context across messages.",
    )

    @property
    def context_size(self) -> int:
        return Agent._context_size

    @context_size.setter
    def context_size(self, value: int):
        Agent._context_size = value

    async def get_last_item(self, generator):
        last_item = None
        async for item in generator:
            last_item = item
        return last_item

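    # Note: `context_size` is intentionally class-level state, so growing it on
    # one agent grows it for every agent in the process. Sketch (``ChatAgent``
    # is a hypothetical subclass):
    #
    #     a, b = ChatAgent(), ChatAgent()
    #     a.context_size = 8192
    #     assert b.context_size == 8192  # shared via Agent._context_size
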
    def set_optimal_context_size(
        self, llm: Any, model: str, prompt: str, ctx_buffer=2048
    ) -> int:
        # Most models average 1.3-1.5 tokens per word
        word_count = len(prompt.split())
        tokens = int(word_count * 1.4)

        # Add buffer for safety
        total_ctx = tokens + ctx_buffer

        if total_ctx > self.context_size:
            logger.info(
                f"Increasing context size from {self.context_size} to {total_ctx}"
            )

        # Grow the context size if necessary
        self.context_size = max(self.context_size, total_ctx)
        # Use actual model maximum context size
        return self.context_size

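    # Worked example of the heuristic above: a 1,000-word prompt is estimated at
    # int(1000 * 1.4) = 1400 tokens; with the default 2048-token buffer the
    # requested context is 3448 tokens, and the shared context_size only grows
    # if it is currently below that value.
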
    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs) -> None:
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
        if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
            agent_registry.register(cls.agent_type, cls)

    def model_dump(self, *args, **kwargs) -> Any:
        # Ensure context is always excluded, even with exclude_unset=True
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("context")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["context"] = True
        return super().model_dump(*args, **kwargs)

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))

    # Agent methods
    def get_agent_type(self):
        return self._agent_type

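    # Example (sketch): `valid_agent_types()` unpacks the Literal annotation, so
    # `Agent.valid_agent_types()` returns {"base"}, and a subclass declaring
    # `agent_type: Literal["chat"] = "chat"` would report {"chat"}. Likewise,
    # `agent.model_dump(exclude={"user"})` still excludes "context" as well.
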
    # async def process_tool_calls(
    #     self,
    #     llm: Any,
    #     model: str,
    #     message: ChatMessage,
    #     tool_message: Any,  # llama response
    #     messages: List[LLMMessage],
    # ) -> AsyncGenerator[ChatMessage, None]:
    #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
    #
    #     self.metrics.tool_count.labels(agent=self.agent_type).inc()
    #     with self.metrics.tool_duration.labels(agent=self.agent_type).time():
    #         if not self.context:
    #             raise ValueError("Context is not set for this agent.")
    #         if not message.metadata.tools:
    #             raise ValueError("tools field not initialized")
    #
    #         tool_metadata = message.metadata.tools
    #         tool_metadata["tool_calls"] = []
    #
    #         message.status = "tooling"
    #
    #         for i, tool_call in enumerate(tool_message.tool_calls):
    #             arguments = tool_call.function.arguments
    #             tool = tool_call.function.name
    #
    #             # Yield status update before processing each tool
    #             message.content = (
    #                 f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
    #             )
    #             yield message
    #             logger.info(f"LLM - {message.content}")
    #
    #             # Process the tool based on its type
    #             match tool:
    #                 case "TickerValue":
    #                     ticker = arguments.get("ticker")
    #                     if not ticker:
    #                         ret = None
    #                     else:
    #                         ret = TickerValue(ticker)
    #
    #                 case "AnalyzeSite":
    #                     url = arguments.get("url")
    #                     question = arguments.get(
    #                         "question", "what is the summary of this content?"
    #                     )
    #
    #                     # Additional status update for long-running operations
    #                     message.content = (
    #                         f"Retrieving and summarizing content from {url}..."
    #                     )
    #                     yield message
    #                     ret = await AnalyzeSite(
    #                         llm=llm, model=model, url=url, question=question
    #                     )
    #
    #                 case "GenerateImage":
    #                     prompt = arguments.get("prompt", None)
    #                     if not prompt:
    #                         logger.info("No prompt supplied to GenerateImage")
    #                         ret = { "error": "No prompt supplied to GenerateImage" }
    #
    #                     # Additional status update for long-running operations
    #                     message.content = (
    #                         f"Generating image for {prompt}..."
    #                     )
    #                     yield message
    #                     ret = await GenerateImage(
    #                         llm=llm, model=model, prompt=prompt
    #                     )
    #                     logger.info("GenerateImage returning", ret)
    #
    #                 case "DateTime":
    #                     tz = arguments.get("timezone")
    #                     ret = DateTime(tz)
    #
    #                 case "WeatherForecast":
    #                     city = arguments.get("city")
    #                     state = arguments.get("state")
    #
    #                     message.content = (
    #                         f"Fetching weather data for {city}, {state}..."
    #                     )
    #                     yield message
    #                     ret = WeatherForecast(city, state)
    #
    #                 case _:
    #                     logger.error(f"Requested tool {tool} does not exist")
    #                     ret = None
    #
    #             # Build response for this tool
    #             tool_response = {
    #                 "role": "tool",
    #                 "content": json.dumps(ret),
    #                 "name": tool_call.function.name,
    #             }
    #             tool_metadata["tool_calls"].append(tool_response)
    #
    #         if len(tool_metadata["tool_calls"]) == 0:
    #             message.status = "done"
    #             yield message
    #             return
    #
    #         message_dict = LLMMessage(
    #             role=tool_message.get("role", "assistant"),
    #             content=tool_message.get("content", ""),
    #             tool_calls=[
    #                 {
    #                     "function": {
    #                         "name": tc["function"]["name"],
    #                         "arguments": tc["function"]["arguments"],
    #                     }
    #                 }
    #                 for tc in tool_message.tool_calls
    #             ],
    #         )
    #
    #         messages.append(message_dict)
    #         messages.extend(tool_metadata["tool_calls"])
    #
    #         message.status = "thinking"
    #         message.content = "Incorporating tool results into response..."
    #         yield message
    #
    #         # Decrease creativity when processing tool call requests
    #         message.content = ""
    #         start_time = time.perf_counter()
    #         for response in llm.chat(
    #             model=model,
    #             messages=messages,
    #             options={
    #                 **message.metadata.options,
    #             },
    #             stream=True,
    #         ):
    #             # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
    #             message.status = "streaming"
    #             message.chunk = response.content
    #             message.content += message.chunk
    #
    #             if not response.finish_reason:
    #                 yield message
    #
    #         if response.finish_reason:
    #             self.collect_metrics(response)
    #             message.metadata.eval_count += response.eval_count
    #             message.metadata.eval_duration += response.eval_duration
    #             message.metadata.prompt_eval_count += response.prompt_eval_count
    #             message.metadata.prompt_eval_duration += response.prompt_eval_duration
    #             self.context_tokens = (
    #                 response.prompt_eval_count + response.eval_count
    #             )
    #             message.status = "done"
    #             yield message
    #
    #         end_time = time.perf_counter()
    #         message.metadata.timers["llm_with_tools"] = end_time - start_time
    #         return

    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
        """
        Extracts the RAG context from the rag_message.
        """
        if not rag_message.content:
            return ""

        context = []
        for chroma_results in rag_message.content:
            for index, metadata in enumerate(chroma_results.metadatas):
                content = "\n".join([
                    line.strip()
                    for line in chroma_results.documents[index].split("\n")
                    if line
                ]).strip()
                context.append(f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
""")
        return "\n".join(context)

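    # The resulting context is a newline-joined series of blocks shaped like
    # (values here are illustrative, not real data):
    #
    #     Source: resume: docs/resume.md
    #     Document reference: 3f2a...
    #     Content: <whitespace-trimmed document text>
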
    async def generate_rag_results(
        self,
        session_id: str,
        prompt: str,
        top_k: int = defines.default_rag_top_k,
        threshold: float = defines.default_rag_threshold,
    ) -> AsyncGenerator[ApiMessage, None]:
        """
        Generate RAG results for the given prompt.

        Args:
            session_id: The session this search belongs to.
            prompt: The query string to generate RAG results for.
            top_k: Maximum number of matches to return per RAG.
            threshold: Minimum similarity score for a match.

        Yields:
            Status messages while searching, followed by a ChatMessageRagSearch
            containing the accumulated results.
        """
        if not self.user:
            error_message = ChatMessageError(
                session_id=session_id,
                content="No user set for RAG generation."
            )
            yield error_message
            return

        results: List[ChromaDBGetResponse] = []
        user: CandidateEntity = self.user
        for rag in user.rags:
            if not rag.enabled:
                continue

            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.SEARCHING,
                content=f"Searching RAG context {rag.name}..."
            )
            yield status_message

            try:
                chroma_results = await user.file_watcher.find_similar(
                    query=prompt, top_k=top_k, threshold=threshold
                )
                if not chroma_results:
                    continue
                query_embedding = np.array(chroma_results["query_embedding"]).flatten()  # type: ignore

                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]  # type: ignore
                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]  # type: ignore

                rag_metadata = ChromaDBGetResponse(
                    name=rag.name,
                    query=prompt,
                    query_embedding=query_embedding.tolist(),
                    ids=chroma_results.get("ids", []),
                    embeddings=chroma_results.get("embeddings", []),
                    documents=chroma_results.get("documents", []),
                    metadatas=chroma_results.get("metadatas", []),
                    umap_embedding_2d=umap_2d.tolist(),
                    umap_embedding_3d=umap_3d.tolist(),
                )
                results.append(rag_metadata)
            except Exception as e:
                continue_message = ChatMessageStatus(
                    session_id=session_id,
                    activity=ApiActivityType.SEARCHING,
                    content=f"Error searching RAG context {rag.name}: {str(e)}"
                )
                yield continue_message

        final_message = ChatMessageRagSearch(
            session_id=session_id,
            content=results,
            status=ApiStatusType.DONE,
        )
        yield final_message
        return

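    # Consuming the generator (sketch; `agent` and `sid` wiring are assumed):
    #
    #     async for msg in agent.generate_rag_results(session_id=sid, prompt="kubernetes"):
    #         if msg.status == ApiStatusType.DONE:
    #             context = agent.get_rag_context(msg)  # msg is the ChatMessageRagSearch
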
    async def llm_one_shot(
        self,
        llm: Any, model: str,
        session_id: str, prompt: str, system_prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7,
    ) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
        if not self.user:
            error_message = ChatMessageError(
                session_id=session_id,
                content="No user set for chat generation."
            )
            yield error_message
            return

        self.set_optimal_context_size(
            llm=llm, model=model, prompt=prompt + system_prompt
        )

        options = ChatOptions(
            seed=8911,
            num_ctx=self.context_size,
            temperature=temperature,
        )

        messages: List[LLMMessage] = [
            LLMMessage(role="system", content=system_prompt),
            LLMMessage(role="user", content=prompt),
        ]

        status_message = ChatMessageStatus(
            session_id=session_id,
            activity=ApiActivityType.GENERATING,
            content="Generating response..."
        )
        yield status_message

        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
        response = None
        content = ""
        async for response in llm.chat_stream(
            model=model,
            messages=messages,
            options={
                **options.model_dump(exclude_unset=True),
            },
            stream=True,
        ):
            if not response:
                error_message = ChatMessageError(
                    session_id=session_id,
                    content="No response from LLM."
                )
                yield error_message
                return

            content += response.content

            if not response.finish_reason:
                streaming_message = ChatMessageStreaming(
                    session_id=session_id,
                    content=response.content,
                    status=ApiStatusType.STREAMING,
                )
                yield streaming_message

        if not response:
            error_message = ChatMessageError(
                session_id=session_id,
                content="No response from LLM."
            )
            yield error_message
            return

        self.user.collect_metrics(agent=self, response=response)
        self.context_tokens = (
            response.usage.prompt_eval_count + response.usage.eval_count
        )

        chat_message = ChatMessage(
            session_id=session_id,
            tunables=tunables,
            status=ApiStatusType.DONE,
            content=content,
            metadata=ChatMessageMetaData(
                options=options,
                eval_count=response.usage.eval_count,
                eval_duration=response.usage.eval_duration,
                prompt_eval_count=response.usage.prompt_eval_count,
                prompt_eval_duration=response.usage.prompt_eval_duration,
            ),
        )
        yield chat_message
        return

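    # One-shot usage (sketch; argument values are illustrative): streamed chunks
    # arrive as ChatMessageStreaming, and the final ChatMessage carries the full
    # content and token-usage metadata.
    #
    #     async for msg in agent.llm_one_shot(
    #         llm, model, session_id=sid,
    #         prompt="Summarize this resume.", system_prompt="You are a recruiter.",
    #     ):
    #         if msg.status == ApiStatusType.DONE:
    #             final = msg
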
    async def generate(
        self, llm: Any, model: str,
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7,
    ) -> AsyncGenerator[ApiMessage, None]:
        if not self.user:
            error_message = ChatMessageError(
                session_id=session_id,
                content="No user set for chat generation."
            )
            yield error_message
            return

        user_message = ChatMessageUser(
            session_id=session_id,
            content=prompt,
        )

        self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
            context = None
            rag_message: ChatMessageRagSearch | None = None
            if self.user:
                message = None
                async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
                    if message.status == ApiStatusType.ERROR:
                        yield message
                        return
                    # Only yield messages that are in a streaming state
                    if message.status == ApiStatusType.STATUS:
                        yield message

                if not isinstance(message, ChatMessageRagSearch):
                    raise ValueError(
                        f"Expected ChatMessageRagSearch, got {type(message)}"
                    )

                rag_message = message
                context = self.get_rag_context(rag_message)

            # Create a pruned-down message list based purely on the prompt and responses,
            # discarding the full preamble generated by prepare_message
            messages: List[LLMMessage] = [
                LLMMessage(role="system", content=self.system_prompt)
            ]
            # Add the conversation history to the messages
            messages.extend([
                LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
                for m in self.conversation
            ])
            # Add the RAG context to the messages if available
            if context:
                messages.append(
                    LLMMessage(
                        role="user",
                        content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
                    )
                )
            else:
                # Only the actual user query is provided with the full context message
                messages.append(
                    LLMMessage(role="user", content=prompt)
                )

            # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
            # message.metadata.tools = {
            #     "available": llm_tools(self.context.tools),
            #     "used": False,
            # }
            # tool_metadata = message.metadata.tools
            #
            # if use_tools:
            #     message.status = "thinking"
            #     message.content = f"Performing tool analysis step 1/2..."
            #     yield message
            #
            #     logger.info("Checking for LLM tool usage")
            #     start_time = time.perf_counter()
            #     # Tools are enabled and available, so query the LLM with a short context of messages
            #     # in case the LLM did something like ask "Do you want me to run the tool?" and the
            #     # user said "Yes" -- need to keep the context in the thread.
            #     tool_metadata["messages"] = (
            #         [{"role": "system", "content": self.system_prompt}] + messages[-6:]
            #         if len(messages) >= 7
            #         else messages
            #     )
            #
            #     response = llm.chat(
            #         model=model,
            #         messages=tool_metadata["messages"],
            #         tools=tool_metadata["available"],
            #         options={
            #             **message.metadata.options,
            #         },
            #         stream=False,  # No need to stream the probe
            #     )
            #     self.collect_metrics(response)
            #
            #     end_time = time.perf_counter()
            #     message.metadata.timers["tool_check"] = end_time - start_time
            #     if not response.tool_calls:
            #         logger.info("LLM indicates tools will not be used")
            #         # The LLM will not use tools, so disable use_tools so we can stream the full response
            #         use_tools = False
            #     else:
            #         tool_metadata["attempted"] = response.tool_calls
            #
            # if use_tools:
            #     logger.info("LLM indicates tools will be used")
            #
            #     # Tools are enabled and available and the LLM indicated it will use them
            #     message.content = (
            #         f"Performing tool analysis step 2/2 (tool use suspected)..."
            #     )
            #     yield message
            #
            #     logger.info(f"Performing LLM call with tools")
            #     start_time = time.perf_counter()
            #     response = llm.chat(
            #         model=model,
            #         messages=tool_metadata["messages"],  # messages,
            #         tools=tool_metadata["available"],
            #         options={
            #             **message.metadata.options,
            #         },
            #         stream=False,
            #     )
            #     self.collect_metrics(response)
            #
            #     end_time = time.perf_counter()
            #     message.metadata.timers["non_streaming"] = end_time - start_time
            #
            #     if not response:
            #         message.status = "error"
            #         message.content = "No response from LLM."
            #         yield message
            #         return
            #
            #     if response.tool_calls:
            #         tool_metadata["used"] = response.tool_calls
            #         # Process all yielded items from the handler
            #         start_time = time.perf_counter()
            #         async for message in self.process_tool_calls(
            #             llm=llm,
            #             model=model,
            #             message=message,
            #             tool_message=response,
            #             messages=messages,
            #         ):
            #             if message.status == "error":
            #                 yield message
            #                 return
            #             yield message
            #         end_time = time.perf_counter()
            #         message.metadata.timers["process_tool_calls"] = end_time - start_time
            #         message.status = "done"
            #         return
            #
            #     logger.info("LLM indicated tools will be used, and then they weren't")
            #     message.content = response.content
            #     message.status = "done"
            #     yield message
            #     return

            # not use_tools
            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.GENERATING,
                content="Generating response..."
            )
            yield status_message

            # Set the response for streaming
            self.set_optimal_context_size(
                llm, model, prompt=prompt
            )

            options = ChatOptions(
                seed=8911,
                num_ctx=self.context_size,
                temperature=temperature,
            )
            logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")
            content = ""
            start_time = time.perf_counter()
            response = None
            async for response in llm.chat_stream(
                model=model,
                messages=messages,
                options={
                    **options.model_dump(exclude_unset=True),
                },
                stream=True,
            ):
                if not response:
                    error_message = ChatMessageError(
                        session_id=session_id,
                        content="No response from LLM."
                    )
                    yield error_message
                    return

                content += response.content

                if not response.finish_reason:
                    streaming_message = ChatMessageStreaming(
                        session_id=session_id,
                        content=response.content,
                    )
                    yield streaming_message

            if not response:
                error_message = ChatMessageError(
                    session_id=session_id,
                    content="No response from LLM."
                )
                yield error_message
                return

            self.user.collect_metrics(agent=self, response=response)
            self.context_tokens = (
                response.usage.prompt_eval_count + response.usage.eval_count
            )
            end_time = time.perf_counter()

            chat_message = ChatMessage(
                session_id=session_id,
                tunables=tunables,
                status=ApiStatusType.DONE,
                content=content,
                metadata=ChatMessageMetaData(
                    options=options,
                    eval_count=response.usage.eval_count,
                    eval_duration=response.usage.eval_duration,
                    prompt_eval_count=response.usage.prompt_eval_count,
                    prompt_eval_duration=response.usage.prompt_eval_duration,
                    rag_results=rag_message.content if rag_message else [],
                    timers={
                        "llm_streamed": end_time - start_time,
                        "llm_with_tools": 0,  # Placeholder for tool processing time
                    },
                ),
            )

            # Add the user and chat messages to the conversation
            self.conversation.append(user_message)
            self.conversation.append(chat_message)
            yield chat_message
            return

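    # Driving a full chat turn (sketch; `render_chunk` and `save_transcript`
    # are hypothetical caller hooks, and the model name is illustrative):
    #
    #     async for msg in agent.generate(llm, "llama3", session_id=sid, prompt="Hi"):
    #         if msg.status == ApiStatusType.STREAMING:
    #             render_chunk(msg.content)
    #         elif msg.status == ApiStatusType.DONE:
    #             save_transcript(agent.conversation)
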
    # async def process_message(
    #     self, llm: Any, model: str, message: Message
    # ) -> AsyncGenerator[Message, None]:
    #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
    #
    #     self.metrics.process_count.labels(agent=self.agent_type).inc()
    #     with self.metrics.process_duration.labels(agent=self.agent_type).time():
    #         if not self.context:
    #             raise ValueError("Context is not set for this agent.")
    #
    #         logger.info(
    #             "TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time"
    #         )
    #         spinner: List[str] = ["\\", "|", "/", "-"]
    #         tick: int = 0
    #         while self.context.processing:
    #             message.status = "waiting"
    #             message.content = (
    #                 f"Busy processing another request. Please wait. {spinner[tick]}"
    #             )
    #             tick = (tick + 1) % len(spinner)
    #             yield message
    #             await asyncio.sleep(1)  # Allow the event loop to process the write
    #
    #         self.context.processing = True
    #
    #         message.system_prompt = (
    #             f"<|system|>\n{self.system_prompt.strip()}\n</|system|>"
    #         )
    #         message.context_prompt = ""
    #         for p in message.preamble.keys():
    #             message.context_prompt += (
    #                 f"\n<|{p}|>\n{message.preamble[p].strip()}\n</|{p}>\n\n"
    #             )
    #         message.context_prompt += f"{message.prompt}"
    #
    #         # Estimate token length of new messages
    #         message.content = f"Optimizing context..."
    #         message.status = "thinking"
    #         yield message
    #
    #         message.context_size = self.set_optimal_context_size(
    #             llm, model, prompt=message.context_prompt
    #         )
    #
    #         message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
    #         message.status = "thinking"
    #         yield message
    #
    #         async for message in self.generate_llm_response(
    #             llm=llm, model=model, message=message
    #         ):
    #             # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}")
    #             if message.status == "error":
    #                 yield message
    #                 self.context.processing = False
    #                 return
    #             yield message
    #
    #         # Done processing, add message to conversation
    #         message.status = "done"
    #         self.conversation.add(message)
    #         self.context.processing = False
    #
    #         return

    def extract_json_blocks(self, text: str, allow_multiple: bool = False) -> List[dict]:
        """
        Extract JSON blocks from text, even if surrounded by markdown or noisy text.
        If allow_multiple is True, returns all JSON blocks; otherwise, only the first.
        """
        return json_extractor.extract_json_blocks(text, allow_multiple)

    def extract_json_from_text(self, text: str) -> str:
        """Extract JSON string from text that may contain other content."""
        return json_extractor.extract_json_from_text(text)

    def extract_markdown_from_text(self, text: str) -> str:
        """Extract Markdown string from text that may contain other content."""
        markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```"
        match = re.search(markdown_pattern, text)
        if match:
            return match.group(2).strip()

        raise ValueError("No Markdown found in the response")

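# Example: extract_markdown_from_text returns the body of the first fenced
# ``md``/``markdown`` block, so for
#     'Here you go:\n```markdown\n# Title\n```'
# it returns '# Title'; text with no such fence raises ValueError.
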
_agents: List[Agent] = []

def get_or_create_agent(
    agent_type: str,
    prometheus_collector: CollectorRegistry,
    user: Optional[CandidateEntity] = None,
) -> Agent:
    """
    Get or create and append a new agent of the specified type, ensuring only one
    global agent per type exists.

    Args:
        agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
        prometheus_collector: Registry used for the agent's metrics.
        user: The candidate the agent is bound to, or None for a global agent.

    Returns:
        The created (or previously created) agent instance.

    Raises:
        ValueError: If no matching agent type is found.
    """
    # Check if a global (non-user) agent with the given agent_type already exists
    if not user:
        for agent in _agents:
            if agent.agent_type == agent_type:
                return agent

    # Find the matching subclass
    for agent_cls in Agent.__subclasses__():
        if agent_cls.model_fields["agent_type"].default == agent_type:
            # Create the agent instance with provided kwargs
            agent = agent_cls(agent_type=agent_type,  # type: ignore[call-arg]
                              user=user)
            _agents.append(agent)
            return agent

    raise ValueError(f"No agent class found for agent_type: {agent_type}")

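# Module-level usage (sketch; 'candidate_chat' is the type named in the
# docstrings above, assuming a subclass registers it, and `registry` is a
# placeholder CollectorRegistry):
#
#     agent = get_or_create_agent("candidate_chat", prometheus_collector=registry)
#     assert get_or_create_agent("candidate_chat", prometheus_collector=registry) is agent
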
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)

CandidateEntity.model_rebuild()