ruff reformat

James Ketr 2025-06-18 13:53:07 -07:00
parent f1c2e16389
commit 2cf3fa7b04
69 changed files with 4863 additions and 4915 deletions
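
A reformat of this size is what ruff's formatter produces when run over the whole repository. As a rough sketch of the assumed setup (the 120-character line length is inferred from the joined lines below and is not recorded in this commit; the actual configuration may differ):

# pyproject.toml — assumed formatter settings, not part of this commit
[tool.ruff]
line-length = 120

# assumed invocation that produced this diff
ruff format .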

View File

@@ -17,11 +17,9 @@ from logger import logger
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
# Maps class_name to (module_name, class_name)
class_registry: Dict[str, Tuple[str, str]] = (
{}
)
class_registry: Dict[str, Tuple[str, str]] = {}
__all__ = ['get_or_create_agent']
__all__ = ["get_or_create_agent"]
package_dir = pathlib.Path(__file__).parent
package_name = __name__
@@ -38,12 +36,7 @@ for path in package_dir.glob("*.py"):
# Find all Agent subclasses in the module
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, AnyAgent)
and obj is not AnyAgent
and obj is not Agent
and name not in class_registry
):
if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
class_registry[name] = (full_module_name, name)
globals()[name] = obj
logger.info(f"Adding agent: {name}")

View File

@@ -32,19 +32,45 @@ from pathlib import Path
from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from models import (
Tunables,
ChatMessageUser,
ChatMessage,
RagEntry,
ChatMessageMetaData,
ApiStatusType,
Candidate,
ChatContextType,
)
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
from models import ChromaDBGetResponse
from utils.metrics import Metrics
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
from models import (
ApiActivityType,
ApiMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageStatus,
ChatMessageStreaming,
LLMMessage,
ChatMessage,
ChatOptions,
ChatMessageUser,
Tunables,
ApiStatusType,
ChatMessageMetaData,
Candidate,
)
from logger import logger
import defines
from .registry import agent_registry
from models import ( ChromaDBGetResponse )
from models import ChromaDBGetResponse
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@@ -59,13 +85,10 @@ class CandidateEntity(Candidate):
CandidateEntity__agents: List[Agent] = []
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
default=None, exclude=True
)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
CandidateEntity__metrics: Optional[Metrics] = Field(
default=None,
description="Metrics collector for this agent, used to track performance and usage."
default=None, description="Metrics collector for this agent, used to track performance and usage."
)
def __init__(self, candidate=None):
@@ -78,7 +101,7 @@ class CandidateEntity(Candidate):
@classmethod
def exists(cls, username: str):
# Validate username format (only allow safe characters)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
return False # Invalid username characters
# Check for minimum and maximum length
@@ -117,11 +140,7 @@ class CandidateEntity(Candidate):
if agent.agent_type == agent_type:
return agent
return get_or_create_agent(
agent_type=agent_type,
user=self,
prometheus_collector=self.prometheus_collector
)
return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)
# Wrapper properties that map into file_watcher
@property
@@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
# Fields managed by initialize()
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
@property
def metrics(self) -> Metrics:
if not self.CandidateEntity__metrics:
@@ -160,15 +180,10 @@ class CandidateEntity(Candidate):
if not self.metrics:
logger.warning("No metrics collector set for this agent.")
return
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
response.usage.prompt_eval_count
)
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
async def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
@@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
)
has_username_rag = any(item.name == self.username for item in self.rags)
if not has_username_rag:
self.rags.append(RagEntry(
self.rags.append(
RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
))
)
)
self.rag_content_size = self.file_watcher.collection.count()
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
class Config:
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
@@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
conversation: List[ChatMessageUser] = Field(
default_factory=list,
description="Conversation history for this agent, used to maintain context across messages."
description="Conversation history for this agent, used to maintain context across messages.",
)
@property
@@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
last_item = item
return last_item
def set_optimal_context_size(
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
) -> int:
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
@@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
@@ -472,16 +487,16 @@ class Agent(BaseModel, ABC):
context = []
for chroma_results in rag_message.content:
for index, metadata in enumerate(chroma_results.metadatas):
content = "\n".join([
line.strip()
for line in chroma_results.documents[index].split("\n")
if line
]).strip()
context.append(f"""
content = "\n".join(
[line.strip() for line in chroma_results.documents[index].split("\n") if line]
).strip()
context.append(
f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
""")
"""
)
return "\n".join(context)
async def generate_rag_results(
@@ -501,10 +516,7 @@ Content: {content}
A list of dictionaries containing the RAG results.
"""
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for RAG generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
yield error_message
return
@@ -517,14 +529,12 @@ Content: {content}
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content = f"Searching RAG context {rag.name}..."
content=f"Searching RAG context {rag.name}...",
)
yield status_message
try:
chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold
)
chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
if not chroma_results:
continue
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
@@ -548,7 +558,7 @@ Content: {content}
continue_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content=f"Error searching RAG context {rag.name}: {str(e)}"
content=f"Error searching RAG context {rag.name}: {str(e)}",
)
yield continue_message
@@ -562,22 +572,20 @@ Content: {content}
async def llm_one_shot(
self,
llm: Any, model: str,
session_id: str, prompt: str, system_prompt: str,
llm: Any,
model: str,
session_id: str,
prompt: str,
system_prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
temperature=0.7,
) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
self.set_optimal_context_size(
llm=llm, model=model, prompt=prompt+system_prompt
)
self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
options = ChatOptions(
seed=8911,
@@ -591,9 +599,7 @@ Content: {content}
]
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content="Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
@@ -609,10 +615,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@@ -627,17 +630,12 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
chat_message = ChatMessage(
session_id=session_id,
@@ -650,23 +648,16 @@ Content: {content}
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
)
),
)
yield chat_message
return
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
@@ -690,37 +681,32 @@ Content: {content}
yield message
if not isinstance(message, ChatMessageRagSearch):
raise ValueError(
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
)
raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
rag_message = message
context = self.get_rag_context(rag_message)
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
# Add the conversation history to the messages
messages.extend([
messages.extend(
[
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
for m in self.conversation
])
]
)
# Add the RAG context to the messages if available
if context:
messages.append(
LLMMessage(
role="user",
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
)
)
else:
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role="user", content=prompt)
)
messages.append(LLMMessage(role="user", content=prompt))
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
@@ -824,16 +810,12 @@ Content: {content}
# not use_tools
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content="Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
# Set the response for streaming
self.set_optimal_context_size(
llm, model, prompt=prompt
)
self.set_optimal_context_size(llm, model, prompt=prompt)
options = ChatOptions(
seed=8911,
@@ -853,10 +835,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@@ -870,17 +849,12 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
end_time = time.perf_counter()
chat_message = ChatMessage(
@@ -899,9 +873,8 @@ Content: {content}
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time
},
),
)
)
# Add the user and chat messages to the conversation
self.conversation.append(user_message)
@@ -996,12 +969,13 @@ Content: {content}
raise ValueError("No Markdown found in the response")
_agents: List[Agent] = []
def get_or_create_agent(
agent_type: str,
prometheus_collector: CollectorRegistry,
user: Optional[CandidateEntity]=None) -> Agent:
agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
@@ -1025,14 +999,16 @@ def get_or_create_agent(
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
user=user)
agent = agent_cls(
agent_type=agent_type, # type: ignore[call-arg]
user=user,
)
_agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)
CandidateEntity.model_rebuild()

View File

@@ -5,7 +5,7 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, Tunables, ApiStatusType)
from models import ApiMessage, Tunables, ApiStatusType
system_message = """
@@ -21,6 +21,7 @@ Always <|context|> when possible. Be concise, and never make up information. If
Before answering, ensure you have spelled the candidate's name correctly.
"""
class CandidateChat(Agent):
"""
CandidateChat Agent
@@ -32,10 +33,7 @@ class CandidateChat(Agent):
system_prompt: str = system_message
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
user = self.user
if not user:
@@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
{system_message}
"""
async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
async for message in super().generate(
llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
):
if message.status == ApiStatusType.ERROR:
yield message
return
yield message
# Register the base agent
agent_registry.register(CandidateChat._agent_type, CandidateChat)

View File

@@ -49,11 +49,13 @@ class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["general"] = "general" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# if not self.context:

View File

@@ -4,7 +4,7 @@ from typing import (
ClassVar,
Any,
AsyncGenerator,
Optional
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
@@ -13,7 +13,15 @@ import time
import os
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@@ -23,6 +31,7 @@ from image_generator.profile_image import generate_image, ImageRequest
seed = int(time.time())
random.seed(seed)
class ImageGenerator(Agent):
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@@ -31,10 +40,7 @@ class ImageGenerator(Agent):
system_prompt: str = "" # No system prompt is used
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
if not self.user:
logger.error("User is not set for ImageGenerator agent.")
@@ -54,11 +60,17 @@ class ImageGenerator(Agent):
yield status_message
logger.info(f"Image generation: {file_path} <- {prompt}")
request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
request = ImageRequest(
filepath=file_path,
session_id=session_id,
prompt=prompt,
iterations=4,
height=256,
width=256,
guidance_scale=7.5,
)
generated_message = None
async for generated_message in generate_image(
request=request
):
async for generated_message in generate_image(request=request):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@@ -68,8 +80,7 @@ class ImageGenerator(Agent):
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -84,20 +95,18 @@ class ImageGenerator(Agent):
session_id=session_id,
status=ApiStatusType.DONE,
content=f"{defines.api_prefix}/profile/{user.username}",
metadata=generated_message.metadata
metadata=generated_message.metadata,
)
yield generated_image
return
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error generating image: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Register the base agent
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)

View File

@@ -8,7 +8,7 @@ from typing import (
Tuple,
AsyncGenerator,
List,
Optional
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
@@ -20,7 +20,16 @@ import os
import random
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ApiMessageType,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@@ -43,6 +52,7 @@ emptyUser = {
"questions": [],
}
def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
return f"""\
You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
@@ -84,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
You are providing those only for use later by the system when casting individuals for the role.
"""
generate_resume_system_prompt = """
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
of that is to create an in-depth resume for the person. You will be provided with the following information:
@@ -115,10 +126,12 @@ import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EthnicNameGenerator:
def __init__(self):
try:
from names_dataset import NameDataset # type: ignore
self.nd = NameDataset()
except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset")
@@ -129,24 +142,24 @@ class EthnicNameGenerator:
# US Census 2020 approximate ethnic distribution
self.ethnic_weights = {
'White': 0.576,
'Hispanic': 0.186,
'Black': 0.134,
'Asian': 0.062,
'Native American': 0.013,
'Pacific Islander': 0.003,
'Mixed/Other': 0.026
"White": 0.576,
"Hispanic": 0.186,
"Black": 0.134,
"Asian": 0.062,
"Native American": 0.013,
"Pacific Islander": 0.003,
"Mixed/Other": 0.026,
}
# Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
self.ethnic_country_mapping = {
'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
'Black': ['US'], # African American names
'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
'Native American': ['US'],
'Pacific Islander': ['US'],
'Mixed/Other': ['US']
"White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
"Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
"Black": ["US"], # African American names
"Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
"Native American": ["US"],
"Pacific Islander": ["US"],
"Mixed/Other": ["US"],
}
def get_weighted_ethnicity(self) -> str:
@@ -155,8 +168,9 @@ class EthnicNameGenerator:
weights = list(self.ethnic_weights.values())
return random.choices(ethnicities, weights=weights)[0]
def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
n: int = 50, use_first_names: bool = True) -> List[str]:
def get_names_by_criteria(
self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
) -> List[str]:
"""Get names matching criteria using NameDataset's get_top_names method"""
if not self.nd:
return []
@@ -166,16 +180,13 @@ class EthnicNameGenerator:
try:
# Get top names for this country
top_names = self.nd.get_top_names(
n=n,
use_first_names=use_first_names,
country_alpha2=country_code,
gender=gender
n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
)
if country_code in top_names:
if use_first_names and gender:
# For first names with gender specified
gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
if gender_key in top_names[country_code]:
all_names.extend(top_names[country_code][gender_key])
elif use_first_names:
@@ -192,25 +203,18 @@ class EthnicNameGenerator:
return list(set(all_names)) # Remove duplicates
def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a name based on ethnicity using the correct NameDataset API"""
if gender == 'random':
gender = random.choice(['Male', 'Female'])
if gender == "random":
gender = random.choice(["Male", "Female"])
countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
countries = self.ethnic_country_mapping.get(ethnicity, ["US"])
# Get first names
first_names = self.get_names_by_criteria(
countries=countries,
gender=gender,
use_first_names=True
)
first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)
# Get last names
last_names = self.get_names_by_criteria(
countries=countries,
use_first_names=False
)
last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)
# Select names or use fallbacks
if first_names:
@@ -230,57 +234,60 @@ class EthnicNameGenerator:
def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
"""Provide culturally appropriate fallback first names"""
fallback_names = {
'White': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
"White": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
"Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
},
'Hispanic': {
'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
"Hispanic": {
"Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
"Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
},
'Black': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
"Black": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
"Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
},
"Asian": {
"Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
"Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
},
'Asian': {
'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
}
}
ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))
def _get_fallback_last_name(self, ethnicity: str) -> str:
"""Provide culturally appropriate fallback last names"""
fallback_surnames = {
'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
"White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
"Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
"Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
"Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
}
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))
def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a random name with ethnicity based on US demographics"""
ethnicity = self.get_weighted_ethnicity()
return self.get_name_by_ethnicity(ethnicity, gender)
def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
"""Generate multiple random names"""
names = []
for _ in range(count):
first, last, ethnicity, actual_gender = self.generate_random_name(gender)
names.append({
'full_name': f"{first} {last}",
'first_name': first,
'last_name': last,
'ethnicity': ethnicity,
'gender': actual_gender
})
names.append(
{
"full_name": f"{first} {last}",
"first_name": first,
"last_name": last,
"ethnicity": ethnicity,
"gender": actual_gender,
}
)
return names
class GeneratePersona(Agent):
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@@ -305,10 +312,7 @@ class GeneratePersona(Agent):
self.full_name = f"{self.first_name} {self.last_name}"
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
self.randomize()
@@ -340,7 +344,8 @@ Incorporate the following into the job description: {original_prompt}
logger.info("🤖 Generating persona...")
generating_message = None
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=prompt,
system_prompt=generate_persona_system_prompt(persona=persona),
@@ -354,8 +359,7 @@ Incorporate the following into the job description: {original_prompt}
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Persona generation failed to generate a response."
session_id=session_id, content="Persona generation failed to generate a response."
)
yield error_message
return
@@ -373,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
self.username = persona.get("username", None)
if not self.username:
raise ValueError("LLM did not generate a username")
self.username = re.sub(r'\s+', '.', self.username)
self.username = re.sub(r"\s+", ".", self.username)
user_dir = os.path.join(defines.user_dir, persona["username"])
while os.path.exists(user_dir):
match = re.match(r"^(.*?)(\d*)$", persona["username"])
@@ -396,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
location_parts = persona["location"].split(",")
if len(location_parts) == 3:
city, state, country = [part.strip() for part in location_parts]
persona["location"] = {
"city": city,
"state": state,
"country": country
}
persona["location"] = {"city": city, "state": state, "country": country}
else:
logger.error(f"Invalid location format: {persona['location']}")
persona["location"] = None
persona["is_ai"] = True
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
)
logger.error(f"❌ Error parsing LLM response: {error_message.content}")
logger.error(traceback.format_exc())
@@ -420,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
# Persona generated
persona_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.JSON,
content = json.dumps(persona)
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
)
yield persona_message
@@ -433,7 +429,7 @@ Incorporate the following into the job description: {original_prompt}
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.THINKING,
content = f"Generating resume for {persona['full_name']}..."
content=f"Generating resume for {persona['full_name']}...",
)
logger.info(f"🤖 {status_message.content}")
yield status_message
@@ -456,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=content,
system_prompt=generate_resume_system_prompt,
@@ -470,8 +467,7 @@ Make sure at least one of the candidate's job descriptions take into account the
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Resume generation failed to generate a response."
session_id=session_id, content="Resume generation failed to generate a response."
)
logger.error(f"{error_message.content}")
yield error_message
@@ -479,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
resume = self.extract_markdown_from_text(generating_message.content)
resume_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.TEXT,
content=resume
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
)
yield resume_message
return
@@ -502,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
raise ValueError("No JSON found in the response")
# Register the base agent
agent_registry.register(GeneratePersona._agent_type, GeneratePersona)

View File

@@ -4,23 +4,31 @@ from typing import (
ClassVar,
Any,
AsyncGenerator,
List
List,
# override
) # NOTE: You must import Optional for late binding to work
import json
from logger import logger
from .base import Agent, agent_registry
from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
from models import (
ApiActivityType,
ApiMessage,
ApiStatusType,
ChatMessage,
ChatMessageError,
ChatMessageResume,
ChatMessageStatus,
SkillAssessment,
SkillStrength,
)
class GenerateResume(Agent):
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_resume_prompt(
self,
skills: List[SkillAssessment]
):
def generate_resume_prompt(self, skills: List[SkillAssessment]):
"""
Generate a professional resume based on skill assessment results
@@ -41,7 +49,7 @@ class GenerateResume(Agent):
SkillStrength.STRONG: [],
SkillStrength.MODERATE: [],
SkillStrength.WEAK: [],
SkillStrength.NONE: []
SkillStrength.NONE: [],
}
experience_evidence = {}
@@ -63,11 +71,7 @@ class GenerateResume(Agent):
experience_evidence[source] = []
experience_evidence[source].append(
{
"skill": skill,
"quote": evidence.quote,
"context": evidence.context
}
{"skill": skill, "quote": evidence.quote, "context": evidence.context}
)
# Build the system prompt
@@ -167,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
) -> AsyncGenerator[ApiMessage, None]:
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = "Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
system_prompt, prompt = self.generate_resume_prompt(skills=skills)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@@ -185,8 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -194,8 +197,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis did not return a valid message."
session_id=session_id, content="Job requirements analysis did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -214,5 +216,6 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
logger.info("✅ Resume generation completed successfully.")
return
# Register the base agent
agent_registry.register(GenerateResume._agent_type, GenerateResume)

View File

@@ -5,17 +5,30 @@ from typing import (
ClassVar,
Any,
AsyncGenerator,
Optional
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import inspect
import json
from .base import Agent, agent_registry
from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
from models import (
ApiActivityType,
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Job,
JobRequirements,
JobRequirementsMessage,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class JobRequirementsAgent(Agent):
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@@ -91,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
"""Analyze job requirements from job description."""
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
status_message = ChatMessageStatus(
session_id=session_id,
content="Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
logger.info(f"🔍 {status_message.content}")
generated_message = None
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@@ -107,8 +120,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response.")
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
@@ -129,18 +142,18 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
display = {
"technical_skills": {
"required": reqs.technical_skills.required,
"preferred": reqs.technical_skills.preferred
"preferred": reqs.technical_skills.preferred,
},
"experience_requirements": {
"required": reqs.experience_requirements.required,
"preferred": reqs.experience_requirements.preferred
"preferred": reqs.experience_requirements.preferred,
},
"soft_skills": reqs.soft_skills,
"experience": reqs.experience,
"education": reqs.education,
"certifications": reqs.certifications,
"preferred_attributes": reqs.preferred_attributes,
"company_values": reqs.company_values
"company_values": reqs.company_values,
}
return display
@@ -149,19 +162,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="User is not set for this agent."
)
error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = "Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
@@ -175,8 +183,7 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -211,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
return
except Exception as e:
status_message.status = ApiStatusType.ERROR
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
status_message.content = (
f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {status_message.content}")
yield status_message
@@ -238,5 +247,6 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
logger.info("✅ Job requirements analysis completed successfully.")
return
# Register the base agent
agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)

View File

@@ -5,20 +5,19 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables )
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables
class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
"""
Generate a response based on the user message and the provided LLM.
@@ -44,8 +43,7 @@ class Chat(Agent):
if not isinstance(rag_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
@@ -53,5 +51,6 @@ class Chat(Agent):
rag_message.status = ApiStatusType.DONE
yield rag_message
# Register the base agent
agent_registry.register(Chat._agent_type, Chat)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type
# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
"""Registry for agent types and classes"""

View File

@@ -4,17 +4,27 @@ from typing import (
ClassVar,
Any,
AsyncGenerator,
Optional
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import json
from .base import Agent, agent_registry
from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail,
SkillAssessment, Tunables)
from models import (
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageSkillAssessment,
ApiStatusType,
EvidenceDetail,
SkillAssessment,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class SkillMatchAgent(Agent):
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@@ -96,15 +106,12 @@ JSON RESPONSE:"""
return system_prompt, prompt
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="Agent not attached to user. Attach the agent to a user before generating responses."
content="Agent not attached to user. Attach the agent to a user before generating responses.",
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -112,10 +119,7 @@ JSON RESPONSE:"""
skill = prompt.strip()
if not skill:
error_message = ChatMessageError(
session_id=session_id,
content="Skill cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
@@ -130,8 +134,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -140,8 +143,7 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
@@ -152,7 +154,9 @@ JSON RESPONSE:"""
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
):
if generated_message.status == ApiStatusType.ERROR:
logger.error(f"⚠️ {generated_message.content}")
yield generated_message
@@ -162,8 +166,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment failed to generate a response."
session_id=session_id, content="Skill assessment failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -171,8 +174,7 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment did not return a valid message."
session_id=session_id, content="Skill assessment did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@@ -195,14 +197,15 @@ JSON RESPONSE:"""
EvidenceDetail(
source=evidence.get("source", ""),
quote=evidence.get("quote", ""),
context=evidence.get("context", "")
) for evidence in skill_assessment_data.get("evidence_details", [])
]
context=evidence.get("context", ""),
)
for evidence in skill_assessment_data.get("evidence_details", [])
],
)
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
@@ -232,5 +235,6 @@ JSON RESPONSE:"""
logger.info("✅ Skill assessment completed successfully.")
return
# Register the base agent
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)

View File

@@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Any, Callable
from logger import logger
from database.manager import DatabaseManager
class BackgroundTaskManager:
"""Manages background tasks for the application using asyncio instead of threading"""
@@ -65,10 +66,12 @@ class BackgroundTaskManager:
stats = await database.get_guest_statistics()
# Log interesting statistics
if stats.get('total_guests', 0) > 0:
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
if stats.get("total_guests", 0) > 0:
logger.info(
f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted")
f"{stats['converted_guests']} converted"
)
return stats
except Exception as e:
@@ -84,6 +87,7 @@ class BackgroundTaskManager:
# Get Redis client safely (using the event loop safe method)
from database.manager import redis_manager
redis = await redis_manager.get_client()
# Clean up rate limit keys older than specified days
@@ -192,10 +196,7 @@ class BackgroundTaskManager:
# Create asyncio tasks for each periodic task
for name, func, interval, *args in periodic_tasks:
task = asyncio.create_task(
self._run_periodic_task(name, func, interval, *args),
name=f"background_{name}"
)
task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
self.tasks.append(task)
logger.info(f"📅 Scheduled background task: {name}")
@@ -238,10 +239,7 @@ class BackgroundTaskManager:
# Wait for all tasks to complete with timeout
if self.tasks:
try:
await asyncio.wait_for(
asyncio.gather(*self.tasks, return_exceptions=True),
timeout=30.0
)
await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
logger.info("✅ All background tasks stopped gracefully")
except asyncio.TimeoutError:
logger.warning("⚠️ Some background tasks did not stop within timeout")
@@ -258,7 +256,7 @@ class BackgroundTaskManager:
"main_loop_id": id(self.main_loop) if self.main_loop else None,
"current_loop_id": None,
"task_count": len(self.tasks),
"tasks": []
"tasks": [],
}
try:
@@ -317,6 +315,7 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun
await task_manager.start()
return task_manager
# For integration with your existing app startup
async def initialize_with_background_tasks(database_manager: DatabaseManager):
"""Initialize database and background tasks together"""

View File

@@ -3,6 +3,7 @@ import os
import sys
import defines
def filter_traceback(tb, app_path=None, module_name=None):
"""
Filter traceback to include only frames from the specified application path or module.
@@ -36,7 +37,8 @@ def filter_traceback(tb, app_path=None, module_name=None):
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
# Combine the filtered stack trace with the exception message
return ''.join(formatted_stack + formatted_exc)
return "".join(formatted_stack + formatted_exc)
def format_exc(app_path=defines.app_path, module_name=None):
"""

View File

@@ -1,4 +1,4 @@
from .core import RedisDatabase
from .manager import DatabaseManager, redis_manager
__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager']
__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]

View File

@@ -1,15 +1,15 @@
KEY_PREFIXES = {
'viewers': 'viewer:',
'candidates': 'candidate:',
'employers': 'employer:',
'jobs': 'job:',
'job_applications': 'job_application:',
'chat_sessions': 'chat_session:',
'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:',
'users': 'user:',
'candidate_documents': 'candidate_documents:',
'job_requirements': 'job_requirements:',
'resumes': 'resume:',
'user_resumes': 'user_resumes:',
"viewers": "viewer:",
"candidates": "candidate:",
"employers": "employer:",
"jobs": "job:",
"job_applications": "job_application:",
"chat_sessions": "chat_session:",
"chat_messages": "chat_messages:",
"ai_parameters": "ai_parameters:",
"users": "user:",
"candidate_documents": "candidate_documents:",
"job_requirements": "job_requirements:",
"resumes": "resume:",
"user_resumes": "user_resumes:",
}

View File

@@ -10,6 +10,7 @@ from .mixins.job import JobMixin
from .mixins.skill import SkillMixin
from .mixins.ai import AIMixin
# RedisDatabase is the main class that combines all mixins for a
# comprehensive Redis database interface.
class RedisDatabase(

View File

@@ -1,4 +1,4 @@
from redis.asyncio import (Redis, ConnectionPool)
from redis.asyncio import Redis, ConnectionPool
from typing import Optional, Optional
import json
import logging
@@ -9,6 +9,7 @@ from .core import RedisDatabase
logger = logging.getLogger(__name__)
# _RedisManager is a singleton class that manages the Redis connection and
# provides methods for connecting, disconnecting, and performing health checks.
#
@@ -42,12 +43,10 @@ class _RedisManager:
retry_on_timeout=True,
socket_keepalive=True,
socket_keepalive_options={},
health_check_interval=30
health_check_interval=30,
)
self.redis = Redis(
connection_pool=self._connection_pool
)
self.redis = Redis(connection_pool=self._connection_pool)
if not self.redis:
raise RuntimeError("Redis client not initialized")
@@ -131,7 +130,7 @@ class _RedisManager:
"uptime_seconds": info.get("uptime_in_seconds", 0),
"connected_clients": info.get("connected_clients", 0),
"used_memory_human": info.get("used_memory_human", "unknown"),
"total_commands_processed": info.get("total_commands_processed", 0)
"total_commands_processed": info.get("total_commands_processed", 0),
}
except Exception as e:
logger.error(f"Redis health check failed: {e}")
@@ -173,9 +172,11 @@ class _RedisManager:
logger.error(f"Failed to get Redis info: {e}")
return None
# Global Redis manager instance
redis_manager = _RedisManager()
# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
class DatabaseManager:
@@ -227,7 +228,7 @@ class DatabaseManager:
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
# Save to local file (you might want to save to cloud storage instead)
with open(backup_filename, 'w') as f:
with open(backup_filename, "w") as f:
json.dump(backup_data, f, indent=2, default=str)
logger.info(f"Backup created: {backup_filename}")
@@ -310,5 +311,3 @@ class DatabaseManager:
if self._shutdown_initiated:
raise RuntimeError("Application is shutting down")
return self.db

View File

@@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class AIMixin(DatabaseProtocol):
"""Mixin for AI operations"""
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
"""Get AI parameters by ID"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
@@ -36,7 +38,7 @@ class AIMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
result[param_id] = self._deserialize(value)
return result
@@ -45,4 +47,3 @@ class AIMixin(DatabaseProtocol):
"""Delete AI parameters"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.delete(key)

View File

@@ -7,6 +7,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class AnalyticsMixin:
"""Mixin for analytics-related database operations"""

View File

@@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class AuthMixin(DatabaseProtocol):
"""Mixin for auth-related database operations"""
@@ -25,8 +26,9 @@ class AuthMixin(DatabaseProtocol):
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if (verification_info.get("email", "").lower() == email_lower and
not verification_info.get("verified", False)):
if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
"verified", False
):
# Extract token from key
token = key.replace("email_verification:", "")
verification_info["token"] = token
@ -115,10 +117,7 @@ class AuthMixin(DatabaseProtocol):
window_start = current_time - timedelta(hours=24)
# Filter out old attempts
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
return len(recent_attempts)
@ -141,16 +140,13 @@ class AuthMixin(DatabaseProtocol):
# Keep only last 24 hours of attempts
window_start = current_time - timedelta(hours=24)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours
json.dumps(recent_attempts)
json.dumps(recent_attempts),
)
return True
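The pattern in this hunk — keep a JSON list of ISO timestamps, drop entries older than the window, rewrite the key with SETEX — is a sliding-window counter. A standalone sketch, assuming an already-connected `redis.asyncio` client:

```python
# Sliding-window counter sketch: one Redis key per subject holding a JSON
# list of ISO-8601 attempt timestamps, trimmed to the window on every write.
import json
from datetime import datetime, timedelta, timezone

from redis.asyncio import Redis


async def record_attempt(redis: Redis, key: str, window_hours: int = 24) -> int:
    now = datetime.now(timezone.utc)
    window_start = now - timedelta(hours=window_hours)

    raw = await redis.get(key)
    attempts = json.loads(raw) if raw else []

    # Keep only attempts that still fall inside the window, then add this one.
    recent = [a for a in attempts if datetime.fromisoformat(a) > window_start]
    recent.append(now.isoformat())

    # Re-store with the window as the TTL so idle keys expire on their own.
    await redis.setex(key, window_hours * 3600, json.dumps(recent))
    return len(recent)
```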
@ -169,14 +165,14 @@ class AuthMixin(DatabaseProtocol):
"user_data": user_data,
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"verified": False
"verified": False,
}
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours in seconds
json.dumps(verification_data, default=str)
json.dumps(verification_data, default=str),
)
logger.info(f"📧 Stored email verification token for {email}")
@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
24 * 60 * 60, # Keep for remaining TTL
json.dumps(token_data, default=str)
json.dumps(token_data, default=str),
)
return True
return False
@ -228,14 +224,14 @@ class AuthMixin(DatabaseProtocol):
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"attempts": 0,
"verified": False
"verified": False,
}
# Store with 10 minute expiration
await self.redis.setex(
key,
10 * 60, # 10 minutes in seconds
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
logger.info(f"🔐 Stored MFA code for {email}")
@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep original TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return mfa_data["attempts"]
return 0
@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep for remaining TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return True
return False
@ -327,7 +323,9 @@ class AuthMixin(DatabaseProtocol):
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
return False
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
"""Store refresh token for a user"""
try:
key = f"refresh_token:{token}"
@ -337,7 +335,7 @@ class AuthMixin(DatabaseProtocol):
"device": device_info.get("device", "unknown"),
"ip_address": device_info.get("ip_address", "unknown"),
"is_revoked": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
@ -420,7 +418,7 @@ class AuthMixin(DatabaseProtocol):
"email": email.lower(),
"expires_at": expires_at.isoformat(),
"used": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
@ -473,7 +471,7 @@ class AuthMixin(DatabaseProtocol):
"timestamp": datetime.now(timezone.utc).isoformat(),
"user_id": user_id,
"event_type": event_type,
"details": details
"details": details,
}
# Add to list (latest events first)
@ -496,7 +494,7 @@ class AuthMixin(DatabaseProtocol):
try:
events = []
for i in range(days):
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
key = f"security_log:{user_id}:{date}"
daily_events = await self.redis.lrange(key, 0, -1) # type: ignore
@ -509,4 +507,3 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
return []
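The security log above shards events into one Redis list per user per day (`security_log:{user_id}:{YYYY-MM-DD}`), pushes the newest event to the head, and reads back the last N days. A hedged sketch of the same layout, with the TTL chosen only for illustration:

```python
# Sketch of per-day security-log lists; the key layout mirrors the code above,
# while the retention period and helper names are illustrative.
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List

from redis.asyncio import Redis


async def log_event(redis: Redis, user_id: str, event: Dict[str, Any]) -> None:
    day = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    key = f"security_log:{user_id}:{day}"
    await redis.lpush(key, json.dumps(event, default=str))  # newest first
    await redis.expire(key, 30 * 24 * 3600)  # keep each day's list ~30 days


async def read_events(redis: Redis, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
    events: List[Dict[str, Any]] = []
    for i in range(days):
        day = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
        raw = await redis.lrange(f"security_log:{user_id}:{day}", 0, -1)
        events.extend(json.loads(item) for item in raw)
    return events
```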

View File

@ -5,11 +5,13 @@ from typing import Any, Dict, TYPE_CHECKING
from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
pass
class BaseMixin(DatabaseProtocol):
"""Base mixin with core Redis operations and utilities"""
@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)

View File

@ -8,6 +8,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ChatMixin(DatabaseProtocol):
"""Mixin for chat-related database operations"""
@ -22,7 +23,7 @@ class ChatMixin(DatabaseProtocol):
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
@ -41,7 +42,7 @@ class ChatMixin(DatabaseProtocol):
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# Chat Sessions operations
@ -71,13 +72,13 @@ class ChatMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
result[session_id] = self._deserialize(value)
return result
async def delete_chat_session(self, session_id: str) -> bool:
'''Delete a chat session from Redis'''
"""Delete a chat session from Redis"""
try:
result = await self.redis.delete(f"chat_session:{session_id}")
return result > 0
@ -86,7 +87,7 @@ class ChatMixin(DatabaseProtocol):
raise
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
'''Delete a specific chat message from Redis'''
"""Delete a specific chat message from Redis"""
try:
# Remove from the session's message list
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
@ -132,7 +133,7 @@ class ChatMixin(DatabaseProtocol):
result = {}
for key in keys:
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
messages = await self.redis.lrange(key, 0, -1) # type: ignore
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
@ -164,8 +165,7 @@ class ChatMixin(DatabaseProtocol):
for session_data in all_sessions.values():
context = session_data.get("context", {})
if (context.get("relatedEntityType") == "candidate" and
context.get("relatedEntityId") == candidate_id):
if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
candidate_sessions.append(session_data)
# Sort by last activity (most recent first)
@ -236,7 +236,6 @@ class ChatMixin(DatabaseProtocol):
return archived_count
# Analytics and Reporting
async def get_chat_statistics(self) -> Dict[str, Any]:
"""Get comprehensive chat statistics"""
@ -250,7 +249,7 @@ class ChatMixin(DatabaseProtocol):
"archived_sessions": 0,
"sessions_by_type": {},
"sessions_with_candidates": 0,
"average_messages_per_session": 0
"average_messages_per_session": 0,
}
# Analyze sessions

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class DocumentMixin(DatabaseProtocol):
"""Mixin for document-related database operations"""
@ -136,8 +137,7 @@ class DocumentMixin(DatabaseProtocol):
query_lower = query.lower()
return [
doc for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
doc
for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
]

View File

@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class JobMixin(DatabaseProtocol):
"""Mixin for job-related database operations"""
async def get_job(self, job_id: str) -> Optional[Dict]:
"""Get job by ID"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
@ -36,7 +38,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
job_id = key.replace(KEY_PREFIXES['jobs'], '')
job_id = key.replace(KEY_PREFIXES["jobs"], "")
result[job_id] = self._deserialize(value)
return result
@ -124,7 +126,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
app_id = key.replace(KEY_PREFIXES["job_applications"], "")
result[app_id] = self._deserialize(value)
return result
@ -189,7 +191,7 @@ class JobMixin(DatabaseProtocol):
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
"document_id": document_id,
}
await self.redis.set(key, self._serialize(requirements_with_meta))
@ -232,7 +234,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
if value:
result[document_id] = self._deserialize(value)
@ -247,11 +249,7 @@ class JobMixin(DatabaseProtocol):
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}
if keys:
# Get cache dates for analysis
@ -264,7 +262,7 @@ class JobMixin(DatabaseProtocol):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
stats["documents_with_requirements"].append(document_id)
# Track cache dates
@ -277,4 +275,3 @@ class JobMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}

View File

@ -7,153 +7,412 @@ if TYPE_CHECKING:
from models import SkillAssessment
class DatabaseProtocol(Protocol):
# Base mixin
redis: Redis
def _serialize(self, data) -> str: ...
def _deserialize(self, data: str): ...
def _serialize(self, data) -> str:
...
def _deserialize(self, data: str):
...
# Chat mixin
async def add_chat_message(self, session_id: str, message_data: Dict): ...
async def archive_chat_session(self, session_id: str): ...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ...
async def delete_chat_messages(self, session_id: str): ...
async def delete_chat_session_completely(self, session_id: str): ...
async def delete_chat_session(self, session_id: str) -> bool: ...
async def add_chat_message(self, session_id: str, message_data: Dict):
...
async def archive_chat_session(self, session_id: str):
...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
...
async def delete_chat_messages(self, session_id: str):
...
async def delete_chat_session_completely(self, session_id: str):
...
async def delete_chat_session(self, session_id: str) -> bool:
...
# Document mixin
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ...
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
...
# Job mixin
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ...
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
...
# User mixin
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ...
async def delete_employer(self, employer_id: str): ...
async def delete_guest(self, guest_id: str) -> bool: ...
async def delete_user(self, email: str): ...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
async def get_all_users(self) -> Dict[str, Any]: ...
async def get_all_viewers(self) -> Dict[str, Any]: ...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ...
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest_statistics(self) -> Dict[str, Any]: ...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_user_by_username(self, username: str) -> Optional[Dict]: ...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
async def invalidate_user_skill_cache(self, user_id: str) -> int: ...
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ...
async def set_employer(self, employer_id: str, employer_data: Dict): ...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
async def update_user_rag_timestamp(self, user_id: str) -> bool: ...
async def user_exists_by_email(self, email: str) -> bool: ...
async def user_exists_by_username(self, username: str) -> bool: ...
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
...
async def delete_employer(self, employer_id: str):
...
async def delete_guest(self, guest_id: str) -> bool:
...
async def delete_user(self, email: str):
...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
...
async def get_all_users(self) -> Dict[str, Any]:
...
async def get_all_viewers(self) -> Dict[str, Any]:
...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
...
async def get_employer(self, employer_id: str) -> Optional[Dict]:
...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest_statistics(self) -> Dict[str, Any]:
...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_user_by_username(self, username: str) -> Optional[Dict]:
...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
...
async def invalidate_user_skill_cache(self, user_id: str) -> int:
...
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
...
async def set_employer(self, employer_id: str, employer_data: Dict):
...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
...
async def update_user_rag_timestamp(self, user_id: str) -> bool:
...
async def user_exists_by_email(self, email: str) -> bool:
...
async def user_exists_by_username(self, username: str) -> bool:
...
# Auth mixin
async def cleanup_expired_verification_tokens(self) -> int: ...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
async def cleanup_orphaned_job_requirements(self) -> int: ...
async def clear_all_data(self: "DatabaseProtocol"): ...
async def clear_all_skill_match_cache(self) -> int: ...
async def cleanup_expired_verification_tokens(self) -> int:
...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
...
async def cleanup_orphaned_job_requirements(self) -> int:
...
async def clear_all_data(self: "DatabaseProtocol"):
...
async def clear_all_skill_match_cache(self) -> int:
...
# Resume mixin
async def delete_ai_parameters(self, param_id: str): ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ...
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
async def delete_authentication(self, user_id: str) -> bool: ...
async def delete_document(self, document_id: str): ...
async def delete_ai_parameters(self, param_id: str):
...
async def delete_job_application(self, application_id: str): ...
async def delete_job_requirements(self, document_id: str) -> bool: ...
async def delete_job(self, job_id: str): ...
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
async def delete_viewer(self, viewer_id: str): ...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
async def get_all_candidates(self) -> Dict[str, Any]: ...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
async def get_all_employers(self) -> Dict[str, Any]: ...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
async def get_all_job_applications(self) -> Dict[str, Any]: ...
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
async def get_all_jobs(self) -> Dict[str, Any]: ...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
async def get_chat_message_count(self, session_id: str) -> int: ...
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
async def get_chat_statistics(self) -> Dict[str, Any]: ...
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
async def get_document(self, document_id: str) -> Optional[Dict]: ...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
async def get_job(self, job_id: str) -> Optional[Dict]: ...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_pending_verifications_count(self) -> int: ...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
async def get_stats(self) -> Dict[str, int]: ...
async def get_verification_attempts_count(self, email: str) -> int: ...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
async def mark_email_verified(self, token: str) -> bool: ...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
async def mark_password_reset_token_used(self, token: str) -> bool: ...
async def record_verification_attempt(self, email: str) -> bool: ...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
async def revoke_refresh_token(self, token: str) -> bool: ...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
async def set_chat_session(self, session_id: str, session_data: Dict): ...
async def set_document(self, document_id: str, document_data: Dict): ...
async def set_job_application(self, application_id: str, application_data: Dict): ...
async def set_job(self, job_id: str, job_data: Dict): ...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
async def update_chat_session_activity(self, session_id: str): ...
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
...
async def delete_all_resumes_for_user(self, user_id: str) -> int:
...
async def delete_authentication(self, user_id: str) -> bool:
...
async def delete_document(self, document_id: str):
...
async def delete_job_application(self, application_id: str):
...
async def delete_job_requirements(self, document_id: str) -> bool:
...
async def delete_job(self, job_id: str):
...
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
...
async def delete_viewer(self, viewer_id: str):
...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
...
async def get_all_ai_parameters(self) -> Dict[str, Any]:
...
async def get_all_candidates(self) -> Dict[str, Any]:
...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
...
async def get_all_chat_sessions(self) -> Dict[str, Any]:
...
async def get_all_employers(self) -> Dict[str, Any]:
...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
...
async def get_all_job_applications(self) -> Dict[str, Any]:
...
async def get_all_job_requirements(self) -> Dict[str, Any]:
...
async def get_all_jobs(self) -> Dict[str, Any]:
...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
...
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
...
async def get_chat_message_count(self, session_id: str) -> int:
...
async def get_chat_messages(self, session_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
...
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
...
async def get_chat_statistics(self) -> Dict[str, Any]:
...
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
...
async def get_document(self, document_id: str) -> Optional[Dict]:
...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_job_application(self, application_id: str) -> Optional[Dict]:
...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
...
async def get_job_requirements_stats(self) -> Dict[str, Any]:
...
async def get_job(self, job_id: str) -> Optional[Dict]:
...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_pending_verifications_count(self) -> int:
...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
...
async def get_stats(self) -> Dict[str, int]:
...
async def get_verification_attempts_count(self, email: str) -> int:
...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
...
async def mark_email_verified(self, token: str) -> bool:
...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
...
async def mark_password_reset_token_used(self, token: str) -> bool:
...
async def record_verification_attempt(self, email: str) -> bool:
...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
...
async def revoke_all_user_tokens(self, user_id: str) -> bool:
...
async def revoke_refresh_token(self, token: str) -> bool:
...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
...
async def set_ai_parameters(self, param_id: str, param_data: Dict):
...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
...
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
...
async def set_chat_session(self, session_id: str, session_data: Dict):
...
async def set_document(self, document_id: str, document_data: Dict):
...
async def set_job_application(self, application_id: str, application_data: Dict):
...
async def set_job(self, job_id: str, job_data: Dict):
...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
...
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
...
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
...
async def update_chat_session_activity(self, session_id: str):
...
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
...
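`DatabaseProtocol` is a `typing.Protocol` whose stubs are all `...` bodies: each mixin inherits it so that `self.redis`, `_serialize`, and methods implemented by sibling mixins type-check, while the concrete database class supplies the real implementations. A compact sketch of that arrangement, with toy names in place of the real protocol and mixins:

```python
# Minimal sketch of the Protocol-plus-mixins arrangement: the protocol gives
# every mixin a typed view of `self`, and the concrete class composes them.
from typing import Optional, Protocol


class StoreProtocol(Protocol):
    data: dict

    def _serialize(self, value) -> str:
        ...

    async def get_user(self, login: str) -> Optional[dict]:
        ...


class UserOps(StoreProtocol):
    """A mixin may call any protocol member without owning its implementation."""

    async def get_user(self, login: str) -> Optional[dict]:
        return self.data.get(login)


class InMemoryStore(UserOps):
    def __init__(self) -> None:
        self.data: dict = {}

    def _serialize(self, value) -> str:
        return str(value)
```

Here `self.data` plays the role that `self.redis` plays in the real mixins: declared on the protocol, provided by the concrete class.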

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ResumeMixin(DatabaseProtocol):
"""Mixin for resume-related database operations"""
@ -14,10 +15,10 @@ class ResumeMixin(DatabaseProtocol):
"""Save a resume for a user"""
try:
# Generate resume_id if not present
if 'id' not in resume_data:
if "id" not in resume_data:
raise ValueError("Resume data must include an 'id' field")
resume_id = resume_data['id']
resume_id = resume_data["id"]
# Store the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
@ -159,7 +160,7 @@ class ResumeMixin(DatabaseProtocol):
for key, value in zip(keys, values):
if value:
# Extract user_id from key format: resume:{user_id}:{resume_id}
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
if len(key_parts) >= 1:
user_id = key_parts[0]
resume_data = self._deserialize(value)
@ -186,12 +187,14 @@ class ResumeMixin(DatabaseProtocol):
matching_resumes = []
for resume in all_resumes:
# Search in resume content, job_id, candidate_id, etc.
searchable_text = " ".join([
searchable_text = " ".join(
[
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", ""))
]).lower()
str(resume.get("created_at", "")),
]
).lower()
if query_lower in searchable_text:
matching_resumes.append(resume)
@ -206,10 +209,7 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific candidate created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
candidate_resumes = [
resume for resume in all_resumes
if resume.get("candidate_id") == candidate_id
]
candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
return candidate_resumes
@ -221,10 +221,7 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific job created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
job_resumes = [
resume for resume in all_resumes
if resume.get("job_id") == job_id
]
job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
return job_resumes
@ -242,7 +239,7 @@ class ResumeMixin(DatabaseProtocol):
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": []
"recent_resumes": [],
}
for resume in all_resumes:
@ -269,7 +266,13 @@ class ResumeMixin(DatabaseProtocol):
return stats
except Exception as e:
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
return {
"total_resumes": 0,
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": [],
}
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
"""Update specific fields of a resume"""

View File

@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class SkillMixin(DatabaseProtocol):
"""Mixin for Skill-related database operations"""
@ -74,7 +75,9 @@ class SkillMixin(DatabaseProtocol):
# Cache for 1 hour by default
await self.redis.set(
cache_key,
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
json.dumps(
assessment.model_dump(mode="json", by_alias=True), default=str
), # Serialize with datetime handling
)
logger.info(f"💾 Skill match cached: {cache_key}")
except Exception as e:

View File

@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class UserMixin(DatabaseProtocol):
"""Mixin for user operations"""
@ -29,7 +30,7 @@ class UserMixin(DatabaseProtocol):
await self.redis.setex(
f"guest_backup:{guest_id}",
86400 * 7, # 7 days TTL
json.dumps(guest_data)
json.dumps(guest_data),
)
logger.info(f"💾 Guest stored with backup: {guest_id}")
@ -83,10 +84,7 @@ class UserMixin(DatabaseProtocol):
"""Get all guests"""
try:
data = await self.redis.hgetall("guests") # type: ignore
return {
guest_id: json.loads(guest_json)
for guest_id, guest_json in data.items()
}
return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
except Exception as e:
logger.error(f"❌ Error getting all guests: {e}")
return {}
@ -120,7 +118,7 @@ class UserMixin(DatabaseProtocol):
# Skip cleanup if guest is very new (less than 1 hour old)
if created_at_str:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
if current_time - created_at < timedelta(hours=1):
preserved_count += 1
logger.info(f"🛡️ Preserving new guest: {guest_id}")
@ -130,7 +128,7 @@ class UserMixin(DatabaseProtocol):
should_delete = False
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity < cutoff_time:
should_delete = True
except ValueError:
@ -172,7 +170,7 @@ class UserMixin(DatabaseProtocol):
"active_last_day": 0,
"converted_guests": 0,
"by_ip": {},
"creation_timeline": {}
"creation_timeline": {},
}
hour_ago = current_time - timedelta(hours=1)
@ -183,7 +181,7 @@ class UserMixin(DatabaseProtocol):
last_activity_str = guest_data.get("last_activity")
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity > hour_ago:
stats["active_last_hour"] += 1
if last_activity > day_ago:
@ -203,8 +201,8 @@ class UserMixin(DatabaseProtocol):
created_at_str = guest_data.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
date_key = created_at.strftime('%Y-%m-%d')
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
date_key = created_at.strftime("%Y-%m-%d")
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except ValueError:
pass
@ -321,7 +319,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
email = key.replace(KEY_PREFIXES['users'], '')
email = key.replace(KEY_PREFIXES["users"], "")
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
if type(value) == str:
result[email] = value
@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
logger.error(f"❌ Error storing user {login}: {e}")
return False
# ================
# Employers
# ================
@ -394,7 +391,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
employer_id = key.replace(KEY_PREFIXES['employers'], '')
employer_id = key.replace(KEY_PREFIXES["employers"], "")
result[employer_id] = self._deserialize(value)
return result
@ -404,7 +401,6 @@ class UserMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.delete(key)
# ================
# Candidates
# ================
@ -435,7 +431,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
result[candidate_id] = self._deserialize(value)
return result
@ -456,7 +452,7 @@ class UserMixin(DatabaseProtocol):
"security_logs": 0,
"ai_parameters": 0,
"candidate_record": 0,
"resumes": 0
"resumes": 0,
}
logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")
@ -495,7 +491,9 @@ class UserMixin(DatabaseProtocol):
deletion_stats["chat_sessions"] = len(candidate_sessions)
deletion_stats["chat_messages"] = messages_deleted
logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting chat sessions: {e}")
@ -528,9 +526,11 @@ class UserMixin(DatabaseProtocol):
logger.info(f"🗑️ Deleted user record by email: {candidate_email}")
# Delete by username (if different from email)
if (candidate_username and
candidate_username != candidate_email and
await self.user_exists_by_username(candidate_username)):
if (
candidate_username
and candidate_username != candidate_email
and await self.user_exists_by_username(candidate_username)
):
await self.delete_user(candidate_username)
user_records_deleted += 1
logger.info(f"🗑️ Deleted user record by username: {candidate_username}")
@ -593,8 +593,7 @@ class UserMixin(DatabaseProtocol):
candidate_ai_params = []
for param_id, param_data in all_ai_params.items():
if (param_data.get("candidateId") == candidate_id or
param_data.get("userId") == candidate_id):
if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
candidate_ai_params.append(param_id)
# Delete each AI parameter set
@ -630,7 +629,9 @@ class UserMixin(DatabaseProtocol):
break
if tokens_deleted > 0:
logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting email verification tokens: {e}")
@ -717,8 +718,10 @@ class UserMixin(DatabaseProtocol):
# 15. Log the deletion as a security event (if we have admin/system user context)
try:
total_items_deleted = sum(deletion_stats.values())
logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}")
logger.info(
f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}"
)
logger.info(f"📊 Deletion breakdown: {deletion_stats}")
except Exception as e:
logger.error(f"❌ Error logging deletion summary: {e}")
@ -774,8 +777,8 @@ class UserMixin(DatabaseProtocol):
"total_candidates_processed": len(candidate_ids),
"successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
"failed_deletions": len([r for r in batch_results.values() if "error" in r]),
"total_items_deleted": sum(total_stats.values())
}
"total_items_deleted": sum(total_stats.values()),
},
}
except Exception as e:
@ -816,7 +819,7 @@ class UserMixin(DatabaseProtocol):
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
@ -835,7 +838,7 @@ class UserMixin(DatabaseProtocol):
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# ================
@ -868,7 +871,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
result[viewer_id] = self._deserialize(value)
return result
@ -877,4 +880,3 @@ class UserMixin(DatabaseProtocol):
"""Delete viewer"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.delete(key)
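The cascading candidate delete above walks many keyspaces, tallies what it removed into a `deletion_stats` dict, and logs per-step failures without aborting the whole operation. A stripped-down sketch of that shape, with step names left abstract:

```python
# Sketch of a cascading delete that keeps per-category counts and survives
# individual step failures; the step callables are placeholders.
import logging
from typing import Awaitable, Callable, Dict, List, Tuple

logger = logging.getLogger(__name__)

DeleteStep = Tuple[str, Callable[[str], Awaitable[int]]]


async def cascade_delete(entity_id: str, steps: List[DeleteStep]) -> Dict[str, int]:
    stats: Dict[str, int] = {}
    for name, step in steps:
        try:
            stats[name] = await step(entity_id)  # each step returns items removed
        except Exception as exc:  # keep going; record zero for the failed step
            logger.error("delete step %s failed for %s: %s", name, entity_id, exc)
            stats[name] = 0
    logger.info("deleted %d items total for %s", sum(stats.values()), entity_id)
    return stats
```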

View File

@ -6,6 +6,7 @@ from datetime import datetime, timezone
from user_agents import parse
import json
class DeviceManager:
def __init__(self, database: RedisDatabase):
self.database = database
@ -34,7 +35,7 @@ class DeviceManager:
"os": user_agent.os.family,
"os_version": user_agent.os.version_string,
"ip_address": request.client.host if request.client else "unknown",
"user_agent": user_agent_string
"user_agent": user_agent_string,
}
async def is_trusted_device(self, user_id: str, device_id: str) -> bool:
@ -54,14 +55,14 @@ class DeviceManager:
device_data = {
**device_info,
"added_at": datetime.now(timezone.utc).isoformat(),
"last_used": datetime.now(timezone.utc).isoformat()
"last_used": datetime.now(timezone.utc).isoformat(),
}
# Store for 90 days
await self.database.redis.setex(
key,
90 * 24 * 60 * 60, # 90 days in seconds
json.dumps(device_data, default=str)
json.dumps(device_data, default=str),
)
logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")
@ -79,7 +80,7 @@ class DeviceManager:
await self.database.redis.setex(
key,
90 * 24 * 60 * 60, # Reset 90 day expiry
json.dumps(device_info, default=str)
json.dumps(device_info, default=str),
)
except Exception as e:
logger.error(f"Error updating device last used: {e}")

View File

@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
import json
from database.manager import RedisDatabase
class EmailService:
def __init__(self):
# Configure these in your .env file
@ -30,36 +31,25 @@ class EmailService:
def _format_template(self, template: str, **kwargs) -> str:
"""Format template with provided variables"""
return template.format(
app_name=self.app_name,
from_name=self.from_name,
frontend_url=self.frontend_url,
**kwargs
app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
)
async def send_verification_email(
self,
to_email: str,
verification_token: str,
user_name: str,
user_type: str = "user"
self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
):
"""Send email verification email using template"""
try:
template = self._get_template("verification")
verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"
subject = self._format_template(
template["subject"],
user_name=user_name,
to_email=to_email
)
subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)
html_content = self._format_template(
template["html"],
user_name=user_name,
user_type=user_type,
to_email=to_email,
verification_link=verification_link
verification_link=verification_link,
)
await self._send_email(to_email, subject, html_content)
@ -70,12 +60,7 @@ class EmailService:
raise
async def send_mfa_email(
self,
to_email: str,
mfa_code: str,
device_name: str,
user_name: str,
ip_address: str = "Unknown"
self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
):
"""Send MFA code email using template"""
try:
@ -91,7 +76,7 @@ class EmailService:
ip_address=ip_address,
login_time=login_time,
mfa_code=mfa_code,
to_email=to_email
to_email=to_email,
)
await self._send_email(to_email, subject, html_content)
@ -101,12 +86,7 @@ class EmailService:
logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
raise
async def send_password_reset_email(
self,
to_email: str,
reset_token: str,
user_name: str
):
async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
"""Send password reset email using template"""
try:
template = self._get_template("password_reset")
@ -115,10 +95,7 @@ class EmailService:
subject = self._format_template(template["subject"])
html_content = self._format_template(
template["html"],
user_name=user_name,
reset_link=reset_link,
to_email=to_email
template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
)
await self._send_email(to_email, subject, html_content)
@ -134,14 +111,14 @@ class EmailService:
if not self.email_user:
raise ValueError("Email user is not configured")
# Create message
msg = MIMEMultipart('alternative')
msg['From'] = f"{self.from_name} <{self.email_user}>"
msg['To'] = to_email
msg['Subject'] = subject
msg['Reply-To'] = self.email_user
msg = MIMEMultipart("alternative")
msg["From"] = f"{self.from_name} <{self.email_user}>"
msg["To"] = to_email
msg["Subject"] = subject
msg["Reply-To"] = self.email_user
# Add HTML content
html_part = MIMEText(html_content, 'html', 'utf-8')
html_part = MIMEText(html_content, "html", "utf-8")
msg.attach(html_part)
# Send email with connection pooling and retry logic
@ -157,7 +134,6 @@ class EmailService:
text = msg.as_string()
server.sendmail(self.email_user, to_email, text)
break # Success, exit retry loop
except smtplib.SMTPException as e:
if attempt == max_retries - 1: # Last attempt
raise
@ -170,6 +146,7 @@ class EmailService:
logger.error(f"❌ SMTP error sending to {to_email}: {e}")
raise
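`_send_email` above builds a `MIMEMultipart("alternative")` message and retries the SMTP send a few times before giving up. A self-contained sketch of that send-with-retry loop; the host, port, and credentials here are placeholders, not the project's configuration:

```python
# Sketch of HTML email sending with simple bounded retries and backoff.
import smtplib
import time
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText


def send_html_email(to_email: str, subject: str, html: str,
                    user: str = "noreply@example.com",
                    password: str = "app-password",
                    host: str = "smtp.example.com", port: int = 587,
                    max_retries: int = 3) -> None:
    msg = MIMEMultipart("alternative")
    msg["From"] = user
    msg["To"] = to_email
    msg["Subject"] = subject
    msg.attach(MIMEText(html, "html", "utf-8"))

    for attempt in range(max_retries):
        try:
            with smtplib.SMTP(host, port, timeout=30) as server:
                server.starttls()
                server.login(user, password)
                server.sendmail(user, to_email, msg.as_string())
            return  # success
        except smtplib.SMTPException:
            if attempt == max_retries - 1:
                raise
            time.sleep(2 ** attempt)  # back off before retrying
```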
class EmailRateLimiter:
def __init__(self, database: RedisDatabase):
self.database = database
@ -191,10 +168,7 @@ class EmailRateLimiter:
email_records = json.loads(count_data)
# Filter out old records
recent_records = [
record for record in email_records
if datetime.fromisoformat(record) > window_start
]
recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]
if len(recent_records) >= limit:
logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")
@ -202,11 +176,7 @@ class EmailRateLimiter:
# Add current email to records
recent_records.append(current_time.isoformat())
await self.database.redis.setex(
key,
window_minutes * 60,
json.dumps(recent_records)
)
await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))
return True
@ -217,11 +187,8 @@ class EmailRateLimiter:
async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
"""Record that an email was sent"""
await self.database.redis.setex(
key,
ttl_minutes * 60,
json.dumps([timestamp.isoformat()])
)
await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))
class VerificationEmailRateLimiter:
def __init__(self, database: RedisDatabase):
@ -242,7 +209,10 @@ class VerificationEmailRateLimiter:
# Check daily limit
daily_count = await self.database.get_verification_attempts_count(email)
if daily_count >= self.max_attempts_per_day:
return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
return (
False,
f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
)
# Check hourly limit
hour_ago = current_time - timedelta(hours=1)
@ -251,13 +221,13 @@ class VerificationEmailRateLimiter:
if data:
attempts_data = json.loads(data)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > hour_ago
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]
if len(recent_attempts) >= self.max_attempts_per_hour:
return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
return (
False,
f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
)
# Check cooldown period
if recent_attempts:
@ -280,7 +250,4 @@ class VerificationEmailRateLimiter:
await self.database.record_verification_attempt(email)
email_service = EmailService()

View File

@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"mfa": {
"subject": "Security Code for Backstory",
"html": """
@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"password_reset": {
"subject": "Reset your Backstory password",
"html": """
@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
}
""",
},
}

View File

@ -10,6 +10,7 @@ from agents.base import CandidateEntity
from database.manager import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager(BaseModel):
"""Manages lifecycle of CandidateEntity instances"""
@ -36,10 +37,7 @@ class EntityManager(BaseModel):
pass
self._cleanup_task = None
def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
self._database = database
@ -58,9 +56,7 @@ class EntityManager(BaseModel):
raise ValueError("EntityManager has not been initialized with required components.")
entity = CandidateEntity(candidate=candidate)
await entity.initialize(
prometheus_collector=self._prometheus_collector,
database=self._database)
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity
@ -105,10 +101,12 @@ class EntityManager(BaseModel):
def _on_entity_deleted(self, user_id: str):
"""Callback when entity is garbage collected"""
def cleanup_callback(weak_ref):
self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
print(f"Entity {user_id} garbage collected")
return cleanup_callback
async def release_entity(self, user_id: str):
@ -138,8 +136,7 @@ class EntityManager(BaseModel):
time_since_access = current_time - entity.last_accessed
# Remove if TTL exceeded and no active references
if (time_since_access > timedelta(minutes=self._ttl_minutes)
and entity.reference_count == 0):
if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
expired_entities.append(user_id)
for user_id in expired_entities:
@ -153,6 +150,7 @@ class EntityManager(BaseModel):
# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)
@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
"""Context manager for safe entity access with automatic reference management"""
@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
finally:
await entity_manager.release_entity(candidate.id)
EntityManager.model_rebuild()
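`EntityManager` hands out `CandidateEntity` instances through an async context manager, counts references, and sweeps idle entities after a TTL (with weakref callbacks as an extra safety net). A reduced sketch of the reference-count/TTL side of that lifecycle, using toy types rather than the project's classes:

```python
# Reduced sketch of reference-counted entity caching behind an async
# context manager; Entity and EntityCache are stand-ins.
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Dict


class Entity:
    def __init__(self, entity_id: str) -> None:
        self.id = entity_id
        self.reference_count = 0
        self.last_accessed = datetime.now(timezone.utc)


class EntityCache:
    def __init__(self, ttl_minutes: int = 30) -> None:
        self._entities: Dict[str, Entity] = {}
        self._ttl = timedelta(minutes=ttl_minutes)

    async def acquire(self, entity_id: str) -> Entity:
        entity = self._entities.setdefault(entity_id, Entity(entity_id))
        entity.reference_count += 1
        entity.last_accessed = datetime.now(timezone.utc)
        return entity

    async def release(self, entity_id: str) -> None:
        entity = self._entities.get(entity_id)
        if entity:
            entity.reference_count = max(0, entity.reference_count - 1)

    def sweep(self) -> None:
        # Drop entities that are idle past the TTL and no longer referenced.
        now = datetime.now(timezone.utc)
        for eid, e in list(self._entities.items()):
            if e.reference_count == 0 and now - e.last_accessed > self._ttl:
                del self._entities[eid]


cache = EntityCache()


@asynccontextmanager
async def get_entity(entity_id: str):
    entity = await cache.acquire(entity_id)
    try:
        yield entity
    finally:
        await cache.release(entity_id)
```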

View File

@ -6,10 +6,7 @@ without getting caught up in serialization format complexities
import sys
from datetime import datetime
from models import (
UserStatus, UserType, SkillLevel, EmploymentType,
Candidate, Employer, Location, Skill
)
from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill
def test_model_creation():
@ -37,7 +34,7 @@ def test_model_creation():
preferred_job_types=[EmploymentType.FULL_TIME],
location=location,
languages=[],
certifications=[]
certifications=[],
)
# Create employer
@ -54,7 +51,7 @@ def test_model_creation():
industry="Technology",
company_size="50-200",
company_description="A test company",
location=location
location=location,
)
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
@ -63,6 +60,7 @@ def test_model_creation():
return candidate, employer
def test_json_api_format():
"""Test JSON serialization in API format (the most important use case)"""
print("\n📡 Testing JSON API format...")
@ -90,6 +88,7 @@ def test_json_api_format():
return True
def test_api_dict_format():
"""Test dictionary format with aliases (for API requests/responses)"""
print("\n📊 Testing API dictionary format...")
@ -120,6 +119,7 @@ def test_api_dict_format():
return True
def test_validation_constraints():
"""Test that validation constraints work"""
print("\n🔒 Testing validation constraints...")
@ -135,7 +135,7 @@ def test_validation_constraints():
status=UserStatus.ACTIVE,
first_name="Jane",
last_name="Doe",
full_name="Jane Doe"
full_name="Jane Doe",
)
print("❌ Validation should have failed but didn't")
return False
@ -143,6 +143,7 @@ def test_validation_constraints():
print(f"✅ Validation error caught: {e}")
return True
def test_enum_values():
"""Test that enum values work correctly"""
print("\n📋 Testing enum values...")
@ -162,6 +163,7 @@ def test_enum_values():
return True
def main():
"""Run all focused tests"""
print("🎯 Focused Pydantic Model Tests")
@ -187,10 +189,12 @@ def main():
except Exception as e:
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{traceback.format_exc()}")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -2,6 +2,7 @@ from pydantic import BaseModel
import json
from typing import Any, List, Set
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
"""
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
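Only the signature of `check_serializable` is visible in this hunk; the docstring says it recurses through an object's fields, collecting paths of non-JSON-serializable values and using a visited set to avoid cycles. A hedged sketch of one way such a checker can work (not this module's actual body), which also sidesteps the shared-mutable-default pitfall in the shown signature by creating fresh containers per call:

```python
# Sketch of a recursive JSON-serializability checker with cycle protection;
# an illustration of the idea, not the module's real implementation.
import json
from typing import Any, List, Optional, Set

from pydantic import BaseModel


def find_unserializable(obj: Any, path: str = "<root>",
                        errors: Optional[List[str]] = None,
                        visited: Optional[Set[int]] = None) -> List[str]:
    errors = [] if errors is None else errors
    visited = set() if visited is None else visited

    if isinstance(obj, (BaseModel, dict, list, tuple, set)):
        if id(obj) in visited:  # already walked this container: break the cycle
            return errors
        visited.add(id(obj))

    if isinstance(obj, BaseModel):
        for name in type(obj).model_fields:
            find_unserializable(getattr(obj, name), f"{path}.{name}", errors, visited)
    elif isinstance(obj, dict):
        for key, value in obj.items():
            find_unserializable(value, f"{path}[{key!r}]", errors, visited)
    elif isinstance(obj, (list, tuple, set)):
        for i, item in enumerate(obj):
            find_unserializable(item, f"{path}[{i}]", errors, visited)
    else:
        try:
            json.dumps(obj)
        except (TypeError, ValueError):
            errors.append(f"{path}: {type(obj).__name__} is not JSON-serializable")
    return errors
```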

View File

@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
T = TypeVar('T', bound=BaseModel)
T = TypeVar("T", bound=BaseModel)
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
data = {field: getattr(source, field) for field in model_cls.__fields__}
return model_cls(**data)
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
return model_cls(**data)
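`cast_to_model` copies every field declared on the target model from a richer source model, which is how a wider user model can be narrowed back to a plain base model. A small hypothetical usage with toy models standing in for the real ones:

```python
# Hypothetical usage of the field-copying cast shown above, with toy models.
from pydantic import BaseModel


class Base(BaseModel):
    id: str
    name: str


class Enriched(Base):
    score: float


def cast(model_cls, source: BaseModel):
    # Same idea as cast_to_model: keep only the fields the target declares.
    data = {field: getattr(source, field) for field in model_cls.model_fields}
    return model_cls(**data)


rich = Enriched(id="c1", name="Ada", score=0.93)
plain = cast(Base, rich)  # Base(id='c1', name='Ada'); score is dropped
assert not hasattr(plain, "score")
```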
def cast_to_base_user_with_type(user) -> BaseUserWithType:
"""
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.

View File

@ -7,6 +7,7 @@ from typing import Any
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
self._pipe = None
@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours
cached_model_type = self._get_model_type(self._model_name) if self._model_name else None
if (
self._pipe is not None and
self._model_name == model and
self._device == device and
current_model_type == cached_model_type and
current_time - self._last_access_time < self._timeout_seconds
self._pipe is not None
and self._model_name == model
and self._device == device
and current_model_type == cached_model_type
and current_time - self._last_access_time < self._timeout_seconds
):
self._last_access_time = current_time
return self._pipe
@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours
model,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
def dummy_safety_checker(images, clip_input):
return images, [False] * len(images)
pipe.safety_checker = dummy_safety_checker
else:
pipe = FluxPipeline.from_pretrained(
@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
try:
pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
except Exception as e:
raise Exception(f"Failed to load LoRA weights: {str(e)}")
@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded for 3 hours
async def cleanup_if_expired(self):
async with self._lock:
if (
self._pipe is not None and
time.time() - self._last_access_time >= self._timeout_seconds
):
if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
await self._unload_model()
async def _periodic_cleanup(self):
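A rough sketch of how a cache like this is typically driven, assuming ImageModelCache is importable with the get_pipeline and cleanup_if_expired signatures shown above; the model id is a placeholder:

import asyncio


async def demo_cache_usage():
    cache = ImageModelCache(timeout_seconds=60)  # short timeout just for the example
    pipe = await cache.get_pipeline("some/flux-model", "cpu")  # placeholder model id
    # ... run generation with `pipe` here ...
    await asyncio.sleep(61)
    await cache.cleanup_if_expired()  # unloads the pipeline once idle past the timeout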

View File

@ -29,6 +29,7 @@ TIME_ESTIMATES = {
}
}
class ImageRequest(BaseModel):
session_id: str
filepath: str
@ -39,18 +40,22 @@ class ImageRequest(BaseModel):
width: int = 256
guidance_scale: float = 7.5
# Global model cache instance
model_cache = ImageModelCache()
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
"""Background worker for Flux image generation"""
try:
# Flux: Run generation in the background and yield progress updates
status_queue.put(ChatMessageStatus(
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content="Initializing image generation.",
activity=ApiActivityType.GENERATING_IMAGE,
))
)
)
# Start the generation task
start_gen_time = time.time()
@ -60,11 +65,13 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
# Send progress updates
progress = int((step + 1) / params.iterations * 100)
status_queue.put(ChatMessageStatus(
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
activity=ApiActivityType.GENERATING_IMAGE,
))
)
)
return callback_kwargs
# Replace this block with your actual Flux pipe call:
@ -84,22 +91,28 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
image.save(params.filepath)
# Final completion status
status_queue.put(ChatMessage(
status_queue.put(
ChatMessage(
session_id=params.session_id,
status=ApiStatusType.DONE,
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
))
)
)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
status_queue.put(ChatMessageError(
status_queue.put(
ChatMessageError(
session_id=params.session_id,
content=f"Error during image generation: {str(e)}",
))
)
)
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
async def async_generate_image(
pipe: Any, params: ImageRequest
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
"""
Single async function that handles background Flux generation with status streaming
"""
@ -109,18 +122,14 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
try:
# Start background worker thread
worker_thread = Thread(
target=flux_worker,
args=(pipe, params, status_queue, task_id),
daemon=True
)
worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
worker_thread.start()
# Initial status
status_message = ChatMessageStatus(
session_id=params.session_id,
content=f"Starting image generation with task ID {task_id}",
activity=ApiActivityType.THINKING
activity=ApiActivityType.THINKING,
)
yield status_message
@ -177,18 +186,16 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
yield final_status
except Exception as e:
error_status = ChatMessageError(
session_id=params.session_id,
content=f'Server error: {str(e)}'
)
error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
logger.error(error_status)
yield error_status
finally:
# Cleanup: ensure thread completion
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
def status(session_id: str, status: str) -> ChatMessageStatus:
"""Update chat message status and return it."""
chat_message = ChatMessageStatus(
@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
)
return chat_message
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
"""Generate an image with specified dimensions and yield status updates with time estimates."""
session_id = request.session_id
@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
try:
# Validate prompt
if not prompt:
error_message = ChatMessageError(
session_id=session_id,
content="Prompt cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
logger.error(error_message.content)
yield error_message
return
@ -216,8 +221,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
# Validate dimensions
if request.height <= 0 or request.width <= 0:
error_message = ChatMessageError(
session_id=session_id,
content="Height and width must be positive integers."
session_id=session_id, content="Height and width must be positive integers."
)
logger.error(error_message.content)
yield error_message
@ -240,7 +244,10 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
yield status(session_id, "Loading generative image model...")
pipe = await model_cache.get_pipeline(request.model, device)
load_time = time.time() - start_time
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
yield status(
session_id,
f"Model loaded in {load_time:.1f} seconds.",
)
progress = None
async for progress in async_generate_image(pipe, request):
@ -252,8 +259,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
if not progress:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -269,10 +275,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
yield chat_message
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error during image generation: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
logger.error(traceback.format_exc())
logger.error(error_message.content)
yield error_message
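A minimal consumer sketch for the streaming generator above, assuming generate_image and ImageRequest are importable from this module; the request field values are illustrative:

import asyncio


async def run_one_image():
    request = ImageRequest(
        session_id="session-123",
        filepath="/tmp/out.png",
        prompt="a lighthouse at dusk",  # assumed field, referenced as the prompt above
        model="some/flux-model",        # placeholder model id
    )
    async for message in generate_image(request):
        # Status messages stream first; the final ChatMessage arrives with status DONE.
        print(type(message).__name__, message.content)


# asyncio.run(run_one_image())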

View File

@ -2,6 +2,7 @@ import json
import re
from typing import List, Union
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
"""
Extract JSON blocks from text, even if surrounded by markdown or noisy text.
@ -31,13 +32,14 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
return found
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
"""Extract standalone JSON objects or arrays from text using proper brace counting."""
found = []
i = 0
while i < len(text):
if text[i] in '{[':
if text[i] in "{[":
# Found potential JSON start
json_str = _extract_complete_json_at_position(text, i)
if json_str:
@ -55,16 +57,17 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un
return found
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
"""
Extract a complete JSON object or array starting at the given position.
Uses proper brace/bracket counting and string escape handling.
"""
if start_pos >= len(text) or text[start_pos] not in '{[':
if start_pos >= len(text) or text[start_pos] not in "{[":
return ""
start_char = text[start_pos]
end_char = '}' if start_char == '{' else ']'
end_char = "}" if start_char == "{" else "]"
count = 1
i = start_pos + 1
@ -76,7 +79,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
if escape_next:
escape_next = False
elif char == '\\' and in_string:
elif char == "\\" and in_string:
escape_next = True
elif char == '"' and not escape_next:
in_string = not in_string
@ -92,6 +95,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
return text[start_pos:i]
return ""
def extract_json_from_text(text: str) -> str:
"""Extract JSON string from text that may contain other content."""
return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
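A small usage sketch of the extractor above, assuming this module is importable; the noisy sample text is made up:

noisy = 'Model said: {"name": "Jane", "skills": ["python", "fastapi"]} ... hope that helps!'
blocks = extract_json_blocks(noisy)        # per the docstring, the embedded object should be returned
print(blocks)
as_string = extract_json_from_text(noisy)  # the same first block re-serialized via json.dumps
print(as_string)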

View File

@ -2,11 +2,11 @@ import os
import warnings
import logging
import defines
def _setup_logging(level=defines.logging_level) -> logging.Logger:
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings(
"ignore", message="Overriding a previously registered kernel"
)
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
# Create a custom formatter
formatter = logging.Formatter(
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Create a handler (e.g., StreamHandler for console output)
@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
logger = logging.getLogger(__name__)
return logger
logger = _setup_logging(level=defines.logging_level)
logger.debug(f"Logging initialized with level: {defines.logging_level}")

View File

@ -58,6 +58,7 @@ background_task_manager = None
prev_int = signal.getsignal(signal.SIGINT)
prev_term = signal.getsignal(signal.SIGTERM)
def signal_handler(signum, frame):
logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
# now call the old handler (it might raise KeyboardInterrupt or exit)
@ -66,9 +67,11 @@ def signal_handler(signum, frame):
elif signum == signal.SIGTERM and callable(prev_term):
prev_term(signum, frame)
# Global background task manager
background_task_manager: Optional[BackgroundTaskManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
if db_manager:
await db_manager.graceful_shutdown()
app = FastAPI(
lifespan=lifespan,
title="Backstory API",
@ -129,11 +133,9 @@ app = FastAPI(
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
if ssl_enabled:
allow_origins = ["https://battle-linux.ketrenos.com:3000",
"https://backstory-beta.ketrenos.com"]
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
else:
allow_origins = ["http://battle-linux.ketrenos.com:3000",
"http://backstory-beta.ketrenos.com"]
allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
# Add CORS middleware
app.add_middleware(
@ -144,12 +146,14 @@ app.add_middleware(
allow_headers=["*"],
)
# ============================
# Debug data type failures
# ============================
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
content=json.dumps({"detail": str(exc)}),
)
# ============================
# Create API router with prefix
# ============================
@ -181,6 +186,7 @@ api_router.include_router(users.router)
# Health Check and Info Endpoints
# ============================
@app.get("/health")
async def health_check(
database=Depends(get_database),
@ -202,15 +208,12 @@ async def health_check(
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"database": {
"status": "connected",
"stats": stats
},
"database": {"status": "connected", "stats": stats},
"redis": {
"version": redis_info.get("redis_version", "unknown"),
"uptime": redis_info.get("uptime_in_seconds", 0),
"memory_used": redis_info.get("used_memory_human", "unknown")
}
"memory_used": redis_info.get("used_memory_human", "unknown"),
},
}
except RuntimeError as e:
@ -219,6 +222,7 @@ async def health_check(
logger.error(f"❌ Health check failed: {e}")
return {"status": "error", "message": str(e)}
# ============================
# Include Router in App
# ============================
@ -231,11 +235,14 @@ app.include_router(api_router)
# ============================
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")
@app.middleware("http")
async def log_requests(request: Request, call_next):
try:
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
logger.info(
f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
)
response = await call_next(request)
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
if response.status_code < 200 or response.status_code >= 300:
@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
return response
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})
# ============================
# Request tracking middleware
# ============================
@ -266,6 +275,7 @@ async def track_requests(request, call_next):
finally:
db_manager.decrement_requests()
# ============================
# FastAPI Metrics
# ============================
@ -277,7 +287,7 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)
# Instrument the FastAPI app
@ -291,6 +301,7 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
# Static File Serving
# ============================
@app.get("/{path:path}")
async def serve_static(path: str, request: Request):
full_path = os.path.join(defines.static_content, path)
@ -300,6 +311,7 @@ async def serve_static(path: str, request: Request):
return FileResponse(os.path.join(defines.static_content, "index.html"))
# Root endpoint when no static files
@app.get("/", include_in_schema=False)
async def root():
@ -309,9 +321,10 @@ async def root():
"version": "1.0.0",
"api_prefix": defines.api_prefix,
"documentation": f"{defines.api_prefix}/docs",
"health": f"{defines.api_prefix}/health"
"health": f"{defines.api_prefix}/health",
}
async def periodic_verification_cleanup():
"""Background task to periodically clean up expired verification tokens"""
try:
@ -324,6 +337,7 @@ async def periodic_verification_cleanup():
except Exception as e:
logger.error(f"❌ Error in periodic verification cleanup: {e}")
if __name__ == "__main__":
host = defines.host
port = defines.port
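The request-logging middleware above follows the standard FastAPI pattern; a stripped-down, self-contained sketch of that pattern (the app and messages here are placeholders, not the Backstory app):

from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware("http")
async def log_requests(request: Request, call_next):
    # Log the inbound request, hand it down the stack, then inspect the response status.
    print(f"-> {request.method} {request.url.path}")
    response = await call_next(request)
    if not (200 <= response.status_code < 300):
        print(f"<- {response.status_code} for {request.url.path}")
    return response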

File diff suppressed because it is too large

View File

@ -0,0 +1,23 @@
[tool.ruff]
# Set line length (default is 88, Black-compatible)
line-length = 120
exclude = [
".git",
"__pycache__",
"migrations",
"venv",
"dist",
"build",
]
target-version = "py312"
[tool.ruff.format]
# Use Ruff's formatter (Black-compatible by default)
quote-style = "double" # Enforce double quotes
indent-style = "space" # Use spaces for indentation
skip-magic-trailing-comma = false # Add trailing commas where applicable
[tool.ruff.per-file-ignores]
# Ignore specific rules for certain files (e.g., tests or scripts)
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
"scripts/**" = ["E402"] # Ignore import order in scripts

View File

@ -1,7 +1,3 @@
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
__all__ = [
"ChromaDBFileWatcher",
"start_file_watcher",
"RagEntry"
]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]

View File

@ -7,10 +7,12 @@ import logging
import defines
class Chunk(TypedDict):
text: str
metadata: Dict[str, Any]
def clear_chunk(chunk: Chunk):
chunk["text"] = ""
chunk["metadata"] = {
@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
}
return chunk
class MarkdownChunker:
def __init__(self):
# Initialize the Markdown parser
@ -76,10 +79,7 @@ class MarkdownChunker:
# Initialize a chunk structure
chunk: Chunk = {
"text": "",
"metadata": {
"source_file": file_path,
"lines": total_lines
},
"metadata": {"source_file": file_path, "lines": total_lines},
}
clear_chunk(chunk)
@ -89,9 +89,7 @@ class MarkdownChunker:
return chunks
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
return {
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
}
return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
lines = []
@ -114,7 +112,7 @@ class MarkdownChunker:
chunks: List[Chunk],
chunk: Chunk,
level: int,
buffer: int = defines.chunk_buffer
buffer: int = defines.chunk_buffer,
) -> int:
is_list = False
# Handle heading nodes
@ -198,9 +196,7 @@ class MarkdownChunker:
# Recursively process children
if not is_list:
for child in node.children:
level = self._process_node(
child, current_headings, chunks, chunk, level=level
)
level = self._process_node(child, current_headings, chunks, chunk, level=level)
# After root-level recursion, finalize any remaining chunk
if node.type == "document":
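A small sketch of consuming the Chunk dicts the class builds, assuming Chunk is importable from this module; the driver function is invented:

from typing import List


def summarize(chunks: List[Chunk]) -> None:
    # Each chunk carries its text plus metadata such as source_file and line counts (see clear_chunk above).
    for chunk in chunks:
        meta = chunk["metadata"]
        print(f"{meta.get('source_file')}: {len(chunk['text'])} chars")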

View File

@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
self,
@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Path for storing file hash state
self.hash_state_path = os.path.join(
self.persist_directory, f"{collection_name}_hash_state.json"
)
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
# Flag to track if this is a new collection
self.is_new_collection = False
@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
process_all: If True, process all files regardless of hash status
"""
# Check for new or modified files
file_paths = glob.glob(
os.path.join(self.watch_directory, "**/*"), recursive=True
)
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
files_checked = 0
files_processed = 0
files_to_process = []
@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
continue
# If file is new, changed, or we're processing all files
if (
process_all
or file_path not in self.file_hashes
or self.file_hashes[file_path] != current_hash
):
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
self.file_hashes[file_path] = current_hash
files_to_process.append(file_path)
logging.info(
f"File {'found' if process_all else 'changed'}: {file_path}"
)
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
logging.info(
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
)
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
# Check for deleted files
deleted_files = []
@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
# Don't block on result, just let it run
logging.info(f"File deleted: {file_path}")
@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not current_hash: # File might have been deleted or is inaccessible
return
if (
file_path in self.file_hashes
and self.file_hashes[file_path] == current_hash
):
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
# File hasn't actually changed in content
logging.info(f"Hash has not changed for {file_path}")
return
@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
# Remove from hash dictionary
if file_path in self.file_hashes:
@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
def _update_umaps(self):
# Update the UMAP embeddings
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
include=["embeddings", "documents", "metadatas"]
))
self._umap_collection = ChromaDBGetResponse.model_validate(
self._collection.get(include=["embeddings", "documents", "metadatas"])
)
if not self._umap_collection or not len(self._umap_collection.embeddings):
logging.warning("⚠️ No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
vectors = np.array(self._umap_collection.embeddings)
self._umap_model_2d = umap.UMAP(
n_components=2,
@ -328,9 +309,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
self._umap_model_3d = umap.UMAP(
n_components=3,
random_state=8911,
@ -373,9 +352,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
self.is_new_collection = True
logging.info(f"Recreating collection: {self.collection_name}")
return chroma_client.get_or_create_collection(
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Generate a more unique ID based on content and metadata
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
:8
]
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
logging.info(f"File deleted: {file_path}")
def on_moved(self, event):
@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if (
existing_results
and "ids" in existing_results
and existing_results["ids"]
):
if existing_results and "ids" in existing_results and existing_results["ids"]:
self.collection.delete(ids=existing_results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
p = Path(file_path)
p_as_md = p.with_suffix(".md")
if p_as_md.exists():
logging.info(
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
)
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
# If file_path.md doesn't exist or file_path is newer than file_path.md,
# fire off markitdown
if (not p_as_md.exists()) or (
p.stat().st_mtime > p_as_md.stat().st_mtime
):
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
self._markitdown(file_path, p_as_md)
return
@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Process all files regardless of hash state
num_processed = await self.scan_directory(process_all=True)
logging.info(
f"Vectorstore initialized with {self.collection.count()} documents"
)
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
self._update_umaps()
@ -676,7 +639,7 @@ def start_file_watcher(
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
database=database
database=database,
)
# Process all files if:
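The change detection in scan_directory above boils down to hashing each file and comparing against the stored state; a self-contained sketch of that idea (not the class's actual method):

import glob
import hashlib
import os


def changed_files(watch_directory: str, known_hashes: dict) -> list:
    """Return files whose content hash differs from the recorded one, updating the record in place."""
    to_process = []
    for path in glob.glob(os.path.join(watch_directory, "**/*"), recursive=True):
        if not os.path.isfile(path):
            continue
        with open(path, "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()
        if known_hashes.get(path) != digest:
            known_hashes[path] = digest
            to_process.append(path)
    return to_process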

View File

@ -13,14 +13,4 @@ from . import employers
from . import admin
from . import system
__all__ = [
"auth",
"candidates",
"resumes",
"jobs",
"chat",
"users",
"employers",
"admin",
"system"
]
__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]

View File

@ -5,8 +5,18 @@ import json
from datetime import datetime, timezone, UTC
from fastapi import (
APIRouter, HTTPException, Depends, Body, Request, HTTPException,
Depends, Query, Path, Body, APIRouter, Request
APIRouter,
HTTPException,
Depends,
Body,
Request,
HTTPException,
Depends,
Query,
Path,
Body,
APIRouter,
Request,
)
from fastapi.responses import JSONResponse
@ -14,52 +24,51 @@ from fastapi.responses import JSONResponse
from utils.rate_limiter import RateLimiter, get_rate_limiter
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_current_admin, get_current_user_or_guest, get_database, background_task_manager
)
from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager
from utils.responses import (
create_paginated_response, create_success_response, create_error_response,
create_paginated_response,
create_success_response,
create_error_response,
)
# Create router for authentication endpoints
router = APIRouter(prefix="/admin", tags=["admin"])
@router.post("/tasks/cleanup-guests")
async def manual_guest_cleanup(
inactive_hours: int = Body(24, embed=True),
current_user=Depends(get_current_admin),
admin_user = Depends(get_current_admin)
admin_user=Depends(get_current_admin),
):
"""Manually trigger guest cleanup (admin only)"""
try:
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
return create_success_response({
return create_success_response(
{
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual guest cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.post("/tasks/cleanup-tokens")
async def manual_token_cleanup(
admin_user = Depends(get_current_admin)
):
async def manual_token_cleanup(admin_user=Depends(get_current_admin)):
"""Manually trigger verification token cleanup (admin only)"""
try:
global background_task_manager
@ -67,31 +76,28 @@ async def manual_token_cleanup(
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
return create_success_response({
return create_success_response(
{
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual token cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.post("/tasks/cleanup-rate-limits")
async def manual_rate_limit_cleanup(
days_old: int = Body(7, embed=True),
admin_user = Depends(get_current_admin)
):
async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)):
"""Manually trigger rate limit data cleanup (admin only)"""
try:
global background_task_manager
@ -99,60 +105,55 @@ async def manual_rate_limit_cleanup(
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
return create_success_response({
return create_success_response(
{
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual rate limit cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
# ========================================
# System Health and Maintenance Endpoints
# ========================================
@router.get("/system/health")
async def get_system_health(
request: Request,
admin_user = Depends(get_current_admin)
):
async def get_system_health(request: Request, admin_user=Depends(get_current_admin)):
"""Get comprehensive system health status (admin only)"""
try:
# Database health
database_manager = getattr(request.app.state, 'database_manager', None)
database_manager = getattr(request.app.state, "database_manager", None)
db_health = {"status": "unavailable", "healthy": False}
if database_manager:
try:
database_manager.get_database()
from database.manager import redis_manager
redis_health = await redis_manager.health_check()
db_health = {
"status": redis_health.get("status", "unknown"),
"healthy": redis_health.get("status") == "healthy",
"details": redis_health
"details": redis_health,
}
except Exception as e:
db_health = {
"status": "error",
"healthy": False,
"error": str(e)
}
db_health = {"status": "error", "healthy": False, "error": str(e)}
# Background task health
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
task_health = {"status": "unavailable", "healthy": False}
if background_task_manager:
@ -166,39 +167,30 @@ async def get_system_health(
"healthy": task_status["running"] and failed_tasks == 0,
"running_tasks": running_tasks,
"failed_tasks": failed_tasks,
"total_tasks": task_status["task_count"]
"total_tasks": task_status["task_count"],
}
except Exception as e:
task_health = {
"status": "error",
"healthy": False,
"error": str(e)
}
task_health = {"status": "error", "healthy": False, "error": str(e)}
# Overall health
overall_healthy = db_health["healthy"] and task_health["healthy"]
return create_success_response({
return create_success_response(
{
"timestamp": datetime.now(UTC).isoformat(),
"overall_healthy": overall_healthy,
"components": {
"database": db_health,
"background_tasks": task_health
"components": {"database": db_health, "background_tasks": task_health},
}
})
)
except Exception as e:
logger.error(f"❌ Error getting system health: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e)))
@router.post("/maintenance/cleanup")
async def run_maintenance_cleanup(
request: Request,
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Run comprehensive maintenance cleanup (admin only)"""
try:
@ -217,50 +209,40 @@ async def run_maintenance_cleanup(
cleanup_results[operation_name] = {
"success": True,
"cleaned_count": result,
"message": f"Cleaned {result} items"
"message": f"Cleaned {result} items",
}
except Exception as e:
cleanup_results[operation_name] = {
"success": False,
"error": str(e),
"message": f"Failed: {str(e)}"
}
cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"}
# Calculate totals
total_cleaned = sum(
result.get("cleaned_count", 0)
for result in cleanup_results.values()
if result.get("success", False)
result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False)
)
successful_operations = len([
r for r in cleanup_results.values()
if r.get("success", False)
])
successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)])
return create_success_response({
return create_success_response(
{
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
"total_cleaned": total_cleaned,
"successful_operations": successful_operations,
"details": cleanup_results
})
"details": cleanup_results,
}
)
except Exception as e:
logger.error(f"❌ Error in maintenance cleanup: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
# ========================================
# Background Task Statistics
# ========================================
@router.get("/tasks/stats")
async def get_task_statistics(
request: Request,
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Get background task execution statistics (admin only)"""
try:
@ -268,7 +250,7 @@ async def get_task_statistics(
guest_stats = await database.get_guest_statistics()
# Get background task manager status
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
task_manager_stats = {}
if background_task_manager:
@ -276,7 +258,7 @@ async def get_task_statistics(
task_manager_stats = {
"running": task_status["running"],
"task_count": task_status["task_count"],
"task_breakdown": {}
"task_breakdown": {},
}
# Count tasks by status
@ -284,40 +266,35 @@ async def get_task_statistics(
status = task["status"]
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
return create_success_response({
return create_success_response(
{
"guest_statistics": guest_stats,
"task_manager": task_manager_stats,
"timestamp": datetime.now(UTC).isoformat()
})
"timestamp": datetime.now(UTC).isoformat(),
}
)
except Exception as e:
logger.error(f"❌ Error getting task statistics: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
# ========================================
# Background Task Status Endpoints
# ========================================
@router.get("/tasks/status")
async def get_background_task_status(
request: Request,
admin_user = Depends(get_current_admin)
):
async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)):
"""Get background task manager status (admin only)"""
try:
# Get background task manager from app state
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager:
return create_success_response({
"running": False,
"message": "Background task manager not initialized",
"tasks": [],
"task_count": 0
})
return create_success_response(
{"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0}
)
# Get comprehensive task status using the new method
task_status = await background_task_manager.get_task_status()
@ -329,87 +306,64 @@ async def get_background_task_status(
}
# Format the response
return create_success_response({
return create_success_response(
{
"running": task_status["running"],
"task_count": task_status["task_count"],
"loop_status": {
"main_loop_id": task_status["main_loop_id"],
"current_loop_id": task_status["current_loop_id"],
"loop_matches": task_status.get("loop_matches", False)
"loop_matches": task_status.get("loop_matches", False),
},
"tasks": task_status["tasks"],
"system_info": system_info
})
"system_info": system_info,
}
)
except Exception as e:
logger.error(f"❌ Get task status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.post("/tasks/run/{task_name}")
async def run_background_task(
task_name: str,
request: Request,
admin_user = Depends(get_current_admin)
):
async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)):
"""Manually trigger a specific background task (admin only)"""
try:
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager:
return JSONResponse(
status_code=503,
content=create_error_response(
"MANAGER_UNAVAILABLE",
"Background task manager not initialized"
)
content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"),
)
# List of available tasks
available_tasks = [
"guest_cleanup",
"token_cleanup",
"guest_stats",
"rate_limit_cleanup",
"orphaned_cleanup"
]
available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"]
if task_name not in available_tasks:
return JSONResponse(
status_code=400,
content=create_error_response(
"INVALID_TASK",
f"Unknown task: {task_name}. Available: {available_tasks}"
)
"INVALID_TASK", f"Unknown task: {task_name}. Available: {available_tasks}"
),
)
# Run the task
result = await background_task_manager.force_run_task(task_name)
return create_success_response({
"task_name": task_name,
"result": result,
"message": f"Task {task_name} completed successfully"
})
return create_success_response(
{"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"}
)
except ValueError as e:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_TASK", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e)))
except Exception as e:
logger.error(f"❌ Error running task {task_name}: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e)))
@router.get("/tasks/list")
async def list_available_tasks(
admin_user = Depends(get_current_admin)
):
async def list_available_tasks(admin_user=Depends(get_current_admin)):
"""List all available background tasks (admin only)"""
try:
tasks = [
@ -417,63 +371,51 @@ async def list_available_tasks(
"name": "guest_cleanup",
"description": "Clean up inactive guest sessions",
"interval": "6 hours",
"parameters": ["inactive_hours (default: 48)"]
"parameters": ["inactive_hours (default: 48)"],
},
{
"name": "token_cleanup",
"description": "Clean up expired email verification tokens",
"interval": "12 hours",
"parameters": []
"parameters": [],
},
{
"name": "guest_stats",
"description": "Update guest usage statistics",
"interval": "1 hour",
"parameters": []
"parameters": [],
},
{
"name": "rate_limit_cleanup",
"description": "Clean up old rate limiting data",
"interval": "24 hours",
"parameters": ["days_old (default: 7)"]
"parameters": ["days_old (default: 7)"],
},
{
"name": "orphaned_cleanup",
"description": "Clean up orphaned database records",
"interval": "6 hours",
"parameters": []
}
"parameters": [],
},
]
return create_success_response({
"total_tasks": len(tasks),
"tasks": tasks
})
return create_success_response({"total_tasks": len(tasks), "tasks": tasks})
except Exception as e:
logger.error(f"❌ Error listing tasks: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LIST_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e)))
@router.post("/tasks/restart")
async def restart_background_tasks(
request: Request,
admin_user = Depends(get_current_admin)
):
async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)):
"""Restart the background task manager (admin only)"""
try:
database_manager = getattr(request.app.state, 'database_manager', None)
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
database_manager = getattr(request.app.state, "database_manager", None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not database_manager:
return JSONResponse(
status_code=503,
content=create_error_response(
"DATABASE_UNAVAILABLE",
"Database manager not available"
)
status_code=503, content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available")
)
# Stop existing background tasks
@ -483,6 +425,7 @@ async def restart_background_tasks(
# Create and start new background task manager
from background_tasks import BackgroundTaskManager
new_background_task_manager = BackgroundTaskManager(database_manager)
await new_background_task_manager.start()
@ -492,22 +435,20 @@ async def restart_background_tasks(
# Get status of new manager
status = await new_background_task_manager.get_task_status()
return create_success_response({
"message": "Background task manager restarted successfully",
"new_status": status
})
return create_success_response(
{"message": "Background task manager restarted successfully", "new_status": status}
)
except Exception as e:
logger.error(f"❌ Error restarting background tasks: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESTART_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e)))
# ============================
# Task Monitoring and Metrics
# ============================
class TaskMetrics:
"""Collect metrics for background tasks"""
@ -544,36 +485,32 @@ class TaskMetrics:
metrics[task_name] = {
"total_runs": self.task_runs[task_name],
"total_errors": self.task_errors[task_name],
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name]
if self.task_runs[task_name] > 0
else 0,
"average_duration": avg_duration,
"last_runs": durations[-10:] if durations else []
"last_runs": durations[-10:] if durations else [],
}
return metrics
# Global task metrics
task_metrics = TaskMetrics()
@router.get("/tasks/metrics")
async def get_task_metrics(
admin_user = Depends(get_current_admin)
):
async def get_task_metrics(admin_user=Depends(get_current_admin)):
"""Get background task metrics (admin only)"""
try:
global task_metrics
metrics = task_metrics.get_metrics()
return create_success_response({
"metrics": metrics,
"timestamp": datetime.now(UTC).isoformat()
})
return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()})
except Exception as e:
logger.error(f"❌ Get task metrics error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("METRICS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e)))
# ============================
@ -581,8 +518,7 @@ async def get_task_metrics(
# ============================
# @router.get("/verification-stats")
async def get_verification_statistics(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Get verification statistics (admin only)"""
try:
@ -591,22 +527,19 @@ async def get_verification_statistics(
stats = {
"pending_verifications": await database.get_pending_verifications_count(),
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(),
}
return create_success_response(stats)
except Exception as e:
logger.error(f"❌ Error getting verification stats: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
@router.post("/cleanup-verifications")
async def cleanup_verification_tokens(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Manually trigger cleanup of expired verification tokens (admin only)"""
try:
@ -617,24 +550,24 @@ async def cleanup_verification_tokens(
logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
return create_success_response({
return create_success_response(
{
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
"cleaned_count": cleaned_count
})
"cleaned_count": cleaned_count,
}
)
except Exception as e:
logger.error(f"❌ Error in manual cleanup: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.get("/pending-verifications")
async def get_pending_verifications(
current_user=Depends(get_current_admin),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get list of pending email verifications (admin only)"""
try:
@ -656,14 +589,16 @@ async def get_pending_verifications(
if not verification_info.get("verified", False):
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
pending_verifications.append({
pending_verifications.append(
{
"email": verification_info.get("email"),
"user_type": verification_info.get("user_type"),
"created_at": verification_info.get("created_at"),
"expires_at": verification_info.get("expires_at"),
"is_expired": current_time > expires_at,
"resend_count": verification_info.get("resend_count", 0)
})
"resend_count": verification_info.get("resend_count", 0),
}
)
if cursor == 0:
break
@ -677,35 +612,27 @@ async def get_pending_verifications(
end = start + limit
paginated_verifications = pending_verifications[start:end]
paginated_response = create_paginated_response(
paginated_verifications,
page, limit, total
)
paginated_response = create_paginated_response(paginated_verifications, page, limit, total)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Error getting pending verifications: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.get("/rate-limits/info")
async def get_user_rate_limit_status(
current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get rate limit status for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(current_user.id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
is_admin = False
@ -725,27 +652,22 @@ async def get_user_rate_limit_status(
except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.get("/rate-limits/{user_id}")
async def get_anyone_rate_limit_status(
user_id: str = Path(...),
admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get rate limit status for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(user_id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
is_admin = False
@ -765,47 +687,36 @@ async def get_anyone_rate_limit_status(
except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.post("/rate-limits/{user_id}/reset")
async def reset_user_rate_limits(
user_id: str = Path(...),
admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Reset rate limits for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(user_id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
if success:
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
return create_success_response({
"message": f"Rate limits reset for {user_type} {user_id}",
"resetBy": admin_user.id
})
return create_success_response(
{"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id}
)
else:
return JSONResponse(
status_code=500,
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
)
except Exception as e:
logger.error(f"❌ Reset rate limits error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
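The TaskMetrics success-rate expression above is runs minus errors over runs, guarded against division by zero; a tiny standalone check with invented numbers:

runs, errors = 12, 3
success_rate = (runs - errors) / runs if runs > 0 else 0
assert abs(success_rate - 0.75) < 1e-9  # 9 successful runs out of 12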

View File

@ -20,18 +20,25 @@ from device_manager import DeviceManager
from email_service import VerificationEmailRateLimiter, email_service
from logger import logger
from models import (
LoginRequest, CreateCandidateRequest, Candidate,
Employer, Guest, AuthResponse, MFARequest,
MFAData, MFAVerifyRequest, ResendVerificationRequest,
MFARequestResponse, MFARequestResponse
)
from utils.dependencies import (
get_current_admin, get_database, get_current_user, create_access_token
LoginRequest,
CreateCandidateRequest,
Candidate,
Employer,
Guest,
AuthResponse,
MFARequest,
MFAData,
MFAVerifyRequest,
ResendVerificationRequest,
MFARequestResponse,
MFARequestResponse,
)
from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token
from utils.responses import create_success_response, create_error_response
from utils.rate_limiter import get_rate_limiter
from utils.auth_utils import (
AuthenticationManager, SecurityConfig,
AuthenticationManager,
SecurityConfig,
validate_password_strength,
)
@ -43,28 +50,31 @@ if JWT_SECRET_KEY == "":
raise ValueError("JWT_SECRET_KEY environment variable is not set")
ALGORITHM = "HS256"
# ============================
# Password Reset Endpoints
# ============================
class PasswordResetRequest(BaseModel):
email: EmailStr
class PasswordResetConfirm(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v)
if not is_valid:
raise ValueError('; '.join(issues))
raise ValueError("; ".join(issues))
return v
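A quick sketch of how this validator surfaces to callers, assuming the sample password below fails validate_password_strength:

from pydantic import ValidationError

try:
    PasswordResetConfirm(token="abc123", new_password="short")  # deliberately weak example password
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # the joined issue list raised by the validator above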
@router.post("/guest")
async def create_guest_session_enhanced(
request: Request,
database: RedisDatabase = Depends(get_database),
rate_limiter: RateLimiter = Depends(get_rate_limiter)
rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
"""Create a guest session with enhanced validation and persistence"""
try:
@ -73,21 +83,15 @@ async def create_guest_session_enhanced(
# Check rate limits for guest session creation
rate_result = await rate_limiter.check_rate_limit(
user_id=ip_address,
user_type="guest_creation",
is_admin=False,
endpoint="/guest"
user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest"
)
if not rate_result.allowed:
logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}")
return JSONResponse(
status_code=429,
content=create_error_response(
"RATE_LIMITED",
rate_result.reason or "Too many guest sessions created"
),
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}
content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"),
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)},
)
# Generate unique guest identifier with timestamp for uniqueness
@ -139,7 +143,7 @@ async def create_guest_session_enhanced(
"email": guest_data["email"],
"username": guest_username,
"session_id": session_id,
"created_at": current_time.isoformat()
"created_at": current_time.isoformat(),
}
await database.set_user(guest_data["email"], user_auth_data)
@ -149,11 +153,11 @@ async def create_guest_session_enhanced(
# Create authentication tokens with longer expiry for guests
access_token = create_access_token(
data={"sub": guest_id, "type": "guest"},
expires_delta=timedelta(hours=48) # Longer expiry for guests
expires_delta=timedelta(hours=48), # Longer expiry for guests
)
refresh_token = create_access_token(
data={"sub": guest_id, "type": "refresh_guest"},
expires_delta=timedelta(days=14) # 2 weeks refresh for guests
expires_delta=timedelta(days=14), # 2 weeks refresh for guests
)
# Verify guest was stored correctly
@ -161,8 +165,7 @@ async def create_guest_session_enhanced(
if not verification:
logger.error(f"❌ Failed to verify guest storage: {guest_id}")
return JSONResponse(
status_code=500,
content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
)
# Create guest object for response
@@ -178,7 +181,7 @@ async def create_guest_session_enhanced(
"user": guest.model_dump(by_alias=True),
"expiresAt": int((current_time + timedelta(hours=48)).timestamp()),
"userType": "guest",
"isGuest": True
"isGuest": True,
}
return create_success_response(auth_response)
@@ -186,25 +189,25 @@ async def create_guest_session_enhanced(
except Exception as e:
logger.error(f"❌ Guest session creation error: {e}")
import traceback
logger.error(traceback.format_exc())
return JSONResponse(
status_code=500,
content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
)
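Note: the guest endpoint above relies on create_access_token from utils.dependencies, whose body is outside this diff. A minimal sketch of what such a helper typically looks like, assuming PyJWT and the JWT_SECRET_KEY / ALGORITHM values defined earlier in this module (the project's real implementation may differ):

    from datetime import datetime, timedelta, UTC
    import jwt  # PyJWT

    def create_access_token(data: dict, expires_delta: timedelta = timedelta(hours=24)) -> str:
        # Copy the caller-supplied claims ("sub", "type", ...) and attach an expiry.
        payload = data.copy()
        payload["exp"] = datetime.now(UTC) + expires_delta
        return jwt.encode(payload, JWT_SECRET_KEY, algorithm=ALGORITHM)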
@router.post("/guest/convert")
async def convert_guest_to_user(
registration_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Convert a guest session to a permanent user account"""
try:
# Verify current user is a guest
if current_user.user_type != "guest":
return JSONResponse(
status_code=400,
content=create_error_response("NOT_GUEST", "Only guest users can be converted")
status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted")
)
guest: Guest = current_user
@@ -215,25 +218,18 @@ async def convert_guest_to_user(
try:
candidate_request = CreateCandidateRequest.model_validate(registration_data)
except ValidationError as e:
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e)))
# Check if email/username already exists
auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists(
candidate_request.email,
candidate_request.username
candidate_request.email, candidate_request.username
)
if user_exists:
return JSONResponse(
status_code=409,
content=create_error_response(
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
)
# Create candidate
@@ -253,7 +249,7 @@ async def convert_guest_to_user(
"updated_at": current_time.isoformat(),
"status": "active",
"is_admin": False,
"converted_from_guest": guest.id
"converted_from_guest": guest.id,
}
candidate = Candidate.model_validate(candidate_data)
@@ -269,7 +265,7 @@ async def convert_guest_to_user(
"id": candidate_id,
"type": "candidate",
"email": candidate.email,
"username": candidate.username
"username": candidate.username,
}
await database.set_user(candidate.email, user_auth_data)
@@ -286,43 +282,45 @@ async def convert_guest_to_user(
access_token = create_access_token(data={"sub": candidate_id})
refresh_token = create_access_token(
data={"sub": candidate_id, "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
auth_response = AuthResponse(
access_token=access_token,
refresh_token=refresh_token,
user=candidate,
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()),
)
logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}")
return create_success_response({
return create_success_response(
{
"message": "Guest account successfully converted to candidate",
"auth": auth_response.model_dump(by_alias=True),
"conversionType": "candidate"
})
"conversionType": "candidate",
}
)
else:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported")
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"),
)
except Exception as e:
logger.error(f"❌ Guest conversion error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
)
@router.post("/logout")
async def logout(
access_token: str = Body(..., alias="accessToken"),
refresh_token: str = Body(..., alias="refreshToken"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Logout endpoint - revokes both access and refresh tokens"""
logger.info(f"🔑 User {current_user.id} is logging out")
@@ -336,21 +334,18 @@ async def logout(
if not user_id or token_type != "refresh":
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
except jwt.PyJWTError as e:
logger.warning(f"⚠️ Invalid refresh token during logout: {e}")
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
# Verify that the refresh token belongs to the current user
if user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user")
)
# Get Redis client
@@ -362,12 +357,14 @@ async def logout(
await redis.setex(
f"blacklisted_token:{refresh_token}",
refresh_ttl,
json.dumps({
json.dumps(
{
"user_id": user_id,
"token_type": "refresh",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
"reason": "user_logout",
}
),
)
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
@@ -385,12 +382,14 @@ async def logout(
await redis.setex(
f"blacklisted_token:{access_token}",
access_ttl,
json.dumps({
json.dumps(
{
"user_id": user_id,
"token_type": "access",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
"reason": "user_logout",
}
),
)
logger.info(f"🔒 Blacklisted access token for user {user_id}")
else:
@@ -409,26 +408,20 @@ async def logout(
# )
logger.info(f"🔑 User {user_id} logged out successfully")
return create_success_response({
return create_success_response(
{
"message": "Logged out successfully",
"tokensRevoked": {
"refreshToken": True,
"accessToken": bool(access_token)
"tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)},
}
})
)
except Exception as e:
logger.error(f"❌ Logout error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e)))
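The logout handler blacklists both tokens under blacklisted_token:<token> keys whose TTL matches the remaining token lifetime. A sketch of the complementary check an auth dependency could run on each request (the key naming follows the code above; the helper itself is an assumption, not shown in this diff):

    async def is_token_blacklisted(redis, token: str) -> bool:
        # EXISTS returns the number of matching keys; > 0 means the token was revoked at logout.
        return await redis.exists(f"blacklisted_token:{token}") > 0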
@router.post("/logout-all")
async def logout_all_devices(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
):
async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)):
"""Logout from all devices by revoking all tokens for the user"""
try:
redis = redis_manager.get_client()
@@ -437,25 +430,20 @@ async def logout_all_devices(
await redis.setex(
f"user_tokens_revoked:{current_user.id}",
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
datetime.now(UTC).isoformat()
datetime.now(UTC).isoformat(),
)
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
return create_success_response({
"message": "Logged out from all devices successfully"
})
return create_success_response({"message": "Logged out from all devices successfully"})
except Exception as e:
logger.error(f"❌ Logout all devices error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
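logout-all only records a user_tokens_revoked:<user_id> timestamp, so enforcement has to happen wherever tokens are decoded. A sketch of that check, assuming tokens carry an issued-at (iat) claim, which this diff does not show:

    from datetime import datetime

    async def issued_before_global_revocation(redis, user_id: str, issued_at: datetime) -> bool:
        revoked_at = await redis.get(f"user_tokens_revoked:{user_id}")
        if not revoked_at:
            return False
        if isinstance(revoked_at, bytes):
            revoked_at = revoked_at.decode()
        # Any token minted before the revocation moment should be rejected.
        return issued_at < datetime.fromisoformat(revoked_at)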
@router.post("/refresh")
async def refresh_token_endpoint(
refresh_token: str = Body(..., alias="refreshToken"),
database: RedisDatabase = Depends(get_database)
refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
):
"""Refresh token endpoint"""
try:
@@ -466,8 +454,7 @@ async def refresh_token_endpoint(
if not user_id or token_type != "refresh":
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
# Create new access token
@@ -484,37 +471,29 @@ async def refresh_token_endpoint(
user = Employer.model_validate(employer_data)
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
auth_response = AuthResponse(
access_token=access_token,
refresh_token=refresh_token, # Keep same refresh token
user=user,
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp())
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()),
)
return create_success_response(auth_response.model_dump(by_alias=True))
except jwt.PyJWTError:
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token"))
except Exception as e:
logger.error(f"❌ Token refresh error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("REFRESH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e)))
@router.post("/resend-verification")
async def resend_verification_email(
request: ResendVerificationRequest,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Resend verification email with comprehensive rate limiting and validation"""
try:
@@ -527,10 +506,7 @@ async def resend_verification_email(
can_send, reason = await rate_limiter.can_send_verification_email(email_lower)
if not can_send:
logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}")
return JSONResponse(
status_code=429,
content=create_error_response("RATE_LIMITED", reason)
)
return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason))
# Clean up expired tokens first
await database.cleanup_expired_verification_tokens()
@@ -541,9 +517,11 @@ async def resend_verification_email(
# User exists and is verified - don't reveal this for security
logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({
return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent."
})
}
)
# Look for pending verification token
verification_data = await database.find_verification_token_by_email(email_lower)
@@ -552,9 +530,11 @@ async def resend_verification_email(
# No pending verification found - don't reveal this for security
logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({
return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent."
})
}
)
# Check if verification token has expired
expires_at = datetime.fromisoformat(verification_data["expires_at"])
@@ -568,8 +548,8 @@ async def resend_verification_email(
status_code=400,
content=create_error_response(
"TOKEN_EXPIRED",
"Your verification link has expired. Please register again to create a new account."
)
"Your verification link has expired. Please register again to create a new account.",
),
)
# Generate new verification token (invalidate old one)
@@ -577,20 +557,19 @@ async def resend_verification_email(
new_token = secrets.token_urlsafe(32)
# Update verification data with new token and reset attempts
verification_data.update({
verification_data.update(
{
"token": new_token,
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
"resent_at": current_time.isoformat(),
"resend_count": verification_data.get("resend_count", 0) + 1
})
"resend_count": verification_data.get("resend_count", 0) + 1,
}
)
# Store new token and remove old one
await database.redis.delete(f"email_verification:{old_token}")
await database.store_email_verification_token(
email_lower,
new_token,
verification_data["user_type"],
verification_data["user_data"]
email_lower, new_token, verification_data["user_type"], verification_data["user_data"]
)
# Get user name for email
@@ -610,69 +589,62 @@ async def resend_verification_email(
await rate_limiter.record_email_sent(email_lower)
# Send new verification email in background
background_tasks.add_task(
email_service.send_verification_email,
email_lower,
new_token,
user_name,
user_type
)
background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type)
# Log security event
await database.log_security_event(
verification_data["user_data"].get("candidate_data", {}).get("id") or
verification_data["user_data"].get("employer_data", {}).get("id") or "unknown",
verification_data["user_data"].get("candidate_data", {}).get("id")
or verification_data["user_data"].get("employer_data", {}).get("id")
or "unknown",
"verification_resend",
{
"email": email_lower,
"user_type": user_type,
"resend_count": verification_data.get("resend_count", 1),
"old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging
"ip_address": "unknown" # You can extract this from request if needed
"ip_address": "unknown", # You can extract this from request if needed
},
)
logger.info(
f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})"
)
return create_success_response(
{
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1),
}
)
logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})")
return create_success_response({
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1)
})
except ValueError as ve:
logger.warning(f"⚠️ Invalid resend verification request: {ve}")
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
except Exception as e:
logger.error(f"❌ Resend verification email error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.")
content=create_error_response(
"RESEND_FAILED", "An error occurred while processing your request. Please try again later."
),
)
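store_email_verification_token is called with (email, token, user_type, user_data) but defined elsewhere. Given the email_verification:<token> key deleted above and the 24-hour expiry used throughout, one plausible shape — an assumption, not the project's actual code:

    import json
    from datetime import datetime, timedelta, UTC

    async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> None:
        record = {
            "email": email,
            "user_type": user_type,
            "user_data": user_data,
            "expires_at": (datetime.now(UTC) + timedelta(hours=24)).isoformat(),
        }
        # Key layout mirrors the delete call above; the TTL lets Redis clean up expired tokens on its own.
        await self.redis.setex(f"email_verification:{token}", int(timedelta(hours=24).total_seconds()), json.dumps(record))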
@router.post("/mfa/request")
async def request_mfa(
request: MFARequest,
background_tasks: BackgroundTasks,
http_request: Request,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Request MFA for login from new device"""
try:
# Verify credentials first
auth_manager = AuthenticationManager(database)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
request.email,
request.password
)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password)
if not is_valid or not user_data:
return JSONResponse(
status_code=401,
content=create_error_response("AUTH_FAILED", "Invalid credentials")
)
return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials"))
# Check if device is trusted
device_manager = DeviceManager(database)
@@ -684,10 +656,7 @@ async def request_mfa(
# Device is trusted, proceed with normal login
await device_manager.update_device_last_used(user_data["id"], request.device_id)
return create_success_response({
"mfa_required": False,
"message": "Device is trusted, proceed with login"
})
return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"})
# Generate MFA code
mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code
@@ -709,8 +678,7 @@ async def request_mfa(
if not email:
return JSONResponse(
status_code=400,
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
)
# Store MFA code
@@ -718,13 +686,7 @@ async def request_mfa(
logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}")
# Send MFA code via email
background_tasks.add_task(
email_service.send_mfa_email,
email,
mfa_code,
request.device_name,
user_name
)
background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name)
logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}")
@@ -735,25 +697,22 @@ async def request_mfa(
device_id=request.device_id,
device_name=request.device_name,
)
mfa_response = MFARequestResponse(
mfa_required=True,
mfa_data=mfa_data
)
mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data)
return create_success_response(mfa_response)
except Exception as e:
logger.error(f"❌ MFA request error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
)
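Both MFA paths persist the generated code before e-mailing it. Based on the mfa_code:<email>:<device_id> keys deleted in verify_mfa and the fields read back there (verified, attempts, an expiry), a sketch of what that record plausibly contains; the storage call and the 10-minute lifetime are assumptions:

    import json
    from datetime import datetime, timedelta, UTC

    mfa_record = {
        "code": mfa_code,  # the 6-digit value from secrets.randbelow(1000000)
        "email": email,
        "device_id": request.device_id,
        "verified": False,
        "attempts": 0,
        "expires_at": (datetime.now(UTC) + timedelta(minutes=10)).isoformat(),  # assumed lifetime
    }
    await database.redis.setex(f"mfa_code:{email.lower()}:{request.device_id}", 600, json.dumps(mfa_record))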
@router.post("/login")
async def login(
request: LoginRequest,
http_request: Request,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""login with automatic MFA email sending for new devices"""
try:
@@ -766,16 +725,12 @@ async def login(
device_id = device_info["device_id"]
# Verify credentials first
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
request.login,
request.password
)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password)
if not is_valid or not user_data:
logger.warning(f"⚠️ Failed login attempt for: {request.login}")
return JSONResponse(
status_code=401,
content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
)
# Check if device is trusted
@@ -804,8 +759,7 @@ async def login(
if not email:
return JSONResponse(
status_code=400,
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
)
# Store MFA code
@@ -817,12 +771,7 @@ async def login(
# Send MFA code via email in background
background_tasks.add_task(
email_service.send_mfa_email,
email,
mfa_code,
device_info["device_name"],
user_name,
ip_address
email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address
)
# Log security event
@@ -834,8 +783,8 @@ async def login(
"device_name": device_info["device_name"],
"ip_address": ip_address,
"user_agent": device_info.get("user_agent", ""),
"auto_sent": True
}
"auto_sent": True,
},
)
logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}")
@@ -847,8 +796,8 @@ async def login(
email=email,
device_id=device_id,
device_name=device_info["device_name"],
code_sent=mfa_code
)
code_sent=mfa_code,
),
)
return create_success_response(mfa_response.model_dump(by_alias=True))
@@ -860,7 +809,7 @@ async def login(
access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
# Get user object
@@ -876,8 +825,7 @@ async def login(
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User profile not found")
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
)
# Log successful login from trusted device
@@ -888,8 +836,8 @@ async def login(
"device_id": device_id,
"device_name": device_info["device_name"],
"ip_address": http_request.client.host if http_request.client else "Unknown",
"trusted_device": True
}
"trusted_device": True,
},
)
# Create response
@@ -897,7 +845,9 @@ async def login(
access_token=access_token,
refresh_token=refresh_token,
user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
)
logger.info(f"🔑 User {request.login} logged in successfully from trusted device")
@@ -908,17 +858,12 @@ async def login(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Login error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGIN_ERROR", "An error occurred during login")
status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login")
)
@router.post("/mfa/verify")
async def verify_mfa(
request: MFAVerifyRequest,
http_request: Request,
database: RedisDatabase = Depends(get_database)
):
async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)):
"""Verify MFA code and complete login with error handling"""
try:
# Get MFA data
@@ -928,13 +873,17 @@ async def verify_mfa(
logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.")
content=create_error_response(
"NO_MFA_SESSION", "No active MFA session found. Please try logging in again."
),
)
if mfa_data.get("verified"):
return JSONResponse(
status_code=400,
content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.")
content=create_error_response(
"ALREADY_VERIFIED", "This MFA code has already been used. Please login again."
),
)
# Check expiration
@@ -944,7 +893,7 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse(
status_code=400,
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.")
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."),
)
# Check attempts
@@ -954,7 +903,9 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse(
status_code=429,
content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.")
content=create_error_response(
"TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again."
),
)
# Verify code
@@ -965,9 +916,8 @@ async def verify_mfa(
return JSONResponse(
status_code=400,
content=create_error_response(
"INVALID_CODE",
f"Invalid MFA code. {remaining_attempts} attempts remaining."
)
"INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining."
),
)
# Mark as verified
@@ -976,20 +926,13 @@ async def verify_mfa(
# Get user data
user_data = await database.get_user(request.email)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
# Add device to trusted devices if requested
if request.remember_device:
device_manager = DeviceManager(database)
device_info = device_manager.parse_device_info(http_request)
await device_manager.add_trusted_device(
user_data["id"],
request.device_id,
device_info
)
await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info)
logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}")
# Update last login
@@ -1000,7 +943,7 @@ async def verify_mfa(
access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
# Get user object
@@ -1016,8 +959,7 @@ async def verify_mfa(
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User profile not found")
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
)
# Log successful MFA verification and login
@@ -1028,8 +970,8 @@ async def verify_mfa(
"device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown",
"device_remembered": request.remember_device,
"attempts_used": current_attempts + 1
}
"attempts_used": current_attempts + 1,
},
)
await database.log_security_event(
@@ -1039,8 +981,8 @@ async def verify_mfa(
"device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown",
"mfa_verified": True,
"new_device": True
}
"new_device": True,
},
)
# Clean up MFA session
@@ -1051,7 +993,9 @@ async def verify_mfa(
access_token=access_token,
refresh_token=refresh_token,
user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
)
logger.info(f"✅ MFA verified and login completed for {request.email}")
@@ -1062,15 +1006,12 @@ async def verify_mfa(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ MFA verification error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
)
@router.post("/password-reset/request")
async def request_password_reset(
request: PasswordResetRequest,
database: RedisDatabase = Depends(get_database)
):
async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
"""Request password reset"""
try:
# Check if user exists
@@ -1100,15 +1041,12 @@ async def request_password_reset(
except Exception as e:
logger.error(f"❌ Password reset request error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", "An error occurred processing the request")
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
)
@router.post("/password-reset/confirm")
async def confirm_password_reset(
request: PasswordResetConfirm,
database: RedisDatabase = Depends(get_database)
):
async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
"""Confirm password reset with token"""
try:
# Find user by reset token
@@ -1122,8 +1060,5 @@ async def confirm_password_reset(
except Exception as e:
logger.error(f"❌ Password reset confirm error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
)

File diff suppressed because it is too large

View File

@@ -4,96 +4,69 @@ Chat routes
import json
import uuid
from datetime import datetime, UTC
from typing import (Dict, Any)
from typing import Dict, Any
from fastapi import (
APIRouter, Depends, Body, Depends, Query, Path,
Body, APIRouter
)
from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter
from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_database, get_current_user, get_current_user_or_guest
)
from utils.responses import (
create_success_response, create_error_response, create_paginated_response
)
from utils.helpers import (
stream_agent_response
)
from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
from utils.responses import create_success_response, create_error_response, create_paginated_response
from utils.helpers import stream_agent_response
import backstory_traceback
import entities.entity_manager as entities
from models import (
Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
)
from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
# Create router for authentication endpoints
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("/sessions/{session_id}/archive")
async def archive_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Archive a chat session"""
try:
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Check if user owns this session or is admin
if session_data.get("userId") != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
)
await database.archive_chat_session(session_id)
return create_success_response({
"message": "Chat session archived successfully",
"sessionId": session_id
})
return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
except Exception as e:
logger.error(f"❌ Archive chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("ARCHIVE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
@router.get("/statistics")
async def get_chat_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get chat statistics (admin/analytics endpoint)"""
try:
stats = await database.get_chat_statistics()
return create_success_response(stats)
except Exception as e:
logger.error(f"❌ Get chat statistics error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
@router.post("/sessions")
async def create_chat_session(
session_data: Dict[str, Any] = Body(...),
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new chat session with optional candidate username association"""
try:
@@ -111,15 +84,14 @@ async def create_chat_session(
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
# Find candidate by username (case-insensitive)
matching_candidates = [
c for c in candidates_list
if c.username.lower() == username.lower()
]
matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
if not matching_candidates:
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
content=create_error_response(
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
),
)
candidate_data = matching_candidates[0]
@@ -148,7 +120,7 @@ async def create_chat_session(
"username": candidate_data.username,
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
"location": candidate_data.location.city if candidate_data.location else "Unknown"
"location": candidate_data.location.city if candidate_data.location else "Unknown",
}
context["additionalContext"] = additional_context
@@ -162,8 +134,10 @@ async def create_chat_session(
chat_session = ChatSession.model_validate(session_data)
await database.set_chat_session(chat_session.id, chat_session.model_dump())
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
logger.info(
f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
)
return create_success_response(chat_session.model_dump(by_alias=True))
@@ -171,49 +145,58 @@ async def create_chat_session(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat session creation error: {e}")
logger.info(json.dumps(session_data, indent=2))
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("/sessions/messages/stream")
async def post_chat_session_message_stream(
user_message: ChatMessageUser = Body(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Post a message to a chat session and stream the response with persistence"""
try:
chat_session_data = await database.get_chat_session(user_message.session_id)
if not chat_session_data:
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
chat_session = ChatSession.model_validate(chat_session_data)
chat_type = chat_session.context.type
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
candidate_info = (
chat_session.context.additional_context.get("candidateInfo", {})
if chat_session.context and chat_session.context.additional_context
else None
)
# Get candidate info if this chat is about a specific candidate
if candidate_info:
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
)
else:
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
)
return JSONResponse(
status_code=400,
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
content=create_error_response(
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
),
)
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
if not candidate:
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
logger.info(
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
)
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
)
logger.info(
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
)
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
# Entity automatically released when done
@@ -222,7 +205,7 @@ async def post_chat_session_message_stream(
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
)
# Persist user message to database
@@ -243,10 +226,8 @@ async def post_chat_session_message_stream(
except Exception:
logger.error(backstory_traceback.format_exc())
logger.error("❌ Chat message streaming error")
return JSONResponse(
status_code=500,
content=create_error_response("STREAMING_ERROR", "")
)
return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
@router.get("/sessions/{session_id}/messages")
async def get_chat_session_messages(
@@ -254,16 +235,13 @@ async def get_chat_session_messages(
current_user=Depends(get_current_user_or_guest),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get persisted chat messages for a session"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Get messages from database
chat_messages = await database.get_chat_messages(session_id)
@@ -288,43 +266,36 @@ async def get_chat_session_messages(
paginated_messages = messages_list[start:end]
paginated_response = create_paginated_response(
[m.model_dump(by_alias=True) for m in paginated_messages],
page, limit, total
[m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get chat messages error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.patch("/sessions/{session_id}")
async def update_chat_session(
session_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Update a chat session's properties"""
try:
# Get the existing session
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only update their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
)
# Validate and apply updates
@@ -333,8 +304,7 @@ async def update_chat_session(
if not filtered_updates:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
)
# Apply updates to session data
@@ -388,40 +358,31 @@ async def update_chat_session(
except ValueError as ve:
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
except Exception as e:
logger.error(f"❌ Update chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("UPDATE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
@router.delete("/sessions/{session_id}")
async def delete_chat_session(
session_id: str = Path(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
)
# Delete all messages associated with this session
@@ -440,42 +401,34 @@ async def delete_chat_session(
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session deleted successfully",
"sessionId": session_id
})
return create_success_response(
{"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Delete chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
@router.patch("/sessions/{session_id}/reset")
async def reset_chat_session(
session_id: str = Path(...),
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
)
# Delete all messages associated with this session
@@ -489,20 +442,12 @@ async def reset_chat_session(
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
# Continue with session deletion even if message deletion fails
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session reset successfully",
"sessionId": session_id
})
return create_success_response(
{"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Reset chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
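post_chat_session_message_stream hands the reply off to stream_agent_response, which streams server-sent events. A minimal client-side sketch of reading that stream, assuming the same data: <json> framing used by the job-upload endpoints later in this commit; the base URL and auth headers are omitted and assumed:

    import json
    import httpx

    async def read_chat_stream(base_url: str, message: dict) -> None:
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", f"{base_url}/chat/sessions/messages/stream", json=message) as response:
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        event = json.loads(line[len("data: "):])
                        # Each event mirrors a ChatMessage* model dumped with by_alias=True.
                        print(event.get("status"), event.get("content", ""))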

View File

@@ -9,19 +9,16 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_current_admin, get_database
)
from utils.dependencies import get_current_admin, get_database
from utils.responses import create_success_response, create_error_response
# Create router for authentication endpoints
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.get("/guest/{guest_id}")
async def debug_guest_session(
guest_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Debug guest session issues (admin only)"""
try:
@@ -45,22 +42,19 @@ async def debug_guest_session(
"primary_storage": {
"exists": primary_exists,
"data": json.loads(primary_data) if primary_data else None,
"ttl": primary_ttl
"ttl": primary_ttl,
},
"backup_storage": {
"exists": backup_exists,
"data": json.loads(backup_data) if backup_data else None,
"ttl": backup_ttl
"ttl": backup_ttl,
},
"user_lookup": user_lookup,
"timestamp": datetime.now(UTC).isoformat()
"timestamp": datetime.now(UTC).isoformat(),
}
return create_success_response(debug_info)
except Exception as e:
logger.error(f"❌ Debug guest session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DEBUG_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))

View File

@@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
CreateEmployerRequest
)
from utils.dependencies import (
get_database
)
from models import CreateEmployerRequest
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
from email_service import email_service
from utils.auth_utils import AuthenticationManager
@@ -23,29 +19,22 @@ from utils.auth_utils import AuthenticationManager
# Create router for job endpoints
router = APIRouter(prefix="/employers", tags=["employers"])
@router.post("")
async def create_employer_with_verification(
request: CreateEmployerRequest,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
):
"""Create a new employer with email verification"""
try:
# Similar to candidate creation but for employer
auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists(
request.email,
request.username
)
user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
if user_exists and conflict_field:
return JSONResponse(
status_code=409,
content=create_error_response(
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
)
employer_id = str(uuid.uuid4())
@@ -64,12 +53,8 @@ async def create_employer_with_verification(
"updatedAt": current_time.isoformat(),
"status": "pending", # Not active until verified
"userType": "employer",
"location": {
"city": "",
"country": "",
"remote": False
},
"socialLinks": []
"location": {"city": "", "country": "", "remote": False},
"socialLinks": [],
}
verification_token = secrets.token_urlsafe(32)
@@ -78,32 +63,25 @@ async def create_employer_with_verification(
request.email,
verification_token,
"employer",
{
"employer_data": employer_data,
"password": request.password,
"username": request.username
}
{"employer_data": employer_data, "password": request.password, "username": request.username},
)
background_tasks.add_task(
email_service.send_verification_email,
request.email,
verification_token,
request.company_name
email_service.send_verification_email, request.email, verification_token, request.company_name
)
logger.info(f"✅ Employer registration initiated for: {request.email}")
return create_success_response({
return create_success_response(
{
"message": "Registration successful! Please check your email to verify your account.",
"email": request.email,
"verificationRequired": True
})
"verificationRequired": True,
}
)
except Exception as e:
logger.error(f"❌ Employer creation error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
)

View File

@@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer
)
from utils.dependencies import (
get_current_admin, get_database, get_current_user
MOCK_UUID,
ApiActivityType,
ApiStatusType,
ChatContextType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
DocumentType,
Job,
JobRequirementsMessage,
Candidate,
Employer,
)
from utils.dependencies import get_current_admin, get_database, get_current_user
from utils.responses import create_paginated_response, create_success_response, create_error_response
import utils.llm_proxy as llm_manager
import entities.entity_manager as entities
# Create router for job endpoints
router = APIRouter(prefix="/jobs", tags=["jobs"])
@@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
activity=ApiActivityType.CONVERTING,
)
yield status_message
@@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
@@ -66,7 +76,7 @@ Return only the markdown content, no other text. Make sure all content is includ
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
@@ -84,7 +94,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
@@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
@@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
@@ -132,7 +139,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Job extraction did not convert successfully"
content="Job extraction did not convert successfully",
)
yield error_message
return
@@ -143,12 +150,11 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
return
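create_job_from_content is an async generator that interleaves status and error messages with a terminal result. A sketch of how a caller can drain it, mirroring the ApiStatusType.DONE filtering above; the trailing arguments are abbreviated and the logging is illustrative:

    async def collect_job_message(database, current_user, markdown_content):
        final_message = None
        async for message in create_job_from_content(database, current_user, markdown_content):
            if message.status == ApiStatusType.DONE:
                final_message = message  # e.g. a JobRequirementsMessage with the extracted job
            else:
                logger.info(f"progress: {message.content}")  # status / error updates as they arrive
        return final_message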
@router.post("")
async def create_job(
job_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
try:
@@ -165,17 +171,14 @@ async def create_job(
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("")
async def create_candidate_job(
job_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
isinstance(current_user, Employer)
@@ -194,10 +197,7 @@ async def create_candidate_job(
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.patch("/{job_id}")
@@ -205,17 +205,14 @@ async def update_job(
job_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Update a candidate"""
try:
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Job not found for update: {job_data}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
job = Job.model_validate(job_data)
@@ -223,8 +220,7 @@ async def update_job(
if current_user.is_admin is False and job.owner_id != current_user.id:
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
)
# Apply updates
@@ -239,25 +235,22 @@ async def update_job(
except Exception as e:
logger.error(f"❌ Update job error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("UPDATE_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
@router.post("/from-content")
async def create_job_from_description(
content: str = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Upload a document for the current candidate"""
async def content_stream_generator(content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
@ -277,10 +270,11 @@ async def create_job_from_description(
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -304,18 +298,25 @@ async def create_job_from_description(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
content="Failed to upload document",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
@router.post("/upload")
async def create_job_from_file(
file: UploadFile = File(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB)
@ -324,55 +325,70 @@ async def create_job_from_file(
if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
content="File size exceeds 10MB limit",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
content="File is empty",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
"""Upload a document for the current candidate"""
async def upload_stream_generator(file_content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename"
content="File must have a valid filename",
)
yield error_message
return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
logger.info(
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
)
# Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
)
yield error_message
return
@ -383,7 +399,7 @@ async def create_job_from_file(
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING
activity=ApiActivityType.CONVERTING,
)
yield status_message
try:
@ -391,7 +407,7 @@ async def create_job_from_file(
stream = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing
url=file.filename, # optional, helps with logging and guessing
)
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content
@ -405,15 +421,18 @@ async def create_job_from_file(
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
async for message in create_job_from_content(
database=database, current_user=current_user, content=file_content
):
yield message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -437,26 +456,27 @@ async def create_job_from_file(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
content="Failed to upload document",
).model_dump(mode="json", by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
@router.get("/{job_id}")
async def get_job(
job_id: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a job by ID"""
try:
job_data = await database.get_job(job_id)
if not job_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Increment view count
job_data["views"] = job_data.get("views", 0) + 1
@ -467,10 +487,8 @@ async def get_job(
except Exception as e:
logger.error(f"❌ Get job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.get("")
async def get_jobs(
@ -479,7 +497,7 @@ async def get_jobs(
sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get paginated list of jobs"""
try:
@ -493,23 +511,18 @@ async def get_jobs(
for job in all_jobs_data.values():
jobs_list.append(Job.model_validate(job))
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("FETCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
@router.get("/search")
async def search_jobs(
@ -517,7 +530,7 @@ async def search_jobs(
filters: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Search jobs"""
try:
@ -532,69 +545,52 @@ async def search_jobs(
if query:
query_lower = query.lower()
jobs_list = [
j for j in jobs_list
if ((j.title and query_lower in j.title.lower()) or
(j.description and query_lower in j.description.lower()) or
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
j
for j in jobs_list
if (
(j.title and query_lower in j.title.lower())
or (j.description and query_lower in j.description.lower())
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
)
]
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, filters=filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Search jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("SEARCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
@router.delete("/{job_id}")
async def delete_job(
job_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Delete a Job"""
try:
# Check if admin user
if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
# Get job data
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Job not found for deletion: {job_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Delete job from database
await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({
"message": "Job deleted successfully",
"jobId": job_id
})
return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
except Exception as e:
logger.error(f"❌ Delete job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))
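The job endpoints above all emit server-sent events: each agent message is dumped to JSON and framed as a `data:` line before being handed to StreamingResponse. A minimal sketch of that framing, using a hypothetical stand-in message model rather than the real ChatMessage types:

import json
from pydantic import BaseModel

class ExampleMessage(BaseModel):
    # Hypothetical stand-in for ChatMessageStatus / ChatMessageError et al.
    session_id: str
    content: str

async def to_sse(messages):
    # Mirrors the to_json helpers above: one "data: ..." frame per message.
    async for message in messages:
        json_data = message.model_dump(mode="json", by_alias=True)
        yield f"data: {json.dumps(json_data)}\n\n".encode("utf-8")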


@ -6,6 +6,7 @@ from utils.llm_proxy import LLMProvider, get_llm
router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/models")
async def list_models(provider: Optional[str] = None):
"""List available models for a provider"""
@ -17,29 +18,25 @@ async def list_models(provider: Optional[str] = None):
try:
provider_enum = LLMProvider(provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {provider}"
)
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
models = await llm.list_models(provider_enum)
return {
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("")
async def list_providers():
"""List all configured providers"""
llm = get_llm()
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value
"default": llm.default_provider.value,
}
@router.post("/{provider}/set-default")
async def set_default_provider(provider: str):
"""Set the default provider"""
@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint
@router.get("/health")
async def health_check():
@ -59,5 +57,5 @@ async def health_check():
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value
"default_provider": llm.default_provider.value,
}
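The /models route above resolves the provider by constructing the enum from its value and mapping ValueError to a 400. A self-contained sketch of that lookup, with a two-member enum standing in for the real LLMProvider from utils.llm_proxy:

from enum import Enum

class ProviderStandIn(str, Enum):
    # Illustrative members only; the real enum is defined elsewhere.
    OLLAMA = "ollama"
    OPENAI = "openai"

def resolve_provider(name: str) -> ProviderStandIn:
    try:
        return ProviderStandIn(name.lower())  # Enum-by-value lookup
    except ValueError:
        raise ValueError(f"Unsupported provider: {name}")

assert resolve_provider("Ollama") is ProviderStandIn.OLLAMA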


@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
import backstory_traceback as backstory_traceback
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
)
from utils.dependencies import (
get_database, get_current_user
)
from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
from utils.dependencies import get_database, get_current_user
from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/resumes", tags=["resumes"])
@router.post("/{candidate_id}/{job_id}")
async def create_candidate_resume(
candidate_id: str = Path(..., description="ID of the candidate"),
job_id: str = Path(..., description="ID of the job"),
resume_content: str = Body(...),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new resume for a candidate/job combination"""
async def message_stream_generator():
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
@ -39,7 +37,7 @@ async def create_candidate_resume(
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Candidate with ID '{candidate_id}' not found"
content=f"Candidate with ID '{candidate_id}' not found",
)
yield error_message
return
@ -50,13 +48,15 @@ async def create_candidate_resume(
logger.error(f"❌ Job with ID '{job_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Job with ID '{job_id}' not found"
content=f"Job with ID '{job_id}' not found",
)
yield error_message
return
job = Job.model_validate(job_data)
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
logger.info(
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
)
# Job and Candidate are valid. Save the resume
resume = Resume(
@ -66,16 +66,13 @@ async def create_candidate_resume(
)
resume_message: ResumeMessage = ResumeMessage(
session_id=MOCK_UUID, # No session ID for document uploads
resume=resume
resume=resume,
)
# Save to database
success = await database.set_resume(current_user.id, resume.model_dump())
if not success:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to save resume to database"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
yield error_message
return
@ -84,10 +81,11 @@ async def create_candidate_resume(
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -111,18 +109,22 @@ async def create_candidate_resume(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Resume creation error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume"
).model_dump(mode='json', by_alias=True))]),
media_type="text/event-stream"
content="Failed to create resume",
).model_dump(mode="json", by_alias=True)
)
]
),
media_type="text/event-stream",
)
@router.get("")
async def get_user_resumes(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get all resumes for the current user"""
try:
resumes_data = await database.get_all_resumes_for_user(current_user.id)
@ -135,19 +137,17 @@ async def get_user_resumes(
if candidate_data:
resume.candidate = Candidate.model_validate(candidate_data)
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date
return create_success_response({
"resumes": resumes,
"count": len(resumes)
})
return create_success_response({"resumes": resumes, "count": len(resumes)})
except Exception as e:
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
@router.get("/{resume_id}")
async def get_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get a specific resume by ID"""
try:
@ -155,21 +155,19 @@ async def get_resume(
if not resume:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"resume": resume
}
return {"success": True, "resume": resume}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
@router.delete("/{resume_id}")
async def delete_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Delete a specific resume"""
try:
@ -177,102 +175,82 @@ async def delete_resume(
if not success:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"message": f"Resume {resume_id} deleted successfully"
}
return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete resume")
@router.get("/candidate/{candidate_id}")
async def get_resumes_by_candidate(
candidate_id: str = Path(..., description="ID of the candidate"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific candidate"""
try:
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
return {
"success": True,
"candidate_id": candidate_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
@router.get("/job/{job_id}")
async def get_resumes_by_job(
job_id: str = Path(..., description="ID of the job"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific job"""
try:
resumes = await database.get_resumes_by_job(current_user.id, job_id)
return {
"success": True,
"job_id": job_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
@router.get("/search")
async def search_resumes(
q: str = Query(..., description="Search query"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Search resumes by content"""
try:
resumes = await database.search_resumes_for_user(current_user.id, q)
return {
"success": True,
"query": q,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to search resumes")
@router.get("/stats")
async def get_resume_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Get resume statistics for the current user"""
try:
stats = await database.get_resume_statistics(current_user.id)
return {
"success": True,
"statistics": stats
}
return {"success": True, "statistics": stats}
except Exception as e:
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
@router.put("/{resume_id}")
async def update_resume(
resume_id: str = Path(..., description="ID of the resume"),
resume: str = Body(..., description="Updated resume content"),
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Update the content of a specific resume"""
try:
updates = {
"resume": resume,
"updated_at": datetime.now(UTC).isoformat()
}
updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
if not updated_resume_data:
@ -280,11 +258,9 @@ async def update_resume(
raise HTTPException(status_code=404, detail="Resume not found")
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
return create_success_response({
"success": True,
"message": f"Resume {resume_id} updated successfully",
"resume": updated_resume
})
return create_success_response(
{"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
)
except HTTPException:
raise
except Exception as e:


@ -10,6 +10,7 @@ from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/system", tags=["system"])
async def get_redis() -> Redis:
"""Dependency to get Redis client"""
return redis_manager.get_client()
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
async def get_system_info(request: Request):
"""Get system information"""
from system_info import system_info # Import system_info function from system_info module
system = system_info()
return create_success_response(system.model_dump(mode='json'))
return create_success_response(system.model_dump(mode="json"))
@router.get("/redis/stats")
async def redis_stats(redis: Redis = Depends(get_redis)):
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
"total_commands_processed": info.get("total_commands_processed"),
"keyspace_hits": info.get("keyspace_hits"),
"keyspace_misses": info.get("keyspace_misses"),
"uptime_in_seconds": info.get("uptime_in_seconds")
"uptime_in_seconds": info.get("uptime_in_seconds"),
}
except Exception as e:
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")


@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
BaseUserWithType
)
from utils.dependencies import (
get_database
)
from models import BaseUserWithType
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
# Create router for job endpoints
router = APIRouter(prefix="/users", tags=["users"])
# reference can be candidateId, username, or email
@router.get("/users/{reference}")
async def get_user(
reference: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a candidate by username"""
try:
# Normalize reference to lowercase for case-insensitive search
@ -32,16 +26,15 @@ async def get_user(
all_candidate_data = await database.get_all_candidates()
if not all_candidate_data:
logger.warning("⚠️ No users found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
user_data = None
for user in all_candidate_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
@ -49,23 +42,19 @@ async def get_user(
all_guest_data = await database.get_all_guests()
if not all_guest_data:
logger.warning("⚠️ No guests found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
for user in all_guest_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
if not user_data:
logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
user = BaseUserWithType.model_validate(user_data)
@ -73,8 +62,4 @@ async def get_user(
except Exception as e:
logger.error(f"❌ Get user error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
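get_user above accepts an id, username, or email and matches case-insensitively across candidates and guests. Roughly equivalent logic over a plain dict, with made-up sample data:

def find_user(users: dict, reference: str):
    query = reference.strip().lower()
    for user in users.values():
        if query in (
            user.get("id", "").lower(),
            user.get("username", "").lower(),
            user.get("email", "").lower(),
        ):
            return user
    return None

sample = {"1": {"id": "1", "username": "JDoe", "email": "jdoe@example.com"}}
assert find_user(sample, "JDOE@EXAMPLE.COM")["id"] == "1"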


@ -4,6 +4,7 @@ import subprocess
import math
from models import SystemInfo
def get_installed_ram():
try:
with open("/proc/meminfo", "r") as f:
@ -19,9 +20,7 @@ def get_graphics_cards():
gpus = []
try:
# Run the ze-monitor utility
result = subprocess.run(
["ze-monitor"], capture_output=True, text=True, check=True
)
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
# Clean up the output (remove leading/trailing whitespace and newlines)
output = result.stdout.strip()
@ -71,6 +70,7 @@ def get_cpu_info():
except Exception as e:
return f"Error retrieving CPU info: {e}"
def system_info() -> SystemInfo:
"""
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.


@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
# Process the forecast data into a simpler format
forecast = {
"location": data["properties"]
.get("relativeLocation", {})
.get("properties", {}),
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
"updated": data["properties"].get("updated", ""),
"periods": [],
}
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
if ticker_symbol == "":
continue
url = (
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
response = requests.get(url)
data = response.json()
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc())
results.append(
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
return results[0] if len(results) == 1 else results
@ -278,9 +272,11 @@ def DateTime(timezone="America/Los_Angeles"):
except Exception as e:
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
async def GenerateImage(llm, model: str, prompt: str):
return {"image_id": "image-a830a83-bd831"}
async def AnalyzeSite(llm, model: str, url: str, question: str):
"""
Fetches content from a URL, extracts the text, and uses Ollama to summarize it.
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
return f"Error processing the website content: {str(e)}"
# %%
class Function(BaseModel):
name: str
@ -355,10 +350,12 @@ class Function(BaseModel):
parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel):
type: str
function: Function
tools: List[Tool] = [
# Tool.model_validate({
# "type": "function",
@ -366,24 +363,20 @@ tools : List[Tool] = [
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
@ -407,7 +400,8 @@ tools : List[Tool] = [
# }
# }
# }),
Tool.model_validate({
Tool.model_validate(
{
"type": "function",
"function": {
"name": "TickerValue",
@ -424,8 +418,10 @@ tools : List[Tool] = [
"additionalProperties": False,
},
},
}),
Tool.model_validate({
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "AnalyzeSite",
@ -463,8 +459,10 @@ tools : List[Tool] = [
},
},
},
}),
Tool.model_validate({
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "DateTime",
@ -480,8 +478,10 @@ tools : List[Tool] = [
"required": [],
},
},
}),
Tool.model_validate({
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "WeatherForecast",
@ -504,22 +504,27 @@ tools : List[Tool] = [
"additionalProperties": False,
},
},
}),
}
),
]
class ToolEntry(BaseModel):
enabled: bool = True
tool: Tool
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled is True]
return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]


@ -14,6 +14,7 @@ from pydantic import BaseModel
logger = logging.getLogger(__name__)
class PasswordSecurity:
"""Handles password hashing and verification using bcrypt"""
@ -32,9 +33,9 @@ class PasswordSecurity:
salt = bcrypt.gensalt()
# Hash the password
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt)
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
return password_hash.decode('utf-8'), salt.decode('utf-8')
return password_hash.decode("utf-8"), salt.decode("utf-8")
@staticmethod
def verify_password(password: str, password_hash: str) -> bool:
@ -49,10 +50,7 @@ class PasswordSecurity:
True if password matches, False otherwise
"""
try:
return bcrypt.checkpw(
password.encode('utf-8'),
password_hash.encode('utf-8')
)
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@ -62,8 +60,10 @@ class PasswordSecurity:
"""Generate a cryptographically secure random token"""
return secrets.token_urlsafe(length)
class AuthenticationRecord(BaseModel):
"""Authentication record for storing user credentials"""
user_id: str
password_hash: str
salt: str
@ -78,18 +78,19 @@ class AuthenticationRecord(BaseModel):
locked_until: Optional[datetime] = None
class Config:
json_encoders = {
datetime: lambda v: v.isoformat() if v else None
}
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
class SecurityConfig:
"""Security configuration constants"""
MAX_LOGIN_ATTEMPTS = 5
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
PASSWORD_MIN_LENGTH = 8
TOKEN_EXPIRY_HOURS = 24
REFRESH_TOKEN_EXPIRY_DAYS = 30
class AuthenticationManager:
"""Manages authentication operations with security features"""
@ -120,7 +121,7 @@ class AuthenticationManager:
password_hash=password_hash,
salt=salt,
last_password_change=datetime.now(timezone.utc),
login_attempts=0
login_attempts=0,
)
# Store in database
@ -129,7 +130,9 @@ class AuthenticationManager:
logger.info(f"🔐 Created authentication record for user {user_id}")
return auth_record
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
async def verify_user_credentials(
self, login: str, password: str
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
"""
Verify user credentials with security checks
@ -164,7 +167,11 @@ class AuthenticationManager:
seconds = int(total_seconds % 60)
time_until_unlock_str = f"{minutes}m {seconds}s"
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}"
return (
False,
None,
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
)
# Verify password
if not self.password_security.verify_password(password, auth_data.password_hash):
@ -176,7 +183,9 @@ class AuthenticationManager:
auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
)
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts")
logger.warning(
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
)
# Update authentication record
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
@ -238,6 +247,7 @@ class AuthenticationManager:
except Exception as e:
logger.error(f"❌ Error updating last login for user {user_id}: {e}")
# Utility functions for common operations
def validate_password_strength(password: str) -> Tuple[bool, list]:
"""
@ -270,6 +280,7 @@ def validate_password_strength(password: str) -> Tuple[bool, list]:
return len(issues) == 0, issues
def sanitize_login_input(login: str) -> str:
"""Sanitize login input (email or username)"""
return login.strip().lower() if login else ""
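For reference, the round trip PasswordSecurity wraps is plain bcrypt hashpw/checkpw; a minimal sketch with a throwaway password:

import bcrypt

password = "example-Passw0rd!"  # throwaway value for illustration only
salt = bcrypt.gensalt()
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)

# Stored as str on the record, re-encoded at verification time, as verify_password does.
stored = password_hash.decode("utf-8")
assert bcrypt.checkpw(password.encode("utf-8"), stored.encode("utf-8"))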


@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
# Global database manager reference
db_manager = None
def set_db_manager(manager: DatabaseManager):
"""Set the global database manager reference"""
global db_manager
db_manager = manager
def get_database() -> RedisDatabase:
"""
Safe database dependency that checks for availability
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
if db_manager is None:
logger.error("Database manager not initialized")
raise HTTPException(
status_code=503,
detail="Database not available - service starting up"
)
raise HTTPException(status_code=503, detail="Database not available - service starting up")
if db_manager.is_shutting_down:
logger.warning("Database is shutting down")
raise HTTPException(
status_code=503,
detail="Service is shutting down"
)
raise HTTPException(status_code=503, detail="Service is shutting down")
try:
return db_manager.get_database()
except RuntimeError as e:
logger.error(f"Database not available: {e}")
raise HTTPException(
status_code=503,
detail="Database connection not available"
)
raise HTTPException(status_code=503, detail="Database connection not available")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Enhanced token verification with guest session recovery"""
try:
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
logger.error(f"❌ Token verification error: {e}")
raise HTTPException(status_code=401, detail="Token verification failed")
async def get_current_user(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user from database"""
try:
@ -135,6 +130,7 @@ async def get_current_user(
candidate_data = await database.get_candidate(user_id)
if candidate_data:
from helpers.model_cast import cast_to_base_user_with_type
if candidate_data.get("is_AI"):
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
else:
@ -152,16 +148,20 @@ async def get_current_user(
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_user_or_guest(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user (including guests) from database"""
try:
# Check candidates first
candidate_data = await database.get_candidate(user_id)
if candidate_data:
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
return (
Candidate.model_validate(candidate_data)
if not candidate_data.get("is_AI")
else CandidateAI.model_validate(candidate_data)
)
# Check employers
employer_data = await database.get_employer(user_id)
@ -180,9 +180,9 @@ async def get_current_user_or_guest(
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_admin(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
user = await get_current_user(user_id=user_id, database=database)
if isinstance(user, Candidate) and user.is_admin:
@ -193,6 +193,7 @@ async def get_current_admin(
logger.warning(f"⚠️ User {user_id} is not an admin")
raise HTTPException(status_code=403, detail="Admin access required")
prometheus_collector = CollectorRegistry()
# Keep the Instrumentator instance alive
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)
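create_access_token above copies the claims, attaches an expiry, and signs with the shared secret. A sketch of that flow, assuming a PyJWT-style jwt module and a placeholder secret (the real key and algorithm are configured elsewhere):

from datetime import datetime, timedelta, timezone
from typing import Optional
import jwt  # assumes PyJWT; the project may use a different JWT library

JWT_SECRET_KEY = "change-me"  # placeholder only
ALGORITHM = "HS256"

def make_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    to_encode["exp"] = datetime.now(timezone.utc) + (expires_delta or timedelta(hours=24))
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)

token = make_access_token({"sub": "user-123"})
claims = jwt.decode(token, JWT_SECRET_KEY, algorithms=[ALGORITHM])
assert claims["sub"] == "user-123"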


@ -1,5 +1,6 @@
from typing import List, Dict
from models import (Job)
from models import Job
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
]
)
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
]
)
return requirements
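Every branch above contributes {"requirement", "domain"} pairs to one flat list; a tiny illustration of the resulting shape using made-up requirement strings:

required = ["Python", "FastAPI"]  # illustrative values only
preferred = ["Redis"]

requirements = []
requirements.extend({"requirement": req, "domain": "Technical Skills (required)"} for req in required)
requirements.extend({"requirement": req, "domain": "Technical Skills (preferred)"} for req in preferred)

assert requirements[-1] == {"requirement": "Redis", "domain": "Technical Skills (preferred)"}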


@ -13,15 +13,13 @@ from fastapi.responses import StreamingResponse
import defines
from logger import logger
from models import DocumentType
from models import (
Job,
ChatMessage, DocumentType, ApiStatusType
)
from models import Job, ChatMessage, DocumentType, ApiStatusType
from typing import List, Dict
from models import (Job)
from models import Job
import utils.llm_proxy as llm_manager
async def get_last_item(generator):
"""Get the last item from an async generator"""
last_item = None
@ -36,7 +34,7 @@ def filter_and_paginate(
limit: int = 20,
sort_by: Optional[str] = None,
sort_order: str = "desc",
filters: Optional[Dict] = None
filters: Optional[Dict] = None,
) -> Tuple[List[Any], int]:
"""Filter, sort, and paginate items"""
filtered_items = items.copy()
@ -47,8 +45,7 @@ def filter_and_paginate(
if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
filtered_items = [item for item in filtered_items if item.get(key) == value]
elif hasattr(filtered_items[0], key) if filtered_items else False:
filtered_items = [item for item in filtered_items
if getattr(item, key, None) == value]
filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
# Sort items
if sort_by and filtered_items:
@ -72,6 +69,7 @@ def filter_and_paginate(
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
"""Stream agent response with proper formatting"""
async def message_stream_generator():
"""Generator to stream messages with persistence"""
final_message = None
@ -97,12 +95,15 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
# metadata and other unnecessary fields for streaming
if generated_message.status != ApiStatusType.DONE:
from models import ChatMessageStreaming, ChatMessageStatus
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
generated_message, ChatMessageStatus
):
raise TypeError(
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
)
json_data = generated_message.model_dump(mode='json', by_alias=True)
json_data = generated_message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n"
@ -147,16 +148,16 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
extension = pathlib.Path(filename).suffix.lower()
type_mapping = {
'.pdf': DocumentType.PDF,
'.docx': DocumentType.DOCX,
'.doc': DocumentType.DOCX,
'.txt': DocumentType.TXT,
'.md': DocumentType.MARKDOWN,
'.markdown': DocumentType.MARKDOWN,
'.png': DocumentType.IMAGE,
'.jpg': DocumentType.IMAGE,
'.jpeg': DocumentType.IMAGE,
'.gif': DocumentType.IMAGE,
".pdf": DocumentType.PDF,
".docx": DocumentType.DOCX,
".doc": DocumentType.DOCX,
".txt": DocumentType.TXT,
".md": DocumentType.MARKDOWN,
".markdown": DocumentType.MARKDOWN,
".png": DocumentType.IMAGE,
".jpg": DocumentType.IMAGE,
".jpeg": DocumentType.IMAGE,
".gif": DocumentType.IMAGE,
}
return type_mapping.get(extension, DocumentType.TXT)
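The mapping above keys off the lowercased file suffix; a standalone sketch of the same detection using plain strings in place of the DocumentType enum:

import pathlib

TYPE_MAPPING = {".pdf": "pdf", ".docx": "docx", ".md": "markdown", ".txt": "txt", ".png": "image"}

def detect_document_type(filename: str) -> str:
    extension = pathlib.Path(filename).suffix.lower()
    return TYPE_MAPPING.get(extension, "txt")  # default mirrors DocumentType.TXT

assert detect_document_type("resume.PDF") == "pdf"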
@ -176,17 +177,12 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
)
yield status_message
@ -199,16 +195,13 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
if not message or not isinstance(message, ChatMessage):
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
@ -226,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
async def create_job_from_content(database, current_user, content: str):
"""Create a job from content using AI analysis"""
from models import (
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
ApiActivityType, ChatContextType, JobRequirementsMessage
MOCK_UUID,
ApiStatusType,
ChatMessageError,
ChatMessageStatus,
ApiActivityType,
ChatContextType,
JobRequirementsMessage,
)
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
@ -248,10 +246,7 @@ async def create_job_from_content(database, current_user, content: str):
yield message
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
@ -260,8 +255,7 @@ async def create_job_from_content(database, current_user, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
session_id=MOCK_UUID, content="No agent found for job requirements chat type"
)
yield error_message
return
@ -269,24 +263,20 @@ async def create_job_from_content(database, current_user, content: str):
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Job extraction did not convert successfully"
session_id=MOCK_UUID, content="Job extraction did not convert successfully"
)
yield error_message
return
@ -296,62 +286,63 @@ async def create_job_from_content(database, current_user, content: str):
yield job_requirements
return
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
]
)
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend([
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
]
)
return requirements

File diff suppressed because it is too large


@ -1,6 +1,7 @@
from prometheus_client import Counter, Histogram # type: ignore
from threading import Lock
def singleton(cls):
instance = None
lock = Lock()


@ -12,50 +12,62 @@ from logger import logger
from .dependencies import get_current_user_or_guest, get_database
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
"""Dependency to get rate limiter instance"""
return RateLimiter(database)
class RateLimitConfig(BaseModel):
"""Rate limit configuration"""
requests_per_minute: int
requests_per_hour: int
requests_per_day: int
burst_limit: int # Maximum requests in a short burst
burst_window_seconds: int = 60 # Window for burst detection
class GuestRateLimitConfig(RateLimitConfig):
"""Rate limits for guest users - more restrictive"""
requests_per_minute: int = 10
requests_per_hour: int = 100
requests_per_day: int = 500
burst_limit: int = 15
burst_window_seconds: int = 60
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
"""Rate limits for authenticated users - more generous"""
requests_per_minute: int = 60
requests_per_hour: int = 1000
requests_per_day: int = 10000
burst_limit: int = 100
burst_window_seconds: int = 60
class PremiumUserRateLimitConfig(RateLimitConfig):
"""Rate limits for premium/admin users - most generous"""
requests_per_minute: int = 120
requests_per_hour: int = 5000
requests_per_day: int = 50000
burst_limit: int = 200
burst_window_seconds: int = 60
class RateLimitResult(BaseModel):
"""Result of rate limit check"""
allowed: bool
reason: Optional[str] = None
retry_after_seconds: Optional[int] = None
remaining_requests: Dict[str, int] = {}
reset_times: Dict[str, datetime] = {}
class RateLimiter:
"""Rate limiter using Redis for distributed rate limiting"""
@ -78,11 +90,7 @@ class RateLimiter:
return self.user_config
async def check_rate_limit(
self,
user_id: str,
user_type: str,
is_admin: bool = False,
endpoint: Optional[str] = None
self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
) -> RateLimitResult:
"""
Check if user has exceeded rate limits
@ -105,7 +113,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
# Add endpoint-specific limiting if provided
@ -125,7 +133,7 @@ class RateLimiter:
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
# Check limits
@ -133,7 +141,7 @@ class RateLimiter:
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
# Check each limit
@ -146,18 +154,22 @@ class RateLimiter:
elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
retry_after = 86400 - (
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
)
else: # burst
retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
logger.warning(
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
)
return RateLimitResult(
allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time)
reset_times=self._calculate_reset_times(current_time),
)
# If we get here, request is allowed - increment counters
@@ -182,17 +194,12 @@ class RateLimiter:
await pipe.execute()
# Calculate remaining requests
remaining = {
k: max(0, limits[k] - (current_counts[k] + 1))
for k in current_counts.keys()
}
remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
return RateLimitResult(
allowed=True,
remaining_requests=remaining,
reset_times=self._calculate_reset_times(current_time)
allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
)
except Exception as e:
@@ -206,18 +213,9 @@ class RateLimiter:
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
return {
"minute": next_minute,
"hour": next_hour,
"day": next_day
}
return {"minute": next_minute, "hour": next_hour, "day": next_day}
async def get_user_rate_limit_status(
self,
user_id: str,
user_type: str,
is_admin: bool = False
) -> Dict[str, Any]:
async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
"""Get current rate limit status for a user"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
@@ -227,7 +225,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
try:
@@ -240,14 +238,14 @@ class RateLimiter:
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
return {
@@ -258,7 +256,7 @@ class RateLimiter:
"limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump()
"config": config.model_dump(),
}
except Exception as e:
@@ -295,11 +293,9 @@ class RateLimiter:
# Rate Limited Decorator
# ============================
def rate_limited(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120,
endpoint_specific: bool = True
guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
):
"""
Decorator to easily apply rate limiting to endpoints
@@ -320,11 +316,13 @@ def rate_limited(
):
return {"message": "Rate limited endpoint"}
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract dependencies from function signature
import inspect
inspect.signature(func)
# Get request, current_user, and rate_limiter from kwargs or args
@@ -336,7 +334,7 @@ def rate_limited(
for param_name, param_value in kwargs.items():
if isinstance(param_value, Request):
request = param_value
elif hasattr(param_value, 'user_type'): # User-like object
elif hasattr(param_value, "user_type"): # User-like object
current_user = param_value
elif isinstance(param_value, RateLimiter):
rate_limiter = param_value
@@ -350,30 +348,33 @@ def rate_limited(
# Apply rate limiting if we have the required components
if request and current_user and rate_limiter:
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
# Call the original function
return await func(*args, **kwargs)
return wrapper
return decorator
async def apply_custom_rate_limiting(
request: Request,
current_user,
rate_limiter: RateLimiter,
guest_per_minute: int,
user_per_minute: int,
admin_per_minute: int
admin_per_minute: int,
):
"""Apply custom rate limiting with specified limits"""
try:
# Determine user info
user_id = current_user.id
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
is_admin = getattr(current_user, 'is_admin', False)
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
is_admin = getattr(current_user, "is_admin", False)
# Determine appropriate limit
if is_admin:
@@ -385,13 +386,17 @@ async def apply_custom_rate_limiting(
# Create custom rate limit key
current_time = datetime.now(UTC)
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
custom_key = (
f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
)
# Check current usage
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
if current_count >= requests_per_minute:
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
logger.warning(
f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
)
raise HTTPException(
status_code=429,
detail={
@@ -399,9 +404,9 @@ async def apply_custom_rate_limiting(
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
"retryAfter": 60 - current_time.second,
"userType": user_type,
"endpoint": request.url.path
"endpoint": request.url.path,
},
headers={"Retry-After": str(60 - current_time.second)}
headers={"Retry-After": str(60 - current_time.second)},
)
# Increment counter
@@ -410,7 +415,9 @@ async def apply_custom_rate_limiting(
pipe.expire(custom_key, 120) # 2 minutes TTL
await pipe.execute()
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
logger.debug(
f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
)
except HTTPException:
raise
@@ -418,15 +425,13 @@ async def apply_custom_rate_limiting(
logger.error(f"❌ Custom rate limiting error: {e}")
# Fail open
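For reference, the fixed-window key built above comes out like this for a candidate streaming chat at 14:37 UTC on 2025-06-18 (session id and user id are made up, not logged output):

# custom_rate_limit:/api/1.0/chat/sessions/123/messages/stream:candidate:abc123:minute:202506181437

The 120 s TTL keeps at most the current and previous minute's counters in Redis, and the bare except branch fails open, so a Redis outage never blocks traffic.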
# ============================
# Alternative: FastAPI Dependency-Based Rate Limiting
# ============================
def create_rate_limit_dependency(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120
):
def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
"""
Create a FastAPI dependency for rate limiting
@@ -441,23 +446,25 @@ def create_rate_limit_dependency(
):
return {"message": "Rate limited endpoint"}
"""
async def rate_limit_dependency(
request: Request,
current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter)
rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
return True
return rate_limit_dependency
# ============================
# Rate Limiting Utilities
# ============================
class EndpointRateLimiter:
"""Utility class for endpoint-specific rate limiting"""
@@ -477,9 +484,11 @@ class EndpointRateLimiter:
return True # No custom limits set
limits = self.custom_limits[endpoint]
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
if getattr(current_user, 'is_admin', False):
if getattr(current_user, "is_admin", False):
user_type = "admin"
limit = limits.get(user_type, limits.get("default", 60))
@@ -491,8 +500,7 @@ class EndpointRateLimiter:
if current_count >= limit:
raise HTTPException(
status_code=429,
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
)
# Increment counter
@@ -501,9 +509,11 @@ class EndpointRateLimiter:
return True
# Global endpoint rate limiter instance
endpoint_rate_limiter = None
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
"""Get endpoint rate limiter instance"""
global endpoint_rate_limiter
@@ -511,15 +521,14 @@ def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limit
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
# Configure endpoint-specific limits
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
})
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
)
return endpoint_rate_limiter
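One detail worth noting from check_endpoint_limit above: the per-user-type lookup falls back to a "default" key and then to 60. So an entry like the following (illustrative, not part of this commit) caps any user type without an explicit entry at 25 requests per minute:

endpoint_rate_limiter.set_endpoint_limits(
    "/api/1.0/candidates/search", {"guest": 5, "candidate": 30, "default": 25}
)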

View File

@@ -3,37 +3,17 @@ Response utility functions for consistent API responses
"""
from typing import Any, Optional, Dict, List
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
return {
"success": True,
"data": data,
"meta": meta
}
return {"success": True, "data": data, "meta": meta}
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
return {
"success": False,
"error": {
"code": code,
"message": message,
"details": details
}
}
return {"success": False, "error": {"code": code, "message": message, "details": details}}
def create_paginated_response(
data: List[Any],
page: int,
limit: int,
total: int
) -> Dict:
def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
total_pages = (total + limit - 1) // limit
has_more = page < total_pages
return {
"data": data,
"total": total,
"page": page,
"limit": limit,
"totalPages": total_pages,
"hasMore": has_more
}
return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}