101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
from __future__ import annotations
|
|
import traceback
|
|
from pydantic import BaseModel, Field # type: ignore
|
|
from typing import (
|
|
Literal,
|
|
get_args,
|
|
List,
|
|
AsyncGenerator,
|
|
TYPE_CHECKING,
|
|
Optional,
|
|
ClassVar,
|
|
Any,
|
|
TypeAlias,
|
|
Dict,
|
|
Tuple,
|
|
)
|
|
import importlib
|
|
import pathlib
|
|
import inspect
|
|
from prometheus_client import CollectorRegistry # type: ignore
|
|
|
|
from . base import Agent
|
|
from logger import logger
|
|
from models import Candidate
|
|
|
|
# Module-level cache of agent instances. get_or_create_agent appends a newly
# built agent here only when no cached agent has the same agent_type.
_agents: List[Agent] = list()
|
|
|
|
def _agent_subclasses() -> List[type]:
    """Return every subclass of ``Agent`` (direct and indirect), breadth-first.

    Direct subclasses come first, in the order ``Agent.__subclasses__()``
    reports them, followed by their subclasses level by level. A class
    reachable through several bases (multiple inheritance) is listed once.
    """
    ordered: List[type] = []
    seen: set = set()
    level = Agent.__subclasses__()
    while level:
        next_level: List[type] = []
        for cls in level:
            if cls in seen:
                continue
            seen.add(cls)
            ordered.append(cls)
            next_level.extend(cls.__subclasses__())
        level = next_level
    return ordered


def get_or_create_agent(
    agent_type: str,
    prometheus_collector: CollectorRegistry,
    user: Optional[Candidate] = None,
    **kwargs,
) -> Agent:
    """
    Get or create and append a new agent of the specified type, ensuring only one agent per type exists.

    If an agent whose ``agent_type`` matches is already cached in ``_agents``,
    that instance is returned unchanged and ``prometheus_collector``, ``user``
    and ``kwargs`` are ignored. Otherwise the first ``Agent`` subclass whose
    ``agent_type`` field defaults to ``agent_type`` is instantiated, cached in
    ``_agents``, and returned. Direct subclasses are searched before indirect
    ones.

    Args:
        agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
        prometheus_collector: Prometheus registry passed to the new agent's constructor.
        user: Optional candidate passed to the new agent's constructor.
        **kwargs: Additional fields required by the specific agent subclass.

    Returns:
        The existing cached agent of this type, or the newly created one.

    Raises:
        ValueError: If no Agent subclass declares ``agent_type`` as its default.
    """
    # Reuse an agent of this type if one has already been created.
    for agent in _agents:
        if agent.agent_type == agent_type:
            # NOTE(review): the cached agent is returned even when `user` differs
            # from the one it was created with -- confirm this sharing is intended.
            return agent

    # Match on the declared default of the `agent_type` model field.
    for agent_cls in _agent_subclasses():
        if agent_cls.model_fields["agent_type"].default == agent_type:
            agent = agent_cls(
                agent_type=agent_type,
                user=user,
                prometheus_collector=prometheus_collector,
                **kwargs,
            )
            # NOTE: a per-agent `agent_persist` opt-out is currently disabled,
            # so every newly created agent is cached.
            _agents.append(agent)
            return agent

    raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
|
|
|
# Type alias for Agent or any subclass. It is simply ``Agent`` itself:
# annotating with the base class already accepts instances of every subclass.
AnyAgent: TypeAlias = Agent
|
|
|
|
# Registry of discovered agent classes, keyed by class name; each value is a
# (module name, class name) pair recorded by the module-discovery loop.
class_registry: Dict[str, Tuple[str, str]] = dict()

# Public API of the package; discovered agent class names get appended to it.
__all__ = ['get_or_create_agent']

# Dotted name and on-disk directory of this package, used to find and import
# the sibling agent modules.
package_name = __name__
package_dir = pathlib.Path(__file__).parent
|
|
|
|
# Import every sibling module (skipping __init__.py, base.py and any file
# starting with "_") and register the Agent subclasses it exposes. Modules are
# visited in sorted order so that registration order, the module recorded in
# class_registry, and the order of __all__ do not depend on the filesystem's
# directory-listing order.
for path in sorted(package_dir.glob("*.py")):
    if path.name in ("__init__.py", "base.py") or path.name.startswith("_"):
        continue

    module_name = path.stem
    full_module_name = f"{package_name}.{module_name}"

    try:
        module = importlib.import_module(full_module_name)

        # getmembers also returns classes a module merely imported, so a class
        # is attributed to the first visited module exposing its name
        # (first-wins via the class_registry membership check).
        for name, obj in inspect.getmembers(module, inspect.isclass):
            if (
                issubclass(obj, AnyAgent)
                and obj is not AnyAgent
                and obj is not Agent
                and name not in class_registry
            ):
                class_registry[name] = (full_module_name, name)
                # Re-export the class at package level.
                globals()[name] = obj
                logger.info(f"Adding agent: {name}")
                __all__.append(name)  # type: ignore
    except ImportError as e:
        logger.error(traceback.format_exc())
        logger.error(f"Error importing {full_module_name}: {e}")
        # Bare raise keeps the original traceback without an extra frame.
        raise
    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(f"Error processing {full_module_name}: {e}")
        raise
|