Refactored code into separate files

parent 4edf11a62d
commit 46faf456cf

.gitignore (vendored)
@@ -1,3 +1,4 @@
venv/**
users/**
!users/eliza
users-prod/**
@@ -193,6 +193,9 @@ RUN pip install fastapi uvicorn "python-jose[cryptography]" bcrypt python-multip
# Needed for email verification
RUN pip install pyyaml user-agents cryptography

# OpenAPI CLI generator
RUN pip install openapi-python-client

# Automatic type conversion pydantic -> typescript
RUN pip install pydantic typing-inspect jinja2
RUN apt-get update \
@@ -31,6 +31,7 @@ services:
      - ./users:/opt/backstory/users:rw # Live mount of user data
      - ./src:/opt/backstory/src:rw # Live mount server src
      - ./frontend/src/types:/opt/backstory/frontend/src/types # Live mount of types for pydantic->ts
      - ./venv:/opt/backstory/venv:rw # Live mount for python venv
    cap_add: # used for running ze-monitor within container
      - CAP_DAC_READ_SEARCH # Bypass all filesystem read access checks
      - CAP_PERFMON # Access to perf_events (vs. overloaded CAP_SYS_ADMIN)
@@ -100,12 +100,13 @@ const JobInfo: React.FC<JobInfoProps> = (props: JobInfoProps) => {
        setAdminStatusType(status.activity);
        setAdminStatus(status.content);
      },
      onMessage: (jobMessage: Types.JobRequirementsMessage) => {
      onMessage: async (jobMessage: Types.JobRequirementsMessage) => {
        const newJob: Types.Job = jobMessage.job
        console.log('onMessage - job', newJob);
        newJob.id = job.id;
        newJob.createdAt = job.createdAt;
        setActiveJob(newJob);
        const updatedJob: Types.Job = await apiClient.updateJob(job.id || '', newJob);
        setActiveJob(updatedJob);
      },
      onError: (error: Types.ChatMessageError) => {
        console.log('onError', error);
@@ -1446,7 +1446,7 @@ class ApiClient {
    options: StreamingOptions = {}
  ): StreamingResponse {
    const body = JSON.stringify(formatApiRequest(chatMessage));
    return this.streamify(`/chat/sessions/${chatMessage.sessionId}/messages/stream`, body, options, "ChatMessage")
    return this.streamify(`/chat/sessions/messages/stream`, body, options, "ChatMessage")
  }

  /**
@@ -1,6 +1,6 @@
// Generated TypeScript types from Pydantic models
// Source: src/backend/models.py
// Generated on: 2025-06-12T15:58:49.974420
// Generated on: 2025-06-13T18:33:37.696910
// DO NOT EDIT MANUALLY - This file is auto-generated

// ============================
@@ -770,6 +770,7 @@ export interface JobRequirements {
  education?: Array<string>;
  certifications?: Array<string>;
  preferredAttributes?: Array<string>;
  companyValues?: Array<string>;
}

export interface JobRequirementsMessage {
@@ -967,7 +968,6 @@ export interface ResendVerificationRequest {

export interface Resume {
  id?: string;
  resumeId: string;
  jobId: string;
  candidateId: string;
  resume: string;
@@ -1,4 +1,5 @@
from __future__ import annotations

import traceback
from pydantic import BaseModel, Field
from typing import (
@@ -17,44 +18,9 @@ from typing import (
import importlib
import pathlib
import inspect
from prometheus_client import CollectorRegistry

from . base import Agent
from .base import Agent, get_or_create_agent
from logger import logger
from models import Candidate

_agents: List[Agent] = []

def get_or_create_agent(agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[Candidate]=None) -> Agent:
    """
    Get or create and append a new agent of the specified type, ensuring only one agent per type exists.

    Args:
        agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
        **kwargs: Additional fields required by the specific agent subclass.

    Returns:
        The created agent instance.

    Raises:
        ValueError: If no matching agent type is found or if an agent of this type already exists.
    """
    # Check if a global (non-user) agent with the given agent_type already exists
    if not user:
        for agent in _agents:
            if agent.agent_type == agent_type:
                return agent

    # Find the matching subclass
    for agent_cls in Agent.__subclasses__():
        if agent_cls.model_fields["agent_type"].default == agent_type:
            # Create the agent instance with provided kwargs
            agent = agent_cls(agent_type=agent_type, user=user, prometheus_collector=prometheus_collector)
            # if agent.agent_persist: # If an agent is not set to persist, do not add it to the list
            _agents.append(agent)
            return agent

    raise ValueError(f"No agent class found for agent_type: {agent_type}")

# Type alias for Agent or any subclass
AnyAgent: TypeAlias = Agent  # BaseModel covers Agent and subclasses
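The helper removed here (and now re-exported from `.base`) dispatches on `Agent.__subclasses__()` and each subclass's declared `agent_type` default. Below is a minimal, self-contained sketch of that pattern; the `ChatAgent`/`ImageAgent` classes and the string `user` field are illustrative stand-ins, not the project's real agents.

```python
# Minimal sketch of the subclass-dispatch pattern used by get_or_create_agent.
# Agent/ChatAgent/ImageAgent below are illustrative stand-ins only.
from typing import List, Optional
from pydantic import BaseModel


class Agent(BaseModel):
    agent_type: str = "base"
    user: Optional[str] = None


class ChatAgent(Agent):
    agent_type: str = "candidate_chat"


class ImageAgent(Agent):
    agent_type: str = "generate_image"


_agents: List[Agent] = []


def get_or_create_agent(agent_type: str, user: Optional[str] = None) -> Agent:
    # Reuse a cached global agent when no user is given
    if user is None:
        for agent in _agents:
            if agent.agent_type == agent_type:
                return agent

    # Pick the subclass whose declared agent_type default matches
    for agent_cls in Agent.__subclasses__():
        if agent_cls.model_fields["agent_type"].default == agent_type:
            agent = agent_cls(user=user)
            _agents.append(agent)
            return agent

    raise ValueError(f"No agent class found for agent_type: {agent_type}")


if __name__ == "__main__":
    a = get_or_create_agent("candidate_chat")
    b = get_or_create_agent("candidate_chat")
    assert a is b  # global agents are cached per type
```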
@ -23,16 +23,204 @@ from datetime import datetime, UTC
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
import json_extractor as json_extractor
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||
|
||||
from models import ( ApiActivityType, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from datetime import datetime, date, UTC
|
||||
from typing_extensions import Annotated, Union
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
import traceback
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
||||
import utils.llm_proxy as llm_manager
|
||||
from database import RedisDatabase
|
||||
from models import ChromaDBGetResponse
|
||||
from utils.metrics import Metrics
|
||||
|
||||
|
||||
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatQuery, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from logger import logger
|
||||
import defines
|
||||
from .registry import agent_registry
|
||||
from metrics import Metrics
|
||||
import model_cast
|
||||
import backstory_traceback as traceback
|
||||
|
||||
from models import ( ChromaDBGetResponse )
|
||||
class CandidateEntity(Candidate):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the entity")
|
||||
last_accessed: datetime = Field(default_factory=lambda: datetime.now(UTC), description="Last accessed timestamp")
|
||||
reference_count: int = Field(default=0, description="Number of active references to this entity")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup resources associated with this entity"""
|
||||
pass
|
||||
|
||||
# Internal instance members
|
||||
CandidateEntity__agents: List[Agent] = []
|
||||
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
|
||||
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
|
||||
CandidateEntity__metrics: Optional[Metrics] = Field(
|
||||
default=None,
|
||||
description="Metrics collector for this agent, used to track performance and usage."
|
||||
)
|
||||
|
||||
def __init__(self, candidate=None):
|
||||
if candidate is not None:
|
||||
# Copy attributes from the candidate instance
|
||||
super().__init__(**vars(candidate))
|
||||
else:
|
||||
raise ValueError("CandidateEntity must be initialized with a Candidate instance or attributes")
|
||||
|
||||
@classmethod
|
||||
def exists(cls, username: str):
|
||||
# Validate username format (only allow safe characters)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||
return False # Invalid username characters
|
||||
|
||||
# Check for minimum and maximum length
|
||||
if not (3 <= len(username) <= 32):
|
||||
return False # Invalid username length
|
||||
|
||||
# Use Path for safe path handling and normalization
|
||||
user_dir = Path(defines.user_dir) / username
|
||||
user_info_path = user_dir / defines.user_info_file
|
||||
|
||||
# Ensure the final path is actually within the intended parent directory
|
||||
# to help prevent directory traversal attacks
|
||||
try:
|
||||
if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
|
||||
return False # Path traversal attempt detected
|
||||
except (ValueError, RuntimeError): # Potential exceptions from resolve()
|
||||
return False
|
||||
|
||||
# Check if file exists
|
||||
return user_info_path.is_file()
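For reference, a standalone sketch of the same username guard that `exists()` applies before touching the filesystem. `BASE_DIR` and the sample usernames are hypothetical; the real code reads `defines.user_dir` and `defines.user_info_file` instead.

```python
# Standalone sketch of the username / path-traversal guard used by exists().
# BASE_DIR and the example usernames are hypothetical.
import re
from pathlib import Path

BASE_DIR = Path("/opt/backstory/users")


def is_safe_username(username: str) -> bool:
    # Same character and length rules as CandidateEntity.exists()
    if not re.match(r"^[a-zA-Z0-9_-]+$", username):
        return False
    if not (3 <= len(username) <= 32):
        return False
    user_dir = BASE_DIR / username
    try:
        # Reject anything that resolves outside the users directory
        return user_dir.resolve().is_relative_to(BASE_DIR.resolve())
    except (ValueError, RuntimeError):
        return False


print(is_safe_username("eliza"))    # True
print(is_safe_username("../etc"))   # False: '.' and '/' are rejected by the regex
print(is_safe_username("ab"))       # False: too short
```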
|
||||
|
||||
def get_or_create_agent(self, agent_type: ChatContextType) -> Agent:
|
||||
"""
|
||||
Get or create an agent of the specified type for this candidate.
|
||||
|
||||
Args:
|
||||
agent_type: The type of agent to create (default is 'candidate_chat').
|
||||
**kwargs: Additional fields required by the specific agent subclass.
|
||||
|
||||
Returns:
|
||||
The created agent instance.
|
||||
"""
|
||||
|
||||
# Only instantiate one agent of each type per user
|
||||
for agent in self.CandidateEntity__agents:
|
||||
if agent.agent_type == agent_type:
|
||||
return agent
|
||||
|
||||
return get_or_create_agent(
|
||||
agent_type=agent_type,
|
||||
user=self,
|
||||
prometheus_collector=self.prometheus_collector
|
||||
)
|
||||
|
||||
# Wrapper properties that map into file_watcher
|
||||
@property
|
||||
def umap_collection(self) -> ChromaDBGetResponse:
|
||||
if not self.CandidateEntity__file_watcher:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__file_watcher.umap_collection
|
||||
|
||||
# Fields managed by initialize()
|
||||
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
|
||||
@property
|
||||
def metrics(self) -> Metrics:
|
||||
if not self.CandidateEntity__metrics:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__metrics
|
||||
|
||||
@property
|
||||
def file_watcher(self) -> ChromaDBFileWatcher:
|
||||
if not self.CandidateEntity__file_watcher:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__file_watcher
|
||||
|
||||
@property
|
||||
def prometheus_collector(self) -> CollectorRegistry:
|
||||
if not self.CandidateEntity__prometheus_collector:
|
||||
raise ValueError("initialize() has not been called with a prometheus_collector.")
|
||||
return self.CandidateEntity__prometheus_collector
|
||||
|
||||
@property
|
||||
def observer(self) -> Any:
|
||||
if not self.CandidateEntity__observer:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__observer
|
||||
|
||||
def collect_metrics(self, agent: Agent, response):
|
||||
if not self.metrics:
|
||||
logger.warning("No metrics collector set for this agent.")
|
||||
return
|
||||
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
|
||||
response.usage.prompt_eval_count
|
||||
)
|
||||
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
if self.CandidateEntity__initialized:
|
||||
# Initialization can only be attempted once; if there are multiple attempts, it means
|
||||
# a subsystem is failing or there is a logic bug in the code.
|
||||
#
|
||||
# NOTE: It is intentional that self.CandidateEntity__initialize = True regardless of whether it
|
||||
# succeeded. This prevents server loops on failure
|
||||
raise ValueError("initialize can only be attempted once")
|
||||
self.CandidateEntity__initialized = True
|
||||
|
||||
if not self.username:
|
||||
raise ValueError("username can not be empty")
|
||||
|
||||
if not prometheus_collector:
|
||||
raise ValueError("prometheus_collector can not be None")
|
||||
|
||||
self.CandidateEntity__prometheus_collector = prometheus_collector
|
||||
self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)
|
||||
|
||||
user_dir = os.path.join(defines.user_dir, self.username)
|
||||
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
|
||||
|
||||
os.makedirs(vector_db_dir, exist_ok=True)
|
||||
os.makedirs(rag_content_dir, exist_ok=True)
|
||||
|
||||
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
|
||||
llm=llm_manager.get_llm(),
|
||||
user_id=self.id,
|
||||
collection_name=self.username,
|
||||
persist_directory=vector_db_dir,
|
||||
watch_directory=rag_content_dir,
|
||||
database=database,
|
||||
recreate=False, # Don't recreate if exists
|
||||
)
|
||||
has_username_rag = any(item.name == self.username for item in self.rags)
|
||||
if not has_username_rag:
|
||||
self.rags.append(RagEntry(
|
||||
name=self.username,
|
||||
description=f"Expert data about {self.full_name}.",
|
||||
))
|
||||
self.rag_content_size = self.file_watcher.collection.count()
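A hedged sketch of the calling side of this lifecycle, assuming a loaded `Candidate`, a `CollectorRegistry`, and a `RedisDatabase` are already available; the wrapper function is illustrative and is not part of this commit.

```python
# Illustrative call site for the lifecycle above. It assumes CandidateEntity
# from this module plus a loaded candidate and database; the wrapper itself
# is a sketch, not repository code.
from prometheus_client import CollectorRegistry


async def build_entity(candidate, database) -> "CandidateEntity":
    entity = CandidateEntity(candidate=candidate)

    # initialize() wires up metrics and the ChromaDB file watcher; it may only
    # be attempted once per entity, so a repeated call raises ValueError.
    await entity.initialize(
        prometheus_collector=CollectorRegistry(), database=database
    )

    # Agents are then created per user, one per type (see get_or_create_agent)
    entity.get_or_create_agent(agent_type="candidate_chat")  # or a ChatContextType member
    return entity
```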
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
@ -46,20 +234,10 @@ class Agent(BaseModel, ABC):
|
||||
agent_type: Literal["base"] = "base"
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
user: Optional[Candidate] = None
|
||||
prometheus_collector: CollectorRegistry = Field(..., description="Prometheus collector for this agent, used to track metrics.", exclude=True)
|
||||
user: Optional[CandidateEntity] = None
|
||||
|
||||
# Tunables (sets default for new Messages attached to this agent)
|
||||
tunables: Tunables = Field(default_factory=Tunables)
|
||||
metrics: Metrics = Field(
|
||||
None, description="Metrics collector for this agent, used to track performance and usage."
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize_metrics(self) -> "Agent":
|
||||
if self.metrics is None:
|
||||
self.metrics = Metrics(prometheus_collector=self.prometheus_collector)
|
||||
return self
|
||||
|
||||
# Agent properties
|
||||
system_prompt: str = ""
|
||||
@ -68,7 +246,7 @@ class Agent(BaseModel, ABC):
|
||||
# context_size is shared across all subclasses
|
||||
_context_size: ClassVar[int] = int(defines.max_context * 0.5)
|
||||
|
||||
conversation: List[ChatMessage] = Field(
|
||||
conversation: List[ChatMessageUser] = Field(
|
||||
default_factory=list,
|
||||
description="Conversation history for this agent, used to maintain context across messages."
|
||||
)
|
||||
@ -295,12 +473,6 @@ class Agent(BaseModel, ABC):
|
||||
# message.metadata.timers["llm_with_tools"] = end_time - start_time
|
||||
# return
|
||||
|
||||
def collect_metrics(self, response):
|
||||
self.metrics.tokens_prompt.labels(agent=self.agent_type).inc(
|
||||
response.usage.prompt_eval_count
|
||||
)
|
||||
self.metrics.tokens_eval.labels(agent=self.agent_type).inc(response.usage.eval_count)
|
||||
|
||||
def get_rag_context(self, rag_message: ChatMessageRagSearch) -> str:
|
||||
"""
|
||||
Extracts the RAG context from the rag_message.
|
||||
@ -329,7 +501,7 @@ Content: {content}
|
||||
prompt: str,
|
||||
top_k: int=defines.default_rag_top_k,
|
||||
threshold: float=defines.default_rag_threshold,
|
||||
) -> AsyncGenerator[ChatMessageRagSearch | ChatMessageError | ChatMessageStatus, None]:
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
"""
|
||||
Generate RAG results for the given query.
|
||||
|
||||
@ -349,7 +521,7 @@ Content: {content}
|
||||
|
||||
results : List[ChromaDBGetResponse] = []
|
||||
entries: int = 0
|
||||
user: Candidate = self.user
|
||||
user: CandidateEntity = self.user
|
||||
for rag in user.rags:
|
||||
if not rag.enabled:
|
||||
continue
|
||||
@ -367,10 +539,10 @@ Content: {content}
|
||||
)
|
||||
if not chroma_results:
|
||||
continue
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten()
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
|
||||
|
||||
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0]
|
||||
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0]
|
||||
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
|
||||
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
|
||||
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
name=rag.name,
|
||||
@ -405,7 +577,15 @@ Content: {content}
|
||||
llm: Any, model: str,
|
||||
session_id: str, prompt: str, system_prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for chat generation."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.set_optimal_context_size(
|
||||
llm=llm, model=model, prompt=prompt+system_prompt
|
||||
@ -466,7 +646,7 @@ Content: {content}
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.collect_metrics(response)
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
@ -493,7 +673,7 @@ Content: {content}
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessageRagSearch, None]:
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
@ -508,8 +688,8 @@ Content: {content}
|
||||
)
|
||||
user = self.user
|
||||
|
||||
self.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
context = None
|
||||
rag_message : ChatMessageRagSearch | None = None
|
||||
if self.user:
|
||||
@ -711,7 +891,7 @@ Content: {content}
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.collect_metrics(response)
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
@ -830,5 +1010,43 @@ Content: {content}
|
||||
|
||||
raise ValueError("No Markdown found in the response")
|
||||
|
||||
_agents: List[Agent] = []
|
||||
|
||||
def get_or_create_agent(
|
||||
agent_type: str,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
user: Optional[CandidateEntity]=None) -> Agent:
|
||||
"""
|
||||
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
||||
|
||||
Args:
|
||||
agent_type: The type of agent to create (e.g., 'general', 'candidate_chat', 'image_generation').
|
||||
**kwargs: Additional fields required by the specific agent subclass.
|
||||
|
||||
Returns:
|
||||
The created agent instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching agent type is found or if an agent of this type already exists.
|
||||
"""
|
||||
# Check if a global (non-user) agent with the given agent_type already exists
|
||||
if not user:
|
||||
for agent in _agents:
|
||||
if agent.agent_type == agent_type:
|
||||
return agent
|
||||
|
||||
# Find the matching subclass
|
||||
for agent_cls in Agent.__subclasses__():
|
||||
if agent_cls.model_fields["agent_type"].default == agent_type:
|
||||
# Create the agent instance with provided kwargs
|
||||
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
|
||||
user=user)
|
||||
_agents.append(agent)
|
||||
return agent
|
||||
|
||||
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Agent._agent_type, Agent)
|
||||
CandidateEntity.model_rebuild()
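`CandidateEntity.model_rebuild()` is needed because, with `from __future__ import annotations`, Pydantic cannot resolve the forward references between `CandidateEntity` and `Agent` until both classes exist. A minimal, self-contained illustration of the same mechanism, using throwaway `Owner`/`Pet` models, follows.

```python
# Minimal illustration of why model_rebuild() is called after the classes are
# defined: with postponed annotations, forward references stay as strings
# until Pydantic is told to resolve them. Owner/Pet are throwaway examples.
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, Field


class Owner(BaseModel):
    # "Pet" is undefined at this point, so Owner is not fully built yet
    pets: List[Pet] = Field(default_factory=list)


class Pet(BaseModel):
    name: str
    owner: Optional[Owner] = None


Owner.model_rebuild()  # resolve the Pet forward reference now that it exists

print(Owner(pets=[Pet(name="Rex")]))
```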
|
||||
|
||||
|
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
from logger import logger

from .registry import agent_registry
from models import ( ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatQuery, ChatMessage, Tunables, ApiStatusType, ChatMessageUser, Candidate)
from models import ( ApiMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ChatQuery, ChatMessage, Tunables, ApiStatusType, ChatMessageUser, Candidate)


system_message = f"""
@@ -38,7 +38,7 @@ class CandidateChat(Agent):
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
    ) -> AsyncGenerator[ApiMessage, None]:
        user = self.user
        if not user:
            logger.error("User is not set for CandidateChat agent.")
@@ -37,7 +37,7 @@ seed = int(time.time())
random.seed(seed)

class ImageGenerator(Agent):
    agent_type: Literal["generate_image"] = "generate_image"
    agent_type: Literal["generate_image"] = "generate_image" # type: ignore
    _agent_type: ClassVar[str] = agent_type # Add this for registration
    agent_persist: bool = False

@@ -303,7 +303,7 @@ class GeneratePersona(Agent):

    generator: Any = Field(default=EthnicNameGenerator(), exclude=True)
    llm: Any = Field(default=None, exclude=True)
    model: str = Field(default=None, exclude=True)
    model: Optional[str] = Field(default=None, exclude=True)

    def randomize(self):
        self.age = random.randint(22, 67)
@@ -21,7 +21,7 @@ import numpy as np # type: ignore

from logger import logger
from .base import Agent, agent_registry
from models import (ApiActivityType, ApiStatusType, Candidate, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, JobRequirements, JobRequirementsMessage, SkillAssessment, SkillStrength, Tunables)
from models import (ApiActivityType, ApiMessage, ApiStatusType, Candidate, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, JobRequirements, JobRequirementsMessage, SkillAssessment, SkillStrength, Tunables)

class GenerateResume(Agent):
    agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
@@ -174,7 +174,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme

    async def generate_resume(
        self, llm: Any, model: str, session_id: str, skills: List[SkillAssessment]
    ) -> AsyncGenerator[ChatMessage | ChatMessageError, None]:
    ) -> AsyncGenerator[ApiMessage, None]:
        # Stage 1A: Analyze job requirements
        status_message = ChatMessageStatus(
            session_id=session_id,
@@ -202,6 +202,15 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
            yield error_message
            return

        if not isinstance(generated_message, ChatMessage):
            error_message = ChatMessageError(
                session_id=session_id,
                content="Job requirements analysis did not return a valid message."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return

        resume_message = ChatMessageResume(
            session_id=session_id,
            status=ApiStatusType.DONE,
@@ -19,7 +19,7 @@ import asyncio
import numpy as np # type: ignore

from .base import Agent, agent_registry, LLMMessage
from models import ApiActivityType, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
from models import ApiActivityType, ApiMessage, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions, ChatSenderType, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
import model_cast
from logger import logger
import defines
@ -34,55 +34,69 @@ class JobRequirementsAgent(Agent):
|
||||
"""Create the prompt for job requirements analysis."""
|
||||
logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
system_prompt = """
|
||||
You are an objective job requirements analyzer. Your task is to extract and categorize the specific skills,
|
||||
experiences, and qualifications required in a job description WITHOUT any reference to any candidate.
|
||||
You are an objective job requirements analyzer. Your task is to extract and categorize the specific skills,
|
||||
experiences, and qualifications required in a job description WITHOUT any reference to any candidate.
|
||||
|
||||
## INSTRUCTIONS:
|
||||
## INSTRUCTIONS:
|
||||
|
||||
1. Analyze ONLY the job description provided.
|
||||
2. Extract company information, job title, and all requirements.
|
||||
3. Extract and categorize all requirements and preferences.
|
||||
4. DO NOT consider any candidate information - this is a pure job analysis task.
|
||||
1. Analyze ONLY the job description provided.
|
||||
2. Extract company information, job title, and all requirements.
|
||||
3. If a requirement is compound (e.g., "5+ years experience with React, Node.js and MongoDB" or "FastAPI/Django/React"), break it down into individual components.
|
||||
4. Categorize requirements into:
|
||||
- Technical skills (required and preferred)
|
||||
- Experience requirements (required and preferred)
|
||||
- Education requirements
|
||||
- Soft skills
|
||||
- Industry knowledge
|
||||
- Responsibilities
|
||||
- Company values
|
||||
5. Extract and categorize all requirements and preferences.
|
||||
6. DO NOT consider any candidate information - this is a pure job analysis task.
|
||||
7. Provide the output in a structured JSON format as specified below.
|
||||
|
||||
## OUTPUT FORMAT:
|
||||
## OUTPUT FORMAT:
|
||||
|
||||
```json
|
||||
{
|
||||
"company_name": "Company Name",
|
||||
"job_title": "Job Title",
|
||||
"job_summary": "Brief summary of the job",
|
||||
"job_requirements": {
|
||||
"technical_skills": {
|
||||
"required": ["skill1", "skill2"],
|
||||
"preferred": ["skill1", "skill2"]
|
||||
},
|
||||
"experience_requirements": {
|
||||
"required": ["exp1", "exp2"],
|
||||
"preferred": ["exp1", "exp2"]
|
||||
},
|
||||
"education_requirements": ["req1", "req2"],
|
||||
"soft_skills": ["skill1", "skill2"],
|
||||
"industry_knowledge": ["knowledge1", "knowledge2"],
|
||||
"responsibilities": ["resp1", "resp2"],
|
||||
"company_values": ["value1", "value2"]
|
||||
}
|
||||
}
|
||||
```
|
||||
```json
|
||||
{
|
||||
"company_name": "Company Name",
|
||||
"job_title": "Job Title",
|
||||
"job_summary": "Brief summary of the job",
|
||||
"job_requirements": {
|
||||
"technical_skills": {
|
||||
"required": ["skill1", "skill2"],
|
||||
"preferred": ["skill1", "skill2"]
|
||||
},
|
||||
"experience_requirements": {
|
||||
"required": ["exp1", "exp2"],
|
||||
"preferred": ["exp1", "exp2"]
|
||||
},
|
||||
"soft_skills": ["skill1", "skill2"],
|
||||
"experience": ["exp1", "exp2"],
|
||||
"education": ["req1", "req2"],
|
||||
"certifications": ["knowledge1", "knowledge2"],
|
||||
"preferred_attributes": ["resp1", "resp2"],
|
||||
"company_values": ["value1", "value2"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Be specific and detailed in your extraction. Break down compound requirements into individual components.
|
||||
For example, "5+ years experience with React, Node.js and MongoDB" should be separated into:
|
||||
- Experience: "5+ years software development"
|
||||
- Technical skills: "React", "Node.js", "MongoDB"
|
||||
Be specific and detailed in your extraction.
|
||||
If a requirement can be broken down into several separate requirements, split them.
|
||||
For example, the technical_skill of "Python/Django/FastAPI" should be separated into different requirements: Python, Django, and FastAPI.
|
||||
|
||||
Avoid vague categorizations and be precise about whether skills are explicitly required or just preferred.
|
||||
"""
|
||||
For example, if the job description mentions: "Python/Django/FastAPI", you should extract it as:
|
||||
|
||||
"technical_skills": { "required": [ "Python", "Django", "FastAPI" ] },
|
||||
|
||||
Avoid vague categorizations and be precise about whether skills are explicitly required or just preferred.
|
||||
"""
|
||||
|
||||
prompt = f"Job Description:\n{job_description}"
|
||||
return system_prompt, prompt
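The agent later pulls the fenced JSON out of the model's reply and validates it against the shared models. Below is a hedged, self-contained sketch of that parse-and-validate step; `Requirements`/`JobRequirements` are simplified stand-ins for the real classes in `src/backend/models.py`, and the sample reply is made up.

```python
# Sketch of turning the model's JSON reply into typed job requirements.
# The models below are simplified stand-ins, not the repository's classes.
import json
import re
from typing import List
from pydantic import BaseModel, Field


class Requirements(BaseModel):
    required: List[str] = Field(default_factory=list)
    preferred: List[str] = Field(default_factory=list)


class JobRequirements(BaseModel):
    technical_skills: Requirements = Field(default_factory=Requirements)
    experience_requirements: Requirements = Field(default_factory=Requirements)
    soft_skills: List[str] = Field(default_factory=list)
    experience: List[str] = Field(default_factory=list)
    education: List[str] = Field(default_factory=list)
    certifications: List[str] = Field(default_factory=list)
    preferred_attributes: List[str] = Field(default_factory=list)
    company_values: List[str] = Field(default_factory=list)


# Accept either a json-fenced block or a bare top-level object.
FENCED_JSON = re.compile(r"`{3}json\s*(\{.*?\})\s*`{3}", re.DOTALL)
BARE_JSON = re.compile(r"\{.*\}", re.DOTALL)


def extract_requirements(reply: str) -> JobRequirements:
    match = FENCED_JSON.search(reply) or BARE_JSON.search(reply)
    if not match:
        raise ValueError("No JSON object found in the LLM response")
    data = json.loads(match.group(1) if match.re is FENCED_JSON else match.group(0))
    return JobRequirements(**data.get("job_requirements", data))


reply = (
    'Analysis complete. {"company_name": "Acme", "job_title": "Backend Engineer", '
    '"job_requirements": {"technical_skills": {"required": ["Python", "Django", "FastAPI"]}, '
    '"soft_skills": ["communication"]}}'
)
print(extract_requirements(reply).technical_skills.required)
```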
|
||||
|
||||
async def analyze_job_requirements(
|
||||
self, llm: Any, model: str, session_id: str, prompt: str
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageError, None]:
|
||||
) -> AsyncGenerator[ChatMessageStreaming | ChatMessage | ChatMessageError | ChatMessageStatus, None]:
|
||||
"""Analyze job requirements from job description."""
|
||||
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
|
||||
status_message = ChatMessageStatus(
|
||||
@ -111,9 +125,38 @@ class JobRequirementsAgent(Agent):
|
||||
yield generated_message
|
||||
return
|
||||
|
||||
def gather_requirements(self, reqs: JobRequirements) -> Dict[str, Any]:
|
||||
# technical_skills: Requirements = Field(..., alias="technicalSkills")
|
||||
# experience_requirements: Requirements = Field(..., alias="experienceRequirements")
|
||||
# soft_skills: Optional[List[str]] = Field(default_factory=list, alias="softSkills")
|
||||
# experience: Optional[List[str]] = []
|
||||
# education: Optional[List[str]] = []
|
||||
# certifications: Optional[List[str]] = []
|
||||
# preferred_attributes: Optional[List[str]] = Field(None, alias="preferredAttributes")
|
||||
# company_values: Optional[List[str]] = Field(None, alias="companyValues")
|
||||
"""Gather and format job requirements for display."""
|
||||
display = {
|
||||
"technical_skills": {
|
||||
"required": reqs.technical_skills.required,
|
||||
"preferred": reqs.technical_skills.preferred
|
||||
},
|
||||
"experience_requirements": {
|
||||
"required": reqs.experience_requirements.required,
|
||||
"preferred": reqs.experience_requirements.preferred
|
||||
},
|
||||
"soft_skills": reqs.soft_skills,
|
||||
"experience": reqs.experience,
|
||||
"education": reqs.education,
|
||||
"certifications": reqs.certifications,
|
||||
"preferred_attributes": reqs.preferred_attributes,
|
||||
"company_values": reqs.company_values
|
||||
}
|
||||
|
||||
return display
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage, None]:
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
@ -182,6 +225,10 @@ class JobRequirementsAgent(Agent):
|
||||
logger.error(f"⚠️ {status_message.content}")
|
||||
yield status_message
|
||||
return
|
||||
|
||||
# Gather and format requirements for display
|
||||
display = self.gather_requirements(requirements)
|
||||
logger.info(f"📋 Job requirements extracted: {json.dumps(display, indent=2)}")
|
||||
job = Job(
|
||||
owner_id=self.user.id,
|
||||
owner_type=self.user.user_type,
|
||||
@ -189,7 +236,6 @@ class JobRequirementsAgent(Agent):
|
||||
title=title,
|
||||
summary=summary,
|
||||
requirements=requirements,
|
||||
session_id=session_id,
|
||||
description=prompt,
|
||||
)
|
||||
job_requirements_message = JobRequirementsMessage(
|
||||
|
@@ -7,7 +7,7 @@ from .base import Agent, agent_registry
from logger import logger

from .registry import agent_registry
from models import ( ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )
from models import ( ApiMessage, ChatMessage, ChromaDBGetResponse, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, ChatOptions, ApiMessageType, ChatSenderType, ApiStatusType, ChatMessageMetaData, Candidate, Tunables )

class Chat(Agent):
    """
@@ -21,7 +21,7 @@ class Chat(Agent):
        session_id: str, prompt: str,
        tunables: Optional[Tunables] = None,
        temperature=0.7
    ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
    ) -> AsyncGenerator[ApiMessage, None]:
        """
        Generate a response based on the user message and the provided LLM.

@@ -19,7 +19,7 @@ import asyncio
import numpy as np # type: ignore

from .base import Agent, agent_registry, LLMMessage
from models import (Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType,
from models import (ApiMessage, Candidate, ChatMessage, ChatMessageError, ChatMessageMetaData, ApiMessageType, ChatMessageRagSearch,
                    ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatOptions,
                    ChatSenderType, ApiStatusType, EvidenceDetail, SkillAssessment, Tunables)
import model_cast
@ -107,43 +107,12 @@ JSON RESPONSE:"""
|
||||
|
||||
return system_prompt, prompt
|
||||
|
||||
async def analyze_job_requirements(
|
||||
self, llm: Any, model: str, session_id: str, requirement: str
|
||||
) -> AsyncGenerator[ChatMessage, None]:
|
||||
"""Analyze job requirements from job description."""
|
||||
system_prompt, prompt = self.create_job_analysis_prompt(requirement)
|
||||
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
generated_message.content = "Error analyzing job requirements."
|
||||
yield generated_message
|
||||
return
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
yield generated_message
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content = "Job requirements analysis failed to generate a response.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
chat_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=generated_message.content,
|
||||
metadata=generated_message.metadata,
|
||||
)
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
@ -163,15 +132,15 @@ JSON RESPONSE:"""
|
||||
yield error_message
|
||||
return
|
||||
|
||||
rag_message = None
|
||||
async for rag_message in self.generate_rag_results(session_id=session_id, prompt=skill):
|
||||
if rag_message.status == ApiStatusType.ERROR:
|
||||
yield rag_message
|
||||
generated_message = None
|
||||
async for generated_message in self.generate_rag_results(session_id=session_id, prompt=skill):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
if rag_message.status != ApiStatusType.DONE:
|
||||
yield rag_message
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
yield generated_message
|
||||
|
||||
if rag_message is None:
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
@ -180,20 +149,30 @@ JSON RESPONSE:"""
|
||||
yield error_message
|
||||
return
|
||||
|
||||
if not isinstance(generated_message, ChatMessageRagSearch):
|
||||
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
rag_message : ChatMessageRagSearch = generated_message
|
||||
|
||||
rag_context = self.get_rag_context(rag_message)
|
||||
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
|
||||
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
|
||||
|
||||
skill_message = None
|
||||
async for skill_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
|
||||
if skill_message.status == ApiStatusType.ERROR:
|
||||
logger.error(f"⚠️ {skill_message.content}")
|
||||
yield skill_message
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
logger.error(f"⚠️ {generated_message.content}")
|
||||
yield generated_message
|
||||
return
|
||||
if skill_message.status != ApiStatusType.DONE:
|
||||
yield skill_message
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
yield generated_message
|
||||
|
||||
if skill_message is None:
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment failed to generate a response."
|
||||
@ -202,7 +181,15 @@ JSON RESPONSE:"""
|
||||
yield error_message
|
||||
return
|
||||
|
||||
json_str = self.extract_json_from_text(skill_message.content)
|
||||
if not isinstance(generated_message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment did not return a valid message."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
json_str = self.extract_json_from_text(generated_message.content)
|
||||
skill_assessment_data = ""
|
||||
skill_assessment = None
|
||||
try:
|
||||
@ -227,7 +214,7 @@ JSON RESPONSE:"""
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{skill_message.content}\n\nJSON:\n{json_str}\n\n"
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
@ -250,7 +237,7 @@ JSON RESPONSE:"""
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=json.dumps(skill_assessment_data),
|
||||
metadata=skill_message.metadata,
|
||||
metadata=generated_message.metadata,
|
||||
skill_assessment=skill_assessment,
|
||||
)
|
||||
yield skill_assessment_message
|
||||
|
@ -1,5 +1,6 @@
|
||||
import redis.asyncio as redis # type: ignore
|
||||
from typing import Optional, Dict, List, Optional, Any
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from redis.asyncio import (Redis, ConnectionPool)
|
||||
from typing import Any, Optional, Dict, List, Optional, TypeGuard, Union
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class _RedisManager:
|
||||
def __init__(self):
|
||||
self.redis: Optional[redis.Redis] = None
|
||||
self.redis: Optional[Redis] = None
|
||||
self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
|
||||
self.redis_db = int(os.getenv("REDIS_DB", "0"))
|
||||
|
||||
@ -23,7 +24,7 @@ class _RedisManager:
|
||||
if not self.redis_url.endswith(f"/{self.redis_db}"):
|
||||
self.redis_url = f"{self.redis_url}/{self.redis_db}"
|
||||
|
||||
self._connection_pool: Optional[redis.ConnectionPool] = None
|
||||
self._connection_pool: Optional[ConnectionPool] = None
|
||||
self._is_connected = False
|
||||
|
||||
async def connect(self):
|
||||
@ -34,7 +35,7 @@ class _RedisManager:
|
||||
|
||||
try:
|
||||
# Create connection pool for better resource management
|
||||
self._connection_pool = redis.ConnectionPool.from_url(
|
||||
self._connection_pool = ConnectionPool.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
@ -45,7 +46,7 @@ class _RedisManager:
|
||||
health_check_interval=30
|
||||
)
|
||||
|
||||
self.redis = redis.Redis(
|
||||
self.redis = Redis(
|
||||
connection_pool=self._connection_pool
|
||||
)
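The change above builds one `ConnectionPool` from the URL and hands it to `Redis()` instead of constructing ad-hoc clients. A hedged, minimal sketch of that setup follows; the URL and pool options are illustrative defaults, not values from this repository, and a Redis server reachable at localhost is assumed.

```python
# Hedged sketch of the pooled-client setup: one ConnectionPool from the URL,
# shared by the Redis client. URL and options are illustrative defaults.
import asyncio
from redis.asyncio import ConnectionPool, Redis


async def main():
    pool = ConnectionPool.from_url(
        "redis://localhost:6379/0",
        encoding="utf-8",
        decode_responses=True,  # return str instead of bytes
        max_connections=10,
    )
    client = Redis(connection_pool=pool)
    try:
        await client.set("healthcheck", "ok")
        print(await client.get("healthcheck"))
    finally:
        await client.close()
        await pool.disconnect()


asyncio.run(main())
```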
|
||||
|
||||
@ -101,7 +102,7 @@ class _RedisManager:
|
||||
self.redis = None
|
||||
self._connection_pool = None
|
||||
|
||||
def get_client(self) -> redis.Redis:
|
||||
def get_client(self) -> Redis:
|
||||
"""Get Redis client instance"""
|
||||
if not self._is_connected or not self.redis:
|
||||
raise RuntimeError("Redis client not initialized or disconnected")
|
||||
@ -174,7 +175,7 @@ class _RedisManager:
|
||||
return None
|
||||
|
||||
class RedisDatabase:
|
||||
def __init__(self, redis: redis.Redis):
|
||||
def __init__(self, redis: Redis):
|
||||
self.redis = redis
|
||||
|
||||
# Redis key prefixes for different data types
|
||||
@ -227,7 +228,7 @@ class RedisDatabase:
|
||||
|
||||
# Add resume_id to user's resume list
|
||||
user_resumes_key = f"{self.KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
await self.redis.rpush(user_resumes_key, resume_id)
|
||||
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
|
||||
|
||||
logger.info(f"📄 Saved resume {resume_id} for user {user_id}")
|
||||
return True
|
||||
@ -255,7 +256,7 @@ class RedisDatabase:
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{self.KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -275,7 +276,7 @@ class RedisDatabase:
|
||||
resumes.append(resume_data)
|
||||
else:
|
||||
# Clean up orphaned resume ID
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
|
||||
|
||||
# Sort by created_at timestamp (most recent first)
|
||||
@ -296,7 +297,7 @@ class RedisDatabase:
|
||||
|
||||
# Remove from user's resume list
|
||||
user_resumes_key = f"{self.KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
|
||||
if result > 0:
|
||||
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
|
||||
@ -313,7 +314,7 @@ class RedisDatabase:
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{self.KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -513,7 +514,7 @@ class RedisDatabase:
|
||||
try:
|
||||
# Get all document IDs for this candidate
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1)
|
||||
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
|
||||
if not document_ids:
|
||||
logger.info(f"No documents found for candidate {candidate_id}")
|
||||
@ -656,7 +657,7 @@ class RedisDatabase:
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
"""Get all documents for a specific candidate"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1)
|
||||
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
@ -675,7 +676,7 @@ class RedisDatabase:
|
||||
documents.append(doc_data)
|
||||
else:
|
||||
# Clean up orphaned document ID
|
||||
await self.redis.lrem(key, 0, doc_id)
|
||||
await self.redis.lrem(key, 0, doc_id)# type: ignore
|
||||
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
|
||||
|
||||
return documents
|
||||
@ -683,12 +684,12 @@ class RedisDatabase:
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Add a document ID to a candidate's document list"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.rpush(key, document_id)
|
||||
await self.redis.rpush(key, document_id)# type: ignore
|
||||
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Remove a document ID from a candidate's document list"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.lrem(key, 0, document_id)
|
||||
await self.redis.lrem(key, 0, document_id)# type: ignore
|
||||
|
||||
async def update_document(self, document_id: str, updates: Dict):
|
||||
"""Update document metadata"""
|
||||
@ -720,7 +721,7 @@ class RedisDatabase:
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
|
||||
"""Get total number of documents for a candidate"""
|
||||
key = f"{self.KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
return await self.redis.llen(key)
|
||||
return await self.redis.llen(key)# type: ignore
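Most of the collections touched in this file follow the same list-as-index pattern: ids are kept in a Redis list, the records live under their own keys, and ids whose record has vanished are pruned with `LREM`. A hedged sketch of that pattern follows; the key names are illustrative and a local Redis instance is assumed for the demo.

```python
# Hedged sketch of the list-as-index pattern used for candidate documents:
# ids in a Redis list, records under their own keys, orphans pruned via LREM.
# Key names are illustrative; a Redis at localhost is assumed.
import asyncio
import json
from typing import List
from redis.asyncio import Redis


async def add_document(r: Redis, candidate_id: str, doc_id: str, doc: dict) -> None:
    await r.set(f"document:{doc_id}", json.dumps(doc))
    await r.rpush(f"candidate_documents:{candidate_id}", doc_id)


async def get_documents(r: Redis, candidate_id: str) -> List[dict]:
    key = f"candidate_documents:{candidate_id}"
    docs: List[dict] = []
    for doc_id in await r.lrange(key, 0, -1):
        raw = await r.get(f"document:{doc_id}")
        if raw is None:
            # Orphaned id: the record was deleted, so drop it from the index
            await r.lrem(key, 0, doc_id)
            continue
        docs.append(json.loads(raw))
    return docs


async def main():
    r = Redis.from_url("redis://localhost:6379/0", decode_responses=True)
    await add_document(r, "candidate-1", "doc-1", {"filename": "resume.pdf"})
    print(await get_documents(r, "candidate-1"))
    await r.close()


asyncio.run(main())
```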
|
||||
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
|
||||
"""Search documents by filename for a candidate"""
|
||||
@ -1712,7 +1713,7 @@ class RedisDatabase:
|
||||
try:
|
||||
# Remove from the session's message list
|
||||
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.lrem(key, 0, message_id)
|
||||
await self.redis.lrem(key, 0, message_id)# type: ignore
|
||||
# Delete the message data itself
|
||||
result = await self.redis.delete(f"chat_message:{message_id}")
|
||||
return result > 0
|
||||
@ -1724,13 +1725,13 @@ class RedisDatabase:
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
||||
"""Get chat messages for a session"""
|
||||
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
messages = await self.redis.lrange(key, 0, -1)
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
return [self._deserialize(msg) for msg in messages if msg]
|
||||
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict):
|
||||
"""Add a chat message to a session"""
|
||||
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.rpush(key, self._serialize(message_data))
|
||||
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
|
||||
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
||||
"""Set all chat messages for a session (replaces existing)"""
|
||||
@ -1742,7 +1743,7 @@ class RedisDatabase:
|
||||
# Add new messages
|
||||
if messages:
|
||||
serialized_messages = [self._serialize(msg) for msg in messages]
|
||||
await self.redis.rpush(key, *serialized_messages)
|
||||
await self.redis.rpush(key, *serialized_messages)# type: ignore
|
||||
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
||||
"""Get all chat messages grouped by session"""
|
||||
@ -1755,7 +1756,7 @@ class RedisDatabase:
|
||||
result = {}
|
||||
for key in keys:
|
||||
session_id = key.replace(self.KEY_PREFIXES['chat_messages'], '')
|
||||
messages = await self.redis.lrange(key, 0, -1)
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
|
||||
|
||||
return result
|
||||
@ -1810,7 +1811,7 @@ class RedisDatabase:
|
||||
async def get_chat_message_count(self, session_id: str) -> int:
|
||||
"""Get the total number of messages in a chat session"""
|
||||
key = f"{self.KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
return await self.redis.llen(key)
|
||||
return await self.redis.llen(key)# type: ignore
|
||||
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
||||
"""Search for messages containing specific text in a session"""
|
||||
@ -2178,7 +2179,7 @@ class RedisDatabase:
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user lookup data by user ID"""
|
||||
try:
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id)
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
@ -2396,10 +2397,10 @@ class RedisDatabase:
|
||||
}
|
||||
|
||||
# Add to list (latest events first)
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str))
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore
|
||||
|
||||
# Keep only last 100 events per day
|
||||
await self.redis.ltrim(key, 0, 99)
|
||||
await self.redis.ltrim(key, 0, 99)# type: ignore
|
||||
|
||||
# Set expiration for 30 days
|
||||
await self.redis.expire(key, 30 * 24 * 60 * 60)
|
||||
@ -2418,7 +2419,7 @@ class RedisDatabase:
|
||||
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
|
||||
key = f"security_log:{user_id}:{date}"
|
||||
|
||||
daily_events = await self.redis.lrange(key, 0, -1)
|
||||
daily_events = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
for event_json in daily_events:
|
||||
events.append(json.loads(event_json))
|
||||
|
||||
@ -2440,7 +2441,7 @@ class RedisDatabase:
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
|
||||
# Store in Redis with both hash and individual key for redundancy
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data))
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore
|
||||
|
||||
# Also store with a longer TTL as backup
|
||||
await self.redis.setex(
|
||||
@ -2458,7 +2459,7 @@ class RedisDatabase:
|
||||
"""Get guest data with fallback to backup"""
|
||||
try:
|
||||
# Try primary storage first
|
||||
data = await self.redis.hget("guests", guest_id)
|
||||
data = await self.redis.hget("guests", guest_id)# type: ignore
|
||||
if data:
|
||||
guest_data = json.loads(data)
|
||||
# Update last activity when accessed
|
||||
@ -2499,7 +2500,7 @@ class RedisDatabase:
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all guests"""
|
||||
try:
|
||||
data = await self.redis.hgetall("guests")
|
||||
data = await self.redis.hgetall("guests")# type: ignore
|
||||
return {
|
||||
guest_id: json.loads(guest_json)
|
||||
for guest_id, guest_json in data.items()
|
||||
@ -2511,7 +2512,7 @@ class RedisDatabase:
|
||||
async def delete_guest(self, guest_id: str) -> bool:
|
||||
"""Delete a guest"""
|
||||
try:
|
||||
result = await self.redis.hdel("guests", guest_id)
|
||||
result = await self.redis.hdel("guests", guest_id)# type: ignore
|
||||
if result:
|
||||
logger.info(f"🗑️ Guest deleted: {guest_id}")
|
||||
return True
|
||||
|
@@ -1,3 +0,0 @@
from .candidate_entity import CandidateEntity
from .entity_manager import entity_manager, get_candidate_entity
@ -1,160 +0,0 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict, TYPE_CHECKING, Literal
|
||||
|
||||
from typing_extensions import Annotated, Union
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
import traceback
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||
import defines
|
||||
from logger import logger
|
||||
import agents as agents
|
||||
from models import (Tunables, CandidateQuestion, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
||||
import llm_proxy as llm_manager
|
||||
from agents.base import Agent
|
||||
from database import RedisDatabase
|
||||
from models import ChromaDBGetResponse
|
||||
|
||||
class CandidateEntity(Candidate):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
|
||||
# Internal instance members
|
||||
CandidateEntity__agents: List[Agent] = []
|
||||
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
|
||||
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
|
||||
def __init__(self, candidate=None):
|
||||
if candidate is not None:
|
||||
# Copy attributes from the candidate instance
|
||||
super().__init__(**vars(candidate))
|
||||
else:
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def exists(cls, username: str):
|
||||
# Validate username format (only allow safe characters)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||
return False # Invalid username characters
|
||||
|
||||
# Check for minimum and maximum length
|
||||
if not (3 <= len(username) <= 32):
|
||||
return False # Invalid username length
|
||||
|
||||
# Use Path for safe path handling and normalization
|
||||
user_dir = Path(defines.user_dir) / username
|
||||
user_info_path = user_dir / defines.user_info_file
|
||||
|
||||
# Ensure the final path is actually within the intended parent directory
|
||||
# to help prevent directory traversal attacks
|
||||
try:
|
||||
if not user_dir.resolve().is_relative_to(Path(defines.user_dir).resolve()):
|
||||
return False # Path traversal attempt detected
|
||||
except (ValueError, RuntimeError): # Potential exceptions from resolve()
|
||||
return False
|
||||
|
||||
# Check if file exists
|
||||
return user_info_path.is_file()
|
||||
|
||||
def get_or_create_agent(self, agent_type: ChatContextType) -> agents.Agent:
|
||||
"""
|
||||
Get or create an agent of the specified type for this candidate.
|
||||
|
||||
Args:
|
||||
agent_type: The type of agent to create (default is 'candidate_chat').
|
||||
**kwargs: Additional fields required by the specific agent subclass.
|
||||
|
||||
Returns:
|
||||
The created agent instance.
|
||||
"""
|
||||
|
||||
# Only instantiate one agent of each type per user
|
||||
for agent in self.CandidateEntity__agents:
|
||||
if agent.agent_type == agent_type:
|
||||
return agent
|
||||
|
||||
return agents.get_or_create_agent(
|
||||
agent_type=agent_type,
|
||||
user=self,
|
||||
prometheus_collector=self.prometheus_collector
|
||||
)
|
||||
|
||||
# Wrapper properties that map into file_watcher
|
||||
@property
|
||||
def umap_collection(self) -> ChromaDBGetResponse:
|
||||
if not self.CandidateEntity__file_watcher:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__file_watcher.umap_collection
|
||||
|
||||
# Fields managed by initialize()
|
||||
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
|
||||
@property
|
||||
def file_watcher(self) -> ChromaDBFileWatcher:
|
||||
if not self.CandidateEntity__file_watcher:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__file_watcher
|
||||
|
||||
@property
|
||||
def prometheus_collector(self) -> CollectorRegistry:
|
||||
if not self.CandidateEntity__prometheus_collector:
|
||||
raise ValueError("initialize() has not been called with a prometheus_collector.")
|
||||
return self.CandidateEntity__prometheus_collector
|
||||
|
||||
@property
|
||||
def observer(self) -> Any:
|
||||
if not self.CandidateEntity__observer:
|
||||
raise ValueError("initialize() has not been called.")
|
||||
return self.CandidateEntity__observer
|
||||
|
||||
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
if self.CandidateEntity__initialized:
|
||||
# Initialization can only be attempted once; if there are multiple attempts, it means
|
||||
# a subsystem is failing or there is a logic bug in the code.
|
||||
#
|
||||
# NOTE: It is intentional that self.CandidateEntity__initialized = True regardless of whether it
|
||||
# succeeded. This prevents server loops on failure
|
||||
raise ValueError("initialize can only be attempted once")
|
||||
self.CandidateEntity__initialized = True
|
||||
|
||||
if not self.username:
|
||||
raise ValueError("username can not be empty")
|
||||
|
||||
user_dir = os.path.join(defines.user_dir, self.username)
|
||||
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
|
||||
|
||||
os.makedirs(vector_db_dir, exist_ok=True)
|
||||
os.makedirs(rag_content_dir, exist_ok=True)
|
||||
|
||||
if prometheus_collector:
|
||||
self.CandidateEntity__prometheus_collector = prometheus_collector
|
||||
|
||||
self.CandidateEntity__observer, self.CandidateEntity__file_watcher = start_file_watcher(
|
||||
llm=llm_manager.get_llm(),
|
||||
user_id=self.id,
|
||||
collection_name=self.username,
|
||||
persist_directory=vector_db_dir,
|
||||
watch_directory=rag_content_dir,
|
||||
database=database,
|
||||
recreate=False, # Don't recreate if exists
|
||||
)
|
||||
has_username_rag = any(item["name"] == self.username for item in self.rags)
|
||||
if not has_username_rag:
|
||||
self.rags.append(RagEntry(
|
||||
name=self.username,
|
||||
description=f"Expert data about {self.full_name}.",
|
||||
))
|
||||
self.rag_content_size = self.file_watcher.collection.count()
|
||||
|
||||
CandidateEntity.model_rebuild()
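For orientation, a minimal sketch of the lifecycle this class enforces; the candidate payload, collector, and Redis handle below are placeholders rather than values taken from this commit:

async def bootstrap_entity(candidate_data: dict, collector: CollectorRegistry, redis_db: RedisDatabase) -> CandidateEntity:
    # Wrap an existing Candidate record, then initialize exactly once.
    candidate = Candidate.model_validate(candidate_data)
    entity = CandidateEntity(candidate=candidate)
    # initialize() wires up the file watcher and RAG collection; a second call raises ValueError.
    await entity.initialize(prometheus_collector=collector, database=redis_db)
    return entity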
|
@ -1,16 +1,18 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
from uuid import uuid4
|
||||
import weakref
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
|
||||
from models import ( Candidate )
|
||||
from .candidate_entity import CandidateEntity
|
||||
from models import Candidate
|
||||
from agents.base import CandidateEntity
|
||||
from database import RedisDatabase
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
|
||||
class EntityManager:
|
||||
class EntityManager(BaseModel):
|
||||
"""Manages lifecycle of CandidateEntity instances"""
|
||||
|
||||
def __init__(self, default_ttl_minutes: int = 30):
|
||||
@ -19,6 +21,7 @@ class EntityManager:
|
||||
self._ttl_minutes = default_ttl_minutes
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._prometheus_collector: Optional[CollectorRegistry] = None
|
||||
self._database: Optional[RedisDatabase] = None
|
||||
|
||||
async def start_cleanup_task(self):
|
||||
"""Start background cleanup task"""
|
||||
@ -35,7 +38,10 @@ class EntityManager:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
"""Initialize the EntityManager with Prometheus collector"""
|
||||
self._prometheus_collector = prometheus_collector
|
||||
self._database = database
|
||||
@ -44,21 +50,26 @@ class EntityManager:
|
||||
"""Get or create CandidateEntity with proper reference tracking"""
|
||||
|
||||
# Check if entity exists and is still valid
|
||||
if candidate.id in self._entities:
|
||||
if id in self._entities:
|
||||
entity = self._entities[candidate.id]
|
||||
entity._last_accessed = datetime.now()
|
||||
entity._reference_count += 1
|
||||
entity.last_accessed = datetime.now()
|
||||
entity.reference_count += 1
|
||||
return entity
|
||||
|
||||
if not self._prometheus_collector or not self._database:
|
||||
raise ValueError("EntityManager has not been initialized with required components.")
|
||||
|
||||
entity = CandidateEntity(candidate=candidate)
|
||||
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
|
||||
await entity.initialize(
|
||||
prometheus_collector=self._prometheus_collector,
|
||||
database=self._database)
|
||||
|
||||
# Store with reference tracking
|
||||
self._entities[candidate.id] = entity
|
||||
self._weak_refs[candidate.id] = weakref.ref(entity, self._on_entity_deleted(candidate.id))
|
||||
|
||||
entity._reference_count = 1
|
||||
entity._last_accessed = datetime.now()
|
||||
entity.reference_count = 1
|
||||
entity.last_accessed = datetime.now()
|
||||
|
||||
return entity
|
||||
|
||||
@ -106,8 +117,8 @@ class EntityManager:
|
||||
"""Explicitly release reference to entity"""
|
||||
if user_id in self._entities:
|
||||
entity = self._entities[user_id]
|
||||
entity._reference_count = max(0, entity._reference_count - 1)
|
||||
entity._last_accessed = datetime.now()
|
||||
entity.reference_count = max(0, entity.reference_count - 1)
|
||||
entity.last_accessed = datetime.now()
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
"""Background task to clean up expired entities"""
|
||||
@ -126,11 +137,11 @@ class EntityManager:
|
||||
expired_entities = []
|
||||
|
||||
for user_id, entity in list(self._entities.items()):
|
||||
time_since_access = current_time - entity._last_accessed
|
||||
time_since_access = current_time - entity.last_accessed
|
||||
|
||||
# Remove if TTL exceeded and no active references
|
||||
if (time_since_access > timedelta(minutes=self._ttl_minutes)
|
||||
and entity._reference_count == 0):
|
||||
and entity.reference_count == 0):
|
||||
expired_entities.append(user_id)
|
||||
|
||||
for user_id in expired_entities:
|
||||
@ -154,3 +165,5 @@ async def get_candidate_entity(candidate: Candidate):
|
||||
yield entity
|
||||
finally:
|
||||
await entity_manager.release_entity(candidate.id)
|
||||
|
||||
EntityManager.model_rebuild()
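The intended call pattern for the context manager above, mirroring how the chat route later in this commit uses it; the agent type and the body of the with-block are illustrative only:

async def run_chat_turn(candidate: Candidate, chat_type: ChatContextType):
    async with get_candidate_entity(candidate=candidate) as entity:
        agent = entity.get_or_create_agent(agent_type=chat_type)
        # ... generate or stream the agent's response here ...
    # On exit, entity_manager releases the entity's reference count.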
|
||||
|
6211 src/backend/main.py
File diff suppressed because it is too large
26 src/backend/routes/__init__.py Normal file
@ -0,0 +1,26 @@
"""
Routes package - API route modules
"""

# Import all route modules for easy access
from . import auth
from . import candidates
from . import resumes
from . import jobs
from . import chat
from . import users
from . import employers
from . import admin
from . import system

__all__ = [
    "auth",
    "candidates",
    "resumes",
    "jobs",
    "chat",
    "users",
    "employers",
    "admin",
    "system"
]
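Since main.py's diff is suppressed above, the wiring below is an assumption about how these modules are mounted; each module is expected to expose a `router`, as admin.py and chat.py do later in this commit:

from fastapi import FastAPI
from routes import auth, candidates, resumes, jobs, chat, users, employers, admin, system

app = FastAPI()
for module in (auth, candidates, resumes, jobs, chat, users, employers, admin, system):
    app.include_router(module.router)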
|
816 src/backend/routes/admin.py Normal file
@ -0,0 +1,816 @@
|
||||
"""
|
||||
Admin routes
|
||||
"""
|
||||
import json
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import (Optional, List, Dict, Any)
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, FastAPI, HTTPException, Depends, Query, Path, Body, status,
Request, BackgroundTasks, File, UploadFile, Form
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from utils.rate_limiter import RateLimiter, get_rate_limiter
|
||||
from database import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_current_user_or_guest, get_database, background_task_manager
|
||||
)
|
||||
from utils.responses import (
|
||||
create_paginated_response, create_success_response, create_error_response,
|
||||
)
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
@router.post("/tasks/cleanup-guests")
|
||||
async def manual_guest_cleanup(
|
||||
inactive_hours: int = Body(24, embed=True),
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Manually trigger guest cleanup (admin only)"""
|
||||
try:
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
|
||||
|
||||
logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual guest cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/tasks/cleanup-tokens")
|
||||
async def manual_token_cleanup(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Manually trigger verification token cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
|
||||
|
||||
logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual token cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/tasks/cleanup-rate-limits")
|
||||
async def manual_rate_limit_cleanup(
|
||||
days_old: int = Body(7, embed=True),
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Manually trigger rate limit data cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
|
||||
|
||||
logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual rate limit cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# System Health and Maintenance Endpoints
|
||||
# ========================================
|
||||
|
||||
@router.get("/system/health")
|
||||
async def get_system_health(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Get comprehensive system health status (admin only)"""
|
||||
try:
|
||||
# Database health
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
db_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if database_manager:
|
||||
try:
|
||||
database = database_manager.get_database()
|
||||
from database import redis_manager
|
||||
redis_health = await redis_manager.health_check()
|
||||
db_health = {
|
||||
"status": redis_health.get("status", "unknown"),
|
||||
"healthy": redis_health.get("status") == "healthy",
|
||||
"details": redis_health
|
||||
}
|
||||
except Exception as e:
|
||||
db_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Background task health
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
task_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if background_task_manager:
|
||||
try:
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
running_tasks = len([t for t in task_status["tasks"] if t["status"] == "running"])
|
||||
failed_tasks = len([t for t in task_status["tasks"] if t["status"] == "failed"])
|
||||
|
||||
task_health = {
|
||||
"status": "healthy" if task_status["running"] and failed_tasks == 0 else "degraded",
|
||||
"healthy": task_status["running"] and failed_tasks == 0,
|
||||
"running_tasks": running_tasks,
|
||||
"failed_tasks": failed_tasks,
|
||||
"total_tasks": task_status["task_count"]
|
||||
}
|
||||
except Exception as e:
|
||||
task_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Overall health
|
||||
overall_healthy = db_health["healthy"] and task_health["healthy"]
|
||||
|
||||
return create_success_response({
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"overall_healthy": overall_healthy,
|
||||
"components": {
|
||||
"database": db_health,
|
||||
"background_tasks": task_health
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting system health: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/maintenance/cleanup")
|
||||
async def run_maintenance_cleanup(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Run comprehensive maintenance cleanup (admin only)"""
|
||||
try:
|
||||
cleanup_results = {}
|
||||
|
||||
# Run various cleanup operations
|
||||
cleanup_operations = [
|
||||
("inactive_guests", lambda: database.cleanup_inactive_guests(72)), # 3 days
|
||||
("expired_tokens", lambda: database.cleanup_expired_verification_tokens()),
|
||||
("orphaned_job_requirements", lambda: database.cleanup_orphaned_job_requirements()),
|
||||
]
|
||||
|
||||
for operation_name, operation_func in cleanup_operations:
|
||||
try:
|
||||
result = await operation_func()
|
||||
cleanup_results[operation_name] = {
|
||||
"success": True,
|
||||
"cleaned_count": result,
|
||||
"message": f"Cleaned {result} items"
|
||||
}
|
||||
except Exception as e:
|
||||
cleanup_results[operation_name] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"Failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Calculate totals
|
||||
total_cleaned = sum(
|
||||
result.get("cleaned_count", 0)
|
||||
for result in cleanup_results.values()
|
||||
if result.get("success", False)
|
||||
)
|
||||
|
||||
successful_operations = len([
|
||||
r for r in cleanup_results.values()
|
||||
if r.get("success", False)
|
||||
])
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
|
||||
"total_cleaned": total_cleaned,
|
||||
"successful_operations": successful_operations,
|
||||
"details": cleanup_results
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in maintenance cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# Background Task Statistics
|
||||
# ========================================
|
||||
|
||||
@router.get("/tasks/stats")
|
||||
async def get_task_statistics(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get background task execution statistics (admin only)"""
|
||||
try:
|
||||
# Get guest statistics
|
||||
guest_stats = await database.get_guest_statistics()
|
||||
|
||||
# Get background task manager status
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
task_manager_stats = {}
|
||||
|
||||
if background_task_manager:
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
task_manager_stats = {
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"task_breakdown": {}
|
||||
}
|
||||
|
||||
# Count tasks by status
|
||||
for task in task_status["tasks"]:
|
||||
status = task["status"]
|
||||
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
|
||||
|
||||
return create_success_response({
|
||||
"guest_statistics": guest_stats,
|
||||
"task_manager": task_manager_stats,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting task statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ========================================
|
||||
# Background Task Status Endpoints
|
||||
# ========================================
|
||||
|
||||
@router.get("/tasks/status")
|
||||
async def get_background_task_status(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Get background task manager status (admin only)"""
|
||||
try:
|
||||
# Get background task manager from app state
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not background_task_manager:
|
||||
return create_success_response({
|
||||
"running": False,
|
||||
"message": "Background task manager not initialized",
|
||||
"tasks": [],
|
||||
"task_count": 0
|
||||
})
|
||||
|
||||
# Get comprehensive task status using the new method
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
|
||||
# Add additional system info
|
||||
system_info = {
|
||||
"uptime_seconds": None, # Could calculate from start time if stored
|
||||
"last_cleanup": None, # Could track last cleanup time
|
||||
}
|
||||
|
||||
# Format the response
|
||||
return create_success_response({
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"loop_status": {
|
||||
"main_loop_id": task_status["main_loop_id"],
|
||||
"current_loop_id": task_status["current_loop_id"],
|
||||
"loop_matches": task_status.get("loop_matches", False)
|
||||
},
|
||||
"tasks": task_status["tasks"],
|
||||
"system_info": system_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/tasks/run/{task_name}")
|
||||
async def run_background_task(
|
||||
task_name: str,
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Manually trigger a specific background task (admin only)"""
|
||||
try:
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"MANAGER_UNAVAILABLE",
|
||||
"Background task manager not initialized"
|
||||
)
|
||||
)
|
||||
|
||||
# List of available tasks
|
||||
available_tasks = [
|
||||
"guest_cleanup",
|
||||
"token_cleanup",
|
||||
"guest_stats",
|
||||
"rate_limit_cleanup",
|
||||
"orphaned_cleanup"
|
||||
]
|
||||
|
||||
if task_name not in available_tasks:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_TASK",
|
||||
f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
)
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = await background_task_manager.force_run_task(task_name)
|
||||
|
||||
return create_success_response({
|
||||
"task_name": task_name,
|
||||
"result": result,
|
||||
"message": f"Task {task_name} completed successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TASK", str(e))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running task {task_name}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("/tasks/list")
|
||||
async def list_available_tasks(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""List all available background tasks (admin only)"""
|
||||
try:
|
||||
tasks = [
|
||||
{
|
||||
"name": "guest_cleanup",
|
||||
"description": "Clean up inactive guest sessions",
|
||||
"interval": "6 hours",
|
||||
"parameters": ["inactive_hours (default: 48)"]
|
||||
},
|
||||
{
|
||||
"name": "token_cleanup",
|
||||
"description": "Clean up expired email verification tokens",
|
||||
"interval": "12 hours",
|
||||
"parameters": []
|
||||
},
|
||||
{
|
||||
"name": "guest_stats",
|
||||
"description": "Update guest usage statistics",
|
||||
"interval": "1 hour",
|
||||
"parameters": []
|
||||
},
|
||||
{
|
||||
"name": "rate_limit_cleanup",
|
||||
"description": "Clean up old rate limiting data",
|
||||
"interval": "24 hours",
|
||||
"parameters": ["days_old (default: 7)"]
|
||||
},
|
||||
{
|
||||
"name": "orphaned_cleanup",
|
||||
"description": "Clean up orphaned database records",
|
||||
"interval": "6 hours",
|
||||
"parameters": []
|
||||
}
|
||||
]
|
||||
|
||||
return create_success_response({
|
||||
"total_tasks": len(tasks),
|
||||
"tasks": tasks
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error listing tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LIST_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/tasks/restart")
|
||||
async def restart_background_tasks(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Restart the background task manager (admin only)"""
|
||||
try:
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
|
||||
if not database_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"DATABASE_UNAVAILABLE",
|
||||
"Database manager not available"
|
||||
)
|
||||
)
|
||||
|
||||
# Stop existing background tasks
|
||||
if background_task_manager:
|
||||
await background_task_manager.stop()
|
||||
logger.info("🛑 Stopped existing background task manager")
|
||||
|
||||
# Create and start new background task manager
|
||||
from background_tasks import BackgroundTaskManager
|
||||
new_background_task_manager = BackgroundTaskManager(database_manager)
|
||||
await new_background_task_manager.start()
|
||||
|
||||
# Update app state
|
||||
request.app.state.background_task_manager = new_background_task_manager
|
||||
|
||||
# Get status of new manager
|
||||
status = await new_background_task_manager.get_task_status()
|
||||
|
||||
return create_success_response({
|
||||
"message": "Background task manager restarted successfully",
|
||||
"new_status": status
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error restarting background tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESTART_ERROR", str(e))
|
||||
)
|
||||
|
||||
# ============================
|
||||
# Task Monitoring and Metrics
|
||||
# ============================
|
||||
|
||||
class TaskMetrics:
|
||||
"""Collect metrics for background tasks"""
|
||||
|
||||
def __init__(self):
|
||||
self.task_runs = {}
|
||||
self.task_durations = {}
|
||||
self.task_errors = {}
|
||||
|
||||
def record_task_run(self, task_name: str, duration: float, success: bool = True):
|
||||
"""Record a task execution"""
|
||||
if task_name not in self.task_runs:
|
||||
self.task_runs[task_name] = 0
|
||||
self.task_durations[task_name] = []
|
||||
self.task_errors[task_name] = 0
|
||||
|
||||
self.task_runs[task_name] += 1
|
||||
self.task_durations[task_name].append(duration)
|
||||
|
||||
if not success:
|
||||
self.task_errors[task_name] += 1
|
||||
|
||||
# Keep only last 100 durations to prevent memory growth
|
||||
if len(self.task_durations[task_name]) > 100:
|
||||
self.task_durations[task_name] = self.task_durations[task_name][-100:]
|
||||
|
||||
def get_metrics(self) -> dict:
|
||||
"""Get task metrics summary"""
|
||||
metrics = {}
|
||||
|
||||
for task_name in self.task_runs:
|
||||
durations = self.task_durations[task_name]
|
||||
avg_duration = sum(durations) / len(durations) if durations else 0
|
||||
|
||||
metrics[task_name] = {
|
||||
"total_runs": self.task_runs[task_name],
|
||||
"total_errors": self.task_errors[task_name],
|
||||
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
|
||||
"average_duration": avg_duration,
|
||||
"last_runs": durations[-10:] if durations else []
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
# Global task metrics
|
||||
task_metrics = TaskMetrics()
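One way the task manager could feed this collector; the wrapper below is purely illustrative and not part of this diff:

import time

async def run_and_record(task_name: str, coro):
    # Time the task and record the outcome against the global task_metrics instance.
    start = time.monotonic()
    try:
        result = await coro
    except Exception:
        task_metrics.record_task_run(task_name, time.monotonic() - start, success=False)
        raise
    task_metrics.record_task_run(task_name, time.monotonic() - start, success=True)
    return result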
|
||||
|
||||
@router.get("/tasks/metrics")
|
||||
async def get_task_metrics(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
"""Get background task metrics (admin only)"""
|
||||
try:
|
||||
global task_metrics
|
||||
metrics = task_metrics.get_metrics()
|
||||
|
||||
return create_success_response({
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task metrics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("METRICS_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Admin Endpoints
|
||||
# ============================
|
||||
# @router.get("/verification-stats")
|
||||
async def get_verification_statistics(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get verification statistics (admin only)"""
|
||||
try:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
stats = {
|
||||
"pending_verifications": await database.get_pending_verifications_count(),
|
||||
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
|
||||
}
|
||||
|
||||
return create_success_response(stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting verification stats: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/cleanup-verifications")
|
||||
async def cleanup_verification_tokens(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Manually trigger cleanup of expired verification tokens (admin only)"""
|
||||
try:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
cleaned_count = await database.cleanup_expired_verification_tokens()
|
||||
|
||||
logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
|
||||
"cleaned_count": cleaned_count
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in manual cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("/pending-verifications")
|
||||
async def get_pending_verifications(
|
||||
current_user = Depends(get_current_admin),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get list of pending email verifications (admin only)"""
|
||||
try:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
pattern = "email_verification:*"
|
||||
cursor = 0
|
||||
pending_verifications = []
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
while True:
|
||||
cursor, keys = await database.redis.scan(cursor, match=pattern, count=100)
|
||||
|
||||
for key in keys:
|
||||
token_data = await database.redis.get(key)
|
||||
if token_data:
|
||||
verification_info = json.loads(token_data)
|
||||
if not verification_info.get("verified", False):
|
||||
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
|
||||
|
||||
pending_verifications.append({
|
||||
"email": verification_info.get("email"),
|
||||
"user_type": verification_info.get("user_type"),
|
||||
"created_at": verification_info.get("created_at"),
|
||||
"expires_at": verification_info.get("expires_at"),
|
||||
"is_expired": current_time > expires_at,
|
||||
"resend_count": verification_info.get("resend_count", 0)
|
||||
})
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
# Sort by creation date (newest first)
|
||||
pending_verifications.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
total = len(pending_verifications)
|
||||
start = (page - 1) * limit
|
||||
end = start + limit
|
||||
paginated_verifications = pending_verifications[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
paginated_verifications,
|
||||
page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting pending verifications: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("/rate-limits/info")
|
||||
async def get_user_rate_limit_status(
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(current_user.id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
|
||||
if user_type == "candidate":
|
||||
candidate_data = await database.get_candidate(current_user.id)
|
||||
if candidate_data:
|
||||
is_admin = candidate_data.get("is_admin", False)
|
||||
elif user_type == "employer":
|
||||
employer_data = await database.get_employer(current_user.id)
|
||||
if employer_data:
|
||||
is_admin = employer_data.get("is_admin", False)
|
||||
|
||||
status = await rate_limiter.get_user_rate_limit_status(current_user.id, user_type, is_admin)
|
||||
|
||||
return create_success_response(status)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("/rate-limits/{user_id}")
|
||||
async def get_anyone_rate_limit_status(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
|
||||
if user_type == "candidate":
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
is_admin = candidate_data.get("is_admin", False)
|
||||
elif user_type == "employer":
|
||||
employer_data = await database.get_employer(user_id)
|
||||
if employer_data:
|
||||
is_admin = employer_data.get("is_admin", False)
|
||||
|
||||
status = await rate_limiter.get_user_rate_limit_status(user_id, user_type, is_admin)
|
||||
|
||||
return create_success_response(status)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.post("/rate-limits/{user_id}/reset")
|
||||
async def reset_user_rate_limits(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Reset rate limits for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
|
||||
|
||||
if success:
|
||||
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
|
||||
return create_success_response({
|
||||
"message": f"Rate limits reset for {user_type} {user_id}",
|
||||
"resetBy": admin_user.id
|
||||
})
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset rate limits error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
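For reference, a hypothetical client call against the manual-task endpoint above; the host, any API prefix, and the bearer-token scheme are assumptions, not details from this diff:

import httpx

resp = httpx.post(
    "http://localhost:8000/admin/tasks/run/guest_cleanup",
    headers={"Authorization": "Bearer <admin-access-token>"},
)
print(resp.status_code, resp.json())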
|
||||
|
1134 src/backend/routes/auth.py Normal file
File diff suppressed because it is too large
1970 src/backend/routes/candidates.py Normal file
File diff suppressed because it is too large
545 src/backend/routes/chat.py Normal file
@ -0,0 +1,545 @@
|
||||
"""
|
||||
Chat routes
|
||||
"""
|
||||
import json
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import (Optional, List, Dict, Any)
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, FastAPI, HTTPException, Depends, Query, Path, Body, status,
Request, BackgroundTasks, File, UploadFile, Form
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from auth_utils import AuthenticationManager, SecurityConfig
|
||||
from database import RedisDatabase, redis_manager
|
||||
from device_manager import DeviceManager
|
||||
from email_service import email_service
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist
|
||||
)
|
||||
from utils.responses import (
|
||||
create_success_response, create_error_response, create_paginated_response
|
||||
)
|
||||
from utils.helpers import (
|
||||
stream_agent_response
|
||||
)
|
||||
import backstory_traceback
|
||||
|
||||
import entities.entity_manager as entities
|
||||
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
EmailVerificationRequest, ResendVerificationRequest,

# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
JobRequirements, JobRequirementsMessage,

# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
Authentication, AuthResponse, CandidateAI,

# Job models
JobApplication, ApplicationStatus,

# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType,
ChatSenderType, ApiMessageType, ChatContextType, ChatMessageRagSearch,

# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,

# Supporting models
Location, RagContentMetadata, RagContentResponse, Resume, ResumeMessage,
Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education
|
||||
)
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
@router.post("/sessions/{session_id}/archive")
|
||||
async def archive_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Archive a chat session"""
|
||||
try:
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
|
||||
# Check if user owns this session or is admin
|
||||
if session_data.get("userId") != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
|
||||
)
|
||||
|
||||
await database.archive_chat_session(session_id)
|
||||
|
||||
return create_success_response({
|
||||
"message": "Chat session archived successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Archive chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("ARCHIVE_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_chat_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get chat statistics (admin/analytics endpoint)"""
|
||||
try:
|
||||
stats = await database.get_chat_statistics()
|
||||
return create_success_response(stats)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat statistics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_chat_session(
|
||||
session_data: Dict[str, Any] = Body(...),
|
||||
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new chat session with optional candidate username association"""
|
||||
try:
|
||||
# Extract username if provided
|
||||
username = session_data.get("username")
|
||||
candidate_id = None
|
||||
candidate_data = None
|
||||
|
||||
# If username is provided, look up the candidate
|
||||
if username:
|
||||
logger.info(f"🔍 Looking up candidate with username: {username}")
|
||||
|
||||
# Get all candidates and find by username
|
||||
all_candidates_data = await database.get_all_candidates()
|
||||
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
|
||||
|
||||
# Find candidate by username (case-insensitive)
|
||||
matching_candidates = [
|
||||
c for c in candidates_list
|
||||
if c.username.lower() == username.lower()
|
||||
]
|
||||
|
||||
if not matching_candidates:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
|
||||
)
|
||||
|
||||
candidate_data = matching_candidates[0]
|
||||
candidate_id = candidate_data.id
|
||||
logger.info(f"✅ Found candidate: {candidate_data.full_name} (ID: {candidate_id})")
|
||||
|
||||
# Add required fields
|
||||
session_id = str(uuid.uuid4())
|
||||
session_data["id"] = session_id
|
||||
session_data["userId"] = current_user.id
|
||||
session_data["createdAt"] = datetime.now(UTC).isoformat()
|
||||
session_data["lastActivity"] = datetime.now(UTC).isoformat()
|
||||
|
||||
# Set up context with candidate association if username was provided
|
||||
context = session_data.get("context", {})
|
||||
if candidate_id and candidate_data:
|
||||
context["relatedEntityId"] = candidate_id
|
||||
context["relatedEntityType"] = "candidate"
|
||||
|
||||
# Add candidate info to additional context for AI reference
|
||||
additional_context = context.get("additionalContext", {})
|
||||
additional_context["candidateInfo"] = {
|
||||
"id": candidate_data.id,
|
||||
"name": candidate_data.full_name,
|
||||
"email": candidate_data.email,
|
||||
"username": candidate_data.username,
|
||||
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
|
||||
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
|
||||
"location": candidate_data.location.city if candidate_data.location else "Unknown"
|
||||
}
|
||||
context["additionalContext"] = additional_context
|
||||
|
||||
# Set a descriptive title if not provided
|
||||
if not session_data.get("title"):
|
||||
session_data["title"] = f"Chat about {candidate_data.full_name}"
|
||||
|
||||
session_data["context"] = context
|
||||
|
||||
# Create chat session
|
||||
chat_session = ChatSession.model_validate(session_data)
|
||||
await database.set_chat_session(chat_session.id, chat_session.model_dump())
|
||||
|
||||
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
|
||||
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
|
||||
|
||||
return create_success_response(chat_session.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat session creation error: {e}")
|
||||
logger.info(json.dumps(session_data, indent=2))
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
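An example request body for the endpoint above; the field values are placeholders, and the server fills in id, userId, and the timestamps:

session_payload = {
    "username": "some-candidate",       # optional: associates the session with a candidate
    "title": "Chat about a candidate",  # optional: defaulted when a candidate is found
    "context": {
        "type": "candidate_chat",       # assumed ChatContextType value
        "additionalContext": {},
    },
}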
|
||||
|
||||
@router.post("/sessions/messages/stream")
|
||||
async def post_chat_session_message_stream(
|
||||
user_message: ChatMessageUser = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Post a message to a chat session and stream the response with persistence"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(user_message.session_id)
|
||||
if not chat_session_data:
|
||||
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
chat_session = ChatSession.model_validate(chat_session_data)
|
||||
chat_type = chat_session.context.type
|
||||
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
|
||||
|
||||
# Get candidate info if this chat is about a specific candidate
|
||||
if candidate_info:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
|
||||
else:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
|
||||
)
|
||||
|
||||
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
|
||||
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
|
||||
if not candidate:
|
||||
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
|
||||
)
|
||||
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
|
||||
|
||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||
# Entity automatically released when done
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=chat_type)
|
||||
if not chat_agent:
|
||||
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
|
||||
)
|
||||
|
||||
# Persist user message to database
|
||||
await database.add_chat_message(user_message.session_id, user_message.model_dump())
|
||||
logger.info(f"💬 User message saved to database for session {user_message.session_id}")
|
||||
|
||||
# Update session last activity
|
||||
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
|
||||
await database.set_chat_session(user_message.session_id, chat_session_data)
|
||||
|
||||
return await stream_agent_response(
|
||||
chat_agent=chat_agent,
|
||||
user_message=user_message,
|
||||
database=database,
|
||||
chat_session_data=chat_session_data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat message streaming error")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STREAMING_ERROR", "")
|
||||
)
|
||||
|
||||
@router.get("/sessions/{session_id}/messages")
|
||||
async def get_chat_session_messages(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get persisted chat messages for a session"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(session_id)
|
||||
if not chat_session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
|
||||
# Get messages from database
|
||||
chat_messages = await database.get_chat_messages(session_id)
|
||||
|
||||
# Convert to ChatMessage objects and sort by timestamp
|
||||
messages_list = []
|
||||
for msg_data in chat_messages:
|
||||
try:
|
||||
message = ChatMessage.model_validate(msg_data)
|
||||
messages_list.append(message)
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Failed to validate message: {e}")
|
||||
continue
|
||||
|
||||
# Sort by timestamp (oldest first for chat history)
|
||||
messages_list.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# Apply pagination
|
||||
total = len(messages_list)
|
||||
start = (page - 1) * limit
|
||||
end = start + limit
|
||||
paginated_messages = messages_list[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[m.model_dump(by_alias=True) for m in paginated_messages],
|
||||
page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat messages error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.patch("/sessions/{session_id}")
|
||||
async def update_chat_session(
|
||||
session_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Update a chat session's properties"""
|
||||
try:
|
||||
# Get the existing session
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only update their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
|
||||
)
|
||||
|
||||
# Validate and apply updates
|
||||
allowed_fields = {"title", "context", "isArchived", "systemPrompt"}
|
||||
filtered_updates = {k: v for k, v in updates.items() if k in allowed_fields}
|
||||
|
||||
if not filtered_updates:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
|
||||
)
|
||||
|
||||
# Apply updates to session data
|
||||
session_dict = session.model_dump()
|
||||
|
||||
# Handle special field mappings (camelCase to snake_case)
|
||||
if "isArchived" in filtered_updates:
|
||||
session_dict["is_archived"] = filtered_updates["isArchived"]
|
||||
if "systemPrompt" in filtered_updates:
|
||||
session_dict["system_prompt"] = filtered_updates["systemPrompt"]
|
||||
if "title" in filtered_updates:
|
||||
session_dict["title"] = filtered_updates["title"]
|
||||
if "context" in filtered_updates:
|
||||
# Merge context updates with existing context
|
||||
existing_context = session_dict.get("context", {})
|
||||
context_updates = filtered_updates["context"]
|
||||
|
||||
# Update specific context fields while preserving others
|
||||
for context_key, context_value in context_updates.items():
|
||||
if context_key == "additionalContext":
|
||||
# Merge additional context
|
||||
existing_additional = existing_context.get("additional_context", {})
|
||||
existing_additional.update(context_value)
|
||||
existing_context["additional_context"] = existing_additional
|
||||
else:
|
||||
# Convert camelCase to snake_case for context fields
|
||||
snake_key = context_key
|
||||
if context_key == "relatedEntityId":
|
||||
snake_key = "related_entity_id"
|
||||
elif context_key == "relatedEntityType":
|
||||
snake_key = "related_entity_type"
|
||||
elif context_key == "aiParameters":
|
||||
snake_key = "ai_parameters"
|
||||
|
||||
existing_context[snake_key] = context_value
|
||||
|
||||
session_dict["context"] = existing_context
|
||||
|
||||
# Update last activity timestamp
|
||||
session_dict["last_activity"] = datetime.now(UTC).isoformat()
|
||||
|
||||
# Validate the updated session
|
||||
updated_session = ChatSession.model_validate(session_dict)
|
||||
|
||||
# Save to database
|
||||
await database.set_chat_session(session_id, updated_session.model_dump())
|
||||
|
||||
logger.info(f"✅ Chat session {session_id} updated by user {current_user.id}")
|
||||
|
||||
return create_success_response(updated_session.model_dump(by_alias=True))
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(ve))
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("UPDATE_ERROR", str(e))
|
||||
)
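An example of the updates payload the handler above accepts; keys outside the whitelist are dropped, and camelCase keys are mapped to snake_case before validation:

updates = {
    "title": "Follow-up questions",
    "isArchived": False,
    "context": {"additionalContext": {"stage": "technical screen"}},  # merged into the existing context
}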
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
try:
|
||||
chat_messages = await database.get_chat_messages(session_id)
message_count = len(chat_messages)  # count before deleting so the log is accurate
await database.delete_chat_messages(session_id)
logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
|
||||
# Continue with session deletion even if message deletion fails
|
||||
|
||||
# Delete the session itself
|
||||
await database.delete_chat_session(session_id)
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session deleted successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.patch("/sessions/{session_id}/reset")
|
||||
async def reset_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
try:
|
||||
# Count the messages before deleting so the log reflects what was actually removed
chat_messages = await database.get_chat_messages(session_id)
message_count = len(chat_messages)
await database.delete_chat_messages(session_id)
logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
|
||||
# Continue with session deletion even if message deletion fails
|
||||
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session reset successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
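For reference, a minimal client-side sketch of exercising the two session-maintenance endpoints above. The base URL, the /chat mount prefix, and the bearer token are assumptions, and httpx is just one possible HTTP client; this is an illustration, not part of the commit.

# Hypothetical usage sketch: reset a chat session, then delete it entirely.
import asyncio
import httpx

BASE_URL = "http://localhost:8000/chat"  # assumed mount point for this router
TOKEN = "<access-token>"                 # placeholder bearer token

async def reset_then_delete(session_id: str) -> None:
    headers = {"Authorization": f"Bearer {TOKEN}"}
    async with httpx.AsyncClient(base_url=BASE_URL, headers=headers) as client:
        # PATCH /sessions/{id}/reset clears the messages but keeps the session
        reset = await client.patch(f"/sessions/{session_id}/reset")
        reset.raise_for_status()
        # DELETE /sessions/{id} removes the session and any remaining messages
        deleted = await client.delete(f"/sessions/{session_id}")
        deleted.raise_for_status()

if __name__ == "__main__":
    asyncio.run(reset_then_delete("example-session-id"))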
87
src/backend/routes/debug.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Debugging Endpoints
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from fastapi import APIRouter, Depends, Body, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, EmailStr, ValidationError, field_validator
|
||||
|
||||
from auth_utils import AuthenticationManager, SecurityConfig
|
||||
import backstory_traceback as backstory_traceback
|
||||
from utils.rate_limiter import RateLimiter
|
||||
from database import RedisDatabase, redis_manager
|
||||
from device_manager import DeviceManager
|
||||
from email_service import VerificationEmailRateLimiter, email_service
|
||||
from logger import logger
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
MFARequestResponse
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist
|
||||
)
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
from auth_utils import (
|
||||
AuthenticationManager,
|
||||
validate_password_strength,
|
||||
sanitize_login_input,
|
||||
SecurityConfig
|
||||
)
|
||||
|
||||
# Create router for debug endpoints
router = APIRouter(prefix="/debug", tags=["debug"])
|
||||
|
||||
@router.get("/guest/{guest_id}")
|
||||
async def debug_guest_session(
|
||||
guest_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Debug guest session issues (admin only)"""
|
||||
try:
|
||||
# Check primary storage
|
||||
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
|
||||
primary_exists = primary_data is not None
|
||||
|
||||
# Check backup storage
|
||||
backup_data = await database.redis.get(f"guest_backup:{guest_id}")
|
||||
backup_exists = backup_data is not None
|
||||
|
||||
# Check user lookup
|
||||
user_lookup = await database.get_user_by_id(guest_id)
|
||||
|
||||
# Get TTL info
|
||||
primary_ttl = await database.redis.ttl(f"guests")
|
||||
backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}")
|
||||
|
||||
debug_info = {
|
||||
"guest_id": guest_id,
|
||||
"primary_storage": {
|
||||
"exists": primary_exists,
|
||||
"data": json.loads(primary_data) if primary_data else None,
|
||||
"ttl": primary_ttl
|
||||
},
|
||||
"backup_storage": {
|
||||
"exists": backup_exists,
|
||||
"data": json.loads(backup_data) if backup_data else None,
|
||||
"ttl": backup_ttl
|
||||
},
|
||||
"user_lookup": user_lookup,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
}
|
||||
|
||||
return create_success_response(debug_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Debug guest session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DEBUG_ERROR", str(e))
|
||||
)
|
128
src/backend/routes/employers.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
Employer Routes
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Depends, Body, Path, Query, Request, BackgroundTasks, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from markitdown import MarkItDown, StreamInfo
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from auth_utils import AuthenticationManager
|
||||
from utils.helpers import create_job_from_content, filter_and_paginate, get_document_type_from_filename, get_skill_cache_key, get_requirements_list
|
||||
from database import RedisDatabase, redis_manager
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiMessageType, ApiStatusType, CandidateAI, ChatContextType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, Document, DocumentContentResponse, DocumentListResponse, DocumentMessage, DocumentOptions, DocumentType, DocumentUpdateRequest, Job, JobRequirements, JobRequirementsMessage, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, RAGDocumentRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest,
|
||||
MFARequestResponse, Resume, ResumeMessage, SkillAssessment, UserType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist, prometheus_collector
|
||||
)
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
import utils.llm_proxy as llm_manager
|
||||
import entities.entity_manager as entities
|
||||
from email_service import email_service
|
||||
|
||||
# Create router for employer endpoints
|
||||
router = APIRouter(prefix="/employers", tags=["employers"])
|
||||
|
||||
@router.post("")
|
||||
async def create_employer_with_verification(
|
||||
request: CreateEmployerRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new employer with email verification"""
|
||||
try:
|
||||
# Similar to candidate creation but for employer
|
||||
auth_manager = AuthenticationManager(database)
|
||||
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(
|
||||
request.email,
|
||||
request.username
|
||||
)
|
||||
|
||||
if user_exists and conflict_field:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content=create_error_response(
|
||||
"USER_EXISTS",
|
||||
f"A user with this {conflict_field} already exists"
|
||||
)
|
||||
)
|
||||
|
||||
employer_id = str(uuid.uuid4())
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
employer_data = {
|
||||
"id": employer_id,
|
||||
"email": request.email,
|
||||
"companyName": request.company_name,
|
||||
"industry": request.industry,
|
||||
"companySize": request.company_size,
|
||||
"companyDescription": request.company_description,
|
||||
"websiteUrl": request.website_url,
|
||||
"phone": request.phone,
|
||||
"createdAt": current_time.isoformat(),
|
||||
"updatedAt": current_time.isoformat(),
|
||||
"status": "pending", # Not active until verified
|
||||
"userType": "employer",
|
||||
"location": {
|
||||
"city": "",
|
||||
"country": "",
|
||||
"remote": False
|
||||
},
|
||||
"socialLinks": []
|
||||
}
|
||||
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
|
||||
await database.store_email_verification_token(
|
||||
request.email,
|
||||
verification_token,
|
||||
"employer",
|
||||
{
|
||||
"employer_data": employer_data,
|
||||
"password": request.password,
|
||||
"username": request.username
|
||||
}
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
email_service.send_verification_email,
|
||||
request.email,
|
||||
verification_token,
|
||||
request.company_name
|
||||
)
|
||||
|
||||
logger.info(f"✅ Employer registration initiated for: {request.email}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Registration successful! Please check your email to verify your account.",
|
||||
"email": request.email,
|
||||
"verificationRequired": True
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Employer creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
|
||||
)
|
||||
|
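A rough sketch of registering an employer against the endpoint above. The exact CreateEmployerRequest field names are not part of this diff, so the camelCase payload keys below are assumptions based on the casing used elsewhere in the codebase; the host is a placeholder.

# Hypothetical usage sketch: POST /employers starts email verification.
import httpx

payload = {
    "email": "recruiter@example.com",
    "username": "acme-recruiter",
    "password": "a-strong-password",
    "companyName": "Acme Corp",
    "industry": "Software",
    "companySize": "51-200",
    "companyDescription": "We build rockets and anvils.",
    "websiteUrl": "https://acme.example.com",
    "phone": "+1-555-0100",
}

response = httpx.post("http://localhost:8000/employers", json=payload)
response.raise_for_status()
# On success the account stays pending until the emailed verification link is used
print(response.json())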
610
src/backend/routes/jobs.py
Normal file
@ -0,0 +1,610 @@
|
||||
"""
|
||||
Job Routes
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Depends, Body, Path, Query, Request, BackgroundTasks, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from markitdown import MarkItDown, StreamInfo
|
||||
|
||||
import backstory_traceback as backstory_traceback
|
||||
import defines
|
||||
from agents.generate_resume import GenerateResume
|
||||
from agents.base import CandidateEntity
|
||||
from utils.helpers import create_job_from_content, filter_and_paginate, get_document_type_from_filename, get_skill_cache_key, get_requirements_list
|
||||
from database import RedisDatabase, redis_manager
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiMessageType, ApiStatusType, CandidateAI, ChatContextType, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, Document, DocumentContentResponse, DocumentListResponse, DocumentMessage, DocumentOptions, DocumentType, DocumentUpdateRequest, Job, JobRequirements, JobRequirementsMessage, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, RAGDocumentRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest,
|
||||
MFARequestResponse, Resume, ResumeMessage, SkillAssessment, UserType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist, prometheus_collector
|
||||
)
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
import utils.llm_proxy as llm_manager
|
||||
import entities.entity_manager as entities
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
|
||||
async def reformat_as_markdown(database: RedisDatabase, candidate_entity: CandidateEntity, content: str):
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.llm_one_shot(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=content,
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
):
|
||||
pass
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
chat_message : ChatMessage = message
|
||||
try:
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
except Exception as e:
# Fall back to the raw LLM output if a markdown block cannot be extracted
logger.warning(f"⚠️ Could not extract markdown block from LLM output: {e}")
logger.info("✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
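# Illustrative consumption sketch (not part of the commit): reformat_as_markdown only
# yields status/error messages followed by a single final ChatMessage, so a caller can
# drain it and keep the last value. The candidate entity is assumed to come from
# entities.get_candidate_entity, mirroring create_job_from_content below.
async def markdown_for(database: RedisDatabase, candidate: Candidate, raw_text: str):
    async with entities.get_candidate_entity(candidate=candidate) as entity:
        final = None
        async for message in reformat_as_markdown(database, entity, raw_text):
            if isinstance(message, ChatMessageError):
                raise RuntimeError(message.content)
            final = message  # the last yielded value is the reformatted ChatMessage
        return final.content if final else None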
|
||||
|
||||
|
||||
async def create_job_from_content(database: RedisDatabase, current_user: Candidate, content: str):
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
|
||||
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
|
||||
message = None
|
||||
async for message in reformat_as_markdown(database, candidate_entity, content):
|
||||
# Only yield one final DONE message
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
markdown_message = message
|
||||
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Job extraction did not convert successfully"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
job_requirements : JobRequirementsMessage = message
|
||||
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new job"""
|
||||
try:
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
# Add required fields
|
||||
job.id = str(uuid.uuid4())
|
||||
job.owner_id = current_user.id
|
||||
job.owner = current_user
|
||||
|
||||
await database.set_job(job.id, job.model_dump())
|
||||
|
||||
return create_success_response(job.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_candidate_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new job"""
|
||||
is_employer = isinstance(current_user, Employer)
|
||||
|
||||
try:
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
# Add required fields
|
||||
job.id = str(uuid.uuid4())
|
||||
job.owner_id = current_user.id
|
||||
job.owner = current_user
|
||||
|
||||
await database.set_job(job.id, job.model_dump())
|
||||
|
||||
return create_success_response(job.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{job_id}")
|
||||
async def update_job(
|
||||
job_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Update a candidate"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Job not found for update: {job_data}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
# Check authorization (user can only update their own profile)
|
||||
if current_user.is_admin is False and job.owner_id != current_user.id:
|
||||
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
|
||||
)
|
||||
|
||||
# Apply updates
|
||||
updates["updatedAt"] = datetime.now(UTC).isoformat()
|
||||
logger.info(f"🔄 Updating job {job_id} with data: {updates}")
|
||||
job_dict = job.model_dump()
|
||||
job_dict.update(updates)
|
||||
updated_job = Job.model_validate(job_dict)
|
||||
await database.set_job(job_id, updated_job.model_dump())
|
||||
|
||||
return create_success_response(updated_job.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("UPDATE_FAILED", str(e))
|
||||
)
|
||||
|
||||
@router.post("/from-content")
|
||||
async def create_job_from_description(
|
||||
content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Upload a document for the current candidate"""
|
||||
async def content_stream_generator(content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"📁 Received file content: size='{len(content)} bytes'")
|
||||
|
||||
last_yield_was_streaming = False
|
||||
async for message in create_job_from_content(database=database, current_user=current_user, content=content):
|
||||
if message.status != ApiStatusType.STREAMING:
|
||||
last_yield_was_streaming = False
|
||||
else:
|
||||
if last_yield_was_streaming:
|
||||
continue
|
||||
last_yield_was_streaming = True
|
||||
logger.info(f"📄 Yielding job creation message status: {message.status}")
|
||||
yield message
|
||||
return
|
||||
|
||||
try:
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"Error in to_json conversion: {e}")
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
to_json(content_stream_generator(content)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Nginx
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
)
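# Hypothetical client-side sketch (not part of the commit): each message in the stream
# above is emitted as a server-sent event line "data: {json}" followed by a blank line,
# matching the to_json() encoder. Host, auth handling, and httpx are assumptions.
import json as _json
import httpx

async def stream_job_from_content(content: str, token: str):
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        async with client.stream(
            "POST", "/jobs/from-content", json=content, headers=headers
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank separator lines
                # status/streaming updates first, then the final job requirements payload
                yield _json.loads(line[len("data: "):])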
|
||||
|
||||
@router.post("/upload")
|
||||
async def create_job_from_file(
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Upload a job document for the current candidate and create a Job"""
|
||||
# Check file size (limit to 10MB)
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
file_content = await file.read()
|
||||
if len(file_content) > max_size:
|
||||
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File size exceeds 10MB limit"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
if len(file_content) == 0:
|
||||
logger.info(f"⚠️ File is empty: {file.filename}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File is empty"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
"""Upload a document for the current candidate"""
|
||||
async def upload_stream_generator(file_content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
|
||||
if not file.filename or file.filename.strip() == "":
|
||||
logger.warning("⚠️ File upload attempt with missing filename")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File must have a valid filename"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
|
||||
|
||||
# Validate file type
|
||||
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
|
||||
|
||||
if file_extension not in allowed_types:
|
||||
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
|
||||
|
||||
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Converting content from {document_type}...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
try:
|
||||
md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
|
||||
stream = io.BytesIO(file_content)
|
||||
stream_info = StreamInfo(
|
||||
extension=file_extension, # e.g., ".pdf"
|
||||
url=file.filename # optional, helps with logging and guessing
|
||||
)
|
||||
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
|
||||
file_content = result.text_content
|
||||
logger.info(f"✅ Converted {file.filename} to Markdown format")
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Failed to convert {file.filename} to Markdown.",
|
||||
)
|
||||
yield error_message
|
||||
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
|
||||
return
|
||||
|
||||
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
|
||||
yield message
|
||||
return
|
||||
|
||||
try:
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"Error in to_json conversion: {e}")
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
to_json(upload_stream_generator(file_content)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Nginx
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
@router.get("/{job_id}")
|
||||
async def get_job(
|
||||
job_id: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get a job by ID"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
|
||||
# Increment view count
|
||||
job_data["views"] = job_data.get("views", 0) + 1
|
||||
await database.set_job(job_id, job_data)
|
||||
|
||||
job = Job.model_validate(job_data)
|
||||
return create_success_response(job.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
async def get_jobs(
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
sortBy: Optional[str] = Query(None, alias="sortBy"),
|
||||
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
|
||||
filters: Optional[str] = Query(None),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get paginated list of jobs"""
|
||||
try:
|
||||
filter_dict = None
|
||||
if filters:
|
||||
filter_dict = json.loads(filters)
|
||||
|
||||
# Get all jobs from Redis
|
||||
all_jobs_data = await database.get_all_jobs()
|
||||
jobs_list = []
|
||||
for job in all_jobs_data.values():
|
||||
jobs_list.append(Job.model_validate(job))
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, sortBy, sortOrder, filter_dict
|
||||
)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("FETCH_FAILED", str(e))
|
||||
)
|
||||
|
||||
@router.get("/search")
|
||||
async def search_jobs(
|
||||
query: str = Query(...),
|
||||
filters: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Search jobs"""
|
||||
try:
|
||||
filter_dict = {}
|
||||
if filters:
|
||||
filter_dict = json.loads(filters)
|
||||
|
||||
# Get all jobs from Redis
|
||||
all_jobs_data = await database.get_all_jobs()
|
||||
jobs_list = [Job.model_validate(data) for data in all_jobs_data.values() if data.get("is_active", True)]
|
||||
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
jobs_list = [
|
||||
j for j in jobs_list
|
||||
if ((j.title and query_lower in j.title.lower()) or
|
||||
(j.description and query_lower in j.description.lower()) or
|
||||
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
|
||||
]
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, filters=filter_dict
|
||||
)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Search jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("SEARCH_FAILED", str(e))
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
async def delete_job(
|
||||
job_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a Job"""
|
||||
try:
|
||||
# Check if admin user
|
||||
if not admin_user.is_admin:
|
||||
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Only admins can delete")
|
||||
)
|
||||
|
||||
# Get job data
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
|
||||
# Delete job from database
|
||||
await database.delete_job(job_id)
|
||||
|
||||
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Job deleted successfully",
|
||||
"jobId": job_id
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", "Failed to delete job")
|
||||
)
|
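The listing endpoint above takes plain pagination parameters plus an optional filters query parameter that is itself a JSON-encoded object (it is passed through json.loads on the server). A small usage sketch, with the host, sortable field, and filter keys as assumptions:

# Hypothetical usage sketch for GET /jobs with pagination and a JSON "filters" param.
import json
import httpx

params = {
    "page": 1,
    "limit": 20,
    "sortBy": "createdAt",                        # assumed sortable field
    "sortOrder": "desc",
    "filters": json.dumps({"isActive": True}),    # assumed filter key
}
response = httpx.get("http://localhost:8000/jobs", params=params)
response.raise_for_status()
print(response.json())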
308
src/backend/routes/resumes.py
Normal file
@ -0,0 +1,308 @@
|
||||
"""
|
||||
Resume Routes
|
||||
"""
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Depends, Body, Path, Query, Request, BackgroundTasks, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import backstory_traceback as backstory_traceback
|
||||
from utils.helpers import filter_and_paginate, get_document_type_from_filename, get_skill_cache_key, get_requirements_list
|
||||
from database import RedisDatabase, redis_manager
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiMessageType, ApiStatusType, CandidateAI, ChatContextType, ChatMessageError, ChatMessageRagSearch, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, Document, DocumentContentResponse, DocumentListResponse, DocumentMessage, DocumentOptions, DocumentType, DocumentUpdateRequest, Job, JobRequirements, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, RAGDocumentRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest,
|
||||
MFARequestResponse, Resume, ResumeMessage, SkillAssessment, UserType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist, prometheus_collector
|
||||
)
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
import utils.llm_proxy as llm_manager
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
|
||||
# Create router for resume endpoints
|
||||
router = APIRouter(prefix="/resumes", tags=["resumes"])
|
||||
|
||||
@router.post("/{candidate_id}/{job_id}")
|
||||
async def create_candidate_resume(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
resume_content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new resume for a candidate/job combination"""
|
||||
async def message_stream_generator():
|
||||
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
|
||||
|
||||
candidate_data = await database.get_candidate(candidate_id)
|
||||
if not candidate_data:
|
||||
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Candidate with ID '{candidate_id}' not found"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
candidate = Candidate.model_validate(candidate_data)
|
||||
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.error(f"❌ Job with ID '{job_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Job with ID '{job_id}' not found"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
|
||||
|
||||
# Job and Candidate are valid. Save the resume
|
||||
resume = Resume(
|
||||
job_id=job_id,
|
||||
candidate_id=candidate_id,
|
||||
resume=resume_content,
|
||||
)
|
||||
resume_message: ResumeMessage = ResumeMessage(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
resume=resume
|
||||
)
|
||||
|
||||
# Save to database
|
||||
success = await database.set_resume(current_user.id, resume.model_dump())
|
||||
if not success:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to save resume to database"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"✅ Successfully saved resume {resume_message.resume.id} for user {current_user.id}")
|
||||
yield resume_message
|
||||
return
|
||||
|
||||
try:
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"Error in to_json conversion: {e}")
|
||||
return
|
||||
|
||||
return StreamingResponse(
|
||||
to_json(message_stream_generator()),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Nginx
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Resume creation error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to create resume"
|
||||
).model_dump(mode='json', by_alias=True))]),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
|
||||
@router.get("")
|
||||
async def get_user_resumes(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get all resumes for the current user"""
|
||||
try:
|
||||
resumes_data = await database.get_all_resumes_for_user(current_user.id)
|
||||
resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data]
|
||||
for resume in resumes:
|
||||
job_data = await database.get_job(resume.job_id)
|
||||
if job_data:
|
||||
resume.job = Job.model_validate(job_data)
|
||||
candidate_data = await database.get_candidate(resume.candidate_id)
|
||||
if candidate_data:
|
||||
resume.candidate = Candidate.model_validate(candidate_data)
|
||||
resumes.sort(key=lambda x: x.updated_at, reverse=True)  # Sort by most recently updated
|
||||
return create_success_response({
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
|
||||
|
||||
@router.get("/{resume_id}")
|
||||
async def get_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get a specific resume by ID"""
|
||||
try:
|
||||
resume = await database.get_resume(current_user.id, resume_id)
|
||||
if not resume:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"resume": resume
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
|
||||
|
||||
@router.delete("/{resume_id}")
|
||||
async def delete_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a specific resume"""
|
||||
try:
|
||||
success = await database.delete_resume(current_user.id, resume_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} deleted successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete resume")
|
||||
|
||||
@router.get("/candidate/{candidate_id}")
|
||||
async def get_resumes_by_candidate(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get all resumes for a specific candidate"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
|
||||
return {
|
||||
"success": True,
|
||||
"candidate_id": candidate_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
|
||||
|
||||
@router.get("/job/{job_id}")
|
||||
async def get_resumes_by_job(
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get all resumes for a specific job"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_job(current_user.id, job_id)
|
||||
return {
|
||||
"success": True,
|
||||
"job_id": job_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
|
||||
|
||||
@router.get("/search")
|
||||
async def search_resumes(
|
||||
q: str = Query(..., description="Search query"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Search resumes by content"""
|
||||
try:
|
||||
resumes = await database.search_resumes_for_user(current_user.id, q)
|
||||
return {
|
||||
"success": True,
|
||||
"query": q,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search resumes")
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_resume_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get resume statistics for the current user"""
|
||||
try:
|
||||
stats = await database.get_resume_statistics(current_user.id)
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
|
||||
|
||||
@router.put("/{resume_id}")
|
||||
async def update_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
resume: str = Body(..., description="Updated resume content"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Update the content of a specific resume"""
|
||||
try:
|
||||
updates = {
|
||||
"resume": resume,
|
||||
"updated_at": datetime.now(UTC).isoformat()
|
||||
}
|
||||
|
||||
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
|
||||
if not updated_resume_data:
|
||||
logger.warning(f"⚠️ Resume {resume_id} not found for user {current_user.id}")
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} updated successfully",
|
||||
"resume": updated_resume
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error updating resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update resume")
|
67
src/backend/routes/system.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Health/system routes
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from fastapi import APIRouter, Depends, Body, HTTPException, Path, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, EmailStr, ValidationError, field_validator
|
||||
|
||||
from auth_utils import AuthenticationManager, SecurityConfig
|
||||
import backstory_traceback as backstory_traceback
|
||||
from utils.rate_limiter import RateLimiter
|
||||
from database import RedisDatabase, redis_manager, Redis
|
||||
from device_manager import DeviceManager
|
||||
from email_service import VerificationEmailRateLimiter, email_service
|
||||
from logger import logger
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
MFARequestResponse
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist
|
||||
)
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
from auth_utils import (
|
||||
AuthenticationManager,
|
||||
validate_password_strength,
|
||||
sanitize_login_input,
|
||||
SecurityConfig
|
||||
)
|
||||
|
||||
# Create router for system endpoints
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
"""Dependency to get Redis client"""
|
||||
return redis_manager.get_client()
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_system_info(request: Request):
|
||||
"""Get system information"""
|
||||
from system_info import system_info # Import system_info function from system_info module
|
||||
system = system_info()
|
||||
|
||||
return create_success_response(system.model_dump(mode='json'))
|
||||
|
||||
@router.get("/redis/stats")
|
||||
async def redis_stats(redis: Redis = Depends(get_redis)):
|
||||
try:
|
||||
info = await redis.info()
|
||||
return {
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory_human": info.get("used_memory_human"),
|
||||
"total_commands_processed": info.get("total_commands_processed"),
|
||||
"keyspace_hits": info.get("keyspace_hits"),
|
||||
"keyspace_misses": info.get("keyspace_misses"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds")
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
|
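The stats payload above exposes keyspace_hits and keyspace_misses, which is enough to derive a cache hit rate on the client side; a small helper sketch:

# Sketch: derive a hit rate from the /system/redis/stats payload above.
def redis_hit_rate(stats: dict) -> float:
    hits = stats.get("keyspace_hits") or 0
    misses = stats.get("keyspace_misses") or 0
    total = hits + misses
    return hits / total if total else 0.0

# e.g. {"keyspace_hits": 900, "keyspace_misses": 100, ...} -> 0.9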
99
src/backend/routes/users.py
Normal file
@ -0,0 +1,99 @@
|
||||
"""
|
||||
User Routes
|
||||
"""
|
||||
import io
|
||||
import json
|
||||
import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import jwt
|
||||
import secrets
|
||||
import uuid
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, Depends, Body, Path, Query, Request, BackgroundTasks, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
|
||||
from database import RedisDatabase, redis_manager
|
||||
from device_manager import DeviceManager
|
||||
from email_service import VerificationEmailRateLimiter, email_service
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiMessageType, ApiStatusType, BaseUserWithType, CandidateAI, ChatContextType, ChatMessageError, ChatMessageRagSearch, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, Document, DocumentContentResponse, DocumentListResponse, DocumentMessage, DocumentOptions, DocumentType, DocumentUpdateRequest, Job, JobRequirements, LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFAVerifyRequest,
|
||||
EmailVerificationRequest, RAGDocumentRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest,
|
||||
MFARequestResponse, Resume, ResumeMessage, SkillAssessment, UserType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, get_current_user_or_guest,
|
||||
create_access_token, verify_token_with_blacklist, prometheus_collector
|
||||
)
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
|
||||
# Create router for user endpoints
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
# reference can be candidateId, username, or email
|
||||
@router.get("/users/{reference}")
|
||||
async def get_user(
|
||||
reference: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get a candidate by username"""
|
||||
try:
|
||||
# Normalize reference to lowercase for case-insensitive search
|
||||
query_lower = reference.lower()
|
||||
|
||||
all_candidate_data = await database.get_all_candidates()
|
||||
if not all_candidate_data:
|
||||
logger.warning(f"⚠️ No users found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
|
||||
user_data = None
|
||||
for user in all_candidate_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
all_guest_data = await database.get_all_guests()
|
||||
if not all_guest_data:
|
||||
logger.warning(f"⚠️ No guests found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
for user in all_guest_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "User not found")
|
||||
)
|
||||
|
||||
user = BaseUserWithType.model_validate(user_data)
|
||||
|
||||
return create_success_response(user.model_dump(by_alias=True))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get user error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
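A brief sketch of resolving a user through the endpoint above; the same path segment accepts an id, username, or email, and the host is an assumption:

# Hypothetical lookup sketch for GET /users/{reference}.
import httpx

for reference in ("eliza", "eliza@example.com"):  # username or email both resolve
    response = httpx.get(f"http://localhost:8000/users/{reference}")
    if response.status_code == 404:
        print(f"No user or guest matches {reference!r}")
        continue
    response.raise_for_status()
    print(response.json())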
54
src/backend/utils/__init__.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""
|
||||
Utils package - Utility functions and dependencies
|
||||
"""
|
||||
|
||||
# Import commonly used utilities for easy access
|
||||
from .dependencies import (
|
||||
get_database,
|
||||
get_current_user,
|
||||
get_current_user_or_guest,
|
||||
get_current_admin,
|
||||
create_access_token,
|
||||
verify_token_with_blacklist
|
||||
)
|
||||
|
||||
from .responses import (
|
||||
create_success_response,
|
||||
create_error_response,
|
||||
create_paginated_response
|
||||
)
|
||||
|
||||
from .helpers import (
|
||||
filter_and_paginate,
|
||||
stream_agent_response,
|
||||
get_candidate_files_dir,
|
||||
get_document_type_from_filename,
|
||||
get_skill_cache_key
|
||||
)
|
||||
|
||||
from .rate_limiter import (
|
||||
get_rate_limiter
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Dependencies
|
||||
"get_database",
|
||||
"get_current_user",
|
||||
"get_current_user_or_guest",
|
||||
"get_current_admin",
|
||||
"get_rate_limiter",
|
||||
"create_access_token",
|
||||
"verify_token_with_blacklist",
|
||||
|
||||
# Responses
|
||||
"create_success_response",
|
||||
"create_error_response",
|
||||
"create_paginated_response",
|
||||
|
||||
# Helpers
|
||||
"filter_and_paginate",
|
||||
"stream_agent_response",
|
||||
"get_candidate_files_dir",
|
||||
"get_document_type_from_filename",
|
||||
"get_skill_cache_key"
|
||||
]
|
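Because the package re-exports its most common helpers, route modules can import them from the package root rather than the individual submodules; a small, hypothetical sketch (the /example router exists only for illustration):

# Sketch: importing shared dependencies and response helpers from the utils package.
from fastapi import APIRouter, Depends
from utils import get_database, get_current_user, create_success_response

router = APIRouter(prefix="/example", tags=["example"])

@router.get("/ping")
async def ping(current_user = Depends(get_current_user), database = Depends(get_database)):
    return create_success_response({"pong": True, "userId": current_user.id})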
206
src/backend/utils/dependencies.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""
|
||||
Shared dependencies for FastAPI routes
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import jwt
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone, UTC
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from prometheus_client import CollectorRegistry
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
import defines
|
||||
|
||||
from database import RedisDatabase, redis_manager, DatabaseManager
|
||||
from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
|
||||
from logger import logger
|
||||
from background_tasks import BackgroundTaskManager
|
||||
|
||||
#from . rate_limiter import RateLimiter
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "")
|
||||
if JWT_SECRET_KEY == "":
|
||||
raise ValueError("JWT_SECRET_KEY environment variable is not set")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
|
||||
# Global database manager reference
|
||||
db_manager = None
|
||||
|
||||
def set_db_manager(manager: DatabaseManager):
|
||||
"""Set the global database manager reference"""
|
||||
global db_manager
|
||||
db_manager = manager
|
||||
|
||||
def get_database() -> RedisDatabase:
|
||||
"""
|
||||
Safe database dependency that checks for availability
|
||||
Raises HTTP 503 if database is not available
|
||||
"""
|
||||
global db_manager
|
||||
|
||||
if db_manager is None:
|
||||
logger.error("Database manager not initialized")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database not available - service starting up"
|
||||
)
|
||||
|
||||
if db_manager.is_shutting_down:
|
||||
logger.warning("Database is shutting down")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service is shutting down"
|
||||
)
|
||||
|
||||
try:
|
||||
return db_manager.get_database()
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Database not available: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database connection not available"
|
||||
)
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.now(UTC) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(UTC) + timedelta(hours=24)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Enhanced token verification with guest session recovery"""
|
||||
try:
|
||||
if not db_manager:
|
||||
raise HTTPException(status_code=500, detail="Database not initialized")
|
||||
# First decode the token
|
||||
payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
token_type: str = payload.get("type", "access")
|
||||
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
|
||||
# Check if token is blacklisted
|
||||
redis = redis_manager.get_client()
|
||||
blacklist_key = f"blacklisted_token:{credentials.credentials}"
|
||||
|
||||
is_blacklisted = await redis.exists(blacklist_key)
|
||||
if is_blacklisted:
|
||||
logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Token has been revoked")
|
||||
|
||||
# For guest tokens, verify guest still exists and update activity
|
||||
if token_type == "guest" or payload.get("type") == "guest":
|
||||
database = db_manager.get_database()
|
||||
guest_data = await database.get_guest(user_id)
|
||||
|
||||
if not guest_data:
|
||||
logger.warning(f"🚫 Guest session not found for token: {user_id}")
|
||||
raise HTTPException(status_code=401, detail="Guest session expired")
|
||||
|
||||
# Update guest activity
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
await database.set_guest(user_id, guest_data)
|
||||
logger.debug(f"🔄 Guest activity updated: {user_id}")
|
||||
|
||||
return user_id
|
||||
|
||||
except jwt.PyJWTError as e:
|
||||
logger.warning(f"⚠️ JWT decode error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||
|
||||
async def get_current_user(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user from database"""
|
||||
try:
|
||||
# Check candidates
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
if candidate_data.get("is_AI"):
|
||||
from model_cast import cast_to_base_user_with_type
|
||||
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
|
||||
else:
|
||||
from model_cast import cast_to_base_user_with_type
|
||||
return cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
|
||||
|
||||
# Check employers
|
||||
employer = await database.get_employer(user_id)
|
||||
if employer:
|
||||
return Employer.model_validate(employer)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_user_or_guest(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user (including guests) from database"""
|
||||
try:
|
||||
# Check candidates first
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
|
||||
|
||||
# Check employers
|
||||
employer_data = await database.get_employer(user_id)
|
||||
if employer_data:
|
||||
return Employer.model_validate(employer_data)
|
||||
|
||||
# Check guests
|
||||
guest_data = await database.get_guest(user_id)
|
||||
if guest_data:
|
||||
return Guest.model_validate(guest_data)
|
||||
|
||||
logger.warning(f"⚠️ User {user_id} not found in database")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
async def get_current_admin(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
user = await get_current_user(user_id=user_id, database=database)
|
||||
if isinstance(user, Candidate) and user.is_admin:
|
||||
return user
|
||||
elif isinstance(user, Employer) and user.is_admin:
|
||||
return user
|
||||
else:
|
||||
logger.warning(f"⚠️ User {user_id} is not an admin")
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
prometheus_collector = CollectorRegistry()
|
||||
|
||||
# Keep the Instrumentator instance alive
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=True,
|
||||
should_ignore_untemplated=True,
|
||||
should_group_untemplated=True,
|
||||
excluded_handlers=[f"{defines.api_prefix}/metrics"],
|
||||
registry=prometheus_collector
|
||||
)
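`get_database()` relies on the module-level `db_manager` being populated at startup via `set_db_manager`, and the `instrumentator` instance needs to be attached to the app. A rough sketch of that wiring, assuming a `DatabaseManager()` constructor and a standard FastAPI lifespan (neither is shown in this commit):

```python
# Hypothetical application startup wiring for dependencies.py
from contextlib import asynccontextmanager

from fastapi import FastAPI

from database import DatabaseManager
from utils.dependencies import instrumentator, set_db_manager

@asynccontextmanager
async def lifespan(app: FastAPI):
    manager = DatabaseManager()  # assumed constructor
    set_db_manager(manager)      # makes get_database() usable in routes
    yield

app = FastAPI(lifespan=lifespan)
instrumentator.instrument(app).expose(app)  # register Prometheus metrics
```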
|
389
src/backend/utils/helpers.py
Normal file
@ -0,0 +1,389 @@
|
||||
"""
|
||||
Helper functions shared across routes
|
||||
"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from datetime import datetime, UTC
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import DocumentType
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
# API
|
||||
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
|
||||
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
|
||||
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
|
||||
JobRequirements, JobRequirementsMessage, LoginRequest,
|
||||
CreateCandidateRequest, CreateEmployerRequest,
|
||||
|
||||
# User models
|
||||
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
|
||||
Authentication, AuthResponse, CandidateAI,
|
||||
|
||||
# Job models
|
||||
JobApplication, ApplicationStatus,
|
||||
|
||||
# Chat models
|
||||
ChatSession, ChatMessage, ChatContext, ChatQuery, ChatSenderType, ApiMessageType, ChatContextType,
|
||||
ChatMessageRagSearch,
|
||||
|
||||
# Document models
|
||||
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
|
||||
|
||||
# Supporting models
|
||||
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
|
||||
|
||||
# Email
|
||||
EmailVerificationRequest,
|
||||
ApiStatusType
|
||||
)
|
||||
|
||||
from typing import List, Dict
|
||||
from models import (Job)
|
||||
import utils.llm_proxy as llm_manager
|
||||
|
||||
async def get_last_item(generator):
|
||||
"""Get the last item from an async generator"""
|
||||
last_item = None
|
||||
async for item in generator:
|
||||
last_item = item
|
||||
return last_item
|
||||
|
||||
|
||||
def filter_and_paginate(
|
||||
items: List[Any],
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "desc",
|
||||
filters: Optional[Dict] = None
|
||||
) -> Tuple[List[Any], int]:
|
||||
"""Filter, sort, and paginate items"""
|
||||
filtered_items = items.copy()
|
||||
|
||||
# Apply filters (simplified filtering logic)
|
||||
if filters:
|
||||
for key, value in filters.items():
|
||||
if filtered_items and isinstance(filtered_items[0], dict) and key in filtered_items[0]:
filtered_items = [item for item in filtered_items if item.get(key) == value]
elif filtered_items and hasattr(filtered_items[0], key):
filtered_items = [item for item in filtered_items
if getattr(item, key, None) == value]
|
||||
|
||||
# Sort items
|
||||
if sort_by and filtered_items:
|
||||
reverse = sort_order.lower() == "desc"
|
||||
try:
|
||||
if isinstance(filtered_items[0], dict):
|
||||
filtered_items.sort(key=lambda x: x.get(sort_by, ""), reverse=reverse)
|
||||
else:
|
||||
filtered_items.sort(key=lambda x: getattr(x, sort_by, ""), reverse=reverse)
|
||||
except (AttributeError, TypeError):
|
||||
pass # Skip sorting if attribute doesn't exist or isn't comparable
|
||||
|
||||
# Paginate
|
||||
total = len(filtered_items)
|
||||
start = (page - 1) * limit
|
||||
end = start + limit
|
||||
paginated_items = filtered_items[start:end]
|
||||
|
||||
return paginated_items, total
|
||||
|
||||
|
||||
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
|
||||
"""Stream agent response with proper formatting"""
|
||||
async def message_stream_generator():
|
||||
"""Generator to stream messages with persistence"""
|
||||
last_log = None
|
||||
final_message = None
|
||||
|
||||
import utils.llm_proxy as llm_manager
|
||||
import agents
|
||||
|
||||
async for generated_message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=user_message.session_id,
|
||||
prompt=user_message.content,
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
logger.error(f"❌ AI generation error: {generated_message.content}")
|
||||
yield f"data: {json.dumps({'status': 'error'})}\n\n"
|
||||
return
|
||||
|
||||
# Store reference to the complete AI message
|
||||
if generated_message.status == ApiStatusType.DONE:
|
||||
final_message = generated_message
|
||||
|
||||
# If the message is not done, convert it to a ChatMessageBase to remove
|
||||
# metadata and other unnecessary fields for streaming
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
from models import ChatMessageStreaming, ChatMessageStatus
|
||||
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
|
||||
raise TypeError(
|
||||
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
|
||||
)
|
||||
|
||||
json_data = generated_message.model_dump(mode='json', by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
|
||||
yield f"data: {json_str}\n\n"
|
||||
|
||||
# After streaming is complete, persist the final AI message to database
|
||||
if final_message and final_message.status == ApiStatusType.DONE:
|
||||
try:
|
||||
if database and chat_session_data:
|
||||
await database.add_chat_message(final_message.session_id, final_message.model_dump())
|
||||
logger.info(f"🤖 Message saved to database for session {final_message.session_id}")
|
||||
|
||||
# Update session last activity again
|
||||
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
|
||||
await database.set_chat_session(final_message.session_id, chat_session_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to save message to database: {e}")
|
||||
|
||||
return StreamingResponse(
|
||||
message_stream_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Nginx
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Access-Control-Allow-Origin": "*", # Adjust for your CORS needs
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_candidate_files_dir(username: str) -> pathlib.Path:
|
||||
"""Get the files directory for a candidate"""
|
||||
files_dir = pathlib.Path(defines.user_dir) / username / "files"
|
||||
files_dir.mkdir(parents=True, exist_ok=True)
|
||||
return files_dir
|
||||
|
||||
|
||||
def get_document_type_from_filename(filename: str) -> DocumentType:
|
||||
"""Determine document type from filename extension"""
|
||||
extension = pathlib.Path(filename).suffix.lower()
|
||||
|
||||
type_mapping = {
|
||||
'.pdf': DocumentType.PDF,
|
||||
'.docx': DocumentType.DOCX,
|
||||
'.doc': DocumentType.DOCX,
|
||||
'.txt': DocumentType.TXT,
|
||||
'.md': DocumentType.MARKDOWN,
|
||||
'.markdown': DocumentType.MARKDOWN,
|
||||
'.png': DocumentType.IMAGE,
|
||||
'.jpg': DocumentType.IMAGE,
|
||||
'.jpeg': DocumentType.IMAGE,
|
||||
'.gif': DocumentType.IMAGE,
|
||||
}
|
||||
|
||||
return type_mapping.get(extension, DocumentType.TXT)
|
||||
|
||||
|
||||
def get_skill_cache_key(candidate_id: str, skill: str) -> str:
|
||||
"""Generate a unique cache key for skill match"""
|
||||
# Create cache key for this specific candidate + skill combination
|
||||
skill_hash = hashlib.md5(skill.lower().encode()).hexdigest()[:8]
|
||||
return f"skill_match:{candidate_id}:{skill_hash}"
|
||||
|
||||
|
||||
async def reformat_as_markdown(database, candidate_entity, content: str):
|
||||
"""Reformat content as markdown using AI agent"""
|
||||
from models import ChatContextType, MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus, ApiActivityType, ChatMessage
|
||||
import utils.llm_proxy as llm_manager
|
||||
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.llm_one_shot(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=content,
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
):
|
||||
pass
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
chat_message: ChatMessage = message
|
||||
try:
|
||||
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
|
||||
except Exception:
pass  # Keep the original content if markdown extraction fails
|
||||
|
||||
logger.info(f"✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
|
||||
async def create_job_from_content(database, current_user, content: str):
|
||||
"""Create a job from content using AI analysis"""
|
||||
from models import (
|
||||
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
|
||||
ApiActivityType, ChatContextType, JobRequirementsMessage
|
||||
)
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
|
||||
import entities.entity_manager as entities
|
||||
|
||||
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
|
||||
message = None
|
||||
async for message in reformat_as_markdown(database, candidate_entity, content):
|
||||
# Only yield one final DONE message
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
markdown_message = message
|
||||
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Job extraction did not convert successfully"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
job_requirements: JobRequirementsMessage = message
|
||||
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
requirements: List[Dict[str, str]] = []
|
||||
|
||||
if job.requirements:
|
||||
if job.requirements.technical_skills:
|
||||
if job.requirements.technical_skills.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
])
|
||||
if job.requirements.technical_skills.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
])
|
||||
|
||||
if job.requirements.experience_requirements:
|
||||
if job.requirements.experience_requirements.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
])
|
||||
if job.requirements.experience_requirements.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
])
|
||||
|
||||
if job.requirements.soft_skills:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Soft Skills"}
|
||||
for req in job.requirements.soft_skills
|
||||
])
|
||||
|
||||
if job.requirements.experience:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience"}
|
||||
for req in job.requirements.experience
|
||||
])
|
||||
|
||||
if job.requirements.education:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Education"}
|
||||
for req in job.requirements.education
|
||||
])
|
||||
|
||||
if job.requirements.certifications:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Certifications"}
|
||||
for req in job.requirements.certifications
|
||||
])
|
||||
|
||||
if job.requirements.preferred_attributes:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
])
|
||||
|
||||
return requirements
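`filter_and_paginate` is meant to feed directly into `create_paginated_response` from `utils.responses`; a minimal sketch with illustrative data:

```python
# Illustrative pairing of filter_and_paginate with the pagination envelope
from utils.helpers import filter_and_paginate
from utils.responses import create_paginated_response

jobs = [{"id": i, "title": f"Job {i}", "remote": i % 2 == 0} for i in range(45)]

items, total = filter_and_paginate(
    jobs, page=2, limit=10, sort_by="id", sort_order="asc", filters={"remote": True}
)
envelope = create_paginated_response(items, page=2, limit=10, total=total)
# 23 jobs match the filter, so envelope["totalPages"] == 3 and envelope["hasMore"] is True
```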
|
@ -1,15 +1,23 @@
|
||||
"""
|
||||
Rate limiting utilities for guest and authenticated users
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from functools import wraps
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
from typing import Callable, Dict, Optional, Tuple, Any
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from database import RedisDatabase
|
||||
from logger import logger
|
||||
|
||||
from .dependencies import get_current_user_or_guest, get_database
|
||||
|
||||
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
|
||||
"""Dependency to get rate limiter instance"""
|
||||
return RateLimiter(database)
|
||||
|
||||
class RateLimitConfig(BaseModel):
|
||||
"""Rate limit configuration"""
|
||||
requests_per_minute: int
|
||||
@ -285,3 +293,235 @@ class RateLimiter:
|
||||
return False
|
||||
|
||||
|
||||
# ============================
|
||||
# Rate Limited Decorator
|
||||
# ============================
|
||||
|
||||
def rate_limited(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120,
|
||||
endpoint_specific: bool = True
|
||||
):
|
||||
"""
|
||||
Decorator to easily apply rate limiting to endpoints
|
||||
|
||||
Args:
|
||||
guest_per_minute: Rate limit for guest users
|
||||
user_per_minute: Rate limit for authenticated users
|
||||
admin_per_minute: Rate limit for admin users
|
||||
endpoint_specific: Whether to apply endpoint-specific limits
|
||||
|
||||
Usage:
|
||||
@rate_limited(guest_per_minute=5, user_per_minute=30)
|
||||
@api_router.post("/my-endpoint")
|
||||
async def my_endpoint(
|
||||
request: Request,
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract dependencies from function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Get request, current_user, and rate_limiter from kwargs or args
|
||||
request = None
|
||||
current_user = None
|
||||
rate_limiter = None
|
||||
|
||||
# Try to find dependencies in kwargs first
|
||||
for param_name, param_value in kwargs.items():
|
||||
if isinstance(param_value, Request):
|
||||
request = param_value
|
||||
elif hasattr(param_value, 'user_type'): # User-like object
|
||||
current_user = param_value
|
||||
elif isinstance(param_value, RateLimiter):
|
||||
rate_limiter = param_value
|
||||
|
||||
# If not found in kwargs, check if they're provided via Depends
|
||||
if not rate_limiter:
|
||||
# Create rate limiter instance (this should ideally come from DI)
|
||||
database = get_database()
|
||||
rate_limiter = RateLimiter(database)
|
||||
|
||||
# Apply rate limiting if we have the required components
|
||||
if request and current_user and rate_limiter:
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
|
||||
# Call the original function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
async def apply_custom_rate_limiting(
|
||||
request: Request,
|
||||
current_user,
|
||||
rate_limiter: RateLimiter,
|
||||
guest_per_minute: int,
|
||||
user_per_minute: int,
|
||||
admin_per_minute: int
|
||||
):
|
||||
"""Apply custom rate limiting with specified limits"""
|
||||
try:
|
||||
# Determine user info
|
||||
user_id = current_user.id
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
is_admin = getattr(current_user, 'is_admin', False)
|
||||
|
||||
# Determine appropriate limit
|
||||
if is_admin:
|
||||
requests_per_minute = admin_per_minute
|
||||
elif user_type == "guest":
|
||||
requests_per_minute = guest_per_minute
|
||||
else:
|
||||
requests_per_minute = user_per_minute
|
||||
|
||||
# Create custom rate limit key
|
||||
current_time = datetime.now(UTC)
|
||||
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
|
||||
# Check current usage
|
||||
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
|
||||
|
||||
if current_count >= requests_per_minute:
|
||||
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
"error": "Rate limit exceeded",
|
||||
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
|
||||
"retryAfter": 60 - current_time.second,
|
||||
"userType": user_type,
|
||||
"endpoint": request.url.path
|
||||
},
|
||||
headers={"Retry-After": str(60 - current_time.second)}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = rate_limiter.redis.pipeline()
|
||||
pipe.incr(custom_key)
|
||||
pipe.expire(custom_key, 120) # 2 minutes TTL
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Custom rate limiting error: {e}")
|
||||
# Fail open
|
||||
|
||||
# ============================
|
||||
# Alternative: FastAPI Dependency-Based Rate Limiting
|
||||
# ============================
|
||||
|
||||
def create_rate_limit_dependency(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120
|
||||
):
|
||||
"""
|
||||
Create a FastAPI dependency for rate limiting
|
||||
|
||||
Usage:
|
||||
rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30)
|
||||
|
||||
@api_router.post("/my-endpoint")
|
||||
async def my_endpoint(
|
||||
rate_check = Depends(rate_limit_5_30),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
async def rate_limit_dependency(
|
||||
request: Request,
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter)
|
||||
):
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
return True
|
||||
|
||||
return rate_limit_dependency
|
||||
|
||||
# ============================
|
||||
# Rate Limiting Utilities
|
||||
# ============================
|
||||
|
||||
class EndpointRateLimiter:
|
||||
"""Utility class for endpoint-specific rate limiting"""
|
||||
|
||||
def __init__(self, rate_limiter: RateLimiter):
|
||||
self.rate_limiter = rate_limiter
|
||||
self.custom_limits = {}
|
||||
|
||||
def set_endpoint_limits(self, endpoint: str, limits: dict):
|
||||
"""Set custom limits for an endpoint"""
|
||||
self.custom_limits[endpoint] = limits
|
||||
|
||||
async def check_endpoint_limit(self, request: Request, current_user) -> bool:
|
||||
"""Check if request exceeds endpoint-specific limits"""
|
||||
endpoint = request.url.path
|
||||
|
||||
if endpoint not in self.custom_limits:
|
||||
return True # No custom limits set
|
||||
|
||||
limits = self.custom_limits[endpoint]
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
|
||||
if getattr(current_user, 'is_admin', False):
|
||||
user_type = "admin"
|
||||
|
||||
limit = limits.get(user_type, limits.get("default", 60))
|
||||
|
||||
current_time = datetime.now(UTC)
|
||||
key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
|
||||
current_count = int(await self.rate_limiter.redis.get(key) or 0)
|
||||
|
||||
if current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
await self.rate_limiter.redis.incr(key)
|
||||
await self.rate_limiter.redis.expire(key, 120)
|
||||
|
||||
return True
|
||||
|
||||
# Global endpoint rate limiter instance
|
||||
endpoint_rate_limiter = None
|
||||
|
||||
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
|
||||
"""Get endpoint rate limiter instance"""
|
||||
global endpoint_rate_limiter
|
||||
if endpoint_rate_limiter is None:
|
||||
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
|
||||
|
||||
# Configure endpoint-specific limits
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
|
||||
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
|
||||
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
|
||||
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
|
||||
})
|
||||
|
||||
return endpoint_rate_limiter
|
||||
|
39
src/backend/utils/responses.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Response utility functions for consistent API responses
|
||||
"""
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"meta": meta
|
||||
}
|
||||
|
||||
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
|
||||
return {
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details
|
||||
}
|
||||
}
|
||||
|
||||
def create_paginated_response(
|
||||
data: List[Any],
|
||||
page: int,
|
||||
limit: int,
|
||||
total: int
|
||||
) -> Dict:
|
||||
total_pages = (total + limit - 1) // limit
|
||||
has_more = page < total_pages
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"totalPages": total_pages,
|
||||
"hasMore": has_more
|
||||
}
|
247
src/tests/api/README.md
Normal file
@ -0,0 +1,247 @@
|
||||
# Backstory API Test Suite
|
||||
|
||||
A comprehensive Python CLI utility for testing the Backstory API. The tool handles authentication with MFA support and token management with local persistence, and provides extensive test coverage for all API endpoints.
|
||||
|
||||
## Features
|
||||
|
||||
- **MFA Authentication**: Full support for Multi-Factor Authentication flow
|
||||
- **Token Persistence**: Stores bearer tokens locally in JSON format for reuse across sessions
|
||||
- **Token Management**: Automatic token refresh and expiration handling
|
||||
- **Comprehensive Coverage**: Tests all API endpoints including candidates, jobs, health checks, and metrics
|
||||
- **Flexible Testing**: Run all tests or focus on specific endpoint groups
|
||||
- **Detailed Reporting**: Get comprehensive test results with timing and error information
|
||||
- **Multiple Output Formats**: Choose between human-readable text and JSON output
|
||||
- **Robust Error Handling**: Includes retry logic and proper error reporting
|
||||
|
||||
## Installation
|
||||
|
||||
1. Clone or download the test suite files
|
||||
2. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Authentication Flows
|
||||
|
||||
### First-time Authentication (with MFA)
|
||||
|
||||
```bash
|
||||
# Login will prompt for MFA code if required
|
||||
python test_suite.py --login your-email@example.com --password your-password
|
||||
|
||||
# Or provide MFA code directly
|
||||
python test_suite.py --login your-email@example.com --password your-password --mfa-code 123456
|
||||
```
|
||||
|
||||
### Using Stored Tokens
|
||||
|
||||
```bash
|
||||
# Use previously stored tokens (no login required)
|
||||
python test_suite.py --use-stored-tokens
|
||||
|
||||
# Or just run normally - stored tokens will be used automatically if valid
|
||||
python test_suite.py --login your-email@example.com --password your-password
|
||||
```
|
||||
|
||||
### Token Management
|
||||
|
||||
```bash
|
||||
# Clear stored tokens
|
||||
python test_suite.py --clear-tokens
|
||||
|
||||
# Specify custom token file location
|
||||
python test_suite.py --login user@example.com --password pass --token-file /path/to/tokens.json
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Command Line Options
|
||||
|
||||
```bash
|
||||
python test_suite.py [OPTIONS]
|
||||
|
||||
Authentication Arguments:
|
||||
--login EMAIL Login email for authentication
|
||||
--password PASSWORD Login password for authentication
|
||||
--mfa-code CODE MFA code from email (if required)
|
||||
--use-stored-tokens Use previously stored tokens without login
|
||||
--clear-tokens Clear stored tokens and exit
|
||||
|
||||
Configuration Arguments:
|
||||
--base-url URL Base URL for the API (default: https://backstory-beta.ketrenos.com)
|
||||
--endpoint GROUP Test specific endpoint group: all, auth, candidates, jobs, health (default: all)
|
||||
--token-file PATH Token storage file (default: .backstory_tokens.json)
|
||||
--verbose, -v Enable verbose logging
|
||||
--output FORMAT Output format: text, json (default: text)
|
||||
--help, -h Show help message
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
**First-time login with MFA:**
|
||||
```bash
|
||||
python test_suite.py --login user@example.com --password mypassword
|
||||
# Will prompt for MFA code if required
|
||||
```
|
||||
|
||||
**Login with MFA code provided:**
|
||||
```bash
|
||||
python test_suite.py --login user@example.com --password mypassword --mfa-code 123456
|
||||
```
|
||||
|
||||
**Use stored tokens (no login required):**
|
||||
```bash
|
||||
python test_suite.py --use-stored-tokens
|
||||
```
|
||||
|
||||
**Run all tests with verbose output:**
|
||||
```bash
|
||||
python test_suite.py --login user@example.com --password mypassword --verbose
|
||||
```
|
||||
|
||||
**Test only candidate endpoints:**
|
||||
```bash
|
||||
python test_suite.py --use-stored-tokens --endpoint candidates
|
||||
```
|
||||
|
||||
**Test against a different environment:**
|
||||
```bash
|
||||
python test_suite.py --login user@example.com --password mypassword --base-url https://localhost:8000
|
||||
```
|
||||
|
||||
**Get JSON output for CI/CD integration:**
|
||||
```bash
|
||||
python test_suite.py --use-stored-tokens --output json
|
||||
```
|
||||
|
||||
**Clear stored tokens:**
|
||||
```bash
|
||||
python test_suite.py --clear-tokens
|
||||
```
|
||||
|
||||
## Authentication Flow
|
||||
|
||||
1. **Initial Login**: Provide email and password
|
||||
2. **Device ID Capture**: The system captures the `device_id` from the login response
|
||||
3. **MFA Verification**: If MFA is required, enter the code sent to your email (device_id is automatically included)
|
||||
4. **Token Storage**: Bearer and refresh tokens are automatically saved locally
|
||||
5. **Subsequent Runs**: Stored tokens are used automatically
|
||||
6. **Token Refresh**: Expired tokens are automatically refreshed using the refresh token
|
||||
7. **Token Expiry**: If refresh fails, you'll be prompted to login again
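The same flow can be driven programmatically with `BackstoryAPIClient` from `test_suite.py`; a rough sketch (email, password, and the MFA prompt are placeholders):

```python
from test_suite import BackstoryAPIClient

client = BackstoryAPIClient(base_url="https://backstory-beta.ketrenos.com")

if not client.is_authenticated():
    result = client.login("user@example.com", "mypassword")
    if result.get("mfa_required"):
        code = input("MFA code from email: ")
        if not client.verify_mfa("user@example.com", code):
            raise SystemExit("MFA verification failed")

# Tokens are now cached in .backstory_tokens.json and reused on the next run
print(client.make_request("GET", "/api/1.0/health").status_code)
```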
|
||||
|
||||
## Token Storage
|
||||
|
||||
- Tokens are stored in `.backstory_tokens.json` by default
|
||||
- The file includes bearer token, refresh token, expiration time, and base URL
|
||||
- Tokens are automatically loaded when the application starts
|
||||
- Different base URLs use separate token storage
|
||||
- Tokens are automatically refreshed when expired
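The file written by `save_tokens()` can be inspected directly if needed; for example:

```python
import json
from pathlib import Path

# Keys written by BackstoryAPIClient.save_tokens():
# bearer_token, refresh_token, expires_at, base_url, user
tokens = json.loads(Path(".backstory_tokens.json").read_text())
print(tokens["base_url"], tokens["expires_at"])
```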
|
||||
|
||||
## API Endpoints Tested
|
||||
|
||||
### Authentication
|
||||
- `POST /api/1.0/auth/login` - User login (captures device_id for MFA flow)
|
||||
- `POST /api/1.0/auth/mfa/verify` - MFA verification (includes device_id)
|
||||
- `POST /api/1.0/auth/refresh` - Token refresh
|
||||
|
||||
### Candidates
|
||||
- `GET /api/1.0/candidates` - List candidates (with pagination)
|
||||
- `POST /api/1.0/candidates` - Create candidate
|
||||
- `GET /api/1.0/candidates/{id}` - Get specific candidate
|
||||
- `PATCH /api/1.0/candidates/{id}` - Update candidate (requires auth)
|
||||
- `GET /api/1.0/candidates/search` - Search candidates
|
||||
|
||||
### Jobs
|
||||
- `GET /api/1.0/jobs` - List jobs (with pagination)
|
||||
- `POST /api/1.0/jobs` - Create job (requires auth)
|
||||
- `GET /api/1.0/jobs/{id}` - Get specific job
|
||||
- `GET /api/1.0/jobs/search` - Search jobs
|
||||
|
||||
### System
|
||||
- `GET /api/1.0/health` - Health check
|
||||
- `GET /api/1.0/` - API information
|
||||
- `GET /api/1.0/redis/stats` - Redis statistics
|
||||
- `GET /api/1.0/metrics` - Prometheus metrics
|
||||
|
||||
## Test Flow
|
||||
|
||||
1. **Authentication**:
|
||||
- Loads stored tokens if available and valid
|
||||
- If no valid tokens, performs login with optional MFA
|
||||
- Saves new tokens for future use
|
||||
2. **System Tests**: Tests health, info, and metrics endpoints
|
||||
3. **Candidate Tests**: Tests candidate CRUD operations and search
|
||||
4. **Job Tests**: Tests job CRUD operations and search
|
||||
5. **Dependent Tests**: If entities are created successfully, tests retrieval and updates
|
||||
|
||||
## Output
|
||||
|
||||
### Text Output
|
||||
The default output provides a comprehensive summary including:
|
||||
- Authentication status and method used
|
||||
- Total tests run and success rate
|
||||
- List of any failed tests with error details
|
||||
- Detailed results for all tests with response times
|
||||
|
||||
### JSON Output
|
||||
Perfect for CI/CD integration, includes:
|
||||
- Summary statistics
|
||||
- Detailed results for each test
|
||||
- Structured error information
|
||||
|
||||
## Error Handling
|
||||
|
||||
The test suite includes robust error handling:
|
||||
- **Authentication Issues**: Clear MFA prompts and token management
|
||||
- **Network Issues**: Automatic retries with exponential backoff
|
||||
- **Token Expiry**: Automatic refresh with fallback to re-authentication
|
||||
- **API Errors**: Detailed error reporting with status codes and messages
|
||||
- **Timeout Handling**: Configurable timeouts for all requests
|
||||
|
||||
## Exit Codes
|
||||
|
||||
- `0`: All tests passed
|
||||
- `1`: One or more tests failed or authentication failed
|
||||
|
||||
## Dependencies
|
||||
|
||||
- `requests`: HTTP client library
|
||||
- `urllib3`: HTTP library with retry capabilities
|
||||
|
||||
## Security Notes
|
||||
|
||||
- Token files are stored locally and contain sensitive authentication data
|
||||
- Ensure appropriate file permissions on token storage files
|
||||
- Tokens are automatically refreshed but may require re-authentication if refresh tokens expire
|
||||
- Consider using environment-specific token files for different deployments
|
||||
- Device IDs are automatically handled during MFA flows and cleared after authentication
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### MFA Issues
|
||||
- If MFA verification fails, ensure you're using the latest code from your email
|
||||
- Check verbose output (`--verbose`) to see device_id information
|
||||
- MFA codes typically expire quickly (usually 5-10 minutes)
|
||||
|
||||
### Token Issues
|
||||
- Use `--clear-tokens` to reset stored authentication
|
||||
- Check token file permissions if loading fails
|
||||
- Different base URLs maintain separate token storage
|
||||
|
||||
### Network Issues
|
||||
- The tool includes automatic retries for network failures
|
||||
- Increase verbosity with `--verbose` to see detailed request information
|
||||
- Check firewall/proxy settings if requests consistently fail
|
||||
|
||||
## Contributing
|
||||
|
||||
To add new tests:
|
||||
|
||||
1. Add test methods to the `BackstoryTestSuite` class
|
||||
2. Call your test methods from `run_all_tests()` or the appropriate endpoint group method
|
||||
3. Follow the existing pattern of creating `TestResult` objects
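For example, a new test method might look like the following (the endpoint path is illustrative, not a documented route):

```python
def _test_get_current_user(self):
    """Example test method following the existing TestResult pattern."""
    result = self.client.make_request("GET", "/api/1.0/users/me")  # assumed endpoint
    self.test_results.append(result)
    logging.info(f"Get current user: {result.status_code} ({result.response_time:.3f}s)")
```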
|
||||
|
||||
## License
|
||||
|
||||
This test suite is provided as-is for testing the Backstory API.
|
733
src/tests/api/test_suite.py
Normal file
@ -0,0 +1,733 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Backstory API Test Suite CLI Utility
|
||||
|
||||
A comprehensive test suite for the Backstory API that handles authentication,
|
||||
token management, and testing of all API endpoints.
|
||||
|
||||
Usage:
|
||||
python test_suite.py --login user@example.com --password mypassword
|
||||
python test_suite.py --login user@example.com --password mypassword --verbose
|
||||
python test_suite.py --login user@example.com --password mypassword --endpoint candidates
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestResult:
|
||||
"""Represents the result of a single test."""
|
||||
endpoint: str
|
||||
method: str
|
||||
status_code: int
|
||||
success: bool
|
||||
response_time: float
|
||||
error_message: Optional[str] = None
|
||||
response_data: Optional[Dict] = None
|
||||
|
||||
|
||||
class BackstoryAPIClient:
|
||||
"""Client for interacting with the Backstory API."""
|
||||
|
||||
def __init__(self, base_url: str = "https://backstory-beta.ketrenos.com", token_file: str = ".backstory_tokens.json"):
|
||||
self.base_url = base_url
|
||||
self.token_file = Path(token_file)
|
||||
self.bearer_token: Optional[str] = None
|
||||
self.refresh_token: Optional[str] = None
|
||||
self.token_expires_at: Optional[datetime] = None
|
||||
self.device_id: Optional[str] = None # Store device_id for MFA flow
|
||||
self.user: Optional[Dict[str, Any]] = None # Store user ID if needed
|
||||
|
||||
# Configure session with retries
|
||||
self.session = requests.Session()
|
||||
retry_strategy = Retry(
|
||||
total=3,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
allowed_methods=["HEAD", "GET", "OPTIONS"],
|
||||
backoff_factor=1
|
||||
)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
self.session.mount("http://", adapter)
|
||||
self.session.mount("https://", adapter)
|
||||
|
||||
# Try to load existing tokens
|
||||
self.load_tokens()
|
||||
|
||||
def save_tokens(self):
|
||||
"""Save tokens to local JSON file."""
|
||||
token_data = {
|
||||
"bearer_token": self.bearer_token,
|
||||
"refresh_token": self.refresh_token,
|
||||
"expires_at": self.token_expires_at.isoformat() if self.token_expires_at else None,
|
||||
"base_url": self.base_url,
|
||||
"user": self.user,
|
||||
}
|
||||
|
||||
logging.debug(f"Saving tokens to {self.token_file}")
|
||||
try:
|
||||
with open(self.token_file, 'w') as f:
|
||||
json.dump(token_data, f, indent=2)
|
||||
logging.debug(f"Tokens saved to {self.token_file}")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to save tokens: {e}")
|
||||
|
||||
def load_tokens(self):
|
||||
"""Load tokens from local JSON file."""
|
||||
if not self.token_file.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(self.token_file, 'r') as f:
|
||||
token_data = json.load(f)
|
||||
|
||||
# Only load tokens if they're for the same base URL
|
||||
if token_data.get("base_url") != self.base_url:
|
||||
logging.debug("Token file exists but for different base URL")
|
||||
return False
|
||||
|
||||
self.user = token_data.get("user")
|
||||
self.bearer_token = token_data.get("bearer_token")
|
||||
self.refresh_token = token_data.get("refresh_token")
|
||||
|
||||
if token_data.get("expires_at"):
|
||||
self.token_expires_at = datetime.fromisoformat(token_data["expires_at"])
|
||||
|
||||
# Check if token is expired
|
||||
if self.token_expires_at and datetime.now() >= self.token_expires_at:
|
||||
logging.debug("Stored token is expired")
|
||||
return self.refresh_access_token()
|
||||
|
||||
if self.bearer_token:
|
||||
self.session.headers.update({
|
||||
"Authorization": f"Bearer {self.bearer_token}"
|
||||
})
|
||||
logging.debug("Loaded tokens from file")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to load tokens: {e}")
|
||||
|
||||
return False
|
||||
|
||||
def clear_tokens(self):
|
||||
"""Clear tokens from memory and file."""
|
||||
self.bearer_token = None
|
||||
self.refresh_token = None
|
||||
self.token_expires_at = None
|
||||
self.device_id = None # Clear device_id as well
|
||||
|
||||
if "Authorization" in self.session.headers:
|
||||
del self.session.headers["Authorization"]
|
||||
|
||||
if self.token_file.exists():
|
||||
try:
|
||||
self.token_file.unlink()
|
||||
logging.debug("Token file deleted")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to delete token file: {e}")
|
||||
|
||||
def is_authenticated(self) -> bool:
|
||||
"""Check if we have valid authentication."""
|
||||
if not self.bearer_token:
|
||||
return False
|
||||
|
||||
if self.token_expires_at and datetime.now() >= self.token_expires_at:
|
||||
return self.refresh_access_token()
|
||||
|
||||
return True
|
||||
|
||||
def login(self, email: str, password: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Authenticate with the API. Returns result indicating if MFA is required.
|
||||
|
||||
Returns:
|
||||
Dict with 'success', 'mfa_required', 'message', and optionally 'device_id' keys
|
||||
"""
|
||||
url = urljoin(self.base_url, "/api/1.0/auth/login")
|
||||
payload = {"login": email, "password": password}
|
||||
|
||||
try:
|
||||
response = self.session.post(url, json=payload, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
data = result.get("data", {})
|
||||
mfaData = data.get("mfaData", {})
|
||||
|
||||
# Store device_id if present (needed for MFA flow)
|
||||
self.device_id = mfaData.get("deviceId")
|
||||
|
||||
# Check if this is a complete login or requires MFA
|
||||
if "accessToken" in data:
|
||||
# Complete login - store tokens
|
||||
self._store_auth_tokens(data)
|
||||
return {"success": True, "mfa_required": False, "message": "Login successful"}
|
||||
|
||||
elif "mfaRequired" in data or "requiresMfa" in data or response.status_code == 202:
|
||||
# MFA required - device_id should be available
|
||||
return {
|
||||
"success": True,
|
||||
"mfa_required": True,
|
||||
"message": "MFA code required",
|
||||
"device_id": self.device_id
|
||||
}
|
||||
|
||||
else:
|
||||
return {"success": False, "mfa_required": False, "message": "Unexpected login response"}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Login failed: {e}")
|
||||
return {"success": False, "mfa_required": False, "message": str(e)}
|
||||
|
||||
def verify_mfa(self, email: str, mfa_code: str) -> bool:
|
||||
"""
|
||||
Verify MFA code and complete authentication.
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
mfa_code: MFA code from email
|
||||
|
||||
Returns:
|
||||
True if MFA verification successful, False otherwise
|
||||
"""
|
||||
# Try different possible MFA endpoints
|
||||
possible_endpoints = [
|
||||
"/api/1.0/auth/mfa/verify",
|
||||
"/api/1.0/auth/verify-mfa",
|
||||
"/api/1.0/auth/mfa",
|
||||
"/api/1.0/auth/verify"
|
||||
]
|
||||
|
||||
# Base payload with multiple field name variations
|
||||
payload = {
|
||||
"email": email,
|
||||
"code": mfa_code,
|
||||
"mfa_code": mfa_code, # Some APIs use this field name
|
||||
"verification_code": mfa_code # Some APIs use this field name
|
||||
}
|
||||
|
||||
# Include device_id if available (required for proper MFA flow)
|
||||
if self.device_id:
|
||||
payload["deviceId"] = self.device_id
|
||||
logging.debug(f"Including device_id in MFA verification: {self.device_id}")
|
||||
else:
|
||||
logging.warning("No device_id available for MFA verification - this may cause issues")
|
||||
|
||||
for endpoint in possible_endpoints:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
|
||||
try:
|
||||
response = self.session.post(url, json=payload, timeout=30)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
data = result.get("data", {})
|
||||
user = data.get("user", {})
|
||||
logging.info(f"Login response: {data}")
|
||||
self.user = user
|
||||
|
||||
if "accessToken" in data:
|
||||
self._store_auth_tokens(data)
|
||||
# Clear device_id after successful MFA
|
||||
self.device_id = None
|
||||
return True
|
||||
elif response.status_code == 404:
|
||||
# Endpoint doesn't exist, try next one
|
||||
continue
|
||||
else:
|
||||
logging.debug(f"MFA verification failed at {endpoint}: {response.status_code} - {response.text}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.debug(f"MFA request failed for {endpoint}: {e}")
|
||||
continue
|
||||
|
||||
logging.error("MFA verification failed - no valid endpoint found")
|
||||
# Clear device_id on failure
|
||||
self.device_id = None
|
||||
return False
|
||||
|
||||
def _store_auth_tokens(self, data: Dict[str, Any]):
|
||||
"""Store authentication tokens from API response."""
|
||||
self.bearer_token = data.get("accessToken")
|
||||
self.refresh_token = data.get("refreshToken")
|
||||
|
||||
# Calculate expiration time (default to 1 hour if not provided)
|
||||
expires_in = data.get("expiresIn", 3600) # seconds
|
||||
self.token_expires_at = datetime.now() + timedelta(seconds=expires_in)
|
||||
|
||||
if self.bearer_token:
|
||||
self.session.headers.update({
|
||||
"Authorization": f"Bearer {self.bearer_token}"
|
||||
})
|
||||
self.save_tokens()
|
||||
|
||||
def refresh_access_token(self) -> bool:
|
||||
"""Refresh the access token using the refresh token."""
|
||||
if not self.refresh_token:
|
||||
return False
|
||||
|
||||
url = urljoin(self.base_url, "/api/1.0/auth/refresh")
|
||||
|
||||
try:
|
||||
response = self.session.post(url, json=self.refresh_token, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
data = result.get("data", {})
|
||||
new_token = data.get("accessToken")
|
||||
|
||||
if new_token:
|
||||
self._store_auth_tokens(data)
|
||||
return True
|
||||
return False
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"Token refresh failed: {e}")
|
||||
self.clear_tokens() # Clear invalid tokens
|
||||
return False
|
||||
|
||||
def make_request(self, method: str, endpoint: str, **kwargs) -> TestResult:
|
||||
"""Make an API request and return a TestResult."""
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
response = self.session.request(method, url, timeout=30, **kwargs)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
success = 200 <= response.status_code < 300
|
||||
error_message = None if success else f"HTTP {response.status_code}: {response.text}"
|
||||
|
||||
try:
|
||||
response_data = response.json() if response.content else None
|
||||
except json.JSONDecodeError:
|
||||
response_data = {"raw_response": response.text}
|
||||
|
||||
return TestResult(
|
||||
endpoint=endpoint,
|
||||
method=method,
|
||||
status_code=response.status_code,
|
||||
success=success,
|
||||
response_time=response_time,
|
||||
error_message=error_message,
|
||||
response_data=response_data
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
response_time = time.time() - start_time
|
||||
return TestResult(
|
||||
endpoint=endpoint,
|
||||
method=method,
|
||||
status_code=0,
|
||||
success=False,
|
||||
response_time=response_time,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
|
||||
class BackstoryTestSuite:
    """Comprehensive test suite for the Backstory API."""

    def __init__(self, client: BackstoryAPIClient):
        self.client = client
        self.test_results: List[TestResult] = []
        self.test_data = {
            "candidate_id": None,
            "job_id": None
        }

    def run_all_tests(self) -> List[TestResult]:
        """Run all tests in the correct order."""
        self.test_results.clear()

        # Basic endpoints (no auth required)
        self._test_health_check()
        self._test_api_info()
        self._test_redis_stats()
        self._test_metrics()

        # Candidates endpoints
        self._test_get_candidates()
        self._test_search_candidates()
        self._test_create_candidate()

        if self.test_data["candidate_id"]:
            self._test_get_candidate()
            self._test_update_candidate()

        # Jobs endpoints
        self._test_get_jobs()
        self._test_search_jobs()
        self._test_create_job()

        if self.test_data["job_id"]:
            self._test_get_job()

        return self.test_results

    def run_endpoint_tests(self, endpoint_group: str) -> List[TestResult]:
        """Run tests for a specific endpoint group."""
        self.test_results.clear()

        if endpoint_group == "auth":
            # Auth tests are handled during login
            pass
        elif endpoint_group == "candidates":
            self._test_get_candidates()
            self._test_search_candidates()
            self._test_create_candidate()
            if self.test_data["candidate_id"]:
                self._test_get_candidate()
                self._test_update_candidate()
        elif endpoint_group == "jobs":
            self._test_get_jobs()
            self._test_search_jobs()
            self._test_create_job()
            if self.test_data["job_id"]:
                self._test_get_job()
        elif endpoint_group == "health":
            self._test_health_check()
            self._test_api_info()
            self._test_redis_stats()
            self._test_metrics()

        return self.test_results

    def _test_health_check(self):
        """Test the health check endpoint."""
        result = self.client.make_request("GET", "/health")
        self.test_results.append(result)
        logging.info(f"Health check: {result.status_code} ({result.response_time:.3f}s)")

    def _test_api_info(self):
        """Test the API info endpoint."""
        result = self.client.make_request("GET", "/api/1.0/")
        self.test_results.append(result)
        logging.info(f"API info: {result.status_code} ({result.response_time:.3f}s)")

    def _test_redis_stats(self):
        """Test the Redis stats endpoint."""
        result = self.client.make_request("GET", "/api/1.0/redis/stats")
        self.test_results.append(result)
        logging.info(f"Redis stats: {result.status_code} ({result.response_time:.3f}s)")

    def _test_metrics(self):
        """Test the metrics endpoint."""
        result = self.client.make_request("GET", "/api/1.0/metrics")
        self.test_results.append(result)
        logging.info(f"Metrics: {result.status_code} ({result.response_time:.3f}s)")

    def _test_get_candidates(self):
        """Test getting candidates with pagination."""
        # Test basic get
        result = self.client.make_request("GET", "/api/1.0/candidates")
        self.test_results.append(result)
        logging.info(f"Get candidates: {result.status_code} ({result.response_time:.3f}s)")

        # Test with pagination parameters
        params = {"page": 1, "limit": 10, "sortOrder": "asc"}
        result = self.client.make_request("GET", "/api/1.0/candidates", params=params)
        self.test_results.append(result)
        logging.info(f"Get candidates (paginated): {result.status_code} ({result.response_time:.3f}s)")

    def _test_search_candidates(self):
        """Test searching candidates."""
        params = {"query": "test", "page": 1, "limit": 5}
        result = self.client.make_request("GET", "/api/1.0/candidates/search", params=params)
        self.test_results.append(result)
        logging.info(f"Search candidates: {result.status_code} ({result.response_time:.3f}s)")

    def _test_create_candidate(self):
        """Test creating a candidate."""
        candidate_data = {
            "full_name": "Test Candidate",
            "first_name": "Test",
            "last_name": "Candidate",
            "username": "glassmonkey",
            "password": "se!curepassWord123",
            "email": "james_backstorytest@ketrenos.com",
            "skills": ["Python", "API Testing"],
            "experience": "5 years"
        }

        result = self.client.make_request("POST", "/api/1.0/candidates", json=candidate_data)
        self.test_results.append(result)

        # Store candidate ID for later tests
        if result.success and result.response_data:
            self.test_data["candidate_id"] = result.response_data.get("id")

        logging.info(f"Create candidate: {result.status_code} ({result.response_time:.3f}s)")

    def _test_get_candidate(self):
        """Test getting a specific candidate."""
        if not self.test_data["candidate_id"]:
            return

        endpoint = f"/api/1.0/candidates/{self.test_data['candidate_id']}"
        result = self.client.make_request("GET", endpoint)
        self.test_results.append(result)
        logging.info(f"Get candidate: {result.status_code} ({result.response_time:.3f}s)")

    def _test_update_candidate(self):
        """Test updating a candidate."""
        if not self.test_data["candidate_id"]:
            return

        update_data = {"skills": ["Python", "API Testing", "Quality Assurance"]}
        endpoint = f"/api/1.0/candidates/{self.test_data['candidate_id']}"

        result = self.client.make_request("PATCH", endpoint, json=update_data)
        self.test_results.append(result)
        logging.info(f"Update candidate: {result.status_code} ({result.response_time:.3f}s)")

    def _test_get_jobs(self):
        """Test getting jobs with pagination."""
        # Test basic get
        result = self.client.make_request("GET", "/api/1.0/jobs")
        self.test_results.append(result)
        logging.info(f"Get jobs: {result.status_code} ({result.response_time:.3f}s)")

        # Test with pagination parameters
        params = {"page": 1, "limit": 10, "sortOrder": "desc"}
        result = self.client.make_request("GET", "/api/1.0/jobs", params=params)
        self.test_results.append(result)
        logging.info(f"Get jobs (paginated): {result.status_code} ({result.response_time:.3f}s)")

    def _test_search_jobs(self):
        """Test searching jobs."""
        params = {"query": "developer", "page": 1, "limit": 5}
        result = self.client.make_request("GET", "/api/1.0/jobs/search", params=params)
        self.test_results.append(result)
        logging.info(f"Search jobs: {result.status_code} ({result.response_time:.3f}s)")

    def _test_create_job(self):
        """Test creating a job."""
        job_data = {
            "title": "Senior Python Developer",
            "owner_id": self.client.user['id'] if self.client.user else None,
            "owner_type": self.client.user['userType'] if self.client.user else None,
            "description": "Looking for an experienced Python developer",
            "requirements": {
                "technical_skills": {"required": ["Python", "FastAPI", "PostgreSQL"], "preferred": ["Docker", "Kubernetes"]},
                "experience_requirements": {"required": [], "preferred": []},
                "soft_skills": ["Communication", "Teamwork", "Problem Solving"]
            },
            "summary": "Senior Python Developer with 5+ years experience",
            "company": "Tech Innovations Inc.",
            "location": "Remote",
            "salary_range": "80000-120000"
        }

        result = self.client.make_request("POST", "/api/1.0/jobs", json=job_data)
        self.test_results.append(result)

        # Store job ID for later tests
        if result.success and result.response_data:
            self.test_data["job_id"] = result.response_data.get("id")

        logging.info(f"Create job: {result.status_code} ({result.response_time:.3f}s)")

    def _test_get_job(self):
        """Test getting a specific job."""
        if not self.test_data["job_id"]:
            return

        endpoint = f"/api/1.0/jobs/{self.test_data['job_id']}"
        result = self.client.make_request("GET", endpoint)
        self.test_results.append(result)
        logging.info(f"Get job: {result.status_code} ({result.response_time:.3f}s)")

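# Usage sketch (illustrative, not part of this change): driving the suite
# programmatically instead of through the CLI entry point below. Assumes a reachable
# API and a valid account without MFA; every name used here appears elsewhere in
# this file. Kept commented out so module import behavior is unchanged.
#
#     client = BackstoryAPIClient("https://localhost:8000", ".backstory_tokens.json")
#     login_result = client.login("user@example.com", "mypassword")
#     if login_result["success"] and not login_result["mfa_required"]:
#         suite = BackstoryTestSuite(client)
#         print_test_summary(suite.run_all_tests())
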
def print_test_summary(results: List[TestResult]):
    """Print a summary of test results."""
    total_tests = len(results)
    successful_tests = sum(1 for r in results if r.success)
    failed_tests = total_tests - successful_tests

    print("\n" + "=" * 80)
    print("TEST SUMMARY")
    print("=" * 80)
    print(f"Total Tests: {total_tests}")
    print(f"Successful: {successful_tests}")
    print(f"Failed: {failed_tests}")
    print(f"Success Rate: {(successful_tests / total_tests * 100):.1f}%" if total_tests > 0 else "No tests run")

    if failed_tests > 0:
        print("\nFAILED TESTS:")
        print("-" * 40)
        for result in results:
            if not result.success:
                print(f"❌ {result.method} {result.endpoint}")
                print(f"   Status: {result.status_code}")
                print(f"   Error: {result.error_message}")
                print()

    print("\nDETAILED RESULTS:")
    print("-" * 40)
    for result in results:
        status_icon = "✅" if result.success else "❌"
        print(f"{status_icon} {result.method} {result.endpoint} - "
              f"{result.status_code} ({result.response_time:.3f}s)")

def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        description="Backstory API Test Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --login user@example.com --password mypassword
  %(prog)s --login user@example.com --password mypassword --mfa-code 123456
  %(prog)s --login user@example.com --password mypassword --verbose
  %(prog)s --login user@example.com --password mypassword --endpoint candidates
  %(prog)s --login user@example.com --password mypassword --base-url https://localhost:8000
  %(prog)s --use-stored-tokens   # Use previously stored tokens
  %(prog)s --clear-tokens        # Clear stored tokens
"""
    )

    parser.add_argument("--login", help="Login email")
    parser.add_argument("--password", help="Login password")
    parser.add_argument("--mfa-code", help="MFA code from email (if required)")
    parser.add_argument("--base-url", default="https://backstory-beta.ketrenos.com",
                        help="Base URL for the API (default: %(default)s)")
    parser.add_argument("--endpoint", choices=["all", "auth", "candidates", "jobs", "health"],
                        default="all", help="Test specific endpoint group (default: %(default)s)")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Enable verbose logging")
    parser.add_argument("--output", choices=["text", "json"], default="text",
                        help="Output format (default: %(default)s)")
    parser.add_argument("--token-file", default=".backstory_tokens.json",
                        help="Token storage file (default: %(default)s)")
    # Note: with action="store_true" and default=True, this flag is effectively always enabled
    parser.add_argument("--use-stored-tokens", action="store_true", default=True,
                        help="Use previously stored tokens without login")
    parser.add_argument("--clear-tokens", action="store_true",
                        help="Clear stored tokens and exit")

    args = parser.parse_args()

    # Configure logging
    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S"
    )

    # Initialize client
    client = BackstoryAPIClient(args.base_url, args.token_file)

    # Handle token clearing
    if args.clear_tokens:
        client.clear_tokens()
        print("✅ Stored tokens cleared!")
        sys.exit(0)

    # Try to use stored tokens first
    if args.use_stored_tokens or client.is_authenticated():
        if client.is_authenticated():
            print("✅ Using stored authentication tokens!")
        else:
            print("❌ No valid stored tokens found!")
            if not args.login:
                print("Please provide --login and --password to authenticate.")
                sys.exit(1)

    # Perform authentication if needed
    if not client.is_authenticated():
        if not args.login or not args.password:
            print("❌ Login credentials required!")
            print("Please provide --login and --password")
            sys.exit(1)

        print(f"Authenticating with {args.base_url}...")
        login_result = client.login(args.login, args.password)

        if not login_result["success"]:
            print(f"❌ Authentication failed: {login_result['message']}")
            sys.exit(1)

        if login_result["mfa_required"]:
            print("📧 MFA required! Check your email for the verification code.")

            if args.verbose and login_result.get("device_id"):
                print(f"🔧 Device ID: {login_result['device_id']}")

            # Get MFA code from command line or prompt user
            mfa_code = args.mfa_code
            if not mfa_code:
                try:
                    mfa_code = input("Enter MFA code: ").strip()
                except KeyboardInterrupt:
                    print("\n❌ Authentication cancelled!")
                    sys.exit(1)

            if not mfa_code:
                print("❌ MFA code required!")
                sys.exit(1)

            print("Verifying MFA code...")
            if not client.verify_mfa(args.login, mfa_code):
                print("❌ MFA verification failed!")
                sys.exit(1)

        print("✅ Authentication successful!")

    # Initialize test suite and run tests
    test_suite = BackstoryTestSuite(client)

    print(f"\nRunning tests for: {args.endpoint}")
    print("=" * 50)

    if args.endpoint == "all":
        results = test_suite.run_all_tests()
    else:
        results = test_suite.run_endpoint_tests(args.endpoint)

    # Output results
    if args.output == "json":
        output = {
            "summary": {
                "total": len(results),
                "successful": sum(1 for r in results if r.success),
                "failed": sum(1 for r in results if not r.success)
            },
            "results": [
                {
                    "endpoint": r.endpoint,
                    "method": r.method,
                    "status_code": r.status_code,
                    "success": r.success,
                    "response_time": r.response_time,
                    "error_message": r.error_message
                }
                for r in results
            ]
        }
        print(json.dumps(output, indent=2))
    else:
        print_test_summary(results)

    # Exit with error code if any tests failed
    failed_count = sum(1 for r in results if not r.success)
    sys.exit(1 if failed_count > 0 else 0)


if __name__ == "__main__":
    main()
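# Example invocation (hypothetical script name; the arguments are those defined in
# main() above, and --output json selects the machine-readable summary):
#
#   python backstory_api_tests.py --login user@example.com --password mypassword \
#       --endpoint health --output json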