from __future__ import annotations

import os
import re
import time
from abc import ABC
from datetime import datetime, UTC
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    ClassVar,
    List,
    Literal,
    Optional,
    get_args,
)
from uuid import uuid4

import numpy as np  # type: ignore
from prometheus_client import CollectorRegistry  # type: ignore
from pydantic import BaseModel, Field  # type: ignore

import defines
import json_extractor
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
from logger import logger
from models import (
    ApiActivityType,
    ApiMessage,
    ApiStatusType,
    Candidate,
    ChatContextType,
    ChatMessage,
    ChatMessageError,
    ChatMessageMetaData,
    ChatMessageRagSearch,
    ChatMessageStatus,
    ChatMessageStreaming,
    ChatMessageUser,
    ChatOptions,
    ChromaDBGetResponse,
    LLMMessage,
    RagEntry,
    Tunables,
)
from rag import start_file_watcher, ChromaDBFileWatcher
from utils.metrics import Metrics

from .registry import agent_registry


class CandidateEntity(Candidate):
    model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc.

    id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the entity")
    last_accessed: datetime = Field(
        default_factory=lambda: datetime.now(UTC), description="Last accessed timestamp"
    )
    reference_count: int = Field(default=0, description="Number of active references to this entity")

    async def cleanup(self):
        """Cleanup resources associated with this entity"""

    # Internal instance members
    CandidateEntity__agents: List[Agent] = []
    CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
    CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
    CandidateEntity__metrics: Optional[Metrics] = Field(
        default=None,
        description="Metrics collector for this agent, used to track performance and usage.",
    )
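    # NOTE (descriptive): the `CandidateEntity__` prefix mirrors Python's name
    # mangling for double-underscore attributes; most of these fields are declared
    # with exclude=True so the runtime-only handles stay out of serialized
    # Candidate data while remaining declarable as pydantic fields.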
    def __init__(self, candidate=None):
        if candidate is not None:
            # Copy attributes from the candidate instance
            super().__init__(**vars(candidate))
        else:
            raise ValueError("CandidateEntity must be initialized with a Candidate instance or attributes")

    @classmethod
    def exists(cls, username: str):
        # Validate username format (only allow safe characters)
        if not re.match(r"^[a-zA-Z0-9_-]+$", username):
            return False  # Invalid username characters

        # Check for minimum and maximum length
        if not (3 <= len(username) <= 32):
            return False  # Invalid username length

        # Use Path for safe path handling and normalization
        user_dir = Path(defines.user_dir) / username
        user_info_path = user_dir / defines.user_info_file

        # Ensure the final path is actually within the intended parent directory
        # to help prevent directory traversal attacks
        try:
            if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
                return False  # Path traversal attempt detected
        except (ValueError, RuntimeError):
            # Potential exceptions from resolve()
            return False

        # Check if file exists
        return user_info_path.is_file()

    def get_or_create_agent(self, agent_type: ChatContextType) -> Agent:
        """
        Get or create an agent of the specified type for this candidate.

        Args:
            agent_type: The type of agent to get or create.

        Returns:
            The agent instance for this type.
        """
        # Only instantiate one agent of each type per user
        for agent in self.CandidateEntity__agents:
            if agent.agent_type == agent_type:
                return agent
        agent = get_or_create_agent(
            agent_type=agent_type,
            user=self,
            prometheus_collector=self.prometheus_collector,
        )
        # Cache the per-user agent so repeat calls return the same instance
        self.CandidateEntity__agents.append(agent)
        return agent

    # Wrapper properties that map into file_watcher
    @property
    def umap_collection(self) -> ChromaDBGetResponse:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher.umap_collection

    # Fields managed by initialize()
    CandidateEntity__initialized: bool = Field(default=False, exclude=True)

    @property
    def metrics(self) -> Metrics:
        if not self.CandidateEntity__metrics:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__metrics

    @property
    def file_watcher(self) -> ChromaDBFileWatcher:
        if not self.CandidateEntity__file_watcher:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__file_watcher

    @property
    def prometheus_collector(self) -> CollectorRegistry:
        if not self.CandidateEntity__prometheus_collector:
            raise ValueError("initialize() has not been called with a prometheus_collector.")
        return self.CandidateEntity__prometheus_collector

    @property
    def observer(self) -> Any:
        if not self.CandidateEntity__observer:
            raise ValueError("initialize() has not been called.")
        return self.CandidateEntity__observer

    def collect_metrics(self, agent: Agent, response):
        if not self.metrics:
            logger.warning("No metrics collector set for this agent.")
            return
        self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
        self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
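    # Minimal lifecycle sketch (illustrative only; `candidate`, `registry`, and
    # `database` are assumed to exist and are not defined in this module):
    #
    #   entity = CandidateEntity(candidate)
    #   await entity.initialize(prometheus_collector=registry, database=database)
    #   agent = entity.get_or_create_agent(agent_type)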
    async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
        if self.CandidateEntity__initialized:
            # Initialization can only be attempted once; if there are multiple attempts, it means
            # a subsystem is failing or there is a logic bug in the code.
            #
            # NOTE: It is intentional that self.CandidateEntity__initialized is set to True
            # regardless of whether initialization succeeded. This prevents server loops on failure.
            raise ValueError("initialize can only be attempted once")
        self.CandidateEntity__initialized = True

        if not self.username:
            raise ValueError("username can not be empty")
        if not prometheus_collector:
            raise ValueError("prometheus_collector can not be None")

        self.CandidateEntity__prometheus_collector = prometheus_collector
        self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)

        user_dir = os.path.join(defines.user_dir, self.username)
        vector_db_dir = os.path.join(user_dir, defines.persist_directory)
        rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)
        os.makedirs(vector_db_dir, exist_ok=True)
        os.makedirs(rag_content_dir, exist_ok=True)

        self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
            llm=llm_manager.get_llm(),
            user_id=self.id,
            collection_name=self.username,
            persist_directory=vector_db_dir,
            watch_directory=rag_content_dir,
            database=database,
            recreate=False,  # Don't recreate if exists
        )

        has_username_rag = any(item.name == self.username for item in self.rags)
        if not has_username_rag:
            self.rags.append(
                RagEntry(
                    name=self.username,
                    description=f"Expert data about {self.full_name}.",
                )
            )
        self.rag_content_size = self.file_watcher.collection.count()


class Agent(BaseModel, ABC):
    """
    Base class for all agent types.

    This class defines the common attributes and methods for all agent types.
    """

    class Config:
        arbitrary_types_allowed = True  # Allow arbitrary types like RedisDatabase

    # Agent management with pydantic
    agent_type: Literal["base"] = "base"
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    user: Optional[CandidateEntity] = None

    # Tunables (sets default for new Messages attached to this agent)
    tunables: Tunables = Field(default_factory=Tunables)

    # Agent properties
    system_prompt: str = ""
    context_tokens: int = 0

    # context_size is shared across all subclasses
    _context_size: ClassVar[int] = int(defines.max_context * 0.5)

    conversation: List[ChatMessageUser] = Field(
        default_factory=list,
        description="Conversation history for this agent, used to maintain context across messages.",
    )
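    # Because _context_size is a ClassVar, the context_size property below reads
    # and writes a single value shared by every Agent instance and subclass:
    # growing the context for one agent grows it for all of them.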
    @property
    def context_size(self) -> int:
        return Agent._context_size

    @context_size.setter
    def context_size(self, value: int):
        Agent._context_size = value

    async def get_last_item(self, generator):
        last_item = None
        async for item in generator:
            last_item = item
        return last_item

    def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
        # Most models average 1.3-1.5 tokens per word
        word_count = len(prompt.split())
        tokens = int(word_count * 1.4)

        # Add buffer for safety
        total_ctx = tokens + ctx_buffer
        if total_ctx > self.context_size:
            logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")

        # Grow the context size if necessary
        self.context_size = max(self.context_size, total_ctx)

        # Use actual model maximum context size
        return self.context_size

    # Class and pydantic model management
    def __init_subclass__(cls, **kwargs) -> None:
        """Auto-register subclasses"""
        super().__init_subclass__(**kwargs)
        # Register this class if it has an agent_type
        if hasattr(cls, "agent_type") and cls.agent_type != Agent._agent_type:
            agent_registry.register(cls.agent_type, cls)

    def model_dump(self, *args, **kwargs) -> Any:
        # Ensure context is always excluded, even with exclude_unset=True
        kwargs.setdefault("exclude", set())
        if isinstance(kwargs["exclude"], set):
            kwargs["exclude"].add("context")
        elif isinstance(kwargs["exclude"], dict):
            kwargs["exclude"]["context"] = True
        return super().model_dump(*args, **kwargs)

    @classmethod
    def valid_agent_types(cls) -> set[str]:
        """Return the set of valid agent_type values."""
        return set(get_args(cls.__annotations__["agent_type"]))

    # Agent methods
    def get_agent_type(self):
        return self._agent_type
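    # Worked example for set_optimal_context_size above (illustrative numbers):
    # a 10,000-word prompt estimates to int(10_000 * 1.4) = 14,000 tokens; adding
    # the default ctx_buffer of 2048 requests 16,048, so the shared context size
    # grows to 16,048 if it was previously smaller.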
    # async def process_tool_calls(
    #     self,
    #     llm: Any,
    #     model: str,
    #     message: ChatMessage,
    #     tool_message: Any,  # llama response
    #     messages: List[LLMMessage],
    # ) -> AsyncGenerator[ChatMessage, None]:
    #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
    #     self.metrics.tool_count.labels(agent=self.agent_type).inc()
    #     with self.metrics.tool_duration.labels(agent=self.agent_type).time():
    #         if not self.context:
    #             raise ValueError("Context is not set for this agent.")
    #         if not message.metadata.tools:
    #             raise ValueError("tools field not initialized")
    #         tool_metadata = message.metadata.tools
    #         tool_metadata["tool_calls"] = []
    #         message.status = "tooling"
    #         for i, tool_call in enumerate(tool_message.tool_calls):
    #             arguments = tool_call.function.arguments
    #             tool = tool_call.function.name
    #             # Yield status update before processing each tool
    #             message.content = (
    #                 f"Processing tool {i+1}/{len(tool_message.tool_calls)}: {tool}..."
    #             )
    #             yield message
    #             logger.info(f"LLM - {message.content}")
    #             # Process the tool based on its type
    #             match tool:
    #                 case "TickerValue":
    #                     ticker = arguments.get("ticker")
    #                     if not ticker:
    #                         ret = None
    #                     else:
    #                         ret = TickerValue(ticker)
    #                 case "AnalyzeSite":
    #                     url = arguments.get("url")
    #                     question = arguments.get(
    #                         "question", "what is the summary of this content?"
    #                     )
    #                     # Additional status update for long-running operations
    #                     message.content = (
    #                         f"Retrieving and summarizing content from {url}..."
    #                     )
    #                     yield message
    #                     ret = await AnalyzeSite(
    #                         llm=llm, model=model, url=url, question=question
    #                     )
    #                 case "GenerateImage":
    #                     prompt = arguments.get("prompt", None)
    #                     if not prompt:
    #                         logger.info("No prompt supplied to GenerateImage")
    #                         ret = {"error": "No prompt supplied to GenerateImage"}
    #                     # Additional status update for long-running operations
    #                     message.content = (
    #                         f"Generating image for {prompt}..."
    #                     )
    #                     yield message
    #                     ret = await GenerateImage(
    #                         llm=llm, model=model, prompt=prompt
    #                     )
    #                     logger.info("GenerateImage returning", ret)
    #                 case "DateTime":
    #                     tz = arguments.get("timezone")
    #                     ret = DateTime(tz)
    #                 case "WeatherForecast":
    #                     city = arguments.get("city")
    #                     state = arguments.get("state")
    #                     message.content = (
    #                         f"Fetching weather data for {city}, {state}..."
    #                     )
    #                     yield message
    #                     ret = WeatherForecast(city, state)
    #                 case _:
    #                     logger.error(f"Requested tool {tool} does not exist")
    #                     ret = None
    #             # Build response for this tool
    #             tool_response = {
    #                 "role": "tool",
    #                 "content": json.dumps(ret),
    #                 "name": tool_call.function.name,
    #             }
    #             tool_metadata["tool_calls"].append(tool_response)
    #         if len(tool_metadata["tool_calls"]) == 0:
    #             message.status = "done"
    #             yield message
    #             return
    #         message_dict = LLMMessage(
    #             role=tool_message.get("role", "assistant"),
    #             content=tool_message.get("content", ""),
    #             tool_calls=[
    #                 {
    #                     "function": {
    #                         "name": tc["function"]["name"],
    #                         "arguments": tc["function"]["arguments"],
    #                     }
    #                 }
    #                 for tc in tool_message.tool_calls
    #             ],
    #         )
    #         messages.append(message_dict)
    #         messages.extend(tool_metadata["tool_calls"])
    #         message.status = "thinking"
    #         message.content = "Incorporating tool results into response..."
    #         yield message
    #         # Decrease creativity when processing tool call requests
    #         message.content = ""
    #         start_time = time.perf_counter()
    #         for response in llm.chat(
    #             model=model,
    #             messages=messages,
    #             options={
    #                 **message.metadata.options,
    #             },
    #             stream=True,
    #         ):
    #             # logger.info(f"LLM::Tools: {'done' if response.finish_reason else 'processing'} - {response}")
    #             message.status = "streaming"
    #             message.chunk = response.content
    #             message.content += message.chunk
    #             if not response.finish_reason:
    #                 yield message
    #             if response.finish_reason:
    #                 self.collect_metrics(response)
    #                 message.metadata.eval_count += response.eval_count
    #                 message.metadata.eval_duration += response.eval_duration
    #                 message.metadata.prompt_eval_count += response.prompt_eval_count
    #                 message.metadata.prompt_eval_duration += response.prompt_eval_duration
    #                 self.context_tokens = (
    #                     response.prompt_eval_count + response.eval_count
    #                 )
    #                 message.status = "done"
    #                 yield message
    #         end_time = time.perf_counter()
    #         message.metadata.timers["llm_with_tools"] = end_time - start_time
    #         return

    def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
        """
        Extracts the RAG context from the rag_message.
        """
        if not rag_message.content:
            return ""

        context = []
        for chroma_results in rag_message.content:
            for index, metadata in enumerate(chroma_results.metadatas):
                content = "\n".join(
                    [line.strip() for line in chroma_results.documents[index].split("\n") if line]
                ).strip()
                context.append(
                    f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
"""
                )
        return "\n".join(context)
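    # For a single match, get_rag_context produces text shaped roughly like
    # (values illustrative):
    #
    #   Source: resume: docs/resume.md
    #   Document reference: <chroma id>
    #   Content: <matched document text with blank lines stripped>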
    async def generate_rag_results(
        self,
        session_id: str,
        prompt: str,
        top_k: int = defines.default_rag_top_k,
        threshold: float = defines.default_rag_threshold,
    ) -> AsyncGenerator[ApiMessage, None]:
        """
        Generate RAG results for the given prompt.

        Args:
            prompt: The query string to generate RAG results for.

        Yields:
            Status messages while searching, ending with a ChatMessageRagSearch
            whose content is the list of per-RAG results.
        """
        if not self.user:
            error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
            yield error_message
            return

        results: List[ChromaDBGetResponse] = []
        user: CandidateEntity = self.user
        for rag in user.rags:
            if not rag.enabled:
                continue
            status_message = ChatMessageStatus(
                session_id=session_id,
                activity=ApiActivityType.SEARCHING,
                content=f"Searching RAG context {rag.name}...",
            )
            yield status_message

            try:
                chroma_results = await user.file_watcher.find_similar(
                    query=prompt, top_k=top_k, threshold=threshold
                )
                if not chroma_results:
                    continue
                query_embedding = np.array(chroma_results["query_embedding"]).flatten()  # type: ignore

                umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]  # type: ignore
                umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]  # type: ignore

                rag_metadata = ChromaDBGetResponse(
                    name=rag.name,
                    query=prompt,
                    query_embedding=query_embedding.tolist(),
                    ids=chroma_results.get("ids", []),
                    embeddings=chroma_results.get("embeddings", []),
                    documents=chroma_results.get("documents", []),
                    metadatas=chroma_results.get("metadatas", []),
                    umap_embedding_2d=umap_2d.tolist(),
                    umap_embedding_3d=umap_3d.tolist(),
                )
                results.append(rag_metadata)
            except Exception as e:
                continue_message = ChatMessageStatus(
                    session_id=session_id,
                    activity=ApiActivityType.SEARCHING,
                    content=f"Error searching RAG context {rag.name}: {str(e)}",
                )
                yield continue_message

        final_message = ChatMessageRagSearch(
            session_id=session_id,
            content=results,
            status=ApiStatusType.DONE,
        )
        yield final_message
        return
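    # Typical consumption pattern (sketch; `agent`, `sid`, and `prompt` assumed):
    #
    #   async for msg in agent.generate_rag_results(session_id=sid, prompt=prompt):
    #       ...
    #   # The final message is a ChatMessageRagSearch whose .content is the
    #   # List[ChromaDBGetResponse] of per-RAG matches.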
    async def llm_one_shot(
        self,
        llm: Any,
        model: str,
        session_id: str,
        prompt: str,
        system_prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7,
    ) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
        if not self.user:
            error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
            yield error_message
            return

        self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
        options = ChatOptions(
            seed=8911,
            num_ctx=self.context_size,
            temperature=temperature,
        )
        messages: List[LLMMessage] = [
            LLMMessage(role="system", content=system_prompt),
            LLMMessage(role="user", content=prompt),
        ]

        status_message = ChatMessageStatus(
            session_id=session_id,
            activity=ApiActivityType.GENERATING,
            content="Generating response...",
        )
        yield status_message

        logger.info(f"Message options: {options.model_dump(exclude_unset=True)}")

        response = None
        content = ""
        async for response in llm.chat_stream(
            model=model,
            messages=messages,
            options={
                **options.model_dump(exclude_unset=True),
            },
            stream=True,
        ):
            if not response:
                error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
                yield error_message
                return

            content += response.content
            if not response.finish_reason:
                streaming_message = ChatMessageStreaming(
                    session_id=session_id,
                    content=response.content,
                    status=ApiStatusType.STREAMING,
                )
                yield streaming_message

        if not response:
            error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
            yield error_message
            return

        self.user.collect_metrics(agent=self, response=response)
        self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count

        chat_message = ChatMessage(
            session_id=session_id,
            tunables=tunables,
            status=ApiStatusType.DONE,
            content=content,
            metadata=ChatMessageMetaData(
                options=options,
                eval_count=response.usage.eval_count,
                eval_duration=response.usage.eval_duration,
                prompt_eval_count=response.usage.prompt_eval_count,
                prompt_eval_duration=response.usage.prompt_eval_duration,
            ),
        )
        yield chat_message
        return

    async def generate(
        self,
        llm: Any,
        model: str,
        session_id: str,
        prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7,
    ) -> AsyncGenerator[ApiMessage, None]:
        if not self.user:
            error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
            yield error_message
            return

        user_message = ChatMessageUser(
            session_id=session_id,
            content=prompt,
        )

        self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
        with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
            context = None
            rag_message: ChatMessageRagSearch | None = None
            if self.user:
                message = None
                async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
                    if message.status == ApiStatusType.ERROR:
                        yield message
                        return
                    # Only yield messages that are in a streaming state
                    if message.status == ApiStatusType.STATUS:
                        yield message

                if not isinstance(message, ChatMessageRagSearch):
                    raise ValueError(f"Expected ChatMessageRagSearch, got {type(message)}")
                rag_message = message
                context = self.get_rag_context(rag_message)

            # Create a pruned down message list based purely on the prompt and responses,
            # discarding the full preamble generated by prepare_message
            messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]

            # Add the conversation history to the messages
            messages.extend(
                [
                    LLMMessage(
                        role="user" if isinstance(m, ChatMessageUser) else "assistant",
                        content=m.content,
                    )
                    for m in self.conversation
                ]
            )

            # Add the RAG context to the messages if available
            if context:
                messages.append(
                    LLMMessage(
                        role="user",
                        content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n\n\nPrompt to respond to:\n{prompt}\n",
                    )
                )
            else:
                # Only the actual user query is provided with the full context message
                messages.append(LLMMessage(role="user", content=prompt))
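            # At this point the message list is, in order: the system prompt, the
            # prior user/assistant turns, then either the RAG-wrapped prompt or
            # the bare prompt (descriptive summary of the code above).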
# tool_metadata["messages"] = ( # [{"role": "system", "content": self.system_prompt}] + messages[-6:] # if len(messages) >= 7 # else messages # ) # response = llm.chat( # model=model, # messages=tool_metadata["messages"], # tools=tool_metadata["available"], # options={ # **message.metadata.options, # }, # stream=False, # No need to stream the probe # ) # self.collect_metrics(response) # end_time = time.perf_counter() # message.metadata.timers["tool_check"] = end_time - start_time # if not response.tool_calls: # logger.info("LLM indicates tools will not be used") # # The LLM will not use tools, so disable use_tools so we can stream the full response # use_tools = False # else: # tool_metadata["attempted"] = response.tool_calls # if use_tools: # logger.info("LLM indicates tools will be used") # # Tools are enabled and available and the LLM indicated it will use them # message.content = ( # f"Performing tool analysis step 2/2 (tool use suspected)..." # ) # yield message # logger.info(f"Performing LLM call with tools") # start_time = time.perf_counter() # response = llm.chat( # model=model, # messages=tool_metadata["messages"], # messages, # tools=tool_metadata["available"], # options={ # **message.metadata.options, # }, # stream=False, # ) # self.collect_metrics(response) # end_time = time.perf_counter() # message.metadata.timers["non_streaming"] = end_time - start_time # if not response: # message.status = "error" # message.content = "No response from LLM." # yield message # return # if response.tool_calls: # tool_metadata["used"] = response.tool_calls # # Process all yielded items from the handler # start_time = time.perf_counter() # async for message in self.process_tool_calls( # llm=llm, # model=model, # message=message, # tool_message=response, # messages=messages, # ): # if message.status == "error": # yield message # return # yield message # end_time = time.perf_counter() # message.metadata.timers["process_tool_calls"] = end_time - start_time # message.status = "done" # return # logger.info("LLM indicated tools will be used, and then they weren't") # message.content = response.content # message.status = "done" # yield message # return # not use_tools status_message = ChatMessageStatus( session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..." ) yield status_message # Set the response for streaming self.set_optimal_context_size( llm, model, prompt=prompt ) options = ChatOptions( seed=8911, num_ctx=self.context_size, temperature=temperature, ) logger.info(f"Message options: {options.model_dump(exclude_unset=True)}") content = "" start_time = time.perf_counter() response = None async for response in llm.chat_stream( model=model, messages=messages, options={ **options.model_dump(exclude_unset=True), }, stream=True, ): if not response: error_message = ChatMessageError( session_id=session_id, content="No response from LLM." ) yield error_message return content += response.content if not response.finish_reason: streaming_message = ChatMessageStreaming( session_id=session_id, content=response.content, ) yield streaming_message if not response: error_message = ChatMessageError( session_id=session_id, content="No response from LLM." 
            if not response:
                error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
                yield error_message
                return

            self.user.collect_metrics(agent=self, response=response)
            self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
            end_time = time.perf_counter()

            chat_message = ChatMessage(
                session_id=session_id,
                tunables=tunables,
                status=ApiStatusType.DONE,
                content=content,
                metadata=ChatMessageMetaData(
                    options=options,
                    eval_count=response.usage.eval_count,
                    eval_duration=response.usage.eval_duration,
                    prompt_eval_count=response.usage.prompt_eval_count,
                    prompt_eval_duration=response.usage.prompt_eval_duration,
                    rag_results=rag_message.content if rag_message else [],
                    timers={
                        "llm_streamed": end_time - start_time,
                        "llm_with_tools": 0,  # Placeholder for tool processing time
                    },
                ),
            )

            # Add the user and chat messages to the conversation
            self.conversation.append(user_message)
            self.conversation.append(chat_message)
            yield chat_message
            return

    # async def process_message(
    #     self, llm: Any, model: str, message: Message
    # ) -> AsyncGenerator[Message, None]:
    #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
    #     self.metrics.process_count.labels(agent=self.agent_type).inc()
    #     with self.metrics.process_duration.labels(agent=self.agent_type).time():
    #         if not self.context:
    #             raise ValueError("Context is not set for this agent.")
    #         logger.info(
    #             "TODO: Implement delay queuing; busy for same agent, otherwise return queue size and estimated wait time"
    #         )
    #         spinner: List[str] = ["\\", "|", "/", "-"]
    #         tick: int = 0
    #         while self.context.processing:
    #             message.status = "waiting"
    #             message.content = (
    #                 f"Busy processing another request. Please wait. {spinner[tick]}"
    #             )
    #             tick = (tick + 1) % len(spinner)
    #             yield message
    #             await asyncio.sleep(1)  # Allow the event loop to process the write
    #         self.context.processing = True
    #         message.system_prompt = (
    #             f"<|system|>\n{self.system_prompt.strip()}\n"
    #         )
    #         message.context_prompt = ""
    #         for p in message.preamble.keys():
    #             message.context_prompt += (
    #                 f"\n<|{p}|>\n{message.preamble[p].strip()}\n\n\n"
    #             )
    #         message.context_prompt += f"{message.prompt}"
    #         # Estimate token length of new messages
    #         message.content = f"Optimizing context..."
    #         message.status = "thinking"
    #         yield message
    #         message.context_size = self.set_optimal_context_size(
    #             llm, model, prompt=message.context_prompt
    #         )
    #         message.content = f"Processing {'RAG augmented ' if message.metadata.rag else ''}query..."
    #         message.status = "thinking"
    #         yield message
    #         async for message in self.generate_llm_response(
    #             llm=llm, model=model, message=message
    #         ):
    #             # logger.info(f"LLM: {message.status} - {f'...{message.content[-20:]}' if len(message.content) > 20 else message.content}")
    #             if message.status == "error":
    #                 yield message
    #                 self.context.processing = False
    #                 return
    #             yield message
    #         # Done processing, add message to conversation
    #         message.status = "done"
    #         self.conversation.add(message)
    #         self.context.processing = False
    #         return
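    # The JSON helpers below delegate to the json_extractor module. Illustrative
    # expectation (assuming json_extractor parses fenced and bare JSON blocks):
    #   self.extract_json_blocks('noise ```json {"a": 1} ``` noise')  ->  [{"a": 1}]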
""" return json_extractor.extract_json_blocks(text, allow_multiple) def extract_json_from_text(self, text: str) -> str: """Extract JSON string from text that may contain other content.""" return json_extractor.extract_json_from_text(text) def extract_markdown_from_text(self, text: str) -> str: """Extract Markdown string from text that may contain other content.""" markdown_pattern = r"```(md|markdown)\s*([\s\S]*?)\s*```" match = re.search(markdown_pattern, text) if match: return match.group(2).strip() raise ValueError("No Markdown found in the response") _agents: List[Agent] = [] def get_or_create_agent( agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity]=None) -> Agent: """ Get or create and append a new agent of the specified type, ensuring only one agent per type exists. Args: agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation'). **kwargs: Additional fields required by the specific agent subclass. Returns: The created agent instance. Raises: ValueError: If no matching agent type is found or if a agent of this type already exists. """ # Check if a global (non-user) agent with the given agent_type already exists if not user: for agent in _agents: if agent.agent_type == agent_type: return agent # Find the matching subclass for agent_cls in Agent.__subclasses__(): if agent_cls.model_fields["agent_type"].default == agent_type: # Create the agent instance with provided kwargs agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg] user=user) _agents.append(agent) return agent raise ValueError(f"No agent class found for agent_type: {agent_type}") # Register the base agent agent_registry.register(Agent._agent_type, Agent) CandidateEntity.model_rebuild()