ruff reformat

James Ketr 2025-06-18 13:53:07 -07:00
parent f1c2e16389
commit 2cf3fa7b04
69 changed files with 4863 additions and 4915 deletions
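
The changes below are mechanical: ruff's formatter rewrites layout (quote style, trailing commas, line wrapping) without changing behavior. A minimal sketch of how such a commit is typically produced, assuming ruff is installed and configured in pyproject.toml (the repository's actual configuration is not part of this diff): run `ruff format .` from the project root. The width of the joined lines below suggests a configured line-length above ruff's default of 88.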

View File

@@ -17,11 +17,9 @@ from logger import logger
 AnyAgent: TypeAlias = Agent  # BaseModel covers Agent and subclasses
 # Maps class_name to (module_name, class_name)
-class_registry: Dict[str, Tuple[str, str]] = (
-    {}
-)
+class_registry: Dict[str, Tuple[str, str]] = {}
-__all__ = ['get_or_create_agent']
+__all__ = ["get_or_create_agent"]
 package_dir = pathlib.Path(__file__).parent
 package_name = __name__
@@ -38,12 +36,7 @@ for path in package_dir.glob("*.py"):
     # Find all Agent subclasses in the module
     for name, obj in inspect.getmembers(module, inspect.isclass):
-        if (
-            issubclass(obj, AnyAgent)
-            and obj is not AnyAgent
-            and obj is not Agent
-            and name not in class_registry
-        ):
+        if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
             class_registry[name] = (full_module_name, name)
             globals()[name] = obj
             logger.info(f"Adding agent: {name}")

View File

@@ -32,19 +32,45 @@ from pathlib import Path
 from rag import start_file_watcher, ChromaDBFileWatcher
 import defines
 from logger import logger
-from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
+from models import (
+    Tunables,
+    ChatMessageUser,
+    ChatMessage,
+    RagEntry,
+    ChatMessageMetaData,
+    ApiStatusType,
+    Candidate,
+    ChatContextType,
+)
 import utils.llm_proxy as llm_manager
 from database.manager import RedisDatabase
 from models import ChromaDBGetResponse
 from utils.metrics import Metrics
-from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
+from models import (
+    ApiActivityType,
+    ApiMessage,
+    ChatMessageError,
+    ChatMessageRagSearch,
+    ChatMessageStatus,
+    ChatMessageStreaming,
+    LLMMessage,
+    ChatMessage,
+    ChatOptions,
+    ChatMessageUser,
+    Tunables,
+    ApiStatusType,
+    ChatMessageMetaData,
+    Candidate,
+)
 from logger import logger
 import defines
 from .registry import agent_registry
-from models import ( ChromaDBGetResponse )
+from models import ChromaDBGetResponse

 class CandidateEntity(Candidate):
     model_config = {"arbitrary_types_allowed": True}  # Allow ChromaDBFileWatcher, etc
@@ -59,13 +85,10 @@ class CandidateEntity(Candidate):
     CandidateEntity__agents: List[Agent] = []
     CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
     CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
-    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
-        default=None, exclude=True
-    )
+    CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
     CandidateEntity__metrics: Optional[Metrics] = Field(
-        default=None,
-        description="Metrics collector for this agent, used to track performance and usage."
+        default=None, description="Metrics collector for this agent, used to track performance and usage."
     )

     def __init__(self, candidate=None):
@@ -78,7 +101,7 @@ class CandidateEntity(Candidate):
     @classmethod
     def exists(cls, username: str):
         # Validate username format (only allow safe characters)
-        if not re.match(r'^[a-zA-Z0-9_-]+$', username):
+        if not re.match(r"^[a-zA-Z0-9_-]+$", username):
             return False  # Invalid username characters
         # Check for minimum and maximum length
@@ -117,11 +140,7 @@ class CandidateEntity(Candidate):
             if agent.agent_type == agent_type:
                 return agent
-        return get_or_create_agent(
-            agent_type=agent_type,
-            user=self,
-            prometheus_collector=self.prometheus_collector
-        )
+        return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)

     # Wrapper properties that map into file_watcher
     @property
@@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
     # Fields managed by initialize()
     CandidateEntity__initialized: bool = Field(default=False, exclude=True)
+
     @property
     def metrics(self) -> Metrics:
         if not self.CandidateEntity__metrics:
@@ -160,15 +180,10 @@ class CandidateEntity(Candidate):
         if not self.metrics:
             logger.warning("No metrics collector set for this agent.")
             return
-        self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
-            response.usage.prompt_eval_count
-        )
+        self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
         self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)

-    async def initialize(
-        self,
-        prometheus_collector: CollectorRegistry,
-        database: RedisDatabase):
+    async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
         if self.CandidateEntity__initialized:
             # Initialization can only be attempted once; if there are multiple attempts, it means
             # a subsystem is failing or there is a logic bug in the code.
@@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
         )
         has_username_rag = any(item.name == self.username for item in self.rags)
         if not has_username_rag:
-            self.rags.append(RagEntry(
-                name=self.username,
-                description=f"Expert data about {self.full_name}.",
-            ))
+            self.rags.append(
+                RagEntry(
+                    name=self.username,
+                    description=f"Expert data about {self.full_name}.",
+                )
+            )
         self.rag_content_size = self.file_watcher.collection.count()

 class Agent(BaseModel, ABC):
     """
     Base class for all agent types.
     This class defines the common attributes and methods for all agent types.
     """

     class Config:
         arbitrary_types_allowed = True  # Allow arbitrary types like RedisDatabase
@@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
     conversation: List[ChatMessageUser] = Field(
         default_factory=list,
-        description="Conversation history for this agent, used to maintain context across messages."
+        description="Conversation history for this agent, used to maintain context across messages.",
     )

     @property
@@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
             last_item = item
         return last_item

-    def set_optimal_context_size(
-        self, llm: Any, model: str, prompt: str, ctx_buffer=2048
-    ) -> int:
+    def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
         # Most models average 1.3-1.5 tokens per word
         word_count = len(prompt.split())
         tokens = int(word_count * 1.4)
@@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
         total_ctx = tokens + ctx_buffer
         if total_ctx > self.context_size:
-            logger.info(
-                f"Increasing context size from {self.context_size} to {total_ctx}"
-            )
+            logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
         # Grow the context size if necessary
         self.context_size = max(self.context_size, total_ctx)
@@ -472,16 +487,16 @@
         context = []
         for chroma_results in rag_message.content:
             for index, metadata in enumerate(chroma_results.metadatas):
-                content = "\n".join([
-                    line.strip()
-                    for line in chroma_results.documents[index].split("\n")
-                    if line
-                ]).strip()
-                context.append(f"""
+                content = "\n".join(
+                    [line.strip() for line in chroma_results.documents[index].split("\n") if line]
+                ).strip()
+                context.append(
+                    f"""
 Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
 Document reference: {chroma_results.ids[index]}
 Content: {content}
-""")
+"""
+                )
         return "\n".join(context)

     async def generate_rag_results(
@@ -501,10 +516,7 @@ Content: {content}
         A list of dictionaries containing the RAG results.
         """
         if not self.user:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="No user set for RAG generation."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
             yield error_message
             return
@@ -517,14 +529,12 @@ Content: {content}
             status_message = ChatMessageStatus(
                 session_id=session_id,
                 activity=ApiActivityType.SEARCHING,
-                content = f"Searching RAG context {rag.name}..."
+                content=f"Searching RAG context {rag.name}...",
             )
             yield status_message
             try:
-                chroma_results = await user.file_watcher.find_similar(
-                    query=prompt, top_k=top_k, threshold=threshold
-                )
+                chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
                 if not chroma_results:
                     continue
                 query_embedding = np.array(chroma_results["query_embedding"]).flatten()  # type: ignore
@@ -548,7 +558,7 @@ Content: {content}
                 continue_message = ChatMessageStatus(
                     session_id=session_id,
                     activity=ApiActivityType.SEARCHING,
-                    content=f"Error searching RAG context {rag.name}: {str(e)}"
+                    content=f"Error searching RAG context {rag.name}: {str(e)}",
                 )
                 yield continue_message
@@ -562,22 +572,20 @@ Content: {content}
     async def llm_one_shot(
         self,
-        llm: Any, model: str,
-        session_id: str, prompt: str, system_prompt: str,
+        llm: Any,
+        model: str,
+        session_id: str,
+        prompt: str,
+        system_prompt: str,
         tunables: Optional[Tunables] = None,
-        temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
+        temperature=0.7,
+    ) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
         if not self.user:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="No user set for chat generation."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
             yield error_message
             return

-        self.set_optimal_context_size(
-            llm=llm, model=model, prompt=prompt+system_prompt
-        )
+        self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)

         options = ChatOptions(
             seed=8911,
@@ -591,9 +599,7 @@ Content: {content}
         ]
         status_message = ChatMessageStatus(
-            session_id=session_id,
-            activity=ApiActivityType.GENERATING,
-            content="Generating response..."
+            session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
         )
         yield status_message
@@ -609,10 +615,7 @@ Content: {content}
             stream=True,
         ):
             if not response:
-                error_message = ChatMessageError(
-                    session_id=session_id,
-                    content="No response from LLM."
-                )
+                error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
                 yield error_message
                 return
@@ -627,17 +630,12 @@ Content: {content}
             yield streaming_message

         if not response:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="No response from LLM."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
             yield error_message
             return
         self.user.collect_metrics(agent=self, response=response)
-        self.context_tokens = (
-            response.usage.prompt_eval_count + response.usage.eval_count
-        )
+        self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count

         chat_message = ChatMessage(
             session_id=session_id,
@@ -650,23 +648,16 @@ Content: {content}
                 eval_duration=response.usage.eval_duration,
                 prompt_eval_count=response.usage.prompt_eval_count,
                 prompt_eval_duration=response.usage.prompt_eval_duration,
-            )
+            ),
         )
         yield chat_message
         return

     async def generate(
-        self, llm: Any, model: str,
-        session_id: str, prompt: str,
-        tunables: Optional[Tunables] = None,
-        temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ApiMessage, None]:
         if not self.user:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="No user set for chat generation."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
             yield error_message
             return
@@ -690,37 +681,32 @@ Content: {content}
             yield message
         if not isinstance(message, ChatMessageRagSearch):
-            raise ValueError(
-                f"Expected ChatMessageRagSearch, got {type(rag_message)}"
-            )
+            raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
         rag_message = message
         context = self.get_rag_context(rag_message)

         # Create a pruned down message list based purely on the prompt and responses,
         # discarding the full preamble generated by prepare_message
-        messages: List[LLMMessage] = [
-            LLMMessage(role="system", content=self.system_prompt)
-        ]
+        messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
         # Add the conversation history to the messages
-        messages.extend([
-            LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
-            for m in self.conversation
-        ])
+        messages.extend(
+            [
+                LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
+                for m in self.conversation
+            ]
+        )
         # Add the RAG context to the messages if available
         if context:
             messages.append(
                 LLMMessage(
                     role="user",
-                    content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
+                    content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
                 )
             )
         else:
             # Only the actual user query is provided with the full context message
-            messages.append(
-                LLMMessage(role="user", content=prompt)
-            )
+            messages.append(LLMMessage(role="user", content=prompt))

         # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
         # message.metadata.tools = {
@@ -824,16 +810,12 @@ Content: {content}
         # not use_tools
         status_message = ChatMessageStatus(
-            session_id=session_id,
-            activity=ApiActivityType.GENERATING,
-            content="Generating response..."
+            session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
         )
         yield status_message

         # Set the response for streaming
-        self.set_optimal_context_size(
-            llm, model, prompt=prompt
-        )
+        self.set_optimal_context_size(llm, model, prompt=prompt)

         options = ChatOptions(
             seed=8911,
@@ -853,10 +835,7 @@ Content: {content}
             stream=True,
         ):
             if not response:
-                error_message = ChatMessageError(
-                    session_id=session_id,
-                    content="No response from LLM."
-                )
+                error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
                 yield error_message
                 return
@@ -870,17 +849,12 @@ Content: {content}
             yield streaming_message

         if not response:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="No response from LLM."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
             yield error_message
             return
         self.user.collect_metrics(agent=self, response=response)
-        self.context_tokens = (
-            response.usage.prompt_eval_count + response.usage.eval_count
-        )
+        self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count

         end_time = time.perf_counter()
         chat_message = ChatMessage(
@@ -899,9 +873,8 @@ Content: {content}
                     "llm_streamed": end_time - start_time,
                     "llm_with_tools": 0,  # Placeholder for tool processing time
                 },
-            )
+            ),
         )

         # Add the user and chat messages to the conversation
         self.conversation.append(user_message)
@@ -996,12 +969,13 @@ Content: {content}
             raise ValueError("No Markdown found in the response")

 _agents: List[Agent] = []

 def get_or_create_agent(
-    agent_type: str,
-    prometheus_collector: CollectorRegistry,
-    user: Optional[CandidateEntity]=None) -> Agent:
+    agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
+) -> Agent:
     """
     Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
@@ -1025,14 +999,16 @@ def get_or_create_agent(
     for agent_cls in Agent.__subclasses__():
         if agent_cls.model_fields["agent_type"].default == agent_type:
             # Create the agent instance with provided kwargs
-            agent = agent_cls(agent_type=agent_type,  # type: ignore[call-arg]
-                              user=user)
+            agent = agent_cls(
+                agent_type=agent_type,  # type: ignore[call-arg]
+                user=user,
+            )
             _agents.append(agent)
             return agent
     raise ValueError(f"No agent class found for agent_type: {agent_type}")

 # Register the base agent
 agent_registry.register(Agent._agent_type, Agent)
 CandidateEntity.model_rebuild()

View File

@@ -5,7 +5,7 @@ from .base import Agent, agent_registry
 from logger import logger
 from .registry import agent_registry
-from models import ( ApiMessage, Tunables, ApiStatusType)
+from models import ApiMessage, Tunables, ApiStatusType

 system_message = """
@@ -21,6 +21,7 @@ Always <|context|> when possible. Be concise, and never make up information. If
 Before answering, ensure you have spelled the candidate's name correctly.
 """

+
 class CandidateChat(Agent):
     """
     CandidateChat Agent
@@ -32,10 +33,7 @@ class CandidateChat(Agent):
     system_prompt: str = system_message

     async def generate(
-        self, llm: Any, model: str,
-        session_id: str, prompt: str,
-        tunables: Optional[Tunables] = None,
-        temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ApiMessage, None]:
         user = self.user
         if not user:
@@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
 {system_message}
 """
-        async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
+        async for message in super().generate(
+            llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
+        ):
             if message.status == ApiStatusType.ERROR:
                 yield message
                 return
             yield message

 # Register the base agent
 agent_registry.register(CandidateChat._agent_type, CandidateChat)

View File

@@ -49,11 +49,13 @@ class Chat(Agent):
     """
     Chat Agent
     """
+
     agent_type: Literal["general"] = "general"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration
     system_prompt: str = system_message
+
     # async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
     #     logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
     #     if not self.context:

View File

@@ -4,7 +4,7 @@ from typing import (
     ClassVar,
     Any,
     AsyncGenerator,
-    Optional
+    Optional,
     # override
 )  # NOTE: You must import Optional for late binding to work
 import random
@@ -13,7 +13,15 @@ import time
 import os
 from .base import Agent, agent_registry
-from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
+from models import (
+    ApiActivityType,
+    ChatMessage,
+    ChatMessageError,
+    ChatMessageStatus,
+    ChatMessageStreaming,
+    ApiStatusType,
+    Tunables,
+)
 from logger import logger
 import defines
 import backstory_traceback as traceback
@@ -23,6 +31,7 @@ from image_generator.profile_image import generate_image, ImageRequest
 seed = int(time.time())
 random.seed(seed)

+
 class ImageGenerator(Agent):
     agent_type: Literal["generate_image"] = "generate_image"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration
@@ -31,10 +40,7 @@ class ImageGenerator(Agent):
     system_prompt: str = ""  # No system prompt is used

     async def generate(
-        self, llm: Any, model: str,
-        session_id: str, prompt: str,
-        tunables: Optional[Tunables] = None,
-        temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         if not self.user:
             logger.error("User is not set for ImageGenerator agent.")
@@ -54,11 +60,17 @@ class ImageGenerator(Agent):
             yield status_message
             logger.info(f"Image generation: {file_path} <- {prompt}")
-            request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
+            request = ImageRequest(
+                filepath=file_path,
+                session_id=session_id,
+                prompt=prompt,
+                iterations=4,
+                height=256,
+                width=256,
+                guidance_scale=7.5,
+            )
             generated_message = None
-            async for generated_message in generate_image(
-                request=request
-            ):
+            async for generated_message in generate_image(request=request):
                 if generated_message.status == ApiStatusType.ERROR:
                     yield generated_message
                     return
@@ -68,8 +80,7 @@ class ImageGenerator(Agent):
             if generated_message is None:
                 error_message = ChatMessageError(
-                    session_id=session_id,
-                    content="Image generation failed to produce a valid response."
+                    session_id=session_id, content="Image generation failed to produce a valid response."
                 )
                 logger.error(f"⚠️ {error_message.content}")
                 yield error_message
@@ -84,20 +95,18 @@ class ImageGenerator(Agent):
                 session_id=session_id,
                 status=ApiStatusType.DONE,
                 content=f"{defines.api_prefix}/profile/{user.username}",
-                metadata=generated_message.metadata
+                metadata=generated_message.metadata,
             )
             yield generated_image
             return
         except Exception as e:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content=f"Error generating image: {str(e)}"
-            )
+            error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
             logger.error(traceback.format_exc())
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
             return

 # Register the base agent
 agent_registry.register(ImageGenerator._agent_type, ImageGenerator)

View File

@@ -8,7 +8,7 @@ from typing import (
     Tuple,
     AsyncGenerator,
     List,
-    Optional
+    Optional,
     # override
 )  # NOTE: You must import Optional for late binding to work
 import random
@@ -20,7 +20,16 @@ import os
 import random
 from .base import Agent, agent_registry
-from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
+from models import (
+    ApiActivityType,
+    ChatMessage,
+    ChatMessageError,
+    ApiMessageType,
+    ChatMessageStatus,
+    ChatMessageStreaming,
+    ApiStatusType,
+    Tunables,
+)
 from logger import logger
 import defines
 import backstory_traceback as traceback
@@ -43,6 +52,7 @@ emptyUser = {
     "questions": [],
 }

+
 def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
     return f"""\
 You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
@@ -84,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
 You are providing those only for use later by the system when casting individuals for the role.
 """

+
 generate_resume_system_prompt = """
 You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
 of that is to create an in-depth resume for the person. You will be provided with the following information:
@@ -115,10 +126,12 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+
 class EthnicNameGenerator:
     def __init__(self):
         try:
             from names_dataset import NameDataset  # type: ignore
+
             self.nd = NameDataset()
         except ImportError:
             logger.error("NameDataset not available. Please install: pip install names-dataset")
@@ -129,24 +142,24 @@ class EthnicNameGenerator:
         # US Census 2020 approximate ethnic distribution
         self.ethnic_weights = {
-            'White': 0.576,
-            'Hispanic': 0.186,
-            'Black': 0.134,
-            'Asian': 0.062,
-            'Native American': 0.013,
-            'Pacific Islander': 0.003,
-            'Mixed/Other': 0.026
+            "White": 0.576,
+            "Hispanic": 0.186,
+            "Black": 0.134,
+            "Asian": 0.062,
+            "Native American": 0.013,
+            "Pacific Islander": 0.003,
+            "Mixed/Other": 0.026,
         }

         # Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
         self.ethnic_country_mapping = {
-            'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
-            'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
-            'Black': ['US'],  # African American names
-            'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
-            'Native American': ['US'],
-            'Pacific Islander': ['US'],
-            'Mixed/Other': ['US']
+            "White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
+            "Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
+            "Black": ["US"],  # African American names
+            "Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
+            "Native American": ["US"],
+            "Pacific Islander": ["US"],
+            "Mixed/Other": ["US"],
         }

     def get_weighted_ethnicity(self) -> str:
@@ -155,8 +168,9 @@ class EthnicNameGenerator:
         weights = list(self.ethnic_weights.values())
         return random.choices(ethnicities, weights=weights)[0]

-    def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
-                              n: int = 50, use_first_names: bool = True) -> List[str]:
+    def get_names_by_criteria(
+        self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
+    ) -> List[str]:
         """Get names matching criteria using NameDataset's get_top_names method"""
         if not self.nd:
             return []
@@ -166,16 +180,13 @@ class EthnicNameGenerator:
             try:
                 # Get top names for this country
                 top_names = self.nd.get_top_names(
-                    n=n,
-                    use_first_names=use_first_names,
-                    country_alpha2=country_code,
-                    gender=gender
+                    n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
                 )
                 if country_code in top_names:
                     if use_first_names and gender:
                         # For first names with gender specified
-                        gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
+                        gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
                         if gender_key in top_names[country_code]:
                             all_names.extend(top_names[country_code][gender_key])
                     elif use_first_names:
@@ -192,25 +203,18 @@ class EthnicNameGenerator:
         return list(set(all_names))  # Remove duplicates

-    def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
+    def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
         """Generate a name based on ethnicity using the correct NameDataset API"""
-        if gender == 'random':
-            gender = random.choice(['Male', 'Female'])
+        if gender == "random":
+            gender = random.choice(["Male", "Female"])

-        countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
+        countries = self.ethnic_country_mapping.get(ethnicity, ["US"])

         # Get first names
-        first_names = self.get_names_by_criteria(
-            countries=countries,
-            gender=gender,
-            use_first_names=True
-        )
+        first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)

         # Get last names
-        last_names = self.get_names_by_criteria(
-            countries=countries,
-            use_first_names=False
-        )
+        last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)

         # Select names or use fallbacks
         if first_names:
@@ -230,57 +234,60 @@ class EthnicNameGenerator:
     def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
         """Provide culturally appropriate fallback first names"""
         fallback_names = {
-            'White': {
-                'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
-                'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
+            "White": {
+                "Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
+                "Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
             },
-            'Hispanic': {
-                'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
-                'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
+            "Hispanic": {
+                "Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
+                "Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
             },
-            'Black': {
-                'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
-                'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
+            "Black": {
+                "Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
+                "Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
+            },
+            "Asian": {
+                "Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
+                "Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
             },
-            'Asian': {
-                'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
-                'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
-            }
         }
-        ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
-        return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
+        ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
+        return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))

     def _get_fallback_last_name(self, ethnicity: str) -> str:
         """Provide culturally appropriate fallback last names"""
         fallback_surnames = {
-            'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
-            'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
-            'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
-            'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
+            "White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
+            "Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
+            "Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
+            "Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
         }
-        return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
+        return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))

-    def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
+    def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
         """Generate a random name with ethnicity based on US demographics"""
         ethnicity = self.get_weighted_ethnicity()
         return self.get_name_by_ethnicity(ethnicity, gender)

-    def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
+    def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
         """Generate multiple random names"""
         names = []
         for _ in range(count):
             first, last, ethnicity, actual_gender = self.generate_random_name(gender)
-            names.append({
-                'full_name': f"{first} {last}",
-                'first_name': first,
-                'last_name': last,
-                'ethnicity': ethnicity,
-                'gender': actual_gender
-            })
+            names.append(
+                {
+                    "full_name": f"{first} {last}",
+                    "first_name": first,
+                    "last_name": last,
+                    "ethnicity": ethnicity,
+                    "gender": actual_gender,
+                }
+            )
         return names

 class GeneratePersona(Agent):
     agent_type: Literal["generate_persona"] = "generate_persona"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration
@@ -305,10 +312,7 @@ class GeneratePersona(Agent):
         self.full_name = f"{self.first_name} {self.last_name}"

     async def generate(
-        self, llm: Any, model: str,
-        session_id: str, prompt: str,
-        tunables: Optional[Tunables] = None,
-        temperature=0.7
+        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
         self.randomize()
@@ -340,7 +344,8 @@ Incorporate the following into the job description: {original_prompt}
         logger.info("🤖 Generating persona...")
         generating_message = None
         async for generating_message in self.llm_one_shot(
-            llm=llm, model=model,
+            llm=llm,
+            model=model,
             session_id=session_id,
             prompt=prompt,
             system_prompt=generate_persona_system_prompt(persona=persona),
@@ -354,8 +359,7 @@ Incorporate the following into the job description: {original_prompt}
         if not generating_message:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Persona generation failed to generate a response."
+                session_id=session_id, content="Persona generation failed to generate a response."
             )
             yield error_message
             return
@@ -373,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
             self.username = persona.get("username", None)
             if not self.username:
                 raise ValueError("LLM did not generate a username")
-            self.username = re.sub(r'\s+', '.', self.username)
+            self.username = re.sub(r"\s+", ".", self.username)
             user_dir = os.path.join(defines.user_dir, persona["username"])
             while os.path.exists(user_dir):
                 match = re.match(r"^(.*?)(\d*)$", persona["username"])
@@ -396,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
                 location_parts = persona["location"].split(",")
                 if len(location_parts) == 3:
                     city, state, country = [part.strip() for part in location_parts]
-                    persona["location"] = {
-                        "city": city,
-                        "state": state,
-                        "country": country
-                    }
+                    persona["location"] = {"city": city, "state": state, "country": country}
                 else:
                     logger.error(f"Invalid location format: {persona['location']}")
                     persona["location"] = None
             persona["is_ai"] = True
         except Exception as e:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
+                session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
             )
             logger.error(f"❌ Error parsing LLM response: {error_message.content}")
             logger.error(traceback.format_exc())
@@ -420,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
         # Persona generated
         persona_message = ChatMessage(
-            session_id=session_id,
-            status=ApiStatusType.DONE,
-            type=ApiMessageType.JSON,
-            content = json.dumps(persona)
+            session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
         )
         yield persona_message
@@ -433,7 +429,7 @@ Incorporate the following into the job description: {original_prompt}
         status_message = ChatMessageStatus(
             session_id=session_id,
             activity=ApiActivityType.THINKING,
-            content = f"Generating resume for {persona['full_name']}..."
+            content=f"Generating resume for {persona['full_name']}...",
         )
         logger.info(f"🤖 {status_message.content}")
         yield status_message
@@ -456,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
 Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
         async for generating_message in self.llm_one_shot(
-            llm=llm, model=model,
+            llm=llm,
+            model=model,
             session_id=session_id,
             prompt=content,
             system_prompt=generate_resume_system_prompt,
@@ -470,8 +467,7 @@ Make sure at least one of the candidate's job descriptions take into account the
         if not generating_message:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Resume generation failed to generate a response."
+                session_id=session_id, content="Resume generation failed to generate a response."
             )
             logger.error(f"{error_message.content}")
             yield error_message
@@ -479,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
         resume = self.extract_markdown_from_text(generating_message.content)
         resume_message = ChatMessage(
-            session_id=session_id,
-            status=ApiStatusType.DONE,
-            type=ApiMessageType.TEXT,
-            content=resume
+            session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
         )
         yield resume_message
         return
@@ -502,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
             raise ValueError("No JSON found in the response")

 # Register the base agent
 agent_registry.register(GeneratePersona._agent_type, GeneratePersona)

View File

@@ -4,23 +4,31 @@ from typing import (
     ClassVar,
     Any,
     AsyncGenerator,
-    List
+    List,
     # override
 )  # NOTE: You must import Optional for late binding to work
 import json
 from logger import logger
 from .base import Agent, agent_registry
-from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
+from models import (
+    ApiActivityType,
+    ApiMessage,
+    ApiStatusType,
+    ChatMessage,
+    ChatMessageError,
+    ChatMessageResume,
+    ChatMessageStatus,
+    SkillAssessment,
+    SkillStrength,
+)

+
 class GenerateResume(Agent):
     agent_type: Literal["generate_resume"] = "generate_resume"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration

-    def generate_resume_prompt(
-        self,
-        skills: List[SkillAssessment]
-    ):
+    def generate_resume_prompt(self, skills: List[SkillAssessment]):
         """
         Generate a professional resume based on skill assessment results
@@ -41,7 +49,7 @@ class GenerateResume(Agent):
             SkillStrength.STRONG: [],
             SkillStrength.MODERATE: [],
             SkillStrength.WEAK: [],
-            SkillStrength.NONE: []
+            SkillStrength.NONE: [],
         }

         experience_evidence = {}
@@ -63,11 +71,7 @@ class GenerateResume(Agent):
                 experience_evidence[source] = []
             experience_evidence[source].append(
-                {
-                    "skill": skill,
-                    "quote": evidence.quote,
-                    "context": evidence.context
-                }
+                {"skill": skill, "quote": evidence.quote, "context": evidence.context}
             )

         # Build the system prompt
@@ -167,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
     ) -> AsyncGenerator[ApiMessage, None]:
         # Stage 1A: Analyze job requirements
         status_message = ChatMessageStatus(
-            session_id=session_id,
-            content = "Analyzing job requirements",
-            activity=ApiActivityType.THINKING
+            session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
         )
         yield status_message

         system_prompt, prompt = self.generate_resume_prompt(skills=skills)
         generated_message = None
-        async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+        async for generated_message in self.llm_one_shot(
+            llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
+        ):
             if generated_message.status == ApiStatusType.ERROR:
                 yield generated_message
                 return
@@ -185,8 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
         if not generated_message:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Job requirements analysis failed to generate a response."
+                session_id=session_id, content="Job requirements analysis failed to generate a response."
             )
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
@@ -194,8 +197,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
         if not isinstance(generated_message, ChatMessage):
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Job requirements analysis did not return a valid message."
+                session_id=session_id, content="Job requirements analysis did not return a valid message."
             )
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
@@ -214,5 +216,6 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
         logger.info("✅ Resume generation completed successfully.")
         return

 # Register the base agent
 agent_registry.register(GenerateResume._agent_type, GenerateResume)

View File

@@ -5,17 +5,30 @@ from typing import (
     ClassVar,
     Any,
     AsyncGenerator,
-    Optional
+    Optional,
     # override
 )  # NOTE: You must import Optional for late binding to work
 import inspect
 import json
 from .base import Agent, agent_registry
-from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
+from models import (
+    ApiActivityType,
+    ApiMessage,
+    ChatMessage,
+    ChatMessageError,
+    ChatMessageStatus,
+    ChatMessageStreaming,
+    ApiStatusType,
+    Job,
+    JobRequirements,
+    JobRequirementsMessage,
+    Tunables,
+)
 from logger import logger
 import backstory_traceback as traceback

+
 class JobRequirementsAgent(Agent):
     agent_type: Literal["job_requirements"] = "job_requirements"  # type: ignore
     _agent_type: ClassVar[str] = agent_type  # Add this for registration
@@ -91,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         """Analyze job requirements from job description."""
         system_prompt, prompt = self.create_job_analysis_prompt(prompt)
         status_message = ChatMessageStatus(
-            session_id=session_id,
-            content="Analyzing job requirements",
-            activity=ApiActivityType.THINKING
+            session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
         )
         yield status_message
         logger.info(f"🔍 {status_message.content}")

         generated_message = None
-        async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
+        async for generated_message in self.llm_one_shot(
+            llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
+        ):
             if generated_message.status == ApiStatusType.ERROR:
                 yield generated_message
                 return
@@ -107,8 +120,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         if not generated_message:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Job requirements analysis failed to generate a response.")
+                session_id=session_id, content="Job requirements analysis failed to generate a response."
+            )
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
             return
@@ -129,18 +142,18 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         display = {
             "technical_skills": {
                 "required": reqs.technical_skills.required,
-                "preferred": reqs.technical_skills.preferred
+                "preferred": reqs.technical_skills.preferred,
             },
             "experience_requirements": {
                 "required": reqs.experience_requirements.required,
-                "preferred": reqs.experience_requirements.preferred
+                "preferred": reqs.experience_requirements.preferred,
             },
             "soft_skills": reqs.soft_skills,
             "experience": reqs.experience,
             "education": reqs.education,
             "certifications": reqs.certifications,
             "preferred_attributes": reqs.preferred_attributes,
-            "company_values": reqs.company_values
+            "company_values": reqs.company_values,
         }
         return display
@@ -149,19 +162,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
     ) -> AsyncGenerator[ApiMessage, None]:
         if not self.user:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="User is not set for this agent."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
             return

         # Stage 1A: Analyze job requirements
         status_message = ChatMessageStatus(
-            session_id=session_id,
-            content = "Analyzing job requirements",
-            activity=ApiActivityType.THINKING
+            session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
         )
         yield status_message
@@ -175,8 +183,7 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         if not generated_message:
             error_message = ChatMessageError(
-                session_id=session_id,
-                content="Job requirements analysis failed to generate a response."
+                session_id=session_id, content="Job requirements analysis failed to generate a response."
             )
             logger.error(f"⚠️ {error_message.content}")
             yield error_message
@@ -211,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
             return
         except Exception as e:
             status_message.status = ApiStatusType.ERROR
-            status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
+            status_message.content = (
+                f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
+            )
             logger.error(traceback.format_exc())
             logger.error(f"⚠️ {status_message.content}")
             yield status_message
@@ -238,5 +247,6 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
         logger.info("✅ Job requirements analysis completed successfully.")
         return

 # Register the base agent
 agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)

View File

@@ -5,20 +5,19 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, Tunables


class Chat(Agent):
    """
    Chat Agent
    """

    agent_type: Literal["rag_search"] = "rag_search"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

    async def generate(
        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
    ) -> AsyncGenerator[ApiMessage, None]:
        """
        Generate a response based on the user message and the provided LLM.
@@ -44,8 +43,7 @@ class Chat(Agent):
        if not isinstance(rag_message, ChatMessageRagSearch):
            logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
            error_message = ChatMessageError(
                session_id=session_id, content="RAG search did not return a valid response."
            )
            yield error_message
            return
@@ -53,5 +51,6 @@ class Chat(Agent):
        rag_message.status = ApiStatusType.DONE
        yield rag_message


# Register the base agent
agent_registry.register(Chat._agent_type, Chat)

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type

# We'll use a registry pattern rather than hardcoded strings


class AgentRegistry:
    """Registry for agent types and classes"""

View File

@@ -4,17 +4,27 @@ from typing import (
    ClassVar,
    Any,
    AsyncGenerator,
    Optional,
    # override
)  # NOTE: You must import Optional for late binding to work
import json

from .base import Agent, agent_registry
from models import (
    ApiMessage,
    ChatMessage,
    ChatMessageError,
    ChatMessageRagSearch,
    ChatMessageSkillAssessment,
    ApiStatusType,
    EvidenceDetail,
    SkillAssessment,
    Tunables,
)
from logger import logger
import backstory_traceback as traceback


class SkillMatchAgent(Agent):
    agent_type: Literal["skill_match"] = "skill_match"  # type: ignore
    _agent_type: ClassVar[str] = agent_type  # Add this for registration

@@ -96,15 +106,12 @@ JSON RESPONSE:"""
        return system_prompt, prompt

    async def generate(
        self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
    ) -> AsyncGenerator[ApiMessage, None]:
        if not self.user:
            error_message = ChatMessageError(
                session_id=session_id,
                content="Agent not attached to user. Attach the agent to a user before generating responses.",
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
@@ -112,10 +119,7 @@ JSON RESPONSE:"""
        skill = prompt.strip()
        if not skill:
            error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
            return
@@ -130,8 +134,7 @@ JSON RESPONSE:"""
        if generated_message is None:
            error_message = ChatMessageError(
                session_id=session_id, content="RAG search did not return a valid response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
@@ -140,8 +143,7 @@ JSON RESPONSE:"""
        if not isinstance(generated_message, ChatMessageRagSearch):
            logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
            error_message = ChatMessageError(
                session_id=session_id, content="RAG search did not return a valid response."
            )
            yield error_message
            return
@@ -152,7 +154,9 @@ JSON RESPONSE:"""
        system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)

        generated_message = None
        async for generated_message in self.llm_one_shot(
            llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
        ):
            if generated_message.status == ApiStatusType.ERROR:
                logger.error(f"⚠️ {generated_message.content}")
                yield generated_message
@@ -162,8 +166,7 @@ JSON RESPONSE:"""
        if generated_message is None:
            error_message = ChatMessageError(
                session_id=session_id, content="Skill assessment failed to generate a response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
@@ -171,8 +174,7 @@ JSON RESPONSE:"""
        if not isinstance(generated_message, ChatMessage):
            error_message = ChatMessageError(
                session_id=session_id, content="Skill assessment did not return a valid message."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message
@@ -195,14 +197,15 @@ JSON RESPONSE:"""
                    EvidenceDetail(
                        source=evidence.get("source", ""),
                        quote=evidence.get("quote", ""),
                        context=evidence.get("context", ""),
                    )
                    for evidence in skill_assessment_data.get("evidence_details", [])
                ],
            )
        except Exception as e:
            error_message = ChatMessageError(
                session_id=session_id,
                content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
            )
            logger.error(traceback.format_exc())
            logger.error(f"⚠️ {error_message.content}")
@@ -232,5 +235,6 @@ JSON RESPONSE:"""
        logger.info("✅ Skill assessment completed successfully.")
        return


# Register the base agent
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)
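The parse step above wraps model construction in a try/except and surfaces both the raw LLM output and the extracted JSON on failure. A hedged, standalone sketch of that defensive pattern; the brace-slicing heuristic is an assumption for illustration, not the agent's exact extraction logic:

import json


def parse_llm_json(raw: str) -> dict:
    """Defensively parse an LLM response that should contain one JSON object."""
    # Tolerate prose around the JSON by slicing the outermost braces
    # (assumption: the model returns exactly one top-level object).
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1:
        raise ValueError(f"No JSON object found in response:\n{raw}")
    json_str = raw[start : end + 1]
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        # Include both the error and the offending payload, as the agent does.
        raise ValueError(f"Failed to parse assessment JSON: {e}\n\nJSON:\n{json_str}") from e


assert parse_llm_json('Sure! {"skill": "python", "match": 0.9}')["match"] == 0.9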

View File

@@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Any, Callable
from logger import logger
from database.manager import DatabaseManager


class BackgroundTaskManager:
    """Manages background tasks for the application using asyncio instead of threading"""
@@ -65,10 +66,12 @@ class BackgroundTaskManager:
            stats = await database.get_guest_statistics()

            # Log interesting statistics
            if stats.get("total_guests", 0) > 0:
                logger.info(
                    f"📊 Guest stats: {stats['total_guests']} total, "
                    f"{stats['active_last_hour']} active in last hour, "
                    f"{stats['converted_guests']} converted"
                )

            return stats
        except Exception as e:
@@ -84,6 +87,7 @@ class BackgroundTaskManager:
            # Get Redis client safely (using the event loop safe method)
            from database.manager import redis_manager

            redis = await redis_manager.get_client()

            # Clean up rate limit keys older than specified days
@@ -192,10 +196,7 @@ class BackgroundTaskManager:
        # Create asyncio tasks for each periodic task
        for name, func, interval, *args in periodic_tasks:
            task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
            self.tasks.append(task)
            logger.info(f"📅 Scheduled background task: {name}")
@@ -238,10 +239,7 @@ class BackgroundTaskManager:
        # Wait for all tasks to complete with timeout
        if self.tasks:
            try:
                await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
                logger.info("✅ All background tasks stopped gracefully")
            except asyncio.TimeoutError:
                logger.warning("⚠️ Some background tasks did not stop within timeout")
@@ -258,7 +256,7 @@ class BackgroundTaskManager:
            "main_loop_id": id(self.main_loop) if self.main_loop else None,
            "current_loop_id": None,
            "task_count": len(self.tasks),
            "tasks": [],
        }

        try:
@@ -317,6 +315,7 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> BackgroundTaskManager:
    await task_manager.start()
    return task_manager


# For integration with your existing app startup
async def initialize_with_background_tasks(database_manager: DatabaseManager):
    """Initialize database and background tasks together"""

View File

@@ -3,6 +3,7 @@ import os
import sys
import defines


def filter_traceback(tb, app_path=None, module_name=None):
    """
    Filter traceback to include only frames from the specified application path or module.
@@ -36,7 +37,8 @@ def filter_traceback(tb, app_path=None, module_name=None):
    formatted_exc = traceback.format_exception_only(exc_type, exc_value)

    # Combine the filtered stack trace with the exception message
    return "".join(formatted_stack + formatted_exc)


def format_exc(app_path=defines.app_path, module_name=None):
    """

View File

@@ -1,4 +1,4 @@
from .core import RedisDatabase
from .manager import DatabaseManager, redis_manager

__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]

View File

@@ -1,15 +1,15 @@
KEY_PREFIXES = {
    "viewers": "viewer:",
    "candidates": "candidate:",
    "employers": "employer:",
    "jobs": "job:",
    "job_applications": "job_application:",
    "chat_sessions": "chat_session:",
    "chat_messages": "chat_messages:",
    "ai_parameters": "ai_parameters:",
    "users": "user:",
    "candidate_documents": "candidate_documents:",
    "job_requirements": "job_requirements:",
    "resumes": "resume:",
    "user_resumes": "user_resumes:",
}
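Keys are composed by concatenating a prefix with an entity ID, which is also how the mixins below strip prefixes back off when scanning. A small illustration with made-up IDs:

KEY_PREFIXES = {"jobs": "job:", "resumes": "resume:"}

job_id = "1234"
key = f"{KEY_PREFIXES['jobs']}{job_id}"            # -> "job:1234"
recovered = key.replace(KEY_PREFIXES["jobs"], "")  # -> "1234"
assert recovered == job_id

# Resumes nest the owner in the key, as the ResumeMixin below does:
resume_key = f"{KEY_PREFIXES['resumes']}user-1:resume-9"  # -> "resume:user-1:resume-9"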

View File

@@ -10,6 +10,7 @@ from .mixins.job import JobMixin
from .mixins.skill import SkillMixin
from .mixins.ai import AIMixin


# RedisDatabase is the main class that combines all mixins for a
# comprehensive Redis database interface.
class RedisDatabase(
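Combining mixins like this yields one cohesive class whose capabilities are split across focused modules; each mixin relies on the shared `redis` attribute and serialization helpers declared by the protocol. A schematic sketch of the composition, with the mixin list abbreviated and bodies stubbed out:

class BaseMixin:
    # Shared helpers every other mixin can call via self.
    def _serialize(self, data) -> str: ...
    def _deserialize(self, data: str): ...


class ChatMixin:  # chat-session operations
    ...


class JobMixin:  # job and job-requirements operations
    ...


class RedisDatabase(BaseMixin, ChatMixin, JobMixin):
    """MRO is left to right, so BaseMixin helpers are visible to every mixin."""

    def __init__(self, redis) -> None:
        self.redis = redis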

View File

@@ -1,4 +1,4 @@
from redis.asyncio import Redis, ConnectionPool
from typing import Optional
import json
import logging
@@ -9,6 +9,7 @@ from .core import RedisDatabase
logger = logging.getLogger(__name__)


# _RedisManager is a singleton class that manages the Redis connection and
# provides methods for connecting, disconnecting, and performing health checks.
#
@@ -42,12 +43,10 @@ class _RedisManager:
                retry_on_timeout=True,
                socket_keepalive=True,
                socket_keepalive_options={},
                health_check_interval=30,
            )

            self.redis = Redis(connection_pool=self._connection_pool)

            if not self.redis:
                raise RuntimeError("Redis client not initialized")
@@ -131,7 +130,7 @@ class _RedisManager:
                "uptime_seconds": info.get("uptime_in_seconds", 0),
                "connected_clients": info.get("connected_clients", 0),
                "used_memory_human": info.get("used_memory_human", "unknown"),
                "total_commands_processed": info.get("total_commands_processed", 0),
            }
        except Exception as e:
            logger.error(f"Redis health check failed: {e}")
@@ -173,9 +172,11 @@ class _RedisManager:
            logger.error(f"Failed to get Redis info: {e}")
            return None


# Global Redis manager instance
redis_manager = _RedisManager()


# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
class DatabaseManager:
@@ -227,7 +228,7 @@ class DatabaseManager:
            backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"

            # Save to local file (you might want to save to cloud storage instead)
            with open(backup_filename, "w") as f:
                json.dump(backup_data, f, indent=2, default=str)

            logger.info(f"Backup created: {backup_filename}")
@@ -310,5 +311,3 @@ class DatabaseManager:
        if self._shutdown_initiated:
            raise RuntimeError("Application is shutting down")
        return self.db
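A hedged sketch of the request-tracking idea behind DatabaseManager: refuse new work once shutdown begins, and let in-flight requests drain before closing connections or taking a backup. Names here are illustrative, not the class's actual attributes:

import asyncio


class GracefulGate:
    """Track active requests so shutdown can wait for them to drain."""

    def __init__(self) -> None:
        self._active = 0
        self._shutdown = False
        self._idle = asyncio.Event()
        self._idle.set()

    def enter(self) -> None:
        if self._shutdown:
            raise RuntimeError("Application is shutting down")
        self._active += 1
        self._idle.clear()

    def leave(self) -> None:
        self._active -= 1
        if self._active == 0:
            self._idle.set()

    async def shutdown(self, timeout: float = 30.0) -> None:
        self._shutdown = True
        # Wait for in-flight requests, then close Redis, back up data, etc.
        await asyncio.wait_for(self._idle.wait(), timeout=timeout)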

View File

@@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)


class AIMixin(DatabaseProtocol):
    """Mixin for AI operations"""

    async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
        """Get AI parameters by ID"""
        key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
@@ -36,7 +38,7 @@ class AIMixin(DatabaseProtocol):
        result = {}
        for key, value in zip(keys, values):
            param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
            result[param_id] = self._deserialize(value)

        return result
@@ -45,4 +47,3 @@ class AIMixin(DatabaseProtocol):
        """Delete AI parameters"""
        key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
        await self.redis.delete(key)

View File

@@ -7,6 +7,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


class AnalyticsMixin:
    """Mixin for analytics-related database operations"""

View File

@@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)


class AuthMixin(DatabaseProtocol):
    """Mixin for auth-related database operations"""
@@ -25,8 +26,9 @@ class AuthMixin(DatabaseProtocol):
                token_data = await self.redis.get(key)
                if token_data:
                    verification_info = json.loads(token_data)
                    if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
                        "verified", False
                    ):
                        # Extract token from key
                        token = key.replace("email_verification:", "")
                        verification_info["token"] = token
@@ -115,10 +117,7 @@ class AuthMixin(DatabaseProtocol):
            window_start = current_time - timedelta(hours=24)

            # Filter out old attempts
            recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]

            return len(recent_attempts)
@@ -141,16 +140,13 @@ class AuthMixin(DatabaseProtocol):
            # Keep only last 24 hours of attempts
            window_start = current_time - timedelta(hours=24)
            recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]

            # Store with 24 hour expiration
            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours
                json.dumps(recent_attempts),
            )

            return True
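The verification-attempt logic above is a sliding-window rate limiter: store ISO timestamps, drop entries older than 24 hours on each write, and let the matching SETEX TTL garbage-collect idle keys. A minimal in-memory sketch of the same window arithmetic:

from datetime import datetime, timedelta, timezone

WINDOW = timedelta(hours=24)


def record_attempt(attempts: list[str]) -> list[str]:
    """Append now and drop attempts older than the window (ISO-8601 strings)."""
    now = datetime.now(timezone.utc)
    window_start = now - WINDOW
    recent = [a for a in attempts if datetime.fromisoformat(a) > window_start]
    recent.append(now.isoformat())
    return recent


attempts: list[str] = []
attempts = record_attempt(attempts)
assert len(attempts) == 1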
@@ -169,14 +165,14 @@ class AuthMixin(DatabaseProtocol):
                "user_data": user_data,
                "expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
                "created_at": datetime.now(timezone.utc).isoformat(),
                "verified": False,
            }

            # Store with 24 hour expiration
            await self.redis.setex(
                key,
                24 * 60 * 60,  # 24 hours in seconds
                json.dumps(verification_data, default=str),
            )

            logger.info(f"📧 Stored email verification token for {email}")
@@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
                await self.redis.setex(
                    key,
                    24 * 60 * 60,  # Keep for remaining TTL
                    json.dumps(token_data, default=str),
                )
                return True
            return False
@@ -228,14 +224,14 @@ class AuthMixin(DatabaseProtocol):
                "expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
                "created_at": datetime.now(timezone.utc).isoformat(),
                "attempts": 0,
                "verified": False,
            }

            # Store with 10 minute expiration
            await self.redis.setex(
                key,
                10 * 60,  # 10 minutes in seconds
                json.dumps(mfa_data, default=str),
            )

            logger.info(f"🔐 Stored MFA code for {email}")
@@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
                await self.redis.setex(
                    key,
                    10 * 60,  # Keep original TTL
                    json.dumps(mfa_data, default=str),
                )
                return mfa_data["attempts"]
            return 0
@@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
                await self.redis.setex(
                    key,
                    10 * 60,  # Keep for remaining TTL
                    json.dumps(mfa_data, default=str),
                )
                return True
            return False
@@ -327,7 +323,9 @@ class AuthMixin(DatabaseProtocol):
            logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
            return False

    async def store_refresh_token(
        self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
    ) -> bool:
        """Store refresh token for a user"""
        try:
            key = f"refresh_token:{token}"
@@ -337,7 +335,7 @@ class AuthMixin(DatabaseProtocol):
                "device": device_info.get("device", "unknown"),
                "ip_address": device_info.get("ip_address", "unknown"),
                "is_revoked": False,
                "created_at": datetime.now(timezone.utc).isoformat(),
            }

            # Store with expiration
@@ -420,7 +418,7 @@ class AuthMixin(DatabaseProtocol):
                "email": email.lower(),
                "expires_at": expires_at.isoformat(),
                "used": False,
                "created_at": datetime.now(timezone.utc).isoformat(),
            }

            # Store with expiration
@@ -473,7 +471,7 @@ class AuthMixin(DatabaseProtocol):
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "user_id": user_id,
                "event_type": event_type,
                "details": details,
            }

            # Add to list (latest events first)
@@ -496,7 +494,7 @@ class AuthMixin(DatabaseProtocol):
        try:
            events = []
            for i in range(days):
                date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
                key = f"security_log:{user_id}:{date}"

                daily_events = await self.redis.lrange(key, 0, -1)  # type: ignore
@@ -509,4 +507,3 @@ class AuthMixin(DatabaseProtocol):
        except Exception as e:
            logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
            return []
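Sharding the security log into one Redis list per user per day keeps each list short and turns a "last N days" query into a loop over N keys. A hedged sketch of the read path; the key naming is copied from the diff, and `redis` is assumed to be any async client whose lrange behaves like redis-py's:

from datetime import datetime, timedelta, timezone
import json


async def read_security_log(redis, user_id: str, days: int = 7) -> list[dict]:
    events: list[dict] = []
    for i in range(days):
        date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
        key = f"security_log:{user_id}:{date}"
        # Each day's events live in their own list; 0..-1 fetches all of them.
        for raw in await redis.lrange(key, 0, -1):
            events.append(json.loads(raw))
    return events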

View File

@@ -5,11 +5,13 @@ from typing import Any, Dict, TYPE_CHECKING
from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    pass


class BaseMixin(DatabaseProtocol):
    """Base mixin with core Redis operations and utilities"""
@@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
        keys = await self.redis.keys(pattern)
        if keys:
            await self.redis.delete(*keys)

View File

@@ -8,6 +8,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)


class ChatMixin(DatabaseProtocol):
    """Mixin for chat-related database operations"""
@@ -22,7 +23,7 @@ class ChatMixin(DatabaseProtocol):
                "total_sessions": 0,
                "total_messages": 0,
                "first_chat": None,
                "last_chat": None,
            }

        total_messages = 0
@@ -41,7 +42,7 @@ class ChatMixin(DatabaseProtocol):
            "total_messages": total_messages,
            "first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
            "last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
            "recent_sessions": sessions[:5],  # Last 5 sessions
        }

    # Chat Sessions operations
@@ -71,13 +72,13 @@ class ChatMixin(DatabaseProtocol):
        result = {}
        for key, value in zip(keys, values):
            session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
            result[session_id] = self._deserialize(value)

        return result

    async def delete_chat_session(self, session_id: str) -> bool:
        """Delete a chat session from Redis"""
        try:
            result = await self.redis.delete(f"chat_session:{session_id}")
            return result > 0
@@ -86,7 +87,7 @@ class ChatMixin(DatabaseProtocol):
            raise

    async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
        """Delete a specific chat message from Redis"""
        try:
            # Remove from the session's message list
            key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
@@ -132,7 +133,7 @@ class ChatMixin(DatabaseProtocol):
        result = {}
        for key in keys:
            session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
            messages = await self.redis.lrange(key, 0, -1)  # type: ignore
            result[session_id] = [self._deserialize(msg) for msg in messages if msg]
@@ -164,8 +165,7 @@ class ChatMixin(DatabaseProtocol):
        for session_data in all_sessions.values():
            context = session_data.get("context", {})
            if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
                candidate_sessions.append(session_data)

        # Sort by last activity (most recent first)
@@ -236,7 +236,6 @@ class ChatMixin(DatabaseProtocol):
        return archived_count

    # Analytics and Reporting
    async def get_chat_statistics(self) -> Dict[str, Any]:
        """Get comprehensive chat statistics"""
@@ -250,7 +249,7 @@ class ChatMixin(DatabaseProtocol):
                "archived_sessions": 0,
                "sessions_by_type": {},
                "sessions_with_candidates": 0,
                "average_messages_per_session": 0,
            }

        # Analyze sessions

View File

@@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)


class DocumentMixin(DatabaseProtocol):
    """Mixin for document-related database operations"""
@@ -136,8 +137,7 @@ class DocumentMixin(DatabaseProtocol):
        query_lower = query.lower()
        return [
            doc
            for doc in all_documents
            if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
        ]

View File

@@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)


class JobMixin(DatabaseProtocol):
    """Mixin for job-related database operations"""

    async def get_job(self, job_id: str) -> Optional[Dict]:
        """Get job by ID"""
        key = f"{KEY_PREFIXES['jobs']}{job_id}"
@@ -36,7 +38,7 @@ class JobMixin(DatabaseProtocol):
        result = {}
        for key, value in zip(keys, values):
            job_id = key.replace(KEY_PREFIXES["jobs"], "")
            result[job_id] = self._deserialize(value)

        return result
@@ -124,7 +126,7 @@ class JobMixin(DatabaseProtocol):
        result = {}
        for key, value in zip(keys, values):
            app_id = key.replace(KEY_PREFIXES["job_applications"], "")
            result[app_id] = self._deserialize(value)

        return result
@@ -189,7 +191,7 @@ class JobMixin(DatabaseProtocol):
            requirements_with_meta = {
                **requirements,
                "cached_at": datetime.now(UTC).isoformat(),
                "document_id": document_id,
            }

            await self.redis.set(key, self._serialize(requirements_with_meta))
@@ -232,7 +234,7 @@ class JobMixin(DatabaseProtocol):
        result = {}
        for key, value in zip(keys, values):
            document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
            if value:
                result[document_id] = self._deserialize(value)
@@ -247,11 +249,7 @@ class JobMixin(DatabaseProtocol):
            pattern = f"{KEY_PREFIXES['job_requirements']}*"
            keys = await self.redis.keys(pattern)

            stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}

            if keys:
                # Get cache dates for analysis
@@ -264,7 +262,7 @@ class JobMixin(DatabaseProtocol):
                    if value:
                        requirements_data = self._deserialize(value)
                        if requirements_data:
                            document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
                            stats["documents_with_requirements"].append(document_id)

                            # Track cache dates
@@ -277,4 +275,3 @@ class JobMixin(DatabaseProtocol):
        except Exception as e:
            logger.error(f"❌ Error getting job requirements stats: {e}")
            return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}

View File

@ -7,153 +7,412 @@ if TYPE_CHECKING:
from models import SkillAssessment from models import SkillAssessment
class DatabaseProtocol(Protocol): class DatabaseProtocol(Protocol):
# Base mixin # Base mixin
redis: Redis redis: Redis
def _serialize(self, data) -> str: ...
def _deserialize(self, data: str): ... def _serialize(self, data) -> str:
...
def _deserialize(self, data: str):
...
# Chat mixin # Chat mixin
async def add_chat_message(self, session_id: str, message_data: Dict): ... async def add_chat_message(self, session_id: str, message_data: Dict):
async def archive_chat_session(self, session_id: str): ... ...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ... async def archive_chat_session(self, session_id: str):
async def delete_chat_messages(self, session_id: str): ... ...
async def delete_chat_session_completely(self, session_id: str): ...
async def delete_chat_session(self, session_id: str) -> bool: ... async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
...
async def delete_chat_messages(self, session_id: str):
...
async def delete_chat_session_completely(self, session_id: str):
...
async def delete_chat_session(self, session_id: str) -> bool:
...
# Document mixin # Document mixin
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ... async def add_document_to_candidate(self, candidate_id: str, document_id: str):
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ... ...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
...
# Job mixin # Job mixin
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ... async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ... ...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
...
# User mixin # User mixin
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ... async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ... ...
async def delete_employer(self, employer_id: str): ...
async def delete_guest(self, guest_id: str) -> bool: ... async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
async def delete_user(self, email: str): ... ...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
async def get_all_users(self) -> Dict[str, Any]: ... async def delete_employer(self, employer_id: str):
async def get_all_viewers(self) -> Dict[str, Any]: ... ...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ... async def delete_guest(self, guest_id: str) -> bool:
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ... ...
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ... async def delete_user(self, email: str):
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ... ...
async def get_guest_statistics(self) -> Dict[str, Any]: ...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ... async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
async def get_user_by_username(self, username: str) -> Optional[Dict]: ... ...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ... async def get_all_users(self) -> Dict[str, Any]:
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ... ...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
async def invalidate_user_skill_cache(self, user_id: str) -> int: ... async def get_all_viewers(self) -> Dict[str, Any]:
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ... ...
async def set_employer(self, employer_id: str, employer_data: Dict): ...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ... async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ... ...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
async def update_user_rag_timestamp(self, user_id: str) -> bool: ... async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
async def user_exists_by_email(self, email: str) -> bool: ... ...
async def user_exists_by_username(self, username: str) -> bool: ...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
...
async def get_employer(self, employer_id: str) -> Optional[Dict]:
...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest_statistics(self) -> Dict[str, Any]:
...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_user_by_username(self, username: str) -> Optional[Dict]:
...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
...
async def invalidate_user_skill_cache(self, user_id: str) -> int:
...
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
...
async def set_employer(self, employer_id: str, employer_data: Dict):
...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
...
async def update_user_rag_timestamp(self, user_id: str) -> bool:
...
async def user_exists_by_email(self, email: str) -> bool:
...
async def user_exists_by_username(self, username: str) -> bool:
...
# Auth mixin # Auth mixin
async def cleanup_expired_verification_tokens(self) -> int: ... async def cleanup_expired_verification_tokens(self) -> int:
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ... ...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
async def cleanup_orphaned_job_requirements(self) -> int: ... async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
async def clear_all_data(self: "DatabaseProtocol"): ... ...
async def clear_all_skill_match_cache(self) -> int: ...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
...
async def cleanup_orphaned_job_requirements(self) -> int:
...
async def clear_all_data(self: "DatabaseProtocol"):
...
async def clear_all_skill_match_cache(self) -> int:
...
# Resume mixin # Resume mixin
async def delete_ai_parameters(self, param_id: str): ... async def delete_ai_parameters(self, param_id: str):
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ... ...
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
async def delete_authentication(self, user_id: str) -> bool: ...
async def delete_document(self, document_id: str): ...
async def delete_job_application(self, application_id: str): ... async def delete_all_candidate_documents(self, candidate_id: str) -> int:
async def delete_job_requirements(self, document_id: str) -> bool: ... ...
async def delete_job(self, job_id: str): ...
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
async def delete_viewer(self, viewer_id: str): ...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
async def get_all_candidates(self) -> Dict[str, Any]: ...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
async def get_all_employers(self) -> Dict[str, Any]: ...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
async def get_all_job_applications(self) -> Dict[str, Any]: ...
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
async def get_all_jobs(self) -> Dict[str, Any]: ...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
async def get_chat_message_count(self, session_id: str) -> int: ...
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
async def get_chat_statistics(self) -> Dict[str, Any]: ...
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
async def get_document(self, document_id: str) -> Optional[Dict]: ...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
async def get_job(self, job_id: str) -> Optional[Dict]: ...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_pending_verifications_count(self) -> int: ...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
async def get_stats(self) -> Dict[str, int]: ...
async def get_verification_attempts_count(self, email: str) -> int: ...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
async def mark_email_verified(self, token: str) -> bool: ...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
async def mark_password_reset_token_used(self, token: str) -> bool: ...
async def record_verification_attempt(self, email: str) -> bool: ...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
async def revoke_refresh_token(self, token: str) -> bool: ...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
async def set_chat_session(self, session_id: str, session_data: Dict): ...
async def set_document(self, document_id: str, document_data: Dict): ...
async def set_job_application(self, application_id: str, application_data: Dict): ...
async def set_job(self, job_id: str, job_data: Dict): ...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
async def update_chat_session_activity(self, session_id: str): ...
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
async def delete_all_resumes_for_user(self, user_id: str) -> int:
...
async def delete_authentication(self, user_id: str) -> bool:
...
async def delete_document(self, document_id: str):
...
async def delete_job_application(self, application_id: str):
...
async def delete_job_requirements(self, document_id: str) -> bool:
...
async def delete_job(self, job_id: str):
...
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
...
async def delete_viewer(self, viewer_id: str):
...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
...
async def get_all_ai_parameters(self) -> Dict[str, Any]:
...
async def get_all_candidates(self) -> Dict[str, Any]:
...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
...
async def get_all_chat_sessions(self) -> Dict[str, Any]:
...
async def get_all_employers(self) -> Dict[str, Any]:
...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
...
async def get_all_job_applications(self) -> Dict[str, Any]:
...
async def get_all_job_requirements(self) -> Dict[str, Any]:
...
async def get_all_jobs(self) -> Dict[str, Any]:
...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
...
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
...
async def get_chat_message_count(self, session_id: str) -> int:
...
async def get_chat_messages(self, session_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
...
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
...
async def get_chat_statistics(self) -> Dict[str, Any]:
...
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
...
async def get_document(self, document_id: str) -> Optional[Dict]:
...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_job_application(self, application_id: str) -> Optional[Dict]:
...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
...
async def get_job_requirements_stats(self) -> Dict[str, Any]:
...
async def get_job(self, job_id: str) -> Optional[Dict]:
...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_pending_verifications_count(self) -> int:
...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
...
async def get_stats(self) -> Dict[str, int]:
...
async def get_verification_attempts_count(self, email: str) -> int:
...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
...
async def mark_email_verified(self, token: str) -> bool:
...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
...
async def mark_password_reset_token_used(self, token: str) -> bool:
...
async def record_verification_attempt(self, email: str) -> bool:
...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
...
async def revoke_all_user_tokens(self, user_id: str) -> bool:
...
async def revoke_refresh_token(self, token: str) -> bool:
...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
...
async def set_ai_parameters(self, param_id: str, param_data: Dict):
...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
...
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
...
async def set_chat_session(self, session_id: str, session_data: Dict):
...
async def set_document(self, document_id: str, document_data: Dict):
...
async def set_job_application(self, application_id: str, application_data: Dict):
...
async def set_job(self, job_id: str, job_data: Dict):
...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
...
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
...
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
...
async def update_chat_session_activity(self, session_id: str):
...
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
...
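
The stub methods above act as a structural-typing contract: each mixin implements a slice of this interface, and the protocol lets a type checker verify the composed database class. A minimal sketch of that pattern with hypothetical names (the real DatabaseProtocol carries many more methods, as listed above):

# Minimal sketch of the protocol/mixin contract pattern (hypothetical names).
from typing import Any, Dict, Optional, Protocol


class DatabaseProtocol(Protocol):
    async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
    async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...


class InMemoryDatabase:
    """Toy implementation that satisfies the protocol, for illustration only."""

    def __init__(self) -> None:
        self._resumes: Dict[str, Dict[str, Any]] = {}

    async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
        return self._resumes.get(f"{user_id}:{resume_id}")

    async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
        self._resumes[f"{user_id}:{resume_data['id']}"] = resume_data
        return True


db: DatabaseProtocol = InMemoryDatabase()  # structural type check happens here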

View File

@@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES

logger = logging.getLogger(__name__)

+
class ResumeMixin(DatabaseProtocol):
    """Mixin for resume-related database operations"""

@@ -14,10 +15,10 @@ class ResumeMixin(DatabaseProtocol):
        """Save a resume for a user"""
        try:
            # Generate resume_id if not present
-            if 'id' not in resume_data:
+            if "id" not in resume_data:
                raise ValueError("Resume data must include an 'id' field")
-            resume_id = resume_data['id']
+            resume_id = resume_data["id"]

            # Store the resume data
            key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"

@@ -159,7 +160,7 @@ class ResumeMixin(DatabaseProtocol):
            for key, value in zip(keys, values):
                if value:
                    # Extract user_id from key format: resume:{user_id}:{resume_id}
-                    key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
+                    key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
                    if len(key_parts) >= 1:
                        user_id = key_parts[0]
                        resume_data = self._deserialize(value)

@@ -186,12 +187,14 @@ class ResumeMixin(DatabaseProtocol):
            matching_resumes = []
            for resume in all_resumes:
                # Search in resume content, job_id, candidate_id, etc.
-                searchable_text = " ".join([
-                    resume.get("resume", ""),
-                    resume.get("job_id", ""),
-                    resume.get("candidate_id", ""),
-                    str(resume.get("created_at", ""))
-                ]).lower()
+                searchable_text = " ".join(
+                    [
+                        resume.get("resume", ""),
+                        resume.get("job_id", ""),
+                        resume.get("candidate_id", ""),
+                        str(resume.get("created_at", "")),
+                    ]
+                ).lower()

                if query_lower in searchable_text:
                    matching_resumes.append(resume)

@@ -206,10 +209,7 @@ class ResumeMixin(DatabaseProtocol):
        """Get all resumes for a specific candidate created by a user"""
        try:
            all_resumes = await self.get_all_resumes_for_user(user_id)
-            candidate_resumes = [
-                resume for resume in all_resumes
-                if resume.get("candidate_id") == candidate_id
-            ]
+            candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]

            logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
            return candidate_resumes

@@ -221,10 +221,7 @@ class ResumeMixin(DatabaseProtocol):
        """Get all resumes for a specific job created by a user"""
        try:
            all_resumes = await self.get_all_resumes_for_user(user_id)
-            job_resumes = [
-                resume for resume in all_resumes
-                if resume.get("job_id") == job_id
-            ]
+            job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]

            logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
            return job_resumes

@@ -242,7 +239,7 @@ class ResumeMixin(DatabaseProtocol):
                "resumes_by_candidate": {},
                "resumes_by_job": {},
                "creation_timeline": {},
-                "recent_resumes": []
+                "recent_resumes": [],
            }

            for resume in all_resumes:

@@ -269,7 +266,13 @@ class ResumeMixin(DatabaseProtocol):
            return stats
        except Exception as e:
            logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
-            return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
+            return {
+                "total_resumes": 0,
+                "resumes_by_candidate": {},
+                "resumes_by_job": {},
+                "creation_timeline": {},
+                "recent_resumes": [],
+            }

    async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
        """Update specific fields of a resume"""

View File

@@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol

logger = logging.getLogger(__name__)

+
class SkillMixin(DatabaseProtocol):
    """Mixin for Skill-related database operations"""

@@ -74,7 +75,9 @@ class SkillMixin(DatabaseProtocol):
            # Cache for 1 hour by default
            await self.redis.set(
                cache_key,
-                json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str)  # Serialize with datetime handling
+                json.dumps(
+                    assessment.model_dump(mode="json", by_alias=True), default=str
+                ),  # Serialize with datetime handling
            )
            logger.info(f"💾 Skill match cached: {cache_key}")
        except Exception as e:
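
The serialization pattern being reflowed here is worth noting: with Pydantic v2, model_dump(mode="json") converts datetimes to ISO strings before json.dumps, with default=str as a fallback. A self-contained sketch, using a stand-in model since the real SkillAssessment fields aren't shown in this diff:

# Sketch of the cache-write pattern above (Pydantic v2 assumed;
# SkillAssessment here is a hypothetical stand-in for the real model).
import json
from datetime import datetime, timezone

from pydantic import BaseModel


class SkillAssessment(BaseModel):
    skill: str
    level: str
    assessed_at: datetime


assessment = SkillAssessment(skill="python", level="expert", assessed_at=datetime.now(timezone.utc))
# mode="json" turns datetimes into ISO strings; default=str catches stragglers.
payload = json.dumps(assessment.model_dump(mode="json", by_alias=True), default=str)
restored = SkillAssessment.model_validate(json.loads(payload))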

View File

@@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES

logger = logging.getLogger(__name__)

+
class UserMixin(DatabaseProtocol):
    """Mixin for user operations"""

@@ -29,7 +30,7 @@ class UserMixin(DatabaseProtocol):
            await self.redis.setex(
                f"guest_backup:{guest_id}",
                86400 * 7,  # 7 days TTL
-                json.dumps(guest_data)
+                json.dumps(guest_data),
            )

            logger.info(f"💾 Guest stored with backup: {guest_id}")

@@ -83,10 +84,7 @@ class UserMixin(DatabaseProtocol):
        """Get all guests"""
        try:
            data = await self.redis.hgetall("guests")  # type: ignore
-            return {
-                guest_id: json.loads(guest_json)
-                for guest_id, guest_json in data.items()
-            }
+            return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
        except Exception as e:
            logger.error(f"❌ Error getting all guests: {e}")
            return {}

@@ -120,7 +118,7 @@ class UserMixin(DatabaseProtocol):
                    # Skip cleanup if guest is very new (less than 1 hour old)
                    if created_at_str:
-                        created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
+                        created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
                        if current_time - created_at < timedelta(hours=1):
                            preserved_count += 1
                            logger.info(f"🛡️ Preserving new guest: {guest_id}")

@@ -130,7 +128,7 @@ class UserMixin(DatabaseProtocol):
                    should_delete = False
                    if last_activity_str:
                        try:
-                            last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
+                            last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
                            if last_activity < cutoff_time:
                                should_delete = True
                        except ValueError:

@@ -172,7 +170,7 @@ class UserMixin(DatabaseProtocol):
                "active_last_day": 0,
                "converted_guests": 0,
                "by_ip": {},
-                "creation_timeline": {}
+                "creation_timeline": {},
            }

            hour_ago = current_time - timedelta(hours=1)

@@ -183,7 +181,7 @@ class UserMixin(DatabaseProtocol):
                last_activity_str = guest_data.get("last_activity")
                if last_activity_str:
                    try:
-                        last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
+                        last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
                        if last_activity > hour_ago:
                            stats["active_last_hour"] += 1
                        if last_activity > day_ago:

@@ -203,8 +201,8 @@ class UserMixin(DatabaseProtocol):
                created_at_str = guest_data.get("created_at")
                if created_at_str:
                    try:
-                        created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
-                        date_key = created_at.strftime('%Y-%m-%d')
+                        created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
+                        date_key = created_at.strftime("%Y-%m-%d")
                        stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
                    except ValueError:
                        pass

@@ -321,7 +319,7 @@ class UserMixin(DatabaseProtocol):
            result = {}
            for key, value in zip(keys, values):
-                email = key.replace(KEY_PREFIXES['users'], '')
+                email = key.replace(KEY_PREFIXES["users"], "")
                logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
                if type(value) == str:
                    result[email] = value

@@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
            logger.error(f"❌ Error storing user {login}: {e}")
            return False
-
    # ================
    # Employers
    # ================

@@ -394,7 +391,7 @@ class UserMixin(DatabaseProtocol):
            result = {}
            for key, value in zip(keys, values):
-                employer_id = key.replace(KEY_PREFIXES['employers'], '')
+                employer_id = key.replace(KEY_PREFIXES["employers"], "")
                result[employer_id] = self._deserialize(value)

            return result

@@ -404,7 +401,6 @@ class UserMixin(DatabaseProtocol):
        key = f"{KEY_PREFIXES['employers']}{employer_id}"
        await self.redis.delete(key)
-
    # ================
    # Candidates
    # ================

@@ -435,7 +431,7 @@ class UserMixin(DatabaseProtocol):
            result = {}
            for key, value in zip(keys, values):
-                candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
+                candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
                result[candidate_id] = self._deserialize(value)

            return result

@@ -456,7 +452,7 @@ class UserMixin(DatabaseProtocol):
            "security_logs": 0,
            "ai_parameters": 0,
            "candidate_record": 0,
-            "resumes": 0
+            "resumes": 0,
        }

        logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")

@@ -495,7 +491,9 @@ class UserMixin(DatabaseProtocol):
            deletion_stats["chat_sessions"] = len(candidate_sessions)
            deletion_stats["chat_messages"] = messages_deleted

-            logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
+            logger.info(
+                f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
+            )
        except Exception as e:
            logger.error(f"❌ Error deleting chat sessions: {e}")

@@ -528,9 +526,11 @@ class UserMixin(DatabaseProtocol):
                logger.info(f"🗑️ Deleted user record by email: {candidate_email}")

            # Delete by username (if different from email)
-            if (candidate_username and
-                candidate_username != candidate_email and
-                await self.user_exists_by_username(candidate_username)):
+            if (
+                candidate_username
+                and candidate_username != candidate_email
+                and await self.user_exists_by_username(candidate_username)
+            ):
                await self.delete_user(candidate_username)
                user_records_deleted += 1
                logger.info(f"🗑️ Deleted user record by username: {candidate_username}")

@@ -593,8 +593,7 @@ class UserMixin(DatabaseProtocol):
            candidate_ai_params = []
            for param_id, param_data in all_ai_params.items():
-                if (param_data.get("candidateId") == candidate_id or
-                    param_data.get("userId") == candidate_id):
+                if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
                    candidate_ai_params.append(param_id)

            # Delete each AI parameter set

@@ -630,7 +629,9 @@ class UserMixin(DatabaseProtocol):
                        break

            if tokens_deleted > 0:
-                logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
+                logger.info(
+                    f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
+                )
        except Exception as e:
            logger.error(f"❌ Error deleting email verification tokens: {e}")

@@ -717,8 +718,10 @@ class UserMixin(DatabaseProtocol):
        # 15. Log the deletion as a security event (if we have admin/system user context)
        try:
            total_items_deleted = sum(deletion_stats.values())
-            logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
-                        f"Total items deleted: {total_items_deleted}")
+            logger.info(
+                f"✅ Completed cascading delete for candidate {candidate_id}. "
+                f"Total items deleted: {total_items_deleted}"
+            )
            logger.info(f"📊 Deletion breakdown: {deletion_stats}")
        except Exception as e:
            logger.error(f"❌ Error logging deletion summary: {e}")

@@ -774,8 +777,8 @@ class UserMixin(DatabaseProtocol):
                "total_candidates_processed": len(candidate_ids),
                "successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
                "failed_deletions": len([r for r in batch_results.values() if "error" in r]),
-                "total_items_deleted": sum(total_stats.values())
-            }
+                "total_items_deleted": sum(total_stats.values()),
+            },
        }

    except Exception as e:

@@ -816,7 +819,7 @@ class UserMixin(DatabaseProtocol):
                "total_sessions": 0,
                "total_messages": 0,
                "first_chat": None,
-                "last_chat": None
+                "last_chat": None,
            }

        total_messages = 0

@@ -835,7 +838,7 @@ class UserMixin(DatabaseProtocol):
            "total_messages": total_messages,
            "first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
            "last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
-            "recent_sessions": sessions[:5]  # Last 5 sessions
+            "recent_sessions": sessions[:5],  # Last 5 sessions
        }

    # ================

@@ -868,7 +871,7 @@ class UserMixin(DatabaseProtocol):
            result = {}
            for key, value in zip(keys, values):
-                viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
+                viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
                result[viewer_id] = self._deserialize(value)

            return result

@@ -877,4 +880,3 @@ class UserMixin(DatabaseProtocol):
        """Delete viewer"""
        key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
        await self.redis.delete(key)
-
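
Several hunks above quote-normalize the same parsing idiom. The .replace("Z", "+00:00") step matters because datetime.fromisoformat only accepts a trailing "Z" on Python 3.11+; on older interpreters the suffix must first be rewritten as an explicit UTC offset:

# The recurring idiom above: fromisoformat() on Python < 3.11 rejects a
# trailing "Z", so the code normalizes it to an explicit UTC offset first.
from datetime import datetime, timezone


def parse_iso8601(value: str) -> datetime:
    return datetime.fromisoformat(value.replace("Z", "+00:00"))


assert parse_iso8601("2025-06-18T13:53:07Z") == datetime(2025, 6, 18, 13, 53, 7, tzinfo=timezone.utc)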

View File

@@ -6,6 +6,7 @@ from datetime import datetime, timezone
from user_agents import parse
import json

+
class DeviceManager:
    def __init__(self, database: RedisDatabase):
        self.database = database

@@ -34,7 +35,7 @@ class DeviceManager:
            "os": user_agent.os.family,
            "os_version": user_agent.os.version_string,
            "ip_address": request.client.host if request.client else "unknown",
-            "user_agent": user_agent_string
+            "user_agent": user_agent_string,
        }

    async def is_trusted_device(self, user_id: str, device_id: str) -> bool:

@@ -54,14 +55,14 @@ class DeviceManager:
            device_data = {
                **device_info,
                "added_at": datetime.now(timezone.utc).isoformat(),
-                "last_used": datetime.now(timezone.utc).isoformat()
+                "last_used": datetime.now(timezone.utc).isoformat(),
            }

            # Store for 90 days
            await self.database.redis.setex(
                key,
                90 * 24 * 60 * 60,  # 90 days in seconds
-                json.dumps(device_data, default=str)
+                json.dumps(device_data, default=str),
            )

            logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")

@@ -79,7 +80,7 @@ class DeviceManager:
                await self.database.redis.setex(
                    key,
                    90 * 24 * 60 * 60,  # Reset 90 day expiry
-                    json.dumps(device_info, default=str)
+                    json.dumps(device_info, default=str),
                )
        except Exception as e:
            logger.error(f"Error updating device last used: {e}")

View File

@@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
import json
from database.manager import RedisDatabase

+
class EmailService:
    def __init__(self):
        # Configure these in your .env file

@@ -30,36 +31,25 @@ class EmailService:
    def _format_template(self, template: str, **kwargs) -> str:
        """Format template with provided variables"""
        return template.format(
-            app_name=self.app_name,
-            from_name=self.from_name,
-            frontend_url=self.frontend_url,
-            **kwargs
+            app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
        )

    async def send_verification_email(
-        self,
-        to_email: str,
-        verification_token: str,
-        user_name: str,
-        user_type: str = "user"
+        self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
    ):
        """Send email verification email using template"""
        try:
            template = self._get_template("verification")
            verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"

-            subject = self._format_template(
-                template["subject"],
-                user_name=user_name,
-                to_email=to_email
-            )
+            subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)

            html_content = self._format_template(
                template["html"],
                user_name=user_name,
                user_type=user_type,
                to_email=to_email,
-                verification_link=verification_link
+                verification_link=verification_link,
            )

            await self._send_email(to_email, subject, html_content)

@@ -70,12 +60,7 @@ class EmailService:
            raise

    async def send_mfa_email(
-        self,
-        to_email: str,
-        mfa_code: str,
-        device_name: str,
-        user_name: str,
-        ip_address: str = "Unknown"
+        self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
    ):
        """Send MFA code email using template"""
        try:

@@ -91,7 +76,7 @@ class EmailService:
                ip_address=ip_address,
                login_time=login_time,
                mfa_code=mfa_code,
-                to_email=to_email
+                to_email=to_email,
            )

            await self._send_email(to_email, subject, html_content)

@@ -101,12 +86,7 @@ class EmailService:
            logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
            raise

-    async def send_password_reset_email(
-        self,
-        to_email: str,
-        reset_token: str,
-        user_name: str
-    ):
+    async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
        """Send password reset email using template"""
        try:
            template = self._get_template("password_reset")

@@ -115,10 +95,7 @@ class EmailService:
            subject = self._format_template(template["subject"])

            html_content = self._format_template(
-                template["html"],
-                user_name=user_name,
-                reset_link=reset_link,
-                to_email=to_email
+                template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
            )

            await self._send_email(to_email, subject, html_content)

@@ -134,14 +111,14 @@ class EmailService:
            if not self.email_user:
                raise ValueError("Email user is not configured")

            # Create message
-            msg = MIMEMultipart('alternative')
-            msg['From'] = f"{self.from_name} <{self.email_user}>"
-            msg['To'] = to_email
-            msg['Subject'] = subject
-            msg['Reply-To'] = self.email_user
+            msg = MIMEMultipart("alternative")
+            msg["From"] = f"{self.from_name} <{self.email_user}>"
+            msg["To"] = to_email
+            msg["Subject"] = subject
+            msg["Reply-To"] = self.email_user

            # Add HTML content
-            html_part = MIMEText(html_content, 'html', 'utf-8')
+            html_part = MIMEText(html_content, "html", "utf-8")
            msg.attach(html_part)

            # Send email with connection pooling and retry logic

@@ -157,7 +134,6 @@ class EmailService:
                        text = msg.as_string()
                        server.sendmail(self.email_user, to_email, text)
                        break  # Success, exit retry loop
-
                except smtplib.SMTPException as e:
                    if attempt == max_retries - 1:  # Last attempt
                        raise

@@ -170,6 +146,7 @@ class EmailService:
            logger.error(f"❌ SMTP error sending to {to_email}: {e}")
            raise

+
class EmailRateLimiter:
    def __init__(self, database: RedisDatabase):
        self.database = database

@@ -191,10 +168,7 @@ class EmailRateLimiter:
                email_records = json.loads(count_data)

                # Filter out old records
-                recent_records = [
-                    record for record in email_records
-                    if datetime.fromisoformat(record) > window_start
-                ]
+                recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]

                if len(recent_records) >= limit:
                    logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")

@@ -202,11 +176,7 @@ class EmailRateLimiter:
                # Add current email to records
                recent_records.append(current_time.isoformat())

-                await self.database.redis.setex(
-                    key,
-                    window_minutes * 60,
-                    json.dumps(recent_records)
-                )
+                await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))

                return True

@@ -217,11 +187,8 @@ class EmailRateLimiter:
    async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
        """Record that an email was sent"""
-        await self.database.redis.setex(
-            key,
-            ttl_minutes * 60,
-            json.dumps([timestamp.isoformat()])
-        )
+        await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))


class VerificationEmailRateLimiter:
    def __init__(self, database: RedisDatabase):

@@ -242,7 +209,10 @@ class VerificationEmailRateLimiter:
        # Check daily limit
        daily_count = await self.database.get_verification_attempts_count(email)
        if daily_count >= self.max_attempts_per_day:
-            return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
+            return (
+                False,
+                f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
+            )

        # Check hourly limit
        hour_ago = current_time - timedelta(hours=1)

@@ -251,13 +221,13 @@ class VerificationEmailRateLimiter:
        if data:
            attempts_data = json.loads(data)
-            recent_attempts = [
-                attempt for attempt in attempts_data
-                if datetime.fromisoformat(attempt) > hour_ago
-            ]
+            recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]

            if len(recent_attempts) >= self.max_attempts_per_hour:
-                return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
+                return (
+                    False,
+                    f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
+                )

        # Check cooldown period
        if recent_attempts:

@@ -280,7 +250,4 @@ class VerificationEmailRateLimiter:
        await self.database.record_verification_attempt(email)


email_service = EmailService()
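
All of the limiters above share one shape: load a JSON list of timestamps, drop entries outside the window, compare against the limit, then write the list back with a TTL so idle keys expire on their own. A condensed sketch of that sliding-window pattern, assuming redis.asyncio and an illustrative key name:

# Condensed sketch of the sliding-window pattern used by the rate limiters
# above (redis.asyncio assumed; key naming is illustrative).
import json
from datetime import datetime, timedelta, timezone

import redis.asyncio as redis


async def allow_send(r: redis.Redis, email: str, limit: int, window_minutes: int) -> bool:
    key = f"email_rate:{email}"
    now = datetime.now(timezone.utc)
    window_start = now - timedelta(minutes=window_minutes)

    raw = await r.get(key)
    records = json.loads(raw) if raw else []
    # Keep only timestamps that still fall inside the window.
    recent = [ts for ts in records if datetime.fromisoformat(ts) > window_start]
    if len(recent) >= limit:
        return False

    recent.append(now.isoformat())
    # TTL matches the window, so an idle key simply expires.
    await r.setex(key, window_minutes * 60, json.dumps(recent))
    return True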

View File

@@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
        </div>
    </body>
    </html>
-    """
+    """,
    },
-
    "mfa": {
        "subject": "Security Code for Backstory",
        "html": """

@@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
        </div>
    </body>
    </html>
-    """
+    """,
    },
-
    "password_reset": {
        "subject": "Reset your Backstory password",
        "html": """

@@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
        </div>
    </body>
    </html>
-    """
-    }
+    """,
+    },
}

View File

@@ -10,6 +10,7 @@ from agents.base import CandidateEntity
from database.manager import RedisDatabase
from prometheus_client import CollectorRegistry  # type: ignore

+
class EntityManager(BaseModel):
    """Manages lifecycle of CandidateEntity instances"""

@@ -36,10 +37,7 @@ class EntityManager(BaseModel):
                pass
            self._cleanup_task = None

-    def initialize(
-        self,
-        prometheus_collector: CollectorRegistry,
-        database: RedisDatabase):
+    def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
        """Initialize the EntityManager with Prometheus collector"""
        self._prometheus_collector = prometheus_collector
        self._database = database

@@ -58,9 +56,7 @@ class EntityManager(BaseModel):
            raise ValueError("EntityManager has not been initialized with required components.")

        entity = CandidateEntity(candidate=candidate)
-        await entity.initialize(
-            prometheus_collector=self._prometheus_collector,
-            database=self._database)
+        await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)

        # Store with reference tracking
        self._entities[candidate.id] = entity

@@ -105,10 +101,12 @@ class EntityManager(BaseModel):
    def _on_entity_deleted(self, user_id: str):
        """Callback when entity is garbage collected"""
+
        def cleanup_callback(weak_ref):
            self._entities.pop(user_id, None)
            self._weak_refs.pop(user_id, None)
            print(f"Entity {user_id} garbage collected")
+
        return cleanup_callback

    async def release_entity(self, user_id: str):

@@ -138,8 +136,7 @@ class EntityManager(BaseModel):
            time_since_access = current_time - entity.last_accessed

            # Remove if TTL exceeded and no active references
-            if (time_since_access > timedelta(minutes=self._ttl_minutes)
-                and entity.reference_count == 0):
+            if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
                expired_entities.append(user_id)

        for user_id in expired_entities:

@@ -153,6 +150,7 @@ class EntityManager(BaseModel):
# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)

+
@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
    """Context manager for safe entity access with automatic reference management"""

@@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
    finally:
        await entity_manager.release_entity(candidate.id)

+
EntityManager.model_rebuild()
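
The _on_entity_deleted hunk reformats a closure that is handed to weakref.ref as a finalizer. A runnable sketch of that callback pattern in isolation:

# Runnable sketch of the weakref-callback pattern from _on_entity_deleted.
import weakref


class Entity:
    def __init__(self, user_id: str) -> None:
        self.user_id = user_id


entities: dict[str, Entity] = {}
weak_refs: dict[str, weakref.ref] = {}


def on_entity_deleted(user_id: str):
    """Build a finalizer that clears registry slots once the entity is collected."""

    def cleanup_callback(weak_ref):
        entities.pop(user_id, None)
        weak_refs.pop(user_id, None)
        print(f"Entity {user_id} garbage collected")

    return cleanup_callback


e = Entity("alice")
weak_refs["alice"] = weakref.ref(e, on_entity_deleted("alice"))
del e  # with CPython refcounting, the callback fires here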

View File

@@ -6,10 +6,7 @@ without getting caught up in serialization format complexities
import sys
from datetime import datetime

-from models import (
-    UserStatus, UserType, SkillLevel, EmploymentType,
-    Candidate, Employer, Location, Skill
-)
+from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill


def test_model_creation():

@@ -37,7 +34,7 @@ def test_model_creation():
        preferred_job_types=[EmploymentType.FULL_TIME],
        location=location,
        languages=[],
-        certifications=[]
+        certifications=[],
    )

    # Create employer

@@ -54,7 +51,7 @@ def test_model_creation():
        industry="Technology",
        company_size="50-200",
        company_description="A test company",
-        location=location
+        location=location,
    )

    print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")

@@ -63,6 +60,7 @@ def test_model_creation():
    return candidate, employer

+
def test_json_api_format():
    """Test JSON serialization in API format (the most important use case)"""
    print("\n📡 Testing JSON API format...")

@@ -90,6 +88,7 @@ def test_json_api_format():
    return True

+
def test_api_dict_format():
    """Test dictionary format with aliases (for API requests/responses)"""
    print("\n📊 Testing API dictionary format...")

@@ -120,6 +119,7 @@ def test_api_dict_format():
    return True

+
def test_validation_constraints():
    """Test that validation constraints work"""
    print("\n🔒 Testing validation constraints...")

@@ -135,7 +135,7 @@ def test_validation_constraints():
            status=UserStatus.ACTIVE,
            first_name="Jane",
            last_name="Doe",
-            full_name="Jane Doe"
+            full_name="Jane Doe",
        )
        print("❌ Validation should have failed but didn't")
        return False

@@ -143,6 +143,7 @@ def test_validation_constraints():
        print(f"✅ Validation error caught: {e}")
        return True

+
def test_enum_values():
    """Test that enum values work correctly"""
    print("\n📋 Testing enum values...")

@@ -162,6 +163,7 @@ def test_enum_values():
    return True

+
def main():
    """Run all focused tests"""
    print("🎯 Focused Pydantic Model Tests")

@@ -187,10 +189,12 @@ def main():
    except Exception as e:
        print(f"\n❌ Test failed: {type(e).__name__}: {e}")
        import traceback
+
        traceback.print_exc()
        print(f"\n{traceback.format_exc()}")
        return False

+
if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel
import json
from typing import Any, List, Set

+
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
    """
    Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.

View File

@@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from BaseUserWithType"
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"

-T = TypeVar('T', bound=BaseModel)
+T = TypeVar("T", bound=BaseModel)

+
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
    data = {field: getattr(source, field) for field in model_cls.__fields__}
    return model_cls(**data)

+
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
    data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
    return model_cls(**data)

+
def cast_to_base_user_with_type(user) -> BaseUserWithType:
    """
    Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.

View File

@@ -7,6 +7,7 @@ from typing import Any
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline

+
class ImageModelCache:  # Stay loaded for 3 hours
    def __init__(self, timeout_seconds: float = 3 * 60 * 60):
        self._pipe = None

@@ -36,11 +37,11 @@ class ImageModelCache:  # Stay loaded for 3 hours
            cached_model_type = self._get_model_type(self._model_name) if self._model_name else None

            if (
-                self._pipe is not None and
-                self._model_name == model and
-                self._device == device and
-                current_model_type == cached_model_type and
-                current_time - self._last_access_time < self._timeout_seconds
+                self._pipe is not None
+                and self._model_name == model
+                and self._device == device
+                and current_model_type == cached_model_type
+                and current_time - self._last_access_time < self._timeout_seconds
            ):
                self._last_access_time = current_time
                return self._pipe

@@ -52,8 +53,10 @@ class ImageModelCache:  # Stay loaded for 3 hours
                    model,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                )
+
                def dummy_safety_checker(images, clip_input):
                    return images, [False] * len(images)
+
                pipe.safety_checker = dummy_safety_checker
            else:
                pipe = FluxPipeline.from_pretrained(

@@ -61,7 +64,7 @@ class ImageModelCache:  # Stay loaded for 3 hours
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                )
                try:
-                    pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
+                    pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
                except Exception as e:
                    raise Exception(f"Failed to load LoRA weights: {str(e)}")

@@ -89,10 +92,7 @@ class ImageModelCache:  # Stay loaded for 3 hours
    async def cleanup_if_expired(self):
        async with self._lock:
-            if (
-                self._pipe is not None and
-                time.time() - self._last_access_time >= self._timeout_seconds
-            ):
+            if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
                await self._unload_model()

    async def _periodic_cleanup(self):
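
The condition being reflowed above is the cache-hit test: reuse the pipeline only while the model name, device, and model type still match and the idle timeout has not elapsed. The same logic in isolation, as a minimal sketch:

# Isolated sketch of the timeout-gated cache-hit test reformatted above.
import time


class TimedCache:
    def __init__(self, timeout_seconds: float) -> None:
        self._value = None
        self._key = None
        self._timeout = timeout_seconds
        self._last_access = 0.0

    def get(self, key: str):
        now = time.time()
        # Hit only when the key matches and the idle window has not expired.
        if self._value is not None and self._key == key and now - self._last_access < self._timeout:
            self._last_access = now  # accessing the entry extends its life
            return self._value
        return None

    def put(self, key: str, value) -> None:
        self._key, self._value, self._last_access = key, value, time.time()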

View File

@@ -29,6 +29,7 @@ TIME_ESTIMATES = {
    }
}

+
class ImageRequest(BaseModel):
    session_id: str
    filepath: str

@@ -39,18 +40,22 @@ class ImageRequest(BaseModel):
    width: int = 256
    guidance_scale: float = 7.5

+
# Global model cache instance
model_cache = ImageModelCache()

+
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
    """Background worker for Flux image generation"""
    try:
        # Flux: Run generation in the background and yield progress updates
-        status_queue.put(ChatMessageStatus(
-            session_id=params.session_id,
-            content="Initializing image generation.",
-            activity=ApiActivityType.GENERATING_IMAGE,
-        ))
+        status_queue.put(
+            ChatMessageStatus(
+                session_id=params.session_id,
+                content="Initializing image generation.",
+                activity=ApiActivityType.GENERATING_IMAGE,
+            )
+        )

        # Start the generation task
        start_gen_time = time.time()

@@ -60,11 +65,13 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
            # Send progress updates
            progress = int((step + 1) / params.iterations * 100)
-            status_queue.put(ChatMessageStatus(
-                session_id=params.session_id,
-                content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
-                activity=ApiActivityType.GENERATING_IMAGE,
-            ))
+            status_queue.put(
+                ChatMessageStatus(
+                    session_id=params.session_id,
+                    content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
+                    activity=ApiActivityType.GENERATING_IMAGE,
+                )
+            )
            return callback_kwargs

        # Replace this block with your actual Flux pipe call:

@@ -84,22 +91,28 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
        image.save(params.filepath)

        # Final completion status
-        status_queue.put(ChatMessage(
-            session_id=params.session_id,
-            status=ApiStatusType.DONE,
-            content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
-        ))
+        status_queue.put(
+            ChatMessage(
+                session_id=params.session_id,
+                status=ApiStatusType.DONE,
+                content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
+            )
+        )

    except Exception as e:
        logger.error(traceback.format_exc())
        logger.error(e)
-        status_queue.put(ChatMessageError(
-            session_id=params.session_id,
-            content=f"Error during image generation: {str(e)}",
-        ))
+        status_queue.put(
+            ChatMessageError(
+                session_id=params.session_id,
+                content=f"Error during image generation: {str(e)}",
+            )
+        )


-async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
+async def async_generate_image(
+    pipe: Any, params: ImageRequest
+) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
    """
    Single async function that handles background Flux generation with status streaming
    """

@@ -109,18 +122,14 @@ async def async_generate_image(
    try:
        # Start background worker thread
-        worker_thread = Thread(
-            target=flux_worker,
-            args=(pipe, params, status_queue, task_id),
-            daemon=True
-        )
+        worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
        worker_thread.start()

        # Initial status
        status_message = ChatMessageStatus(
            session_id=params.session_id,
            content=f"Starting image generation with task ID {task_id}",
-            activity=ApiActivityType.THINKING
+            activity=ApiActivityType.THINKING,
        )
        yield status_message

@@ -177,18 +186,16 @@ async def async_generate_image(
        yield final_status

    except Exception as e:
-        error_status = ChatMessageError(
-            session_id=params.session_id,
-            content=f'Server error: {str(e)}'
-        )
+        error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
        logger.error(error_status)
        yield error_status

    finally:
        # Cleanup: ensure thread completion
-        if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
+        if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
            worker_thread.join(timeout=1.0)  # Wait up to 1 second for cleanup


def status(session_id: str, status: str) -> ChatMessageStatus:
    """Update chat message status and return it."""
    chat_message = ChatMessageStatus(

@@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
    )
    return chat_message

+
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    """Generate an image with specified dimensions and yield status updates with time estimates."""
    session_id = request.session_id

@@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
    try:
        # Validate prompt
        if not prompt:
-            error_message = ChatMessageError(
-                session_id=session_id,
-                content="Prompt cannot be empty."
-            )
+            error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
            logger.error(error_message.content)
            yield error_message
            return

@@ -216,8 +221,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
        # Validate dimensions
        if request.height <= 0 or request.width <= 0:
            error_message = ChatMessageError(
-                session_id=session_id,
-                content="Height and width must be positive integers."
+                session_id=session_id, content="Height and width must be positive integers."
            )
            logger.error(error_message.content)
            yield error_message

@@ -240,7 +244,10 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
        yield status(session_id, "Loading generative image model...")
        pipe = await model_cache.get_pipeline(request.model, device)
        load_time = time.time() - start_time
-        yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
+        yield status(
+            session_id,
+            f"Model loaded in {load_time:.1f} seconds.",
+        )

        progress = None
        async for progress in async_generate_image(pipe, request):

@@ -252,8 +259,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
        if not progress:
            error_message = ChatMessageError(
-                session_id=session_id,
-                content="Image generation failed to produce a valid response."
+                session_id=session_id, content="Image generation failed to produce a valid response."
            )
            logger.error(f"⚠️ {error_message.content}")
            yield error_message

@@ -269,10 +275,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
            yield chat_message

    except Exception as e:
-        error_message = ChatMessageError(
-            session_id=session_id,
-            content=f"Error during image generation: {str(e)}"
-        )
+        error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
        logger.error(traceback.format_exc())
        logger.error(error_message.content)
        yield error_message
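
flux_worker and async_generate_image above bridge a worker thread to an async stream through a queue.Queue. A minimal, self-contained sketch of that bridge with a dummy worker and a None sentinel; the polling details here are illustrative, not the repo's exact loop:

# Minimal sketch of the thread-to-async bridge used by async_generate_image:
# a worker pushes status items onto a queue.Queue, the async side drains it.
import asyncio
import queue
import time
from threading import Thread
from typing import AsyncGenerator


def worker(q: queue.Queue) -> None:
    for step in range(3):
        time.sleep(0.1)  # stand-in for one diffusion step
        q.put(f"step {step + 1}/3")
    q.put(None)  # sentinel: generation finished


async def stream_status() -> AsyncGenerator[str, None]:
    q: queue.Queue = queue.Queue()
    Thread(target=worker, args=(q,), daemon=True).start()
    while True:
        try:
            item = q.get_nowait()
        except queue.Empty:
            await asyncio.sleep(0.05)  # yield control instead of blocking
            continue
        if item is None:
            return
        yield item


async def main() -> None:
    async for message in stream_status():
        print(message)


asyncio.run(main())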

View File

@@ -2,6 +2,7 @@ import json
import re
from typing import List, Union

+
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
    """
    Extract JSON blocks from text, even if surrounded by markdown or noisy text.

@@ -31,13 +32,14 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
    return found

+
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
    """Extract standalone JSON objects or arrays from text using proper brace counting."""
    found = []
    i = 0

    while i < len(text):
-        if text[i] in '{[':
+        if text[i] in "{[":
            # Found potential JSON start
            json_str = _extract_complete_json_at_position(text, i)
            if json_str:

@@ -55,16 +57,17 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
    return found

+
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
    """
    Extract a complete JSON object or array starting at the given position.
    Uses proper brace/bracket counting and string escape handling.
    """
-    if start_pos >= len(text) or text[start_pos] not in '{[':
+    if start_pos >= len(text) or text[start_pos] not in "{[":
        return ""

    start_char = text[start_pos]
-    end_char = '}' if start_char == '{' else ']'
+    end_char = "}" if start_char == "{" else "]"
    count = 1
    i = start_pos + 1

@@ -76,7 +79,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
        if escape_next:
            escape_next = False
-        elif char == '\\' and in_string:
+        elif char == "\\" and in_string:
            escape_next = True
        elif char == '"' and not escape_next:
            in_string = not in_string

@@ -92,6 +95,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
        return text[start_pos:i]

    return ""

+
def extract_json_from_text(text: str) -> str:
    """Extract JSON string from text that may contain other content."""
    return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
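
_extract_complete_json_at_position above scans with a depth counter while tracking string and escape state. A compact standalone rendering of the same brace-counting technique (not the project's exact implementation):

# Standalone illustration of the brace-counting idea used above: find the
# matching close brace while respecting strings and escapes, then json.loads.
import json


def first_json_object(text: str) -> dict:
    start = text.index("{")
    depth, in_string, escape = 0, False, False
    for i in range(start, len(text)):
        ch = text[i]
        if escape:
            escape = False  # the escaped character never affects state
        elif ch == "\\" and in_string:
            escape = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return json.loads(text[start : i + 1])
    raise ValueError("no complete JSON object found")


noisy = 'Model said: ```json\n{"name": "Backstory", "ok": true}\n``` thanks!'
print(first_json_object(noisy))  # {'name': 'Backstory', 'ok': True}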

View File

@@ -2,11 +2,11 @@ import os
import warnings
import logging
import defines

def _setup_logging(level=defines.logging_level) -> logging.Logger:
    os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
-    warnings.filterwarnings(
-        "ignore", message="Overriding a previously registered kernel"
-    )
+    warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
    warnings.filterwarnings("ignore", message="Warning only once for all operators")
    warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
    warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
    # Create a custom formatter
    formatter = logging.Formatter(
-        fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S"
+        fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )
    # Create a handler (e.g., StreamHandler for console output)
@@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
    logger = logging.getLogger(__name__)
    return logger

logger = _setup_logging(level=defines.logging_level)
logger.debug(f"Logging initialized with level: {defines.logging_level}")
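With that formatter, every record carries its level and call site; an illustrative record (the file and line depend on the caller):

logger.info("Vector store ready")
# INFO - rag.py:87 - Vector store ready

Note that datefmt is configured but has no effect here: without %(asctime)s in fmt, records carry no timestamp.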

View File

@@ -58,6 +58,7 @@ background_task_manager = None
prev_int = signal.getsignal(signal.SIGINT)
prev_term = signal.getsignal(signal.SIGTERM)

def signal_handler(signum, frame):
    logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
    # now call the old handler (it might raise KeyboardInterrupt or exit)
@@ -66,9 +67,11 @@ def signal_handler(signum, frame):
    elif signum == signal.SIGTERM and callable(prev_term):
        prev_term(signum, frame)

# Global background task manager
background_task_manager: Optional[BackgroundTaskManager] = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
@@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
    if db_manager:
        await db_manager.graceful_shutdown()

app = FastAPI(
    lifespan=lifespan,
    title="Backstory API",
@@ -129,11 +133,9 @@ app = FastAPI(
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
if ssl_enabled:
-    allow_origins = ["https://battle-linux.ketrenos.com:3000",
-                     "https://backstory-beta.ketrenos.com"]
+    allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
else:
-    allow_origins = ["http://battle-linux.ketrenos.com:3000",
-                     "http://backstory-beta.ketrenos.com"]
+    allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]

# Add CORS middleware
app.add_middleware(
@@ -144,12 +146,14 @@ app.add_middleware(
    allow_headers=["*"],
)

# ============================
# Debug data type failures
# ============================
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    import traceback

    logger.error(traceback.format_exc())
    logger.error(backstory_traceback.format_exc())
    logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
@@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
        content=json.dumps({"detail": str(exc)}),
    )

# ============================
# Create API router with prefix
# ============================
@@ -181,6 +186,7 @@ api_router.include_router(users.router)
# Health Check and Info Endpoints
# ============================

@app.get("/health")
async def health_check(
    database=Depends(get_database),
@@ -202,15 +208,12 @@ async def health_check(
        return {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
-            "database": {
-                "status": "connected",
-                "stats": stats
-            },
+            "database": {"status": "connected", "stats": stats},
            "redis": {
                "version": redis_info.get("redis_version", "unknown"),
                "uptime": redis_info.get("uptime_in_seconds", 0),
-                "memory_used": redis_info.get("used_memory_human", "unknown")
-            }
+                "memory_used": redis_info.get("used_memory_human", "unknown"),
+            },
        }
    except RuntimeError as e:
@@ -219,6 +222,7 @@ async def health_check(
        logger.error(f"❌ Health check failed: {e}")
        return {"status": "error", "message": str(e)}

# ============================
# Include Router in App
# ============================
@@ -231,11 +235,14 @@ app.include_router(api_router)
# ============================
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")

@app.middleware("http")
async def log_requests(request: Request, call_next):
    try:
        if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
-            logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
+            logger.info(
+                f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
+            )
        response = await call_next(request)
        if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
            if response.status_code < 200 or response.status_code >= 300:
@@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
        return response
    except Exception as e:
        import traceback

        logger.error(traceback.format_exc())
        logger.error(backstory_traceback.format_exc())
        logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
        return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})

# ============================
# Request tracking middleware
# ============================
@@ -266,6 +275,7 @@ async def track_requests(request, call_next):
    finally:
        db_manager.decrement_requests()

# ============================
# FastAPI Metrics
# ============================
@@ -277,7 +287,7 @@ instrumentator = Instrumentator(
    should_ignore_untemplated=True,
    should_group_untemplated=True,
    excluded_handlers=[f"{defines.api_prefix}/metrics"],
-    registry=prometheus_collector
+    registry=prometheus_collector,
)

# Instrument the FastAPI app
@@ -291,6 +301,7 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
# Static File Serving
# ============================

@app.get("/{path:path}")
async def serve_static(path: str, request: Request):
    full_path = os.path.join(defines.static_content, path)
@@ -300,6 +311,7 @@ async def serve_static(path: str, request: Request):
    return FileResponse(os.path.join(defines.static_content, "index.html"))

# Root endpoint when no static files
@app.get("/", include_in_schema=False)
async def root():
@@ -309,9 +321,10 @@ async def root():
        "version": "1.0.0",
        "api_prefix": defines.api_prefix,
        "documentation": f"{defines.api_prefix}/docs",
-        "health": f"{defines.api_prefix}/health"
+        "health": f"{defines.api_prefix}/health",
    }

async def periodic_verification_cleanup():
    """Background task to periodically clean up expired verification tokens"""
    try:
@@ -324,6 +337,7 @@ async def periodic_verification_cleanup():
    except Exception as e:
        logger.error(f"❌ Error in periodic verification cleanup: {e}")

if __name__ == "__main__":
    host = defines.host
    port = defines.port
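The __main__ block reads the bind address from defines; a minimal sketch of handing it to an ASGI server (uvicorn is an assumption here, the actual launch call is outside this hunk):

import uvicorn

# Hypothetical launch; ssl_enabled above suggests certfile/keyfile kwargs in the real call
uvicorn.run(app, host=host, port=port)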

File diff suppressed because it is too large

View File

@@ -0,0 +1,23 @@
[tool.ruff]
# Set line length (default is 88, Black-compatible)
line-length = 120
exclude = [
".git",
"__pycache__",
"migrations",
"venv",
"dist",
"build",
]
target-version = "py312"
[tool.ruff.format]
# Use Ruff's formatter (Black-compatible by default)
quote-style = "double" # Enforce double quotes
indent-style = "space" # Use spaces for indentation
skip-magic-trailing-comma = false # Add trailing commas where applicable
[tool.ruff.per-file-ignores]
# Ignore specific rules for certain files (e.g., tests or scripts)
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
"scripts/**" = ["E402"] # Ignore import order in scripts

View File

@@ -1,7 +1,3 @@
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry

-__all__ = [
-    "ChromaDBFileWatcher",
-    "start_file_watcher",
-    "RagEntry"
-]
+__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]

View File

@@ -7,10 +7,12 @@ import logging
import defines

class Chunk(TypedDict):
    text: str
    metadata: Dict[str, Any]

def clear_chunk(chunk: Chunk):
    chunk["text"] = ""
    chunk["metadata"] = {
@@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
    }
    return chunk

class MarkdownChunker:
    def __init__(self):
        # Initialize the Markdown parser
@@ -76,10 +79,7 @@ class MarkdownChunker:
        # Initialize a chunk structure
        chunk: Chunk = {
            "text": "",
-            "metadata": {
-                "source_file": file_path,
-                "lines": total_lines
-            },
+            "metadata": {"source_file": file_path, "lines": total_lines},
        }
        clear_chunk(chunk)
@@ -89,9 +89,7 @@ class MarkdownChunker:
        return chunks

    def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
-        return {
-            k: ("" if v is None else v) for k, v in metadata.items() if v is not None
-        }
+        return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}

    def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
        lines = []
@@ -114,7 +112,7 @@ class MarkdownChunker:
        chunks: List[Chunk],
        chunk: Chunk,
        level: int,
-        buffer: int = defines.chunk_buffer
+        buffer: int = defines.chunk_buffer,
    ) -> int:
        is_list = False
        # Handle heading nodes
@@ -198,9 +196,7 @@ class MarkdownChunker:
        # Recursively process children
        if not is_list:
            for child in node.children:
-                level = self._process_node(
-                    child, current_headings, chunks, chunk, level=level
-                )
+                level = self._process_node(child, current_headings, chunks, chunk, level=level)
        # After root-level recursion, finalize any remaining chunk
        if node.type == "document":
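For reference, the SyntaxTreeNode walked by _process_node comes from markdown-it-py; a minimal sketch of building one (the chunker presumably initializes the same parser):

from markdown_it import MarkdownIt
from markdown_it.tree import SyntaxTreeNode

tokens = MarkdownIt().parse("# Title\n\nBody text.")
root = SyntaxTreeNode(tokens)  # root.type == "document"
print([child.type for child in root.children])  # e.g. ['heading', 'paragraph']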

View File

@@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100

class RagEntry(BaseModel):
    name: str
    description: str = ""
    enabled: bool = True

class ChromaDBFileWatcher(FileSystemEventHandler):
    def __init__(
        self,
@@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Path for storing file hash state
-        self.hash_state_path = os.path.join(
-            self.persist_directory, f"{collection_name}_hash_state.json"
-        )
+        self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
        # Flag to track if this is a new collection
        self.is_new_collection = False
@@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            process_all: If True, process all files regardless of hash status
        """
        # Check for new or modified files
-        file_paths = glob.glob(
-            os.path.join(self.watch_directory, "**/*"), recursive=True
-        )
+        file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
        files_checked = 0
        files_processed = 0
        files_to_process = []
@@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
                continue
            # If file is new, changed, or we're processing all files
-            if (
-                process_all
-                or file_path not in self.file_hashes
-                or self.file_hashes[file_path] != current_hash
-            ):
+            if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
                self.file_hashes[file_path] = current_hash
                files_to_process.append(file_path)
-                logging.info(
-                    f"File {'found' if process_all else 'changed'}: {file_path}"
-                )
+                logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
-        logging.info(
-            f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
-        )
+        logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
        # Check for deleted files
        deleted_files = []
@@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            if not os.path.exists(file_path):
                deleted_files.append(file_path)
                # Schedule removal
-                asyncio.run_coroutine_threadsafe(
-                    self.remove_file_from_collection(file_path), self.loop
-                )
+                asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
                # Don't block on result, just let it run
                logging.info(f"File deleted: {file_path}")
@@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        if not current_hash:  # File might have been deleted or is inaccessible
            return
-        if (
-            file_path in self.file_hashes
-            and self.file_hashes[file_path] == current_hash
-        ):
+        if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
            # File hasn't actually changed in content
            logging.info(f"Hash has not changed for {file_path}")
            return
@@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        if results and "ids" in results and results["ids"]:
            self.collection.delete(ids=results["ids"])
            await self.database.update_user_rag_timestamp(self.user_id)
-            logging.info(
-                f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
-            )
+            logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
        # Remove from hash dictionary
        if file_path in self.file_hashes:
@@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
    def _update_umaps(self):
        # Update the UMAP embeddings
-        self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
-            include=["embeddings", "documents", "metadatas"]
-        ))
+        self._umap_collection = ChromaDBGetResponse.model_validate(
+            self._collection.get(include=["embeddings", "documents", "metadatas"])
+        )
        if not self._umap_collection or not len(self._umap_collection.embeddings):
            logging.warning("⚠️ No embeddings found in the collection.")
            return
        # During initialization
-        logging.info(
-            f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
-        )
+        logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        vectors = np.array(self._umap_collection.embeddings)
        self._umap_model_2d = umap.UMAP(
            n_components=2,
@@ -328,9 +309,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        #     f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
        # )  # Should be 2
-        logging.info(
-            f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
-        )
+        logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
        self._umap_model_3d = umap.UMAP(
            n_components=3,
            random_state=8911,
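A standalone sketch of the reduction configured above (umap-learn API; the random_state echoes the diff, the data is synthetic):

import numpy as np
import umap

vectors = np.random.rand(200, 384)  # e.g. 200 embeddings of dimension 384
coords_2d = umap.UMAP(n_components=2, random_state=8911).fit_transform(vectors)
print(coords_2d.shape)  # (200, 2), one projected point per embedding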
@@ -373,9 +352,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            self.is_new_collection = True
            logging.info(f"Recreating collection: {self.collection_name}")
-        return chroma_client.get_or_create_collection(
-            name=self.collection_name, metadata={"hnsw:space": "cosine"}
-        )
+        return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})

    async def get_embedding(self, text: str) -> np.ndarray:
        """Generate and normalize an embedding for the given text."""
@@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            # Generate a more unique ID based on content and metadata
            path_hash = ""
            if "path" in metadata:
-                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
-                    :8
-                ]
+                path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
            content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
            chunk_id = f"{path_hash}_{i}_{content_hash}"
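The ID scheme above is deterministic because it hashes the chunk's source path and content rather than relying on insertion order alone; a sketch:

import hashlib

def make_chunk_id(source_file: str, i: int, text: str) -> str:
    path_hash = hashlib.md5(source_file.encode()).hexdigest()[:8]
    content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
    return f"{path_hash}_{i}_{content_hash}"

print(make_chunk_id("resume.md", 0, "First chunk of text"))
# Same inputs always produce the same ID, so re-indexing an
# unchanged file upserts rather than duplicates.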
@@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            return
        file_path = event.src_path
-        asyncio.run_coroutine_threadsafe(
-            self.remove_file_from_collection(file_path), self.loop
-        )
+        asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
        logging.info(f"File deleted: {file_path}")

    def on_moved(self, event):
@@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        try:
            # Remove existing entries for this file
            existing_results = self.collection.get(where={"path": file_path})
-            if (
-                existing_results
-                and "ids" in existing_results
-                and existing_results["ids"]
-            ):
+            if existing_results and "ids" in existing_results and existing_results["ids"]:
                self.collection.delete(ids=existing_results["ids"])
                await self.database.update_user_rag_timestamp(self.user_id)
@@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
            p = Path(file_path)
            p_as_md = p.with_suffix(".md")
            if p_as_md.exists():
-                logging.info(
-                    f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
-                )
+                logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
            # If file_path.md doesn't exist or file_path is newer than file_path.md,
            # fire off markitdown
-            if (not p_as_md.exists()) or (
-                p.stat().st_mtime > p_as_md.stat().st_mtime
-            ):
+            if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
                self._markitdown(file_path, p_as_md)
                return
@@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
        # Process all files regardless of hash state
        num_processed = await self.scan_directory(process_all=True)
-        logging.info(
-            f"Vectorstore initialized with {self.collection.count()} documents"
-        )
+        logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
        self._update_umaps()
@@ -676,7 +639,7 @@ def start_file_watcher(
        persist_directory=persist_directory,
        collection_name=collection_name,
        recreate=recreate,
-        database=database
+        database=database,
    )
    # Process all files if:

View File

@@ -13,14 +13,4 @@ from . import employers
from . import admin
from . import system

-__all__ = [
-    "auth",
-    "candidates",
-    "resumes",
-    "jobs",
-    "chat",
-    "users",
-    "employers",
-    "admin",
-    "system"
-]
+__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]

View File

@@ -5,8 +5,18 @@ import json
from datetime import datetime, timezone, UTC
from fastapi import (
-    APIRouter, HTTPException, Depends, Body, Request, HTTPException,
-    Depends, Query, Path, Body, APIRouter, Request
+    APIRouter,
+    HTTPException,
+    Depends,
+    Body,
+    Request,
+    HTTPException,
+    Depends,
+    Query,
+    Path,
+    Body,
+    APIRouter,
+    Request,
)
from fastapi.responses import JSONResponse
@@ -14,52 +24,51 @@ from fastapi.responses import JSONResponse
from utils.rate_limiter import RateLimiter, get_rate_limiter
from database.manager import RedisDatabase
from logger import logger
-from utils.dependencies import (
-    get_current_admin, get_current_user_or_guest, get_database, background_task_manager
-)
+from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager
from utils.responses import (
-    create_paginated_response, create_success_response, create_error_response,
+    create_paginated_response,
+    create_success_response,
+    create_error_response,
)

# Create router for authentication endpoints
router = APIRouter(prefix="/admin", tags=["admin"])

@router.post("/tasks/cleanup-guests")
async def manual_guest_cleanup(
    inactive_hours: int = Body(24, embed=True),
    current_user=Depends(get_current_admin),
-    admin_user = Depends(get_current_admin)
+    admin_user=Depends(get_current_admin),
):
    """Manually trigger guest cleanup (admin only)"""
    try:
        if not background_task_manager:
            return JSONResponse(
                status_code=500,
-                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
+                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
            )
        cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
        logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
-        return create_success_response({
-            "message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
-            "cleaned_count": cleaned_count,
-            "triggered_by": admin_user.id
-        })
+        return create_success_response(
+            {
+                "message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
+                "cleaned_count": cleaned_count,
+                "triggered_by": admin_user.id,
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Manual guest cleanup error: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("CLEANUP_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))

@router.post("/tasks/cleanup-tokens")
-async def manual_token_cleanup(
-    admin_user = Depends(get_current_admin)
-):
+async def manual_token_cleanup(admin_user=Depends(get_current_admin)):
    """Manually trigger verification token cleanup (admin only)"""
    try:
        global background_task_manager
@@ -67,31 +76,28 @@ async def manual_token_cleanup(
        if not background_task_manager:
            return JSONResponse(
                status_code=500,
-                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
+                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
            )
        cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
        logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
-        return create_success_response({
-            "message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
-            "cleaned_count": cleaned_count,
-            "triggered_by": admin_user.id
-        })
+        return create_success_response(
+            {
+                "message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
+                "cleaned_count": cleaned_count,
+                "triggered_by": admin_user.id,
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Manual token cleanup error: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("CLEANUP_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))

@router.post("/tasks/cleanup-rate-limits")
-async def manual_rate_limit_cleanup(
-    days_old: int = Body(7, embed=True),
-    admin_user = Depends(get_current_admin)
-):
+async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)):
    """Manually trigger rate limit data cleanup (admin only)"""
    try:
        global background_task_manager
@@ -99,60 +105,55 @@ async def manual_rate_limit_cleanup(
        if not background_task_manager:
            return JSONResponse(
                status_code=500,
-                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
+                content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
            )
        cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
        logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
-        return create_success_response({
-            "message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
-            "cleaned_count": cleaned_count,
-            "triggered_by": admin_user.id
-        })
+        return create_success_response(
+            {
+                "message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
+                "cleaned_count": cleaned_count,
+                "triggered_by": admin_user.id,
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Manual rate limit cleanup error: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("CLEANUP_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
# ========================================
# System Health and Maintenance Endpoints
# ========================================

@router.get("/system/health")
-async def get_system_health(
-    request: Request,
-    admin_user = Depends(get_current_admin)
-):
+async def get_system_health(request: Request, admin_user=Depends(get_current_admin)):
    """Get comprehensive system health status (admin only)"""
    try:
        # Database health
-        database_manager = getattr(request.app.state, 'database_manager', None)
+        database_manager = getattr(request.app.state, "database_manager", None)
        db_health = {"status": "unavailable", "healthy": False}
        if database_manager:
            try:
                database_manager.get_database()
                from database.manager import redis_manager
                redis_health = await redis_manager.health_check()
                db_health = {
                    "status": redis_health.get("status", "unknown"),
                    "healthy": redis_health.get("status") == "healthy",
-                    "details": redis_health
+                    "details": redis_health,
                }
            except Exception as e:
-                db_health = {
-                    "status": "error",
-                    "healthy": False,
-                    "error": str(e)
-                }
+                db_health = {"status": "error", "healthy": False, "error": str(e)}
        # Background task health
-        background_task_manager = getattr(request.app.state, 'background_task_manager', None)
+        background_task_manager = getattr(request.app.state, "background_task_manager", None)
        task_health = {"status": "unavailable", "healthy": False}
        if background_task_manager:
@@ -166,39 +167,30 @@ async def get_system_health(
                    "healthy": task_status["running"] and failed_tasks == 0,
                    "running_tasks": running_tasks,
                    "failed_tasks": failed_tasks,
-                    "total_tasks": task_status["task_count"]
+                    "total_tasks": task_status["task_count"],
                }
            except Exception as e:
-                task_health = {
-                    "status": "error",
-                    "healthy": False,
-                    "error": str(e)
-                }
+                task_health = {"status": "error", "healthy": False, "error": str(e)}
        # Overall health
        overall_healthy = db_health["healthy"] and task_health["healthy"]
-        return create_success_response({
-            "timestamp": datetime.now(UTC).isoformat(),
-            "overall_healthy": overall_healthy,
-            "components": {
-                "database": db_health,
-                "background_tasks": task_health
-            }
-        })
+        return create_success_response(
+            {
+                "timestamp": datetime.now(UTC).isoformat(),
+                "overall_healthy": overall_healthy,
+                "components": {"database": db_health, "background_tasks": task_health},
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Error getting system health: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("HEALTH_CHECK_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e)))

@router.post("/maintenance/cleanup")
async def run_maintenance_cleanup(
-    request: Request,
-    admin_user = Depends(get_current_admin),
-    database: RedisDatabase = Depends(get_database)
+    request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
    """Run comprehensive maintenance cleanup (admin only)"""
    try:
@@ -217,50 +209,40 @@ async def run_maintenance_cleanup(
            cleanup_results[operation_name] = {
                "success": True,
                "cleaned_count": result,
-                "message": f"Cleaned {result} items"
+                "message": f"Cleaned {result} items",
            }
        except Exception as e:
-            cleanup_results[operation_name] = {
-                "success": False,
-                "error": str(e),
-                "message": f"Failed: {str(e)}"
-            }
+            cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"}
        # Calculate totals
        total_cleaned = sum(
-            result.get("cleaned_count", 0)
-            for result in cleanup_results.values()
-            if result.get("success", False)
+            result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False)
        )
-        successful_operations = len([
-            r for r in cleanup_results.values()
-            if r.get("success", False)
-        ])
+        successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)])
-        return create_success_response({
-            "message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
-            "total_cleaned": total_cleaned,
-            "successful_operations": successful_operations,
-            "details": cleanup_results
-        })
+        return create_success_response(
+            {
+                "message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
+                "total_cleaned": total_cleaned,
+                "successful_operations": successful_operations,
+                "details": cleanup_results,
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Error in maintenance cleanup: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("CLEANUP_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))

# ========================================
# Background Task Statistics
# ========================================

@router.get("/tasks/stats")
async def get_task_statistics(
-    request: Request,
-    admin_user = Depends(get_current_admin),
-    database: RedisDatabase = Depends(get_database)
+    request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
    """Get background task execution statistics (admin only)"""
    try:
@@ -268,7 +250,7 @@ async def get_task_statistics(
        guest_stats = await database.get_guest_statistics()
        # Get background task manager status
-        background_task_manager = getattr(request.app.state, 'background_task_manager', None)
+        background_task_manager = getattr(request.app.state, "background_task_manager", None)
        task_manager_stats = {}
        if background_task_manager:
@@ -276,7 +258,7 @@ async def get_task_statistics(
            task_manager_stats = {
                "running": task_status["running"],
                "task_count": task_status["task_count"],
-                "task_breakdown": {}
+                "task_breakdown": {},
            }
            # Count tasks by status
@@ -284,40 +266,35 @@ async def get_task_statistics(
                status = task["status"]
                task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
-        return create_success_response({
-            "guest_statistics": guest_stats,
-            "task_manager": task_manager_stats,
-            "timestamp": datetime.now(UTC).isoformat()
-        })
+        return create_success_response(
+            {
+                "guest_statistics": guest_stats,
+                "task_manager": task_manager_stats,
+                "timestamp": datetime.now(UTC).isoformat(),
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Error getting task statistics: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("STATS_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))

# ========================================
# Background Task Status Endpoints
# ========================================
@router.get("/tasks/status") @router.get("/tasks/status")
async def get_background_task_status( async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)):
request: Request,
admin_user = Depends(get_current_admin)
):
"""Get background task manager status (admin only)""" """Get background task manager status (admin only)"""
try: try:
# Get background task manager from app state # Get background task manager from app state
background_task_manager = getattr(request.app.state, 'background_task_manager', None) background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager: if not background_task_manager:
return create_success_response({ return create_success_response(
"running": False, {"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0}
"message": "Background task manager not initialized", )
"tasks": [],
"task_count": 0
})
# Get comprehensive task status using the new method # Get comprehensive task status using the new method
task_status = await background_task_manager.get_task_status() task_status = await background_task_manager.get_task_status()
@ -329,87 +306,64 @@ async def get_background_task_status(
} }
# Format the response # Format the response
return create_success_response({ return create_success_response(
{
"running": task_status["running"], "running": task_status["running"],
"task_count": task_status["task_count"], "task_count": task_status["task_count"],
"loop_status": { "loop_status": {
"main_loop_id": task_status["main_loop_id"], "main_loop_id": task_status["main_loop_id"],
"current_loop_id": task_status["current_loop_id"], "current_loop_id": task_status["current_loop_id"],
"loop_matches": task_status.get("loop_matches", False) "loop_matches": task_status.get("loop_matches", False),
}, },
"tasks": task_status["tasks"], "tasks": task_status["tasks"],
"system_info": system_info "system_info": system_info,
}) }
)
except Exception as e: except Exception as e:
logger.error(f"❌ Get task status error: {e}") logger.error(f"❌ Get task status error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
@router.post("/tasks/run/{task_name}") @router.post("/tasks/run/{task_name}")
async def run_background_task( async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)):
task_name: str,
request: Request,
admin_user = Depends(get_current_admin)
):
"""Manually trigger a specific background task (admin only)""" """Manually trigger a specific background task (admin only)"""
try: try:
background_task_manager = getattr(request.app.state, 'background_task_manager', None) background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager: if not background_task_manager:
return JSONResponse( return JSONResponse(
status_code=503, status_code=503,
content=create_error_response( content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"),
"MANAGER_UNAVAILABLE",
"Background task manager not initialized"
)
) )
# List of available tasks # List of available tasks
available_tasks = [ available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"]
"guest_cleanup",
"token_cleanup",
"guest_stats",
"rate_limit_cleanup",
"orphaned_cleanup"
]
if task_name not in available_tasks: if task_name not in available_tasks:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response( content=create_error_response(
"INVALID_TASK", "INVALID_TASK", f"Unknown task: {task_name}. Available: {available_tasks}"
f"Unknown task: {task_name}. Available: {available_tasks}" ),
)
) )
# Run the task # Run the task
result = await background_task_manager.force_run_task(task_name) result = await background_task_manager.force_run_task(task_name)
return create_success_response({ return create_success_response(
"task_name": task_name, {"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"}
"result": result, )
"message": f"Task {task_name} completed successfully"
})
except ValueError as e: except ValueError as e:
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e)))
status_code=400,
content=create_error_response("INVALID_TASK", str(e))
)
except Exception as e: except Exception as e:
logger.error(f"❌ Error running task {task_name}: {e}") logger.error(f"❌ Error running task {task_name}: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e)))
status_code=500,
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
)
@router.get("/tasks/list") @router.get("/tasks/list")
async def list_available_tasks( async def list_available_tasks(admin_user=Depends(get_current_admin)):
admin_user = Depends(get_current_admin)
):
"""List all available background tasks (admin only)""" """List all available background tasks (admin only)"""
try: try:
tasks = [ tasks = [
@@ -417,63 +371,51 @@ async def list_available_tasks(
                "name": "guest_cleanup",
                "description": "Clean up inactive guest sessions",
                "interval": "6 hours",
-                "parameters": ["inactive_hours (default: 48)"]
+                "parameters": ["inactive_hours (default: 48)"],
            },
            {
                "name": "token_cleanup",
                "description": "Clean up expired email verification tokens",
                "interval": "12 hours",
-                "parameters": []
+                "parameters": [],
            },
            {
                "name": "guest_stats",
                "description": "Update guest usage statistics",
                "interval": "1 hour",
-                "parameters": []
+                "parameters": [],
            },
            {
                "name": "rate_limit_cleanup",
                "description": "Clean up old rate limiting data",
                "interval": "24 hours",
-                "parameters": ["days_old (default: 7)"]
+                "parameters": ["days_old (default: 7)"],
            },
            {
                "name": "orphaned_cleanup",
                "description": "Clean up orphaned database records",
                "interval": "6 hours",
-                "parameters": []
-            }
+                "parameters": [],
+            },
        ]
-        return create_success_response({
-            "total_tasks": len(tasks),
-            "tasks": tasks
-        })
+        return create_success_response({"total_tasks": len(tasks), "tasks": tasks})
    except Exception as e:
        logger.error(f"❌ Error listing tasks: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("LIST_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e)))

@router.post("/tasks/restart")
-async def restart_background_tasks(
-    request: Request,
-    admin_user = Depends(get_current_admin)
-):
+async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)):
    """Restart the background task manager (admin only)"""
    try:
-        database_manager = getattr(request.app.state, 'database_manager', None)
-        background_task_manager = getattr(request.app.state, 'background_task_manager', None)
+        database_manager = getattr(request.app.state, "database_manager", None)
+        background_task_manager = getattr(request.app.state, "background_task_manager", None)
        if not database_manager:
            return JSONResponse(
-                status_code=503,
-                content=create_error_response(
-                    "DATABASE_UNAVAILABLE",
-                    "Database manager not available"
-                )
+                status_code=503, content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available")
            )
        # Stop existing background tasks
@@ -483,6 +425,7 @@ async def restart_background_tasks(
        # Create and start new background task manager
        from background_tasks import BackgroundTaskManager

        new_background_task_manager = BackgroundTaskManager(database_manager)
        await new_background_task_manager.start()
@@ -492,22 +435,20 @@ async def restart_background_tasks(
        # Get status of new manager
        status = await new_background_task_manager.get_task_status()
-        return create_success_response({
-            "message": "Background task manager restarted successfully",
-            "new_status": status
-        })
+        return create_success_response(
+            {"message": "Background task manager restarted successfully", "new_status": status}
+        )
    except Exception as e:
        logger.error(f"❌ Error restarting background tasks: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("RESTART_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e)))

# ============================
# Task Monitoring and Metrics
# ============================

class TaskMetrics:
    """Collect metrics for background tasks"""
@@ -544,36 +485,32 @@ class TaskMetrics:
            metrics[task_name] = {
                "total_runs": self.task_runs[task_name],
                "total_errors": self.task_errors[task_name],
-                "success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
+                "success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name]
+                if self.task_runs[task_name] > 0
+                else 0,
                "average_duration": avg_duration,
-                "last_runs": durations[-10:] if durations else []
+                "last_runs": durations[-10:] if durations else [],
            }
        return metrics
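The reflowed success_rate keeps the same arithmetic: (runs - errors) / runs, guarded against division by zero. For instance, 10 runs with 2 errors yields (10 - 2) / 10 = 0.8.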
# Global task metrics
task_metrics = TaskMetrics()

@router.get("/tasks/metrics")
-async def get_task_metrics(
-    admin_user = Depends(get_current_admin)
-):
+async def get_task_metrics(admin_user=Depends(get_current_admin)):
    """Get background task metrics (admin only)"""
    try:
        global task_metrics
        metrics = task_metrics.get_metrics()
-        return create_success_response({
-            "metrics": metrics,
-            "timestamp": datetime.now(UTC).isoformat()
-        })
+        return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()})
    except Exception as e:
        logger.error(f"❌ Get task metrics error: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("METRICS_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e)))

# ============================
@@ -581,8 +518,7 @@ async def get_task_metrics(
# ============================
# @router.get("/verification-stats")
async def get_verification_statistics(
-    current_user = Depends(get_current_admin),
-    database: RedisDatabase = Depends(get_database)
+    current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
    """Get verification statistics (admin only)"""
    try:
@@ -591,22 +527,19 @@ async def get_verification_statistics(
        stats = {
            "pending_verifications": await database.get_pending_verifications_count(),
-            "expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
+            "expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(),
        }
        return create_success_response(stats)
    except Exception as e:
        logger.error(f"❌ Error getting verification stats: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("STATS_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))

@router.post("/cleanup-verifications")
async def cleanup_verification_tokens(
-    current_user = Depends(get_current_admin),
-    database: RedisDatabase = Depends(get_database)
+    current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
    """Manually trigger cleanup of expired verification tokens (admin only)"""
    try:
@@ -617,24 +550,24 @@ async def cleanup_verification_tokens(
        logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
-        return create_success_response({
-            "message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
-            "cleaned_count": cleaned_count
-        })
+        return create_success_response(
+            {
+                "message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
+                "cleaned_count": cleaned_count,
+            }
+        )
    except Exception as e:
        logger.error(f"❌ Error in manual cleanup: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("CLEANUP_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))

@router.get("/pending-verifications")
async def get_pending_verifications(
    current_user=Depends(get_current_admin),
    page: int = Query(1, ge=1),
    limit: int = Query(20, ge=1, le=100),
-    database: RedisDatabase = Depends(get_database)
+    database: RedisDatabase = Depends(get_database),
):
    """Get list of pending email verifications (admin only)"""
    try:
@@ -656,14 +589,16 @@ async def get_pending_verifications(
                    if not verification_info.get("verified", False):
                        expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
-                        pending_verifications.append({
-                            "email": verification_info.get("email"),
-                            "user_type": verification_info.get("user_type"),
-                            "created_at": verification_info.get("created_at"),
-                            "expires_at": verification_info.get("expires_at"),
-                            "is_expired": current_time > expires_at,
-                            "resend_count": verification_info.get("resend_count", 0)
-                        })
+                        pending_verifications.append(
+                            {
+                                "email": verification_info.get("email"),
+                                "user_type": verification_info.get("user_type"),
+                                "created_at": verification_info.get("created_at"),
+                                "expires_at": verification_info.get("expires_at"),
+                                "is_expired": current_time > expires_at,
+                                "resend_count": verification_info.get("resend_count", 0),
+                            }
+                        )
            if cursor == 0:
                break
@@ -677,35 +612,27 @@ async def get_pending_verifications(
        end = start + limit
        paginated_verifications = pending_verifications[start:end]
-        paginated_response = create_paginated_response(
-            paginated_verifications,
-            page, limit, total
-        )
+        paginated_response = create_paginated_response(paginated_verifications, page, limit, total)
        return create_success_response(paginated_response)
    except Exception as e:
        logger.error(f"❌ Error getting pending verifications: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("FETCH_ERROR", str(e))
-        )
+        return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
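Pagination above is plain list slicing: with page=2 and limit=20, the start index (computed in the elided lines, presumably as (page - 1) * limit) is 20, so the response carries pending_verifications[20:40] out of the full total.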
@router.get("/rate-limits/info") @router.get("/rate-limits/info")
async def get_user_rate_limit_status( async def get_user_rate_limit_status(
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter), rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get rate limit status for a user (admin only)""" """Get rate limit status for a user (admin only)"""
try: try:
# Get user to determine type # Get user to determine type
user_data = await database.get_user_by_id(current_user.id) user_data = await database.get_user_by_id(current_user.id)
if not user_data: if not user_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
user_type = user_data.get("type", "unknown") user_type = user_data.get("type", "unknown")
is_admin = False is_admin = False
@@ -725,27 +652,22 @@ async def get_user_rate_limit_status(
except Exception as e: except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}") logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
@router.get("/rate-limits/{user_id}") @router.get("/rate-limits/{user_id}")
async def get_anyone_rate_limit_status( async def get_anyone_rate_limit_status(
user_id: str = Path(...), user_id: str = Path(...),
admin_user=Depends(get_current_admin), admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter), rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get rate limit status for a user (admin only)""" """Get rate limit status for a user (admin only)"""
try: try:
# Get user to determine type # Get user to determine type
user_data = await database.get_user_by_id(user_id) user_data = await database.get_user_by_id(user_id)
if not user_data: if not user_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
user_type = user_data.get("type", "unknown") user_type = user_data.get("type", "unknown")
is_admin = False is_admin = False
@@ -765,47 +687,36 @@ async def get_anyone_rate_limit_status(
except Exception as e: except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}") logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
@router.post("/rate-limits/{user_id}/reset") @router.post("/rate-limits/{user_id}/reset")
async def reset_user_rate_limits( async def reset_user_rate_limits(
user_id: str = Path(...), user_id: str = Path(...),
admin_user=Depends(get_current_admin), admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter), rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Reset rate limits for a user (admin only)""" """Reset rate limits for a user (admin only)"""
try: try:
# Get user to determine type # Get user to determine type
user_data = await database.get_user_by_id(user_id) user_data = await database.get_user_by_id(user_id)
if not user_data: if not user_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
user_type = user_data.get("type", "unknown") user_type = user_data.get("type", "unknown")
success = await rate_limiter.reset_user_rate_limits(user_id, user_type) success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
if success: if success:
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}") logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
return create_success_response({ return create_success_response(
"message": f"Rate limits reset for {user_type} {user_id}", {"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id}
"resetBy": admin_user.id )
})
else: else:
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
) )
except Exception as e: except Exception as e:
logger.error(f"❌ Reset rate limits error: {e}") logger.error(f"❌ Reset rate limits error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)


@@ -20,18 +20,25 @@ from device_manager import DeviceManager
from email_service import VerificationEmailRateLimiter, email_service from email_service import VerificationEmailRateLimiter, email_service
from logger import logger from logger import logger
from models import ( from models import (
LoginRequest, CreateCandidateRequest, Candidate, LoginRequest,
Employer, Guest, AuthResponse, MFARequest, CreateCandidateRequest,
MFAData, MFAVerifyRequest, ResendVerificationRequest, Candidate,
MFARequestResponse, MFARequestResponse Employer,
) Guest,
from utils.dependencies import ( AuthResponse,
get_current_admin, get_database, get_current_user, create_access_token MFARequest,
MFAData,
MFAVerifyRequest,
ResendVerificationRequest,
MFARequestResponse,
) )
from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token
from utils.responses import create_success_response, create_error_response from utils.responses import create_success_response, create_error_response
from utils.rate_limiter import get_rate_limiter from utils.rate_limiter import get_rate_limiter
from utils.auth_utils import ( from utils.auth_utils import (
AuthenticationManager, SecurityConfig, AuthenticationManager,
SecurityConfig,
validate_password_strength, validate_password_strength,
) )
@@ -43,28 +50,31 @@ if JWT_SECRET_KEY == "":
raise ValueError("JWT_SECRET_KEY environment variable is not set") raise ValueError("JWT_SECRET_KEY environment variable is not set")
ALGORITHM = "HS256" ALGORITHM = "HS256"
# ============================ # ============================
# Password Reset Endpoints # Password Reset Endpoints
# ============================ # ============================
class PasswordResetRequest(BaseModel): class PasswordResetRequest(BaseModel):
email: EmailStr email: EmailStr
class PasswordResetConfirm(BaseModel): class PasswordResetConfirm(BaseModel):
token: str token: str
new_password: str new_password: str
@field_validator('new_password') @field_validator("new_password")
def validate_password_strength(cls, v): def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v) is_valid, issues = validate_password_strength(v)
if not is_valid: if not is_valid:
raise ValueError('; '.join(issues)) raise ValueError("; ".join(issues))
return v return v
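A note on the validator above: Pydantic runs field_validator before the route body executes, so a weak password is rejected as a 422 validation error and never reaches the handler. A self-contained sketch of the pattern, with a stand-in strength checker in place of the real validate_password_strength from utils.auth_utils:

from pydantic import BaseModel, ValidationError, field_validator

def check_strength(pw: str) -> tuple[bool, list[str]]:
    # Stand-in for utils.auth_utils.validate_password_strength
    issues = []
    if len(pw) < 12:
        issues.append("must be at least 12 characters")
    if pw == pw.lower():
        issues.append("must contain an uppercase letter")
    return (not issues, issues)

class ResetConfirm(BaseModel):
    token: str
    new_password: str

    @field_validator("new_password")
    def strong_enough(cls, v: str) -> str:
        ok, issues = check_strength(v)
        if not ok:
            raise ValueError("; ".join(issues))
        return v

try:
    ResetConfirm(token="t", new_password="weak")
except ValidationError as e:
    print(e.errors()[0]["msg"])  # Value error, must be at least 12 characters; ...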
@router.post("/guest") @router.post("/guest")
async def create_guest_session_enhanced( async def create_guest_session_enhanced(
request: Request, request: Request,
database: RedisDatabase = Depends(get_database), database: RedisDatabase = Depends(get_database),
rate_limiter: RateLimiter = Depends(get_rate_limiter) rate_limiter: RateLimiter = Depends(get_rate_limiter),
): ):
"""Create a guest session with enhanced validation and persistence""" """Create a guest session with enhanced validation and persistence"""
try: try:
@@ -73,21 +83,15 @@ async def create_guest_session_enhanced(
# Check rate limits for guest session creation # Check rate limits for guest session creation
rate_result = await rate_limiter.check_rate_limit( rate_result = await rate_limiter.check_rate_limit(
user_id=ip_address, user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest"
user_type="guest_creation",
is_admin=False,
endpoint="/guest"
) )
if not rate_result.allowed: if not rate_result.allowed:
logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}") logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}")
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content=create_error_response( content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"),
"RATE_LIMITED", headers={"Retry-After": str(rate_result.retry_after_seconds or 300)},
rate_result.reason or "Too many guest sessions created"
),
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}
) )
# Generate unique guest identifier with timestamp for uniqueness # Generate unique guest identifier with timestamp for uniqueness
@@ -139,7 +143,7 @@ async def create_guest_session_enhanced(
"email": guest_data["email"], "email": guest_data["email"],
"username": guest_username, "username": guest_username,
"session_id": session_id, "session_id": session_id,
"created_at": current_time.isoformat() "created_at": current_time.isoformat(),
} }
await database.set_user(guest_data["email"], user_auth_data) await database.set_user(guest_data["email"], user_auth_data)
@@ -149,11 +153,11 @@ async def create_guest_session_enhanced(
# Create authentication tokens with longer expiry for guests # Create authentication tokens with longer expiry for guests
access_token = create_access_token( access_token = create_access_token(
data={"sub": guest_id, "type": "guest"}, data={"sub": guest_id, "type": "guest"},
expires_delta=timedelta(hours=48) # Longer expiry for guests expires_delta=timedelta(hours=48), # Longer expiry for guests
) )
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": guest_id, "type": "refresh_guest"}, data={"sub": guest_id, "type": "refresh_guest"},
expires_delta=timedelta(days=14) # 2 weeks refresh for guests expires_delta=timedelta(days=14), # 2 weeks refresh for guests
) )
# Verify guest was stored correctly # Verify guest was stored correctly
@@ -161,8 +165,7 @@ async def create_guest_session_enhanced(
if not verification: if not verification:
logger.error(f"❌ Failed to verify guest storage: {guest_id}") logger.error(f"❌ Failed to verify guest storage: {guest_id}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
) )
# Create guest object for response # Create guest object for response
@@ -178,7 +181,7 @@ async def create_guest_session_enhanced(
"user": guest.model_dump(by_alias=True), "user": guest.model_dump(by_alias=True),
"expiresAt": int((current_time + timedelta(hours=48)).timestamp()), "expiresAt": int((current_time + timedelta(hours=48)).timestamp()),
"userType": "guest", "userType": "guest",
"isGuest": True "isGuest": True,
} }
return create_success_response(auth_response) return create_success_response(auth_response)
@@ -186,25 +189,25 @@ async def create_guest_session_enhanced(
except Exception as e: except Exception as e:
logger.error(f"❌ Guest session creation error: {e}") logger.error(f"❌ Guest session creation error: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
) )
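Because /guest is unauthenticated, the limiter above keys on the client IP instead of a user id. A fixed-window variant is easy to sketch with two Redis primitives; the threshold and window below are illustrative values, not the project's configuration:

async def allow_guest_creation(redis, ip: str, limit: int = 10, window_s: int = 300) -> bool:
    """True if this IP may create another guest session in the current window."""
    key = f"rl:guest_creation:{ip}"
    count = await redis.incr(key)          # atomic increment; creates the key at 1
    if count == 1:
        await redis.expire(key, window_s)  # start the window on the first hit
    return count <= limit

On a False result the route would answer 429 with a Retry-After header, as the handler above does.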
@router.post("/guest/convert") @router.post("/guest/convert")
async def convert_guest_to_user( async def convert_guest_to_user(
registration_data: Dict[str, Any] = Body(...), registration_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Convert a guest session to a permanent user account""" """Convert a guest session to a permanent user account"""
try: try:
# Verify current user is a guest # Verify current user is a guest
if current_user.user_type != "guest": if current_user.user_type != "guest":
return JSONResponse( return JSONResponse(
status_code=400, status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted")
content=create_error_response("NOT_GUEST", "Only guest users can be converted")
) )
guest: Guest = current_user guest: Guest = current_user
@@ -215,25 +218,18 @@ async def convert_guest_to_user(
try: try:
candidate_request = CreateCandidateRequest.model_validate(registration_data) candidate_request = CreateCandidateRequest.model_validate(registration_data)
except ValidationError as e: except ValidationError as e:
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e)))
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(e))
)
# Check if email/username already exists # Check if email/username already exists
auth_manager = AuthenticationManager(database) auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists( user_exists, conflict_field = await auth_manager.check_user_exists(
candidate_request.email, candidate_request.email, candidate_request.username
candidate_request.username
) )
if user_exists: if user_exists:
return JSONResponse( return JSONResponse(
status_code=409, status_code=409,
content=create_error_response( content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
) )
# Create candidate # Create candidate
@@ -253,7 +249,7 @@ async def convert_guest_to_user(
"updated_at": current_time.isoformat(), "updated_at": current_time.isoformat(),
"status": "active", "status": "active",
"is_admin": False, "is_admin": False,
"converted_from_guest": guest.id "converted_from_guest": guest.id,
} }
candidate = Candidate.model_validate(candidate_data) candidate = Candidate.model_validate(candidate_data)
@@ -269,7 +265,7 @@ async def convert_guest_to_user(
"id": candidate_id, "id": candidate_id,
"type": "candidate", "type": "candidate",
"email": candidate.email, "email": candidate.email,
"username": candidate.username "username": candidate.username,
} }
await database.set_user(candidate.email, user_auth_data) await database.set_user(candidate.email, user_auth_data)
@@ -286,43 +282,45 @@ async def convert_guest_to_user(
access_token = create_access_token(data={"sub": candidate_id}) access_token = create_access_token(data={"sub": candidate_id})
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": candidate_id, "type": "refresh"}, data={"sub": candidate_id, "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
) )
auth_response = AuthResponse( auth_response = AuthResponse(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
user=candidate, user=candidate,
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()),
) )
logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}") logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}")
return create_success_response({ return create_success_response(
{
"message": "Guest account successfully converted to candidate", "message": "Guest account successfully converted to candidate",
"auth": auth_response.model_dump(by_alias=True), "auth": auth_response.model_dump(by_alias=True),
"conversionType": "candidate" "conversionType": "candidate",
}) }
)
else: else:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported") content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"),
) )
except Exception as e: except Exception as e:
logger.error(f"❌ Guest conversion error: {e}") logger.error(f"❌ Guest conversion error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
) )
@router.post("/logout") @router.post("/logout")
async def logout( async def logout(
access_token: str = Body(..., alias="accessToken"), access_token: str = Body(..., alias="accessToken"),
refresh_token: str = Body(..., alias="refreshToken"), refresh_token: str = Body(..., alias="refreshToken"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Logout endpoint - revokes both access and refresh tokens""" """Logout endpoint - revokes both access and refresh tokens"""
logger.info(f"🔑 User {current_user.id} is logging out") logger.info(f"🔑 User {current_user.id} is logging out")
@@ -336,21 +334,18 @@ async def logout(
if not user_id or token_type != "refresh": if not user_id or token_type != "refresh":
return JSONResponse( return JSONResponse(
status_code=401, status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
) )
except jwt.PyJWTError as e: except jwt.PyJWTError as e:
logger.warning(f"⚠️ Invalid refresh token during logout: {e}") logger.warning(f"⚠️ Invalid refresh token during logout: {e}")
return JSONResponse( return JSONResponse(
status_code=401, status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
) )
# Verify that the refresh token belongs to the current user # Verify that the refresh token belongs to the current user
if user_id != current_user.id: if user_id != current_user.id:
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user")
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
) )
# Get Redis client # Get Redis client
@@ -362,12 +357,14 @@ async def logout(
await redis.setex( await redis.setex(
f"blacklisted_token:{refresh_token}", f"blacklisted_token:{refresh_token}",
refresh_ttl, refresh_ttl,
json.dumps({ json.dumps(
{
"user_id": user_id, "user_id": user_id,
"token_type": "refresh", "token_type": "refresh",
"revoked_at": datetime.now(UTC).isoformat(), "revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout" "reason": "user_logout",
}) }
),
) )
logger.info(f"🔒 Blacklisted refresh token for user {user_id}") logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
@@ -385,12 +382,14 @@ async def logout(
await redis.setex( await redis.setex(
f"blacklisted_token:{access_token}", f"blacklisted_token:{access_token}",
access_ttl, access_ttl,
json.dumps({ json.dumps(
{
"user_id": user_id, "user_id": user_id,
"token_type": "access", "token_type": "access",
"revoked_at": datetime.now(UTC).isoformat(), "revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout" "reason": "user_logout",
}) }
),
) )
logger.info(f"🔒 Blacklisted access token for user {user_id}") logger.info(f"🔒 Blacklisted access token for user {user_id}")
else: else:
@@ -409,26 +408,20 @@ async def logout(
# ) # )
logger.info(f"🔑 User {user_id} logged out successfully") logger.info(f"🔑 User {user_id} logged out successfully")
return create_success_response({ return create_success_response(
{
"message": "Logged out successfully", "message": "Logged out successfully",
"tokensRevoked": { "tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)},
"refreshToken": True,
"accessToken": bool(access_token)
} }
}) )
except Exception as e: except Exception as e:
logger.error(f"❌ Logout error: {e}") logger.error(f"❌ Logout error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e)))
status_code=500,
content=create_error_response("LOGOUT_ERROR", str(e))
)
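The logout flow above writes blacklisted_token:<jwt> keys whose TTL matches the remaining token lifetime, so Redis expires the entries on its own. The read side that an auth dependency needs is then a single existence check (a sketch, assuming the same key scheme):

async def is_token_revoked(redis, token: str) -> bool:
    # EXISTS returns the number of matching keys (0 or 1 here)
    return await redis.exists(f"blacklisted_token:{token}") == 1

Sizing each key's TTL to the token's own expiry keeps the blacklist from growing without bound: an entry only needs to outlive the token it revokes.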
@router.post("/logout-all") @router.post("/logout-all")
async def logout_all_devices( async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)):
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
):
"""Logout from all devices by revoking all tokens for the user""" """Logout from all devices by revoking all tokens for the user"""
try: try:
redis = redis_manager.get_client() redis = redis_manager.get_client()
@@ -437,25 +430,20 @@ async def logout_all_devices(
await redis.setex( await redis.setex(
f"user_tokens_revoked:{current_user.id}", f"user_tokens_revoked:{current_user.id}",
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
datetime.now(UTC).isoformat() datetime.now(UTC).isoformat(),
) )
logger.info(f"🔒 All tokens revoked for user {current_user.id}") logger.info(f"🔒 All tokens revoked for user {current_user.id}")
return create_success_response({ return create_success_response({"message": "Logged out from all devices successfully"})
"message": "Logged out from all devices successfully"
})
except Exception as e: except Exception as e:
logger.error(f"❌ Logout all devices error: {e}") logger.error(f"❌ Logout all devices error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
status_code=500,
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
)
@router.post("/refresh") @router.post("/refresh")
async def refresh_token_endpoint( async def refresh_token_endpoint(
refresh_token: str = Body(..., alias="refreshToken"), refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database)
): ):
"""Refresh token endpoint""" """Refresh token endpoint"""
try: try:
@@ -466,8 +454,7 @@ async def refresh_token_endpoint(
if not user_id or token_type != "refresh": if not user_id or token_type != "refresh":
return JSONResponse( return JSONResponse(
status_code=401, status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
) )
# Create new access token # Create new access token
@@ -484,37 +471,29 @@ async def refresh_token_endpoint(
user = Employer.model_validate(employer_data) user = Employer.model_validate(employer_data)
if not user: if not user:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
auth_response = AuthResponse( auth_response = AuthResponse(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, # Keep same refresh token refresh_token=refresh_token, # Keep same refresh token
user=user, user=user,
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()) expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()),
) )
return create_success_response(auth_response.model_dump(by_alias=True)) return create_success_response(auth_response.model_dump(by_alias=True))
except jwt.PyJWTError: except jwt.PyJWTError:
return JSONResponse( return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token"))
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
except Exception as e: except Exception as e:
logger.error(f"❌ Token refresh error: {e}") logger.error(f"❌ Token refresh error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e)))
status_code=500,
content=create_error_response("REFRESH_ERROR", str(e))
)
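For reference, create_access_token presumably wraps PyJWT roughly as below. ALGORITHM is "HS256" per the module constant; the secret and expiry values here are placeholders, not the project's settings:

import jwt
from datetime import datetime, timedelta, UTC

def make_token(data: dict, secret: str, expires_delta: timedelta = timedelta(hours=24)) -> str:
    # "exp" is a registered JWT claim; PyJWT converts the datetime to an epoch timestamp
    payload = {**data, "exp": datetime.now(UTC) + expires_delta}
    return jwt.encode(payload, secret, algorithm="HS256")

token = make_token({"sub": "user-123", "type": "refresh"}, "dev-secret", timedelta(days=7))
claims = jwt.decode(token, "dev-secret", algorithms=["HS256"])  # raises PyJWTError if expired or tampered
assert claims["type"] == "refresh"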
@router.post("/resend-verification") @router.post("/resend-verification")
async def resend_verification_email( async def resend_verification_email(
request: ResendVerificationRequest, request: ResendVerificationRequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Resend verification email with comprehensive rate limiting and validation""" """Resend verification email with comprehensive rate limiting and validation"""
try: try:
@@ -527,10 +506,7 @@ async def resend_verification_email(
can_send, reason = await rate_limiter.can_send_verification_email(email_lower) can_send, reason = await rate_limiter.can_send_verification_email(email_lower)
if not can_send: if not can_send:
logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}") logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}")
return JSONResponse( return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason))
status_code=429,
content=create_error_response("RATE_LIMITED", reason)
)
# Clean up expired tokens first # Clean up expired tokens first
await database.cleanup_expired_verification_tokens() await database.cleanup_expired_verification_tokens()
@@ -541,9 +517,11 @@ async def resend_verification_email(
# User exists and is verified - don't reveal this for security # User exists and is verified - don't reveal this for security
logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}") logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({ return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent." "message": "If your email is in our system and pending verification, a new verification email has been sent."
}) }
)
# Look for pending verification token # Look for pending verification token
verification_data = await database.find_verification_token_by_email(email_lower) verification_data = await database.find_verification_token_by_email(email_lower)
@@ -552,9 +530,11 @@ async def resend_verification_email(
# No pending verification found - don't reveal this for security # No pending verification found - don't reveal this for security
logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}") logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({ return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent." "message": "If your email is in our system and pending verification, a new verification email has been sent."
}) }
)
# Check if verification token has expired # Check if verification token has expired
expires_at = datetime.fromisoformat(verification_data["expires_at"]) expires_at = datetime.fromisoformat(verification_data["expires_at"])
@@ -568,8 +548,8 @@ async def resend_verification_email(
status_code=400, status_code=400,
content=create_error_response( content=create_error_response(
"TOKEN_EXPIRED", "TOKEN_EXPIRED",
"Your verification link has expired. Please register again to create a new account." "Your verification link has expired. Please register again to create a new account.",
) ),
) )
# Generate new verification token (invalidate old one) # Generate new verification token (invalidate old one)
@@ -577,20 +557,19 @@ async def resend_verification_email(
new_token = secrets.token_urlsafe(32) new_token = secrets.token_urlsafe(32)
# Update verification data with new token and reset attempts # Update verification data with new token and reset attempts
verification_data.update({ verification_data.update(
{
"token": new_token, "token": new_token,
"expires_at": (current_time + timedelta(hours=24)).isoformat(), "expires_at": (current_time + timedelta(hours=24)).isoformat(),
"resent_at": current_time.isoformat(), "resent_at": current_time.isoformat(),
"resend_count": verification_data.get("resend_count", 0) + 1 "resend_count": verification_data.get("resend_count", 0) + 1,
}) }
)
# Store new token and remove old one # Store new token and remove old one
await database.redis.delete(f"email_verification:{old_token}") await database.redis.delete(f"email_verification:{old_token}")
await database.store_email_verification_token( await database.store_email_verification_token(
email_lower, email_lower, new_token, verification_data["user_type"], verification_data["user_data"]
new_token,
verification_data["user_type"],
verification_data["user_data"]
) )
# Get user name for email # Get user name for email
@@ -610,69 +589,62 @@ async def resend_verification_email(
await rate_limiter.record_email_sent(email_lower) await rate_limiter.record_email_sent(email_lower)
# Send new verification email in background # Send new verification email in background
background_tasks.add_task( background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type)
email_service.send_verification_email,
email_lower,
new_token,
user_name,
user_type
)
# Log security event # Log security event
await database.log_security_event( await database.log_security_event(
verification_data["user_data"].get("candidate_data", {}).get("id") or verification_data["user_data"].get("candidate_data", {}).get("id")
verification_data["user_data"].get("employer_data", {}).get("id") or "unknown", or verification_data["user_data"].get("employer_data", {}).get("id")
or "unknown",
"verification_resend", "verification_resend",
{ {
"email": email_lower, "email": email_lower,
"user_type": user_type, "user_type": user_type,
"resend_count": verification_data.get("resend_count", 1), "resend_count": verification_data.get("resend_count", 1),
"old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging "old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging
"ip_address": "unknown" # You can extract this from request if needed "ip_address": "unknown", # You can extract this from request if needed
},
)
logger.info(
f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})"
)
return create_success_response(
{
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1),
} }
) )
logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})")
return create_success_response({
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1)
})
except ValueError as ve: except ValueError as ve:
logger.warning(f"⚠️ Invalid resend verification request: {ve}") logger.warning(f"⚠️ Invalid resend verification request: {ve}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
except Exception as e: except Exception as e:
logger.error(f"❌ Resend verification email error: {e}") logger.error(f"❌ Resend verification email error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.") content=create_error_response(
"RESEND_FAILED", "An error occurred while processing your request. Please try again later."
),
) )
@router.post("/mfa/request") @router.post("/mfa/request")
async def request_mfa( async def request_mfa(
request: MFARequest, request: MFARequest,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
http_request: Request, http_request: Request,
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Request MFA for login from new device""" """Request MFA for login from new device"""
try: try:
# Verify credentials first # Verify credentials first
auth_manager = AuthenticationManager(database) auth_manager = AuthenticationManager(database)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials( is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password)
request.email,
request.password
)
if not is_valid or not user_data: if not is_valid or not user_data:
return JSONResponse( return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials"))
status_code=401,
content=create_error_response("AUTH_FAILED", "Invalid credentials")
)
# Check if device is trusted # Check if device is trusted
device_manager = DeviceManager(database) device_manager = DeviceManager(database)
@@ -684,10 +656,7 @@ async def request_mfa(
# Device is trusted, proceed with normal login # Device is trusted, proceed with normal login
await device_manager.update_device_last_used(user_data["id"], request.device_id) await device_manager.update_device_last_used(user_data["id"], request.device_id)
return create_success_response({ return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"})
"mfa_required": False,
"message": "Device is trusted, proceed with login"
})
# Generate MFA code # Generate MFA code
mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code
@@ -709,8 +678,7 @@ async def request_mfa(
if not email: if not email:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
) )
# Store MFA code # Store MFA code
@@ -718,13 +686,7 @@ async def request_mfa(
logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}") logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}")
# Send MFA code via email # Send MFA code via email
background_tasks.add_task( background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name)
email_service.send_mfa_email,
email,
mfa_code,
request.device_name,
user_name
)
logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}") logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}")
@@ -735,25 +697,22 @@ async def request_mfa(
device_id=request.device_id, device_id=request.device_id,
device_name=request.device_name, device_name=request.device_name,
) )
mfa_response = MFARequestResponse( mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data)
mfa_required=True,
mfa_data=mfa_data
)
return create_success_response(mfa_response) return create_success_response(mfa_response)
except Exception as e: except Exception as e:
logger.error(f"❌ MFA request error: {e}") logger.error(f"❌ MFA request error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
) )
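Two details of the MFA code handling deserve a note: secrets.randbelow(1000000) draws a uniform 6-digit code once zero-padded, and the later comparison should be constant-time. A short sketch:

import secrets

def new_mfa_code() -> str:
    return f"{secrets.randbelow(1_000_000):06d}"  # uniform over 000000..999999

def code_matches(submitted: str, stored: str) -> bool:
    # compare_digest avoids leaking, via timing, how many characters matched
    return secrets.compare_digest(submitted, stored)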
@router.post("/login") @router.post("/login")
async def login( async def login(
request: LoginRequest, request: LoginRequest,
http_request: Request, http_request: Request,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""login with automatic MFA email sending for new devices""" """login with automatic MFA email sending for new devices"""
try: try:
@@ -766,16 +725,12 @@ async def login(
device_id = device_info["device_id"] device_id = device_info["device_id"]
# Verify credentials first # Verify credentials first
is_valid, user_data, error_message = await auth_manager.verify_user_credentials( is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password)
request.login,
request.password
)
if not is_valid or not user_data: if not is_valid or not user_data:
logger.warning(f"⚠️ Failed login attempt for: {request.login}") logger.warning(f"⚠️ Failed login attempt for: {request.login}")
return JSONResponse( return JSONResponse(
status_code=401, status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
) )
# Check if device is trusted # Check if device is trusted
@@ -804,8 +759,7 @@ async def login(
if not email: if not email:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
) )
# Store MFA code # Store MFA code
@@ -817,12 +771,7 @@ async def login(
# Send MFA code via email in background # Send MFA code via email in background
background_tasks.add_task( background_tasks.add_task(
email_service.send_mfa_email, email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address
email,
mfa_code,
device_info["device_name"],
user_name,
ip_address
) )
# Log security event # Log security event
@@ -834,8 +783,8 @@ async def login(
"device_name": device_info["device_name"], "device_name": device_info["device_name"],
"ip_address": ip_address, "ip_address": ip_address,
"user_agent": device_info.get("user_agent", ""), "user_agent": device_info.get("user_agent", ""),
"auto_sent": True "auto_sent": True,
} },
) )
logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}") logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}")
@@ -847,8 +796,8 @@ async def login(
email=email, email=email,
device_id=device_id, device_id=device_id,
device_name=device_info["device_name"], device_name=device_info["device_name"],
code_sent=mfa_code code_sent=mfa_code,
) ),
) )
return create_success_response(mfa_response.model_dump(by_alias=True)) return create_success_response(mfa_response.model_dump(by_alias=True))
@@ -860,7 +809,7 @@ async def login(
access_token = create_access_token(data={"sub": user_data["id"]}) access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"}, data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
) )
# Get user object # Get user object
@@ -876,8 +825,7 @@ async def login(
if not user: if not user:
return JSONResponse( return JSONResponse(
status_code=404, status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
content=create_error_response("USER_NOT_FOUND", "User profile not found")
) )
# Log successful login from trusted device # Log successful login from trusted device
@@ -888,8 +836,8 @@ async def login(
"device_id": device_id, "device_id": device_id,
"device_name": device_info["device_name"], "device_name": device_info["device_name"],
"ip_address": http_request.client.host if http_request.client else "Unknown", "ip_address": http_request.client.host if http_request.client else "Unknown",
"trusted_device": True "trusted_device": True,
} },
) )
# Create response # Create response
@@ -897,7 +845,9 @@ async def login(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
user=user, user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
) )
logger.info(f"🔑 User {request.login} logged in successfully from trusted device") logger.info(f"🔑 User {request.login} logged in successfully from trusted device")
@@ -908,17 +858,12 @@ async def login(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Login error: {e}") logger.error(f"❌ Login error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login")
content=create_error_response("LOGIN_ERROR", "An error occurred during login")
) )
@router.post("/mfa/verify") @router.post("/mfa/verify")
async def verify_mfa( async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)):
request: MFAVerifyRequest,
http_request: Request,
database: RedisDatabase = Depends(get_database)
):
"""Verify MFA code and complete login with error handling""" """Verify MFA code and complete login with error handling"""
try: try:
# Get MFA data # Get MFA data
@@ -928,13 +873,17 @@ async def verify_mfa(
logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}") logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}")
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.") content=create_error_response(
"NO_MFA_SESSION", "No active MFA session found. Please try logging in again."
),
) )
if mfa_data.get("verified"): if mfa_data.get("verified"):
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.") content=create_error_response(
"ALREADY_VERIFIED", "This MFA code has already been used. Please login again."
),
) )
# Check expiration # Check expiration
@@ -944,7 +893,7 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}") await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.") content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."),
) )
# Check attempts # Check attempts
@@ -954,7 +903,9 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}") await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.") content=create_error_response(
"TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again."
),
) )
# Verify code # Verify code
@@ -965,9 +916,8 @@ async def verify_mfa(
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response( content=create_error_response(
"INVALID_CODE", "INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining."
f"Invalid MFA code. {remaining_attempts} attempts remaining." ),
)
) )
# Mark as verified # Mark as verified
@@ -976,20 +926,13 @@ async def verify_mfa(
# Get user data # Get user data
user_data = await database.get_user(request.email) user_data = await database.get_user(request.email)
if not user_data: if not user_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
# Add device to trusted devices if requested # Add device to trusted devices if requested
if request.remember_device: if request.remember_device:
device_manager = DeviceManager(database) device_manager = DeviceManager(database)
device_info = device_manager.parse_device_info(http_request) device_info = device_manager.parse_device_info(http_request)
await device_manager.add_trusted_device( await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info)
user_data["id"],
request.device_id,
device_info
)
logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}") logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}")
# Update last login # Update last login
@@ -1000,7 +943,7 @@ async def verify_mfa(
access_token = create_access_token(data={"sub": user_data["id"]}) access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"}, data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
) )
# Get user object # Get user object
@@ -1016,8 +959,7 @@ async def verify_mfa(
if not user: if not user:
return JSONResponse( return JSONResponse(
status_code=404, status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
content=create_error_response("USER_NOT_FOUND", "User profile not found")
) )
# Log successful MFA verification and login # Log successful MFA verification and login
@@ -1028,8 +970,8 @@ async def verify_mfa(
"device_id": request.device_id, "device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown", "ip_address": http_request.client.host if http_request.client else "Unknown",
"device_remembered": request.remember_device, "device_remembered": request.remember_device,
"attempts_used": current_attempts + 1 "attempts_used": current_attempts + 1,
} },
) )
await database.log_security_event( await database.log_security_event(
@@ -1039,8 +981,8 @@ async def verify_mfa(
"device_id": request.device_id, "device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown", "ip_address": http_request.client.host if http_request.client else "Unknown",
"mfa_verified": True, "mfa_verified": True,
"new_device": True "new_device": True,
} },
) )
# Clean up MFA session # Clean up MFA session
@@ -1051,7 +993,9 @@ async def verify_mfa(
access_token=access_token, access_token=access_token,
refresh_token=refresh_token, refresh_token=refresh_token,
user=user, user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
) )
logger.info(f"✅ MFA verified and login completed for {request.email}") logger.info(f"✅ MFA verified and login completed for {request.email}")
@@ -1062,15 +1006,12 @@ async def verify_mfa(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ MFA verification error: {e}") logger.error(f"❌ MFA verification error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
) )
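The trusted-device flow above hinges on a stable device identifier, but this diff does not show how DeviceManager.parse_device_info derives it. One plausible sketch hashes request headers; the inputs chosen here are assumptions, and header-derived fingerprints are deliberately coarse, which is why the flow still emails a code before trusting a device:

import hashlib
from fastapi import Request

def device_fingerprint(request: Request) -> str:
    # Hypothetical derivation; the real project may persist a client-generated id instead
    user_agent = request.headers.get("user-agent", "")
    accept_language = request.headers.get("accept-language", "")
    raw = f"{user_agent}|{accept_language}".encode()
    return hashlib.sha256(raw).hexdigest()[:32]  # stable per browser/locale pair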
@router.post("/password-reset/request") @router.post("/password-reset/request")
async def request_password_reset( async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
request: PasswordResetRequest,
database: RedisDatabase = Depends(get_database)
):
"""Request password reset""" """Request password reset"""
try: try:
# Check if user exists # Check if user exists
@@ -1100,15 +1041,12 @@ async def request_password_reset(
except Exception as e: except Exception as e:
logger.error(f"❌ Password reset request error: {e}") logger.error(f"❌ Password reset request error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
content=create_error_response("RESET_ERROR", "An error occurred processing the request")
) )
@router.post("/password-reset/confirm") @router.post("/password-reset/confirm")
async def confirm_password_reset( async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
request: PasswordResetConfirm,
database: RedisDatabase = Depends(get_database)
):
"""Confirm password reset with token""" """Confirm password reset with token"""
try: try:
# Find user by reset token # Find user by reset token
@@ -1122,8 +1060,5 @@ async def confirm_password_reset(
except Exception as e: except Exception as e:
logger.error(f"❌ Password reset confirm error: {e}") logger.error(f"❌ Password reset confirm error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
) )

File diff suppressed because it is too large


@@ -4,96 +4,69 @@ Chat routes
import json import json
import uuid import uuid
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import (Dict, Any) from typing import Dict, Any
from fastapi import ( from fastapi import APIRouter, Depends, Body, Query, Path
APIRouter, Depends, Body, Depends, Query, Path,
Body, APIRouter
)
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from utils.dependencies import ( from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
get_database, get_current_user, get_current_user_or_guest from utils.responses import create_success_response, create_error_response, create_paginated_response
) from utils.helpers import stream_agent_response
from utils.responses import (
create_success_response, create_error_response, create_paginated_response
)
from utils.helpers import (
stream_agent_response
)
import backstory_traceback import backstory_traceback
import entities.entity_manager as entities import entities.entity_manager as entities
from models import ( from models import Candidate, ChatMessageUser, BaseUserWithType, ChatSession, ChatMessage
Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
)
# Create router for authentication endpoints # Create router for authentication endpoints
router = APIRouter(prefix="/chat", tags=["chat"]) router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("/sessions/{session_id}/archive") @router.post("/sessions/{session_id}/archive")
async def archive_chat_session( async def archive_chat_session(
session_id: str = Path(...), session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
): ):
"""Archive a chat session""" """Archive a chat session"""
try: try:
session_data = await database.get_chat_session(session_id) session_data = await database.get_chat_session(session_id)
if not session_data: if not session_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
# Check if user owns this session or is admin # Check if user owns this session or is admin
if session_data.get("userId") != current_user.id: if session_data.get("userId") != current_user.id:
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
) )
await database.archive_chat_session(session_id) await database.archive_chat_session(session_id)
return create_success_response({ return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
"message": "Chat session archived successfully",
"sessionId": session_id
})
except Exception as e: except Exception as e:
logger.error(f"❌ Archive chat session error: {e}") logger.error(f"❌ Archive chat session error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
status_code=500,
content=create_error_response("ARCHIVE_ERROR", str(e))
)
@router.get("/statistics") @router.get("/statistics")
async def get_chat_statistics( async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get chat statistics (admin/analytics endpoint)""" """Get chat statistics (admin/analytics endpoint)"""
try: try:
stats = await database.get_chat_statistics() stats = await database.get_chat_statistics()
return create_success_response(stats) return create_success_response(stats)
except Exception as e: except Exception as e:
logger.error(f"❌ Get chat statistics error: {e}") logger.error(f"❌ Get chat statistics error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
@router.post("/sessions") @router.post("/sessions")
async def create_chat_session( async def create_chat_session(
session_data: Dict[str, Any] = Body(...), session_data: Dict[str, Any] = Body(...),
current_user: BaseUserWithType = Depends(get_current_user_or_guest), current_user: BaseUserWithType = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Create a new chat session with optional candidate username association""" """Create a new chat session with optional candidate username association"""
try: try:
@@ -111,15 +84,14 @@ async def create_chat_session(
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()] candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
# Find candidate by username (case-insensitive) # Find candidate by username (case-insensitive)
matching_candidates = [ matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
c for c in candidates_list
if c.username.lower() == username.lower()
]
if not matching_candidates: if not matching_candidates:
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found") content=create_error_response(
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
),
) )
candidate_data = matching_candidates[0] candidate_data = matching_candidates[0]
@@ -148,7 +120,7 @@ async def create_chat_session(
"username": candidate_data.username, "username": candidate_data.username,
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [], "skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
"experience": len(candidate_data.experience) if candidate_data.experience else 0, "experience": len(candidate_data.experience) if candidate_data.experience else 0,
"location": candidate_data.location.city if candidate_data.location else "Unknown" "location": candidate_data.location.city if candidate_data.location else "Unknown",
} }
context["additionalContext"] = additional_context context["additionalContext"] = additional_context
@@ -162,8 +134,10 @@ async def create_chat_session(
chat_session = ChatSession.model_validate(session_data) chat_session = ChatSession.model_validate(session_data)
await database.set_chat_session(chat_session.id, chat_session.model_dump()) await database.set_chat_session(chat_session.id, chat_session.model_dump())
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" + logger.info(
(f" about candidate {candidate_data.full_name}" if candidate_data else "")) f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
)
return create_success_response(chat_session.model_dump(by_alias=True)) return create_success_response(chat_session.model_dump(by_alias=True))
@@ -171,49 +145,58 @@ async def create_chat_session(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat session creation error: {e}") logger.error(f"❌ Chat session creation error: {e}")
logger.info(json.dumps(session_data, indent=2)) logger.info(json.dumps(session_data, indent=2))
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
@router.post("/sessions/messages/stream") @router.post("/sessions/messages/stream")
async def post_chat_session_message_stream( async def post_chat_session_message_stream(
user_message: ChatMessageUser = Body(...), user_message: ChatMessageUser = Body(...),
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Post a message to a chat session and stream the response with persistence""" """Post a message to a chat session and stream the response with persistence"""
try: try:
chat_session_data = await database.get_chat_session(user_message.session_id) chat_session_data = await database.get_chat_session(user_message.session_id)
if not chat_session_data: if not chat_session_data:
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id) logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
chat_session = ChatSession.model_validate(chat_session_data) chat_session = ChatSession.model_validate(chat_session_data)
chat_type = chat_session.context.type chat_type = chat_session.context.type
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None candidate_info = (
chat_session.context.additional_context.get("candidateInfo", {})
if chat_session.context and chat_session.context.additional_context
else None
)
# Get candidate info if this chat is about a specific candidate # Get candidate info if this chat is about a specific candidate
if candidate_info: if candidate_info:
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}") logger.info(
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
)
else: else:
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}") logger.info(
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
)
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association") content=create_error_response(
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
),
) )
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
if not candidate: if not candidate:
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}") logger.info(
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
)
return JSONResponse( return JSONResponse(
status_code=404, status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session") content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
)
logger.info(
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
) )
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
# Entity automatically released when done # Entity automatically released when done
@@ -222,7 +205,7 @@ async def post_chat_session_message_stream(
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}") logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
return JSONResponse( return JSONResponse(
status_code=400, status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type") content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
) )
# Persist user message to database # Persist user message to database
@@ -243,10 +226,8 @@ async def post_chat_session_message_stream(
except Exception: except Exception:
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error("❌ Chat message streaming error") logger.error("❌ Chat message streaming error")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
status_code=500,
content=create_error_response("STREAMING_ERROR", "")
)
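The endpoint above streams its reply as server-sent events, one "data: {json}\n\n" frame per message (the same framing the to_json helpers later in this diff emit). A minimal consumer sketch; the base URL and payload shape are assumptions:

import json
import httpx

async def stream_chat(base_url: str, user_message: dict) -> None:
    # Read the SSE stream produced by POST /sessions/messages/stream
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", f"{base_url}/sessions/messages/stream", json=user_message) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    event = json.loads(line[len("data: "):])
                    print(event.get("status"), event.get("content"))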
@router.get("/sessions/{session_id}/messages") @router.get("/sessions/{session_id}/messages")
async def get_chat_session_messages( async def get_chat_session_messages(
@@ -254,16 +235,13 @@ async def get_chat_session_messages(
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get persisted chat messages for a session""" """Get persisted chat messages for a session"""
try: try:
chat_session_data = await database.get_chat_session(session_id) chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data: if not chat_session_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
# Get messages from database # Get messages from database
chat_messages = await database.get_chat_messages(session_id) chat_messages = await database.get_chat_messages(session_id)
@@ -288,43 +266,36 @@ async def get_chat_session_messages(
paginated_messages = messages_list[start:end] paginated_messages = messages_list[start:end]
paginated_response = create_paginated_response( paginated_response = create_paginated_response(
[m.model_dump(by_alias=True) for m in paginated_messages], [m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
page, limit, total
) )
return create_success_response(paginated_response) return create_success_response(paginated_response)
except Exception as e: except Exception as e:
logger.error(f"❌ Get chat messages error: {e}") logger.error(f"❌ Get chat messages error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
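filter_and_paginate and create_paginated_response are imported helpers that this diff never defines. A sketch consistent with the start/end arithmetic and call sites above; the response field names are assumptions:

from typing import Any, Dict, List, Optional, Tuple

def filter_and_paginate(items: List[Any], page: int, limit: int,
                        sort_by: Optional[str] = None, sort_order: str = "desc",
                        filters: Optional[Dict[str, Any]] = None) -> Tuple[List[Any], int]:
    # Filtering and sorting elided; pagination mirrors the slicing used above
    total = len(items)
    start = (page - 1) * limit
    return items[start:start + limit], total

def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict[str, Any]:
    return {"data": data, "page": page, "limit": limit, "total": total,
            "hasMore": page * limit < total}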
@router.patch("/sessions/{session_id}") @router.patch("/sessions/{session_id}")
async def update_chat_session( async def update_chat_session(
session_id: str = Path(...), session_id: str = Path(...),
updates: Dict[str, Any] = Body(...), updates: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Update a chat session's properties""" """Update a chat session's properties"""
try: try:
# Get the existing session # Get the existing session
session_data = await database.get_chat_session(session_id) session_data = await database.get_chat_session(session_id)
if not session_data: if not session_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
session = ChatSession.model_validate(session_data) session = ChatSession.model_validate(session_data)
# Check authorization - user can only update their own sessions # Check authorization - user can only update their own sessions
if session.user_id != current_user.id: if session.user_id != current_user.id:
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
) )
# Validate and apply updates # Validate and apply updates
@@ -333,8 +304,7 @@ async def update_chat_session(
if not filtered_updates: if not filtered_updates:
return JSONResponse( return JSONResponse(
status_code=400, status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
) )
# Apply updates to session data # Apply updates to session data
@@ -388,40 +358,31 @@ async def update_chat_session(
except ValueError as ve: except ValueError as ve:
logger.warning(f"⚠️ Validation error updating chat session: {ve}") logger.warning(f"⚠️ Validation error updating chat session: {ve}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
except Exception as e: except Exception as e:
logger.error(f"❌ Update chat session error: {e}") logger.error(f"❌ Update chat session error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
status_code=500,
content=create_error_response("UPDATE_ERROR", str(e))
)
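model_dump(by_alias=True) throughout these handlers emits camelCase keys ("sessionId", "additionalContext") from snake_case attributes. A sketch of the Pydantic v2 configuration that produces this, assuming the models use an alias generator:

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel

class CamelModel(BaseModel):
    # snake_case in Python, camelCase on the wire
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

class Example(CamelModel):
    session_id: str
    user_id: str

print(Example(session_id="s1", user_id="u1").model_dump(by_alias=True))
# -> {'sessionId': 's1', 'userId': 'u1'}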
@router.delete("/sessions/{session_id}") @router.delete("/sessions/{session_id}")
async def delete_chat_session( async def delete_chat_session(
session_id: str = Path(...), session_id: str = Path(...),
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Delete a chat session and all its messages""" """Delete a chat session and all its messages"""
try: try:
# Get the session to verify it exists and check ownership # Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id) session_data = await database.get_chat_session(session_id)
if not session_data: if not session_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
session = ChatSession.model_validate(session_data) session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions # Check authorization - user can only delete their own sessions
if session.user_id != current_user.id: if session.user_id != current_user.id:
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
) )
# Delete all messages associated with this session # Delete all messages associated with this session
@@ -440,42 +401,34 @@ async def delete_chat_session(
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}") logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
return create_success_response({ return create_success_response(
"success": True, {"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
"message": "Chat session deleted successfully", )
"sessionId": session_id
})
except Exception as e: except Exception as e:
logger.error(f"❌ Delete chat session error: {e}") logger.error(f"❌ Delete chat session error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
status_code=500,
content=create_error_response("DELETE_ERROR", str(e))
)
@router.patch("/sessions/{session_id}/reset") @router.patch("/sessions/{session_id}/reset")
async def reset_chat_session( async def reset_chat_session(
session_id: str = Path(...), session_id: str = Path(...),
current_user=Depends(get_current_user_or_guest), current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Delete a chat session and all its messages""" """Delete a chat session and all its messages"""
try: try:
# Get the session to verify it exists and check ownership # Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id) session_data = await database.get_chat_session(session_id)
if not session_data: if not session_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
session = ChatSession.model_validate(session_data) session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions # Check authorization - user can only delete their own sessions
if session.user_id != current_user.id: if session.user_id != current_user.id:
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
) )
# Delete all messages associated with this session # Delete all messages associated with this session
@@ -489,20 +442,12 @@ async def reset_chat_session(
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}") logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
# Continue with session deletion even if message deletion fails # Continue with session deletion even if message deletion fails
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}") logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
return create_success_response({ return create_success_response(
"success": True, {"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
"message": "Chat session reset successfully", )
"sessionId": session_id
})
except Exception as e: except Exception as e:
logger.error(f"❌ Reset chat session error: {e}") logger.error(f"❌ Reset chat session error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)
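The update, delete, and reset handlers above repeat the same load-then-check-ownership block. One way to factor it out is a FastAPI dependency; a sketch reusing the names from this file:

from fastapi import Depends, HTTPException, Path

async def get_owned_session(
    session_id: str = Path(...),
    current_user=Depends(get_current_user_or_guest),
    database: RedisDatabase = Depends(get_database),
) -> ChatSession:
    # 404 if the session does not exist, 403 if it belongs to another user
    session_data = await database.get_chat_session(session_id)
    if not session_data:
        raise HTTPException(status_code=404, detail="Chat session not found")
    session = ChatSession.model_validate(session_data)
    if session.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="Cannot access another user's chat session")
    return session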

View File

@@ -9,19 +9,16 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from utils.dependencies import ( from utils.dependencies import get_current_admin, get_database
get_current_admin, get_database
)
from utils.responses import create_success_response, create_error_response from utils.responses import create_success_response, create_error_response
# Create router for authentication endpoints # Create router for authentication endpoints
router = APIRouter(prefix="/auth", tags=["authentication"]) router = APIRouter(prefix="/auth", tags=["authentication"])
@router.get("/guest/{guest_id}") @router.get("/guest/{guest_id}")
async def debug_guest_session( async def debug_guest_session(
guest_id: str = Path(...), guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
): ):
"""Debug guest session issues (admin only)""" """Debug guest session issues (admin only)"""
try: try:
@@ -45,22 +42,19 @@ async def debug_guest_session(
"primary_storage": { "primary_storage": {
"exists": primary_exists, "exists": primary_exists,
"data": json.loads(primary_data) if primary_data else None, "data": json.loads(primary_data) if primary_data else None,
"ttl": primary_ttl "ttl": primary_ttl,
}, },
"backup_storage": { "backup_storage": {
"exists": backup_exists, "exists": backup_exists,
"data": json.loads(backup_data) if backup_data else None, "data": json.loads(backup_data) if backup_data else None,
"ttl": backup_ttl "ttl": backup_ttl,
}, },
"user_lookup": user_lookup, "user_lookup": user_lookup,
"timestamp": datetime.now(UTC).isoformat() "timestamp": datetime.now(UTC).isoformat(),
} }
return create_success_response(debug_info) return create_success_response(debug_info)
except Exception as e: except Exception as e:
logger.error(f"❌ Debug guest session error: {e}") logger.error(f"❌ Debug guest session error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))
status_code=500,
content=create_error_response("DEBUG_ERROR", str(e))
)
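The debug handler above inspects a guest session held in primary and backup Redis keys, each with its own TTL, plus a user-id reverse lookup. A sketch of the write side those reads imply, using redis.asyncio; the key names and TTL are assumptions:

import json
from datetime import UTC, datetime
import redis.asyncio as redis

async def store_guest_session(r: redis.Redis, guest_id: str, data: dict, ttl_seconds: int = 86400) -> None:
    payload = json.dumps({**data, "updatedAt": datetime.now(UTC).isoformat()})
    # Primary record plus a backup copy, both expiring on the same schedule
    await r.setex(f"guest:{guest_id}", ttl_seconds, payload)
    await r.setex(f"guest_backup:{guest_id}", ttl_seconds, payload)
    # Reverse lookup from user id to guest id, if one is attached
    if "userId" in data:
        await r.setex(f"user_guest:{data['userId']}", ttl_seconds, guest_id)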

View File

@@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from models import ( from models import CreateEmployerRequest
CreateEmployerRequest from utils.dependencies import get_database
)
from utils.dependencies import (
get_database
)
from utils.responses import create_success_response, create_error_response from utils.responses import create_success_response, create_error_response
from email_service import email_service from email_service import email_service
from utils.auth_utils import AuthenticationManager from utils.auth_utils import AuthenticationManager
@@ -23,29 +19,22 @@ from utils.auth_utils import AuthenticationManager
# Create router for job endpoints # Create router for job endpoints
router = APIRouter(prefix="/employers", tags=["employers"]) router = APIRouter(prefix="/employers", tags=["employers"])
@router.post("") @router.post("")
async def create_employer_with_verification( async def create_employer_with_verification(
request: CreateEmployerRequest, request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
): ):
"""Create a new employer with email verification""" """Create a new employer with email verification"""
try: try:
# Similar to candidate creation but for employer # Similar to candidate creation but for employer
auth_manager = AuthenticationManager(database) auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists( user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
request.email,
request.username
)
if user_exists and conflict_field: if user_exists and conflict_field:
return JSONResponse( return JSONResponse(
status_code=409, status_code=409,
content=create_error_response( content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
) )
employer_id = str(uuid.uuid4()) employer_id = str(uuid.uuid4())
@@ -64,12 +53,8 @@ async def create_employer_with_verification(
"updatedAt": current_time.isoformat(), "updatedAt": current_time.isoformat(),
"status": "pending", # Not active until verified "status": "pending", # Not active until verified
"userType": "employer", "userType": "employer",
"location": { "location": {"city": "", "country": "", "remote": False},
"city": "", "socialLinks": [],
"country": "",
"remote": False
},
"socialLinks": []
} }
verification_token = secrets.token_urlsafe(32) verification_token = secrets.token_urlsafe(32)
@@ -78,32 +63,25 @@ async def create_employer_with_verification(
request.email, request.email,
verification_token, verification_token,
"employer", "employer",
{ {"employer_data": employer_data, "password": request.password, "username": request.username},
"employer_data": employer_data,
"password": request.password,
"username": request.username
}
) )
background_tasks.add_task( background_tasks.add_task(
email_service.send_verification_email, email_service.send_verification_email, request.email, verification_token, request.company_name
request.email,
verification_token,
request.company_name
) )
logger.info(f"✅ Employer registration initiated for: {request.email}") logger.info(f"✅ Employer registration initiated for: {request.email}")
return create_success_response({ return create_success_response(
{
"message": "Registration successful! Please check your email to verify your account.", "message": "Registration successful! Please check your email to verify your account.",
"email": request.email, "email": request.email,
"verificationRequired": True "verificationRequired": True,
}) }
)
except Exception as e: except Exception as e:
logger.error(f"❌ Employer creation error: {e}") logger.error(f"❌ Employer creation error: {e}")
return JSONResponse( return JSONResponse(
status_code=500, status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
) )
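Registration above parks the employer record behind a verification token and emails the token out-of-band; the account only becomes active once the token comes back. A sketch of the completing endpoint; the route path and the database accessors are hypothetical:

from fastapi import Depends
from fastapi.responses import JSONResponse

@router.post("/verify/{token}")
async def verify_employer_email(token: str, database: RedisDatabase = Depends(get_database)):
    # get_pending_verification is a hypothetical accessor for the record
    # stored alongside the token during registration
    pending = await database.get_pending_verification(token)
    if not pending or pending.get("user_type") != "employer":
        return JSONResponse(
            status_code=404,
            content=create_error_response("INVALID_TOKEN", "Unknown or expired verification token"),
        )
    employer_data = pending["employer_data"]
    employer_data["status"] = "active"
    await database.set_employer(employer_data["id"], employer_data)  # hypothetical setter
    return create_success_response({"message": "Email verified", "email": employer_data["email"]})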

View File

@@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from models import ( from models import (
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer MOCK_UUID,
) ApiActivityType,
from utils.dependencies import ( ApiStatusType,
get_current_admin, get_database, get_current_user ChatContextType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
DocumentType,
Job,
JobRequirementsMessage,
Candidate,
Employer,
) )
from utils.dependencies import get_current_admin, get_database, get_current_user
from utils.responses import create_paginated_response, create_success_response, create_error_response from utils.responses import create_paginated_response, create_success_response, create_error_response
import utils.llm_proxy as llm_manager import utils.llm_proxy as llm_manager
import entities.entity_manager as entities import entities.entity_manager as entities
# Create router for job endpoints # Create router for job endpoints
router = APIRouter(prefix="/jobs", tags=["jobs"]) router = APIRouter(prefix="/jobs", tags=["jobs"])
@@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
if not chat_agent: if not chat_agent:
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type" content="No agent found for job requirements chat type",
) )
yield error_message yield error_message
return return
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Reformatting job description as markdown...", content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING activity=ApiActivityType.CONVERTING,
) )
yield status_message yield status_message
@@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
system_prompt=""" system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown. You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included. Return only the markdown content, no other text. Make sure all content is included.
""" """,
): ):
pass pass
@@ -66,7 +76,7 @@ Return only the markdown content, no other text. Make sure all content is includ
logger.error("❌ Failed to reformat job description to markdown") logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description" content="Failed to reformat job description",
) )
yield error_message yield error_message
return return
@@ -84,7 +94,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...", content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO activity=ApiActivityType.INFO,
) )
yield status_message yield status_message
await asyncio.sleep(0) # Let the status message propagate await asyncio.sleep(0) # Let the status message propagate
@@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, ChatMessage): if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description" content="Failed to reformat job description",
) )
yield error_message yield error_message
return return
@@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not chat_agent: if not chat_agent:
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type" content="No agent found for job requirements chat type",
) )
yield error_message yield error_message
return return
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Analyzing document for company and requirement details...", content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING activity=ApiActivityType.SEARCHING,
) )
yield status_message yield status_message
message = None message = None
async for message in chat_agent.generate( async for message in chat_agent.generate(
llm=llm_manager.get_llm(), llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
): ):
if message.status != ApiStatusType.DONE: if message.status != ApiStatusType.DONE:
yield message yield message
@@ -132,7 +139,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, JobRequirementsMessage): if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Job extraction did not convert successfully" content="Job extraction did not convert successfully",
) )
yield error_message yield error_message
return return
@@ -143,12 +150,11 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
return return
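create_job_from_content chains two agent passes: reformat the raw text to markdown, then extract structured job requirements, forwarding progress messages and keeping only each pass's final DONE message. The control flow condenses to a sketch like this, reusing this file's imports; two_stage_pipeline is an illustrative name:

async def two_stage_pipeline(markdown_agent, requirements_agent, content: str):
    # Pass 1: reformat to markdown, relaying progress to the caller
    message = None
    async for message in markdown_agent.generate(
        llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=content
    ):
        if message.status != ApiStatusType.DONE:
            yield message
    if not isinstance(message, ChatMessage):
        yield ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
        return
    # Pass 2: extract requirements from the markdown produced by pass 1
    async for message in requirements_agent.generate(
        llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=message.content
    ):
        yield message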
@router.post("") @router.post("")
async def create_job( async def create_job(
job_data: Dict[str, Any] = Body(...), job_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Create a new job""" """Create a new job"""
try: try:
@@ -165,17 +171,14 @@ async def create_job(
except Exception as e: except Exception as e:
logger.error(f"❌ Job creation error: {e}") logger.error(f"❌ Job creation error: {e}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
@router.post("") @router.post("")
async def create_candidate_job( async def create_candidate_job(
job_data: Dict[str, Any] = Body(...), job_data: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Create a new job""" """Create a new job"""
isinstance(current_user, Employer) isinstance(current_user, Employer)
@@ -194,10 +197,7 @@ async def create_candidate_job(
except Exception as e: except Exception as e:
logger.error(f"❌ Job creation error: {e}") logger.error(f"❌ Job creation error: {e}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
@router.patch("/{job_id}") @router.patch("/{job_id}")
@@ -205,17 +205,14 @@ async def update_job(
job_id: str = Path(...), job_id: str = Path(...),
updates: Dict[str, Any] = Body(...), updates: Dict[str, Any] = Body(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Update a candidate""" """Update a candidate"""
try: try:
job_data = await database.get_job(job_id) job_data = await database.get_job(job_id)
if not job_data: if not job_data:
logger.warning(f"⚠️ Job not found for update: {job_data}") logger.warning(f"⚠️ Job not found for update: {job_data}")
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
job = Job.model_validate(job_data) job = Job.model_validate(job_data)
@@ -223,8 +220,7 @@ async def update_job(
if current_user.is_admin is False and job.owner_id != current_user.id: if current_user.is_admin is False and job.owner_id != current_user.id:
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}") logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
return JSONResponse( return JSONResponse(
status_code=403, status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
) )
# Apply updates # Apply updates
@@ -239,25 +235,22 @@ async def update_job(
except Exception as e: except Exception as e:
logger.error(f"❌ Update job error: {e}") logger.error(f"❌ Update job error: {e}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
status_code=400,
content=create_error_response("UPDATE_FAILED", str(e))
)
@router.post("/from-content") @router.post("/from-content")
async def create_job_from_description( async def create_job_from_description(
content: str = Body(...), content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
): ):
"""Upload a document for the current candidate""" """Upload a document for the current candidate"""
async def content_stream_generator(content): async def content_stream_generator(content):
# Verify user is a candidate # Verify user is a candidate
if current_user.user_type != "candidate": if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}") logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents" content="Only candidates can upload documents",
) )
yield error_message yield error_message
return return
@@ -277,10 +270,11 @@ async def create_job_from_description(
return return
try: try:
async def to_json(method): async def to_json(method):
try: try:
async for message in method: async for message in method:
json_data = message.model_dump(mode='json', by_alias=True) json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data) json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8") yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e: except Exception as e:
@@ -304,18 +298,25 @@ async def create_job_from_description(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}") logger.error(f"❌ Document upload error: {e}")
return StreamingResponse( return StreamingResponse(
iter([json.dumps(ChatMessageError( iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document" content="Failed to upload document",
).model_dump(by_alias=True)).encode("utf-8")]), ).model_dump(by_alias=True)
media_type="text/event-stream" ).encode("utf-8")
]
),
media_type="text/event-stream",
) )
@router.post("/upload") @router.post("/upload")
async def create_job_from_file( async def create_job_from_file(
file: UploadFile = File(...), file: UploadFile = File(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Upload a job document for the current candidate and create a Job""" """Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB) # Check file size (limit to 10MB)
@@ -324,55 +325,70 @@ async def create_job_from_file(
if len(file_content) > max_size: if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)") logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse( return StreamingResponse(
iter([json.dumps(ChatMessageError( iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit" content="File size exceeds 10MB limit",
).model_dump(by_alias=True)).encode("utf-8")]), ).model_dump(by_alias=True)
media_type="text/event-stream" ).encode("utf-8")
]
),
media_type="text/event-stream",
) )
if len(file_content) == 0: if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}") logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse( return StreamingResponse(
iter([json.dumps(ChatMessageError( iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty" content="File is empty",
).model_dump(by_alias=True)).encode("utf-8")]), ).model_dump(by_alias=True)
media_type="text/event-stream" ).encode("utf-8")
]
),
media_type="text/event-stream",
) )
"""Upload a document for the current candidate""" """Upload a document for the current candidate"""
async def upload_stream_generator(file_content): async def upload_stream_generator(file_content):
# Verify user is a candidate # Verify user is a candidate
if current_user.user_type != "candidate": if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}") logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents" content="Only candidates can upload documents",
) )
yield error_message yield error_message
return return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
if not file.filename or file.filename.strip() == "": if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename") logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename" content="File must have a valid filename",
) )
yield error_message yield error_message
return return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'") logger.info(
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
)
# Validate file type # Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif'] allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else "" file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types: if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}") logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}" content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
) )
yield error_message yield error_message
return return
@@ -383,7 +399,7 @@ async def create_job_from_file(
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...", content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING activity=ApiActivityType.CONVERTING,
) )
yield status_message yield status_message
try: try:
@@ -391,7 +407,7 @@ async def create_job_from_file(
stream = io.BytesIO(file_content) stream = io.BytesIO(file_content)
stream_info = StreamInfo( stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf" extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing url=file.filename, # optional, helps with logging and guessing
) )
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown") result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content file_content = result.text_content
@@ -405,15 +421,18 @@ async def create_job_from_file(
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}") logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content): async for message in create_job_from_content(
database=database, current_user=current_user, content=file_content
):
yield message yield message
return return
try: try:
async def to_json(method): async def to_json(method):
try: try:
async for message in method: async for message in method:
json_data = message.model_dump(mode='json', by_alias=True) json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data) json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8") yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e: except Exception as e:
@@ -437,26 +456,27 @@ async def create_job_from_file(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}") logger.error(f"❌ Document upload error: {e}")
return StreamingResponse( return StreamingResponse(
iter([json.dumps(ChatMessageError( iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document" content="Failed to upload document",
).model_dump(mode='json', by_alias=True)).encode("utf-8")]), ).model_dump(mode="json", by_alias=True)
media_type="text/event-stream" ).encode("utf-8")
]
),
media_type="text/event-stream",
) )
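Before job extraction, uploads are normalized to markdown with MarkItDown. A standalone sketch of that conversion step, mirroring the convert_stream call above and assuming MarkItDown and StreamInfo are importable from the markitdown package:

import io
from markitdown import MarkItDown, StreamInfo

def to_markdown(file_content: bytes, file_extension: str, filename: str) -> str:
    md = MarkItDown()
    # extension and url help MarkItDown pick the right converter
    stream_info = StreamInfo(extension=file_extension, url=filename)
    result = md.convert_stream(io.BytesIO(file_content), stream_info=stream_info, output_format="markdown")
    return result.text_content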
@router.get("/{job_id}") @router.get("/{job_id}")
async def get_job( async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
job_id: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
"""Get a job by ID""" """Get a job by ID"""
try: try:
job_data = await database.get_job(job_id) job_data = await database.get_job(job_id)
if not job_data: if not job_data:
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
# Increment view count # Increment view count
job_data["views"] = job_data.get("views", 0) + 1 job_data["views"] = job_data.get("views", 0) + 1
@@ -467,10 +487,8 @@ async def get_job(
except Exception as e: except Exception as e:
logger.error(f"❌ Get job error: {e}") logger.error(f"❌ Get job error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
@router.get("") @router.get("")
async def get_jobs( async def get_jobs(
@@ -479,7 +497,7 @@ async def get_jobs(
sortBy: Optional[str] = Query(None, alias="sortBy"), sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"), sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None), filters: Optional[str] = Query(None),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get paginated list of jobs""" """Get paginated list of jobs"""
try: try:
@@ -493,23 +511,18 @@ async def get_jobs(
for job in all_jobs_data.values(): for job in all_jobs_data.values():
jobs_list.append(Job.model_validate(job)) jobs_list.append(Job.model_validate(job))
paginated_jobs, total = filter_and_paginate( paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
jobs_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_response = create_paginated_response( paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs], [j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
page, limit, total
) )
return create_success_response(paginated_response) return create_success_response(paginated_response)
except Exception as e: except Exception as e:
logger.error(f"❌ Get jobs error: {e}") logger.error(f"❌ Get jobs error: {e}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
status_code=400,
content=create_error_response("FETCH_FAILED", str(e))
)
@router.get("/search") @router.get("/search")
async def search_jobs( async def search_jobs(
@@ -517,7 +530,7 @@ async def search_jobs(
filters: Optional[str] = Query(None), filters: Optional[str] = Query(None),
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100), limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Search jobs""" """Search jobs"""
try: try:
@@ -532,69 +545,52 @@ async def search_jobs(
if query: if query:
query_lower = query.lower() query_lower = query.lower()
jobs_list = [ jobs_list = [
j for j in jobs_list j
if ((j.title and query_lower in j.title.lower()) or for j in jobs_list
(j.description and query_lower in j.description.lower()) or if (
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])) (j.title and query_lower in j.title.lower())
or (j.description and query_lower in j.description.lower())
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
)
] ]
paginated_jobs, total = filter_and_paginate( paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
jobs_list, page, limit, filters=filter_dict
)
paginated_response = create_paginated_response( paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs], [j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
page, limit, total
) )
return create_success_response(paginated_response) return create_success_response(paginated_response)
except Exception as e: except Exception as e:
logger.error(f"❌ Search jobs error: {e}") logger.error(f"❌ Search jobs error: {e}")
return JSONResponse( return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
status_code=400,
content=create_error_response("SEARCH_FAILED", str(e))
)
@router.delete("/{job_id}") @router.delete("/{job_id}")
async def delete_job( async def delete_job(
job_id: str = Path(...), job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
): ):
"""Delete a Job""" """Delete a Job"""
try: try:
# Check if admin user # Check if admin user
if not admin_user.is_admin: if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}") logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse( return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
# Get job data # Get job data
job_data = await database.get_job(job_id) job_data = await database.get_job(job_id)
if not job_data: if not job_data:
logger.warning(f"⚠️ Job not found for deletion: {job_id}") logger.warning(f"⚠️ Job not found for deletion: {job_id}")
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
# Delete job from database # Delete job from database
await database.delete_job(job_id) await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}") logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({ return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
"message": "Job deleted successfully",
"jobId": job_id
})
except Exception as e: except Exception as e:
logger.error(f"❌ Delete job error: {e}") logger.error(f"❌ Delete job error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)

View File

@@ -6,6 +6,7 @@ from utils.llm_proxy import LLMProvider, get_llm
router = APIRouter(prefix="/providers", tags=["providers"]) router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/models") @router.get("/models")
async def list_models(provider: Optional[str] = None): async def list_models(provider: Optional[str] = None):
"""List available models for a provider""" """List available models for a provider"""
@@ -17,29 +18,25 @@ async def list_models(provider: Optional[str] = None):
try: try:
provider_enum = LLMProvider(provider.lower()) provider_enum = LLMProvider(provider.lower())
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
status_code=400,
detail=f"Unsupported provider: {provider}"
)
models = await llm.list_models(provider_enum) models = await llm.list_models(provider_enum)
return { return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("") @router.get("")
async def list_providers(): async def list_providers():
"""List all configured providers""" """List all configured providers"""
llm = get_llm() llm = get_llm()
return { return {
"providers": [provider.value for provider in llm._initialized_providers], "providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value "default": llm.default_provider.value,
} }
@router.post("/{provider}/set-default") @router.post("/{provider}/set-default")
async def set_default_provider(provider: str): async def set_default_provider(provider: str):
"""Set the default provider""" """Set the default provider"""
@@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint # Health check endpoint
@router.get("/health") @router.get("/health")
async def health_check(): async def health_check():
@@ -59,5 +57,5 @@ async def health_check():
return { return {
"status": "healthy", "status": "healthy",
"providers_configured": len(llm._initialized_providers), "providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value "default_provider": llm.default_provider.value,
} }
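A usage sketch against the provider endpoints above, reusing get_llm and LLMProvider from utils.llm_proxy as imported at the top of this file:

from utils.llm_proxy import LLMProvider, get_llm

async def show_providers() -> None:
    llm = get_llm()
    # Mirrors GET /providers and GET /providers/models
    print("default:", llm.default_provider.value)
    for provider in llm._initialized_providers:
        models = await llm.list_models(provider)
        print(provider.value, "->", len(models), "models")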

View File

@@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
import backstory_traceback as backstory_traceback import backstory_traceback as backstory_traceback
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from models import ( from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage from utils.dependencies import get_database, get_current_user
)
from utils.dependencies import (
get_database, get_current_user
)
from utils.responses import create_success_response from utils.responses import create_success_response
# Create router for authentication endpoints # Create router for authentication endpoints
router = APIRouter(prefix="/resumes", tags=["resumes"]) router = APIRouter(prefix="/resumes", tags=["resumes"])
@router.post("/{candidate_id}/{job_id}") @router.post("/{candidate_id}/{job_id}")
async def create_candidate_resume( async def create_candidate_resume(
candidate_id: str = Path(..., description="ID of the candidate"), candidate_id: str = Path(..., description="ID of the candidate"),
job_id: str = Path(..., description="ID of the job"), job_id: str = Path(..., description="ID of the job"),
resume_content: str = Body(...), resume_content: str = Body(...),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Create a new resume for a candidate/job combination""" """Create a new resume for a candidate/job combination"""
async def message_stream_generator(): async def message_stream_generator():
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}") logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
@@ -39,7 +37,7 @@ async def create_candidate_resume(
logger.error(f"❌ Candidate with ID '{candidate_id}' not found") logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content=f"Candidate with ID '{candidate_id}' not found" content=f"Candidate with ID '{candidate_id}' not found",
) )
yield error_message yield error_message
return return
@@ -50,13 +48,15 @@ async def create_candidate_resume(
logger.error(f"❌ Job with ID '{job_id}' not found") logger.error(f"❌ Job with ID '{job_id}' not found")
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content=f"Job with ID '{job_id}' not found" content=f"Job with ID '{job_id}' not found",
) )
yield error_message yield error_message
return return
job = Job.model_validate(job_data) job = Job.model_validate(job_data)
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'") logger.info(
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
)
# Job and Candidate are valid. Save the resume # Job and Candidate are valid. Save the resume
resume = Resume( resume = Resume(
@@ -66,16 +66,13 @@ async def create_candidate_resume(
) )
resume_message: ResumeMessage = ResumeMessage( resume_message: ResumeMessage = ResumeMessage(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
resume=resume resume=resume,
) )
# Save to database # Save to database
success = await database.set_resume(current_user.id, resume.model_dump()) success = await database.set_resume(current_user.id, resume.model_dump())
if not success: if not success:
error_message = ChatMessageError( error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
session_id=MOCK_UUID,
content="Failed to save resume to database"
)
yield error_message yield error_message
return return
@ -84,10 +81,11 @@ async def create_candidate_resume(
return return
try: try:
async def to_json(method): async def to_json(method):
try: try:
async for message in method: async for message in method:
json_data = message.model_dump(mode='json', by_alias=True) json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data) json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8") yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e: except Exception as e:
@ -111,18 +109,22 @@ async def create_candidate_resume(
logger.error(backstory_traceback.format_exc()) logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Resume creation error: {e}") logger.error(f"❌ Resume creation error: {e}")
return StreamingResponse( return StreamingResponse(
iter([json.dumps(ChatMessageError( iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume" content="Failed to create resume",
).model_dump(mode='json', by_alias=True))]), ).model_dump(mode="json", by_alias=True)
media_type="text/event-stream" )
]
),
media_type="text/event-stream",
) )
@router.get("") @router.get("")
async def get_user_resumes( async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
"""Get all resumes for the current user""" """Get all resumes for the current user"""
try: try:
resumes_data = await database.get_all_resumes_for_user(current_user.id) resumes_data = await database.get_all_resumes_for_user(current_user.id)
@ -135,19 +137,17 @@ async def get_user_resumes(
if candidate_data: if candidate_data:
resume.candidate = Candidate.model_validate(candidate_data) resume.candidate = Candidate.model_validate(candidate_data)
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by most recently updated resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by most recently updated
return create_success_response({ return create_success_response({"resumes": resumes, "count": len(resumes)})
"resumes": resumes,
"count": len(resumes)
})
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}") logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resumes") raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
@router.get("/{resume_id}") @router.get("/{resume_id}")
async def get_resume( async def get_resume(
resume_id: str = Path(..., description="ID of the resume"), resume_id: str = Path(..., description="ID of the resume"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get a specific resume by ID""" """Get a specific resume by ID"""
try: try:
@ -155,21 +155,19 @@ async def get_resume(
if not resume: if not resume:
raise HTTPException(status_code=404, detail="Resume not found") raise HTTPException(status_code=404, detail="Resume not found")
return { return {"success": True, "resume": resume}
"success": True,
"resume": resume
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}") logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume") raise HTTPException(status_code=500, detail="Failed to retrieve resume")
@router.delete("/{resume_id}") @router.delete("/{resume_id}")
async def delete_resume( async def delete_resume(
resume_id: str = Path(..., description="ID of the resume"), resume_id: str = Path(..., description="ID of the resume"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Delete a specific resume""" """Delete a specific resume"""
try: try:
@ -177,102 +175,82 @@ async def delete_resume(
if not success: if not success:
raise HTTPException(status_code=404, detail="Resume not found") raise HTTPException(status_code=404, detail="Resume not found")
return { return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
"success": True,
"message": f"Resume {resume_id} deleted successfully"
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}") logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete resume") raise HTTPException(status_code=500, detail="Failed to delete resume")
@router.get("/candidate/{candidate_id}") @router.get("/candidate/{candidate_id}")
async def get_resumes_by_candidate( async def get_resumes_by_candidate(
candidate_id: str = Path(..., description="ID of the candidate"), candidate_id: str = Path(..., description="ID of the candidate"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get all resumes for a specific candidate""" """Get all resumes for a specific candidate"""
try: try:
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id) resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
return { return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
"success": True,
"candidate_id": candidate_id,
"resumes": resumes,
"count": len(resumes)
}
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}") logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes") raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
@router.get("/job/{job_id}") @router.get("/job/{job_id}")
async def get_resumes_by_job( async def get_resumes_by_job(
job_id: str = Path(..., description="ID of the job"), job_id: str = Path(..., description="ID of the job"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Get all resumes for a specific job""" """Get all resumes for a specific job"""
try: try:
resumes = await database.get_resumes_by_job(current_user.id, job_id) resumes = await database.get_resumes_by_job(current_user.id, job_id)
return { return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
"success": True,
"job_id": job_id,
"resumes": resumes,
"count": len(resumes)
}
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}") logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes") raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
@router.get("/search") @router.get("/search")
async def search_resumes( async def search_resumes(
q: str = Query(..., description="Search query"), q: str = Query(..., description="Search query"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Search resumes by content""" """Search resumes by content"""
try: try:
resumes = await database.search_resumes_for_user(current_user.id, q) resumes = await database.search_resumes_for_user(current_user.id, q)
return { return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
"success": True,
"query": q,
"resumes": resumes,
"count": len(resumes)
}
except Exception as e: except Exception as e:
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}") logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to search resumes") raise HTTPException(status_code=500, detail="Failed to search resumes")
@router.get("/stats") @router.get("/stats")
async def get_resume_statistics( async def get_resume_statistics(
current_user = Depends(get_current_user), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database)
): ):
"""Get resume statistics for the current user""" """Get resume statistics for the current user"""
try: try:
stats = await database.get_resume_statistics(current_user.id) stats = await database.get_resume_statistics(current_user.id)
return { return {"success": True, "statistics": stats}
"success": True,
"statistics": stats
}
except Exception as e: except Exception as e:
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}") logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics") raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
@router.put("/{resume_id}") @router.put("/{resume_id}")
async def update_resume( async def update_resume(
resume_id: str = Path(..., description="ID of the resume"), resume_id: str = Path(..., description="ID of the resume"),
resume: str = Body(..., description="Updated resume content"), resume: str = Body(..., description="Updated resume content"),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database) database: RedisDatabase = Depends(get_database),
): ):
"""Update the content of a specific resume""" """Update the content of a specific resume"""
try: try:
updates = { updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
"resume": resume,
"updated_at": datetime.now(UTC).isoformat()
}
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates) updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
if not updated_resume_data: if not updated_resume_data:
@ -280,11 +258,9 @@ async def update_resume(
raise HTTPException(status_code=404, detail="Resume not found") raise HTTPException(status_code=404, detail="Resume not found")
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
return create_success_response({ return create_success_response(
"success": True, {"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
"message": f"Resume {resume_id} updated successfully", )
"resume": updated_resume
})
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
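Since create_candidate_resume streams its ChatMessageError and ResumeMessage objects as server-sent events, a client reads the "data:" frames line by line. A minimal consumer sketch, assuming the httpx library; the base URL and token handling are illustrative, and only the /resumes/{candidate_id}/{job_id} suffix comes from this router:

import asyncio
import json

import httpx

async def consume_resume_stream(candidate_id: str, job_id: str, resume_content: str, token: str) -> None:
    # Hypothetical host; the router above only fixes the path suffix.
    url = f"http://localhost:8000/resumes/{candidate_id}/{job_id}"
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(timeout=None) as client:
        # resume_content is sent as the JSON body, matching resume_content: str = Body(...)
        async with client.stream("POST", url, json=resume_content, headers=headers) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    message = json.loads(line[len("data: "):])
                    print(message.get("status"), message.get("content"))

asyncio.run(consume_resume_stream("cand-1", "job-1", "# My Resume", "example-token"))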

View File

@ -10,6 +10,7 @@ from utils.responses import create_success_response
# Create router for authentication endpoints # Create router for authentication endpoints
router = APIRouter(prefix="/system", tags=["system"]) router = APIRouter(prefix="/system", tags=["system"])
async def get_redis() -> Redis: async def get_redis() -> Redis:
"""Dependency to get Redis client""" """Dependency to get Redis client"""
return redis_manager.get_client() return redis_manager.get_client()
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
async def get_system_info(request: Request): async def get_system_info(request: Request):
"""Get system information""" """Get system information"""
from system_info import system_info # Import system_info function from system_info module from system_info import system_info # Import system_info function from system_info module
system = system_info() system = system_info()
return create_success_response(system.model_dump(mode='json')) return create_success_response(system.model_dump(mode="json"))
@router.get("/redis/stats") @router.get("/redis/stats")
async def redis_stats(redis: Redis = Depends(get_redis)): async def redis_stats(redis: Redis = Depends(get_redis)):
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
"total_commands_processed": info.get("total_commands_processed"), "total_commands_processed": info.get("total_commands_processed"),
"keyspace_hits": info.get("keyspace_hits"), "keyspace_hits": info.get("keyspace_hits"),
"keyspace_misses": info.get("keyspace_misses"), "keyspace_misses": info.get("keyspace_misses"),
"uptime_in_seconds": info.get("uptime_in_seconds") "uptime_in_seconds": info.get("uptime_in_seconds"),
} }
except Exception as e: except Exception as e:
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}") raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
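The keyspace_hits and keyspace_misses counters surfaced above come straight from Redis INFO; turning them into a hit rate is a small, useful companion to the endpoint. A sketch (the helper name is mine, not part of this commit):

def redis_hit_rate(stats: dict) -> float:
    # Cache hit rate derived from the INFO counters returned by /system/redis/stats.
    hits = stats.get("keyspace_hits") or 0
    misses = stats.get("keyspace_misses") or 0
    total = hits + misses
    return hits / total if total else 0.0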

View File

@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase from database.manager import RedisDatabase
from logger import logger from logger import logger
from models import ( from models import BaseUserWithType
BaseUserWithType from utils.dependencies import get_database
)
from utils.dependencies import (
get_database
)
from utils.responses import create_success_response, create_error_response from utils.responses import create_success_response, create_error_response
# Create router for job endpoints # Create router for job endpoints
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
# reference can be candidateId, username, or email # reference can be candidateId, username, or email
@router.get("/{reference}") @router.get("/{reference}")
async def get_user( async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
reference: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
"""Get a candidate by username""" """Get a candidate by username"""
try: try:
# Normalize reference to lowercase for case-insensitive search # Normalize reference to lowercase for case-insensitive search
@ -32,16 +26,15 @@ async def get_user(
all_candidate_data = await database.get_all_candidates() all_candidate_data = await database.get_all_candidates()
if not all_candidate_data: if not all_candidate_data:
logger.warning("⚠️ No users found in database") logger.warning("⚠️ No users found in database")
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
user_data = None user_data = None
for user in all_candidate_data.values(): for user in all_candidate_data.values():
if (user.get("id", "").lower() == query_lower or if (
user.get("username", "").lower() == query_lower or user.get("id", "").lower() == query_lower
user.get("email", "").lower() == query_lower): or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user user_data = user
break break
@ -49,23 +42,19 @@ async def get_user(
all_guest_data = await database.get_all_guests() all_guest_data = await database.get_all_guests()
if not all_guest_data: if not all_guest_data:
logger.warning("⚠️ No guests found in database") logger.warning("⚠️ No guests found in database")
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
for user in all_guest_data.values(): for user in all_guest_data.values():
if (user.get("id", "").lower() == query_lower or if (
user.get("username", "").lower() == query_lower or user.get("id", "").lower() == query_lower
user.get("email", "").lower() == query_lower): or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user user_data = user
break break
if not user_data: if not user_data:
logger.warning(f"⚠️ User nor Guest found for reference: {reference}") logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
return JSONResponse( return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
status_code=404,
content=create_error_response("NOT_FOUND", "User not found")
)
user = BaseUserWithType.model_validate(user_data) user = BaseUserWithType.model_validate(user_data)
@ -73,8 +62,4 @@ async def get_user(
except Exception as e: except Exception as e:
logger.error(f"❌ Get user error: {e}") logger.error(f"❌ Get user error: {e}")
return JSONResponse( return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
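The candidate and guest scans above apply the same case-insensitive match over id, username, and email; factoring the predicate out makes that intent explicit. A sketch (the helper name is hypothetical):

def matches_reference(user: dict, reference: str) -> bool:
    # The case-insensitive id/username/email comparison used twice in get_user.
    query = reference.strip().lower()
    return any((user.get(field) or "").lower() == query for field in ("id", "username", "email"))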

View File

@ -4,6 +4,7 @@ import subprocess
import math import math
from models import SystemInfo from models import SystemInfo
def get_installed_ram(): def get_installed_ram():
try: try:
with open("/proc/meminfo", "r") as f: with open("/proc/meminfo", "r") as f:
@ -19,9 +20,7 @@ def get_graphics_cards():
gpus = [] gpus = []
try: try:
# Run the ze-monitor utility # Run the ze-monitor utility
result = subprocess.run( result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
["ze-monitor"], capture_output=True, text=True, check=True
)
# Clean up the output (remove leading/trailing whitespace and newlines) # Clean up the output (remove leading/trailing whitespace and newlines)
output = result.stdout.strip() output = result.stdout.strip()
@ -71,6 +70,7 @@ def get_cpu_info():
except Exception as e: except Exception as e:
return f"Error retrieving CPU info: {e}" return f"Error retrieving CPU info: {e}"
def system_info() -> SystemInfo: def system_info() -> SystemInfo:
""" """
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length. Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
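get_installed_ram reads /proc/meminfo, whose MemTotal field is reported in kB. A sketch of the parse it implies (the function name, rounding, and GiB conversion are my assumptions):

def get_installed_ram_gb() -> float:
    # MemTotal in /proc/meminfo is expressed in kB; convert to GiB.
    with open("/proc/meminfo", "r") as f:
        for line in f:
            if line.startswith("MemTotal:"):
                kb = int(line.split()[1])
                return round(kb / (1024 * 1024), 1)
    return 0.0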

View File

@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
# Process the forecast data into a simpler format # Process the forecast data into a simpler format
forecast = { forecast = {
"location": data["properties"] "location": data["properties"].get("relativeLocation", {}).get("properties", {}),
.get("relativeLocation", {})
.get("properties", {}),
"updated": data["properties"].get("updated", ""), "updated": data["properties"].get("updated", ""),
"periods": [], "periods": [],
} }
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
if ticker_symbol == "": if ticker_symbol == "":
continue continue
url = ( url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
response = requests.get(url) response = requests.get(url)
data = response.json() data = response.json()
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
logging.error(f"Error fetching data for {ticker_symbol}: {e}") logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
results.append( results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
return results[0] if len(results) == 1 else results return results[0] if len(results) == 1 else results
@ -278,9 +272,11 @@ def DateTime(timezone="America/Los_Angeles"):
except Exception as e: except Exception as e:
return {"error": f"Invalid timezone {timezone}: {str(e)}"} return {"error": f"Invalid timezone {timezone}: {str(e)}"}
async def GenerateImage(llm, model: str, prompt: str): async def GenerateImage(llm, model: str, prompt: str):
return {"image_id": "image-a830a83-bd831"} return {"image_id": "image-a830a83-bd831"}
async def AnalyzeSite(llm, model: str, url: str, question: str): async def AnalyzeSite(llm, model: str, url: str, question: str):
""" """
Fetches content from a URL, extracts the text, and uses Ollama to summarize it. Fetches content from a URL, extracts the text, and uses Ollama to summarize it.
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
return f"Error processing the website content: {str(e)}" return f"Error processing the website content: {str(e)}"
# %% # %%
class Function(BaseModel): class Function(BaseModel):
name: str name: str
@ -355,10 +350,12 @@ class Function(BaseModel):
parameters: Dict[str, Any] parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {} returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel): class Tool(BaseModel):
type: str type: str
function: Function function: Function
tools: List[Tool] = [ tools: List[Tool] = [
# Tool.model_validate({ # Tool.model_validate({
# "type": "function", # "type": "function",
@ -366,24 +363,20 @@ tools : List[Tool] = [
# "name": "GenerateImage", # "name": "GenerateImage",
# "description": """\ # "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION: # CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content # 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123") # 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/> # 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url) # 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths # 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format # 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE: # CORRECT EXAMPLE:
# User: "Draw a cat" # User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"} # Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>" # Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS): # WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...) # - ![](https://example.com/...)
# - ![Cat image](any_url) # - ![Cat image](any_url)
# - <img src="..."> # - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system. # The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """, # """,
# "parameters": { # "parameters": {
@ -407,7 +400,8 @@ tools : List[Tool] = [
# } # }
# } # }
# }), # }),
Tool.model_validate({ Tool.model_validate(
{
"type": "function", "type": "function",
"function": { "function": {
"name": "TickerValue", "name": "TickerValue",
@ -424,8 +418,10 @@ tools : List[Tool] = [
"additionalProperties": False, "additionalProperties": False,
}, },
}, },
}), }
Tool.model_validate({ ),
Tool.model_validate(
{
"type": "function", "type": "function",
"function": { "function": {
"name": "AnalyzeSite", "name": "AnalyzeSite",
@ -463,8 +459,10 @@ tools : List[Tool] = [
}, },
}, },
}, },
}), }
Tool.model_validate({ ),
Tool.model_validate(
{
"type": "function", "type": "function",
"function": { "function": {
"name": "DateTime", "name": "DateTime",
@ -480,8 +478,10 @@ tools : List[Tool] = [
"required": [], "required": [],
}, },
}, },
}), }
Tool.model_validate({ ),
Tool.model_validate(
{
"type": "function", "type": "function",
"function": { "function": {
"name": "WeatherForecast", "name": "WeatherForecast",
@ -504,22 +504,27 @@ tools : List[Tool] = [
"additionalProperties": False, "additionalProperties": False,
}, },
}, },
}), }
),
] ]
class ToolEntry(BaseModel): class ToolEntry(BaseModel):
enabled: bool = True enabled: bool = True
tool: Tool tool: Tool
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]: def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled is True] return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
def all_tools() -> List[ToolEntry]: def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools] return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True] return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
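Putting the helpers together: all_tools() wraps every registered Tool in a ToolEntry, entries can be toggled, and llm_tools() emits only the enabled ones as JSON-ready dicts. A usage sketch (disabling TickerValue is an arbitrary example):

entries = all_tools()
for entry in entries:
    if entry.tool.function.name == "TickerValue":
        entry.enabled = False  # drop the stock-price tool for this session
payload = llm_tools(entries)  # plain dicts for the enabled tools only
assert all(tool["function"]["name"] != "TickerValue" for tool in payload)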

View File

@ -14,6 +14,7 @@ from pydantic import BaseModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PasswordSecurity: class PasswordSecurity:
"""Handles password hashing and verification using bcrypt""" """Handles password hashing and verification using bcrypt"""
@ -32,9 +33,9 @@ class PasswordSecurity:
salt = bcrypt.gensalt() salt = bcrypt.gensalt()
# Hash the password # Hash the password
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt) password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
return password_hash.decode('utf-8'), salt.decode('utf-8') return password_hash.decode("utf-8"), salt.decode("utf-8")
@staticmethod @staticmethod
def verify_password(password: str, password_hash: str) -> bool: def verify_password(password: str, password_hash: str) -> bool:
@ -49,10 +50,7 @@ class PasswordSecurity:
True if password matches, False otherwise True if password matches, False otherwise
""" """
try: try:
return bcrypt.checkpw( return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
password.encode('utf-8'),
password_hash.encode('utf-8')
)
except Exception as e: except Exception as e:
logger.error(f"Password verification error: {e}") logger.error(f"Password verification error: {e}")
return False return False
@ -62,8 +60,10 @@ class PasswordSecurity:
"""Generate a cryptographically secure random token""" """Generate a cryptographically secure random token"""
return secrets.token_urlsafe(length) return secrets.token_urlsafe(length)
class AuthenticationRecord(BaseModel): class AuthenticationRecord(BaseModel):
"""Authentication record for storing user credentials""" """Authentication record for storing user credentials"""
user_id: str user_id: str
password_hash: str password_hash: str
salt: str salt: str
@ -78,18 +78,19 @@ class AuthenticationRecord(BaseModel):
locked_until: Optional[datetime] = None locked_until: Optional[datetime] = None
class Config: class Config:
json_encoders = { json_encoders = {datetime: lambda v: v.isoformat() if v else None}
datetime: lambda v: v.isoformat() if v else None
}
class SecurityConfig: class SecurityConfig:
"""Security configuration constants""" """Security configuration constants"""
MAX_LOGIN_ATTEMPTS = 5 MAX_LOGIN_ATTEMPTS = 5
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15 ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
PASSWORD_MIN_LENGTH = 8 PASSWORD_MIN_LENGTH = 8
TOKEN_EXPIRY_HOURS = 24 TOKEN_EXPIRY_HOURS = 24
REFRESH_TOKEN_EXPIRY_DAYS = 30 REFRESH_TOKEN_EXPIRY_DAYS = 30
class AuthenticationManager: class AuthenticationManager:
"""Manages authentication operations with security features""" """Manages authentication operations with security features"""
@ -120,7 +121,7 @@ class AuthenticationManager:
password_hash=password_hash, password_hash=password_hash,
salt=salt, salt=salt,
last_password_change=datetime.now(timezone.utc), last_password_change=datetime.now(timezone.utc),
login_attempts=0 login_attempts=0,
) )
# Store in database # Store in database
@ -129,7 +130,9 @@ class AuthenticationManager:
logger.info(f"🔐 Created authentication record for user {user_id}") logger.info(f"🔐 Created authentication record for user {user_id}")
return auth_record return auth_record
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: async def verify_user_credentials(
self, login: str, password: str
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
""" """
Verify user credentials with security checks Verify user credentials with security checks
@ -164,7 +167,11 @@ class AuthenticationManager:
seconds = int(total_seconds % 60) seconds = int(total_seconds % 60)
time_until_unlock_str = f"{minutes}m {seconds}s" time_until_unlock_str = f"{minutes}m {seconds}s"
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.") logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}" return (
False,
None,
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
)
# Verify password # Verify password
if not self.password_security.verify_password(password, auth_data.password_hash): if not self.password_security.verify_password(password, auth_data.password_hash):
@ -176,7 +183,9 @@ class AuthenticationManager:
auth_data.locked_until = datetime.now(timezone.utc) + timedelta( auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
) )
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts") logger.warning(
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
)
# Update authentication record # Update authentication record
await self.database.set_authentication(user_data["id"], auth_data.model_dump()) await self.database.set_authentication(user_data["id"], auth_data.model_dump())
@ -238,6 +247,7 @@ class AuthenticationManager:
except Exception as e: except Exception as e:
logger.error(f"❌ Error updating last login for user {user_id}: {e}") logger.error(f"❌ Error updating last login for user {user_id}: {e}")
# Utility functions for common operations # Utility functions for common operations
def validate_password_strength(password: str) -> Tuple[bool, list]: def validate_password_strength(password: str) -> Tuple[bool, list]:
""" """
@ -270,6 +280,7 @@ def validate_password_strength(password: str) -> Tuple[bool, list]:
return len(issues) == 0, issues return len(issues) == 0, issues
def sanitize_login_input(login: str) -> str: def sanitize_login_input(login: str) -> str:
"""Sanitize login input (email or username)""" """Sanitize login input (email or username)"""
return login.strip().lower() if login else "" return login.strip().lower() if login else ""
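A minimal round trip through the primitives above, assuming hash_password is a staticmethod like verify_password (consistent with how the class is written):

password = "S3cure-pass!"
password_hash, salt = PasswordSecurity.hash_password(password)
assert PasswordSecurity.verify_password(password, password_hash)

ok, issues = validate_password_strength("short")
if not ok:
    print(issues)  # e.g. shorter than SecurityConfig.PASSWORD_MIN_LENGTH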

View File

@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
# Global database manager reference # Global database manager reference
db_manager = None db_manager = None
def set_db_manager(manager: DatabaseManager): def set_db_manager(manager: DatabaseManager):
"""Set the global database manager reference""" """Set the global database manager reference"""
global db_manager global db_manager
db_manager = manager db_manager = manager
def get_database() -> RedisDatabase: def get_database() -> RedisDatabase:
""" """
Safe database dependency that checks for availability Safe database dependency that checks for availability
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
if db_manager is None: if db_manager is None:
logger.error("Database manager not initialized") logger.error("Database manager not initialized")
raise HTTPException( raise HTTPException(status_code=503, detail="Database not available - service starting up")
status_code=503,
detail="Database not available - service starting up"
)
if db_manager.is_shutting_down: if db_manager.is_shutting_down:
logger.warning("Database is shutting down") logger.warning("Database is shutting down")
raise HTTPException( raise HTTPException(status_code=503, detail="Service is shutting down")
status_code=503,
detail="Service is shutting down"
)
try: try:
return db_manager.get_database() return db_manager.get_database()
except RuntimeError as e: except RuntimeError as e:
logger.error(f"Database not available: {e}") logger.error(f"Database not available: {e}")
raise HTTPException( raise HTTPException(status_code=503, detail="Database connection not available")
status_code=503,
detail="Database connection not available"
)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy() to_encode = data.copy()
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM) encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt return encoded_jwt
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)): async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Enhanced token verification with guest session recovery""" """Enhanced token verification with guest session recovery"""
try: try:
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
logger.error(f"❌ Token verification error: {e}") logger.error(f"❌ Token verification error: {e}")
raise HTTPException(status_code=401, detail="Token verification failed") raise HTTPException(status_code=401, detail="Token verification failed")
async def get_current_user( async def get_current_user(
user_id: str = Depends(verify_token_with_blacklist), user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType: ) -> BaseUserWithType:
"""Get current user from database""" """Get current user from database"""
try: try:
@ -135,6 +130,7 @@ async def get_current_user(
candidate_data = await database.get_candidate(user_id) candidate_data = await database.get_candidate(user_id)
if candidate_data: if candidate_data:
from helpers.model_cast import cast_to_base_user_with_type from helpers.model_cast import cast_to_base_user_with_type
if candidate_data.get("is_AI"): if candidate_data.get("is_AI"):
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data)) return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
else: else:
@ -152,16 +148,20 @@ async def get_current_user(
logger.error(f"❌ Error getting current user: {e}") logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
async def get_current_user_or_guest( async def get_current_user_or_guest(
user_id: str = Depends(verify_token_with_blacklist), user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType: ) -> BaseUserWithType:
"""Get current user (including guests) from database""" """Get current user (including guests) from database"""
try: try:
# Check candidates first # Check candidates first
candidate_data = await database.get_candidate(user_id) candidate_data = await database.get_candidate(user_id)
if candidate_data: if candidate_data:
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) return (
Candidate.model_validate(candidate_data)
if not candidate_data.get("is_AI")
else CandidateAI.model_validate(candidate_data)
)
# Check employers # Check employers
employer_data = await database.get_employer(user_id) employer_data = await database.get_employer(user_id)
@ -180,9 +180,9 @@ async def get_current_user_or_guest(
logger.error(f"❌ Error getting current user: {e}") logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found") raise HTTPException(status_code=404, detail="User not found")
async def get_current_admin( async def get_current_admin(
user_id: str = Depends(verify_token_with_blacklist), user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType: ) -> BaseUserWithType:
user = await get_current_user(user_id=user_id, database=database) user = await get_current_user(user_id=user_id, database=database)
if isinstance(user, Candidate) and user.is_admin: if isinstance(user, Candidate) and user.is_admin:
@ -193,6 +193,7 @@ async def get_current_admin(
logger.warning(f"⚠️ User {user_id} is not an admin") logger.warning(f"⚠️ User {user_id} is not an admin")
raise HTTPException(status_code=403, detail="Admin access required") raise HTTPException(status_code=403, detail="Admin access required")
prometheus_collector = CollectorRegistry() prometheus_collector = CollectorRegistry()
# Keep the Instrumentator instance alive # Keep the Instrumentator instance alive
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True, should_ignore_untemplated=True,
should_group_untemplated=True, should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"], excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector registry=prometheus_collector,
) )
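Routes compose these dependencies exactly as the routers above do; a hypothetical endpoint using the guest-aware variant (the /whoami path is illustrative):

from fastapi import APIRouter, Depends

router = APIRouter()

@router.get("/whoami")
async def whoami(user: BaseUserWithType = Depends(get_current_user_or_guest)):
    # Works for candidates, guests, and employers alike.
    return user.model_dump(mode="json", by_alias=True)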

View File

@ -1,5 +1,6 @@
from typing import List, Dict from typing import List, Dict
from models import (Job) from models import Job
def get_requirements_list(job: Job) -> List[Dict[str, str]]: def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = [] requirements: List[Dict[str, str]] = []
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
if job.requirements: if job.requirements:
if job.requirements.technical_skills: if job.requirements.technical_skills:
if job.requirements.technical_skills.required: if job.requirements.technical_skills.required:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"} {"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required for req in job.requirements.technical_skills.required
]) ]
)
if job.requirements.technical_skills.preferred: if job.requirements.technical_skills.preferred:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"} {"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred for req in job.requirements.technical_skills.preferred
]) ]
)
if job.requirements.experience_requirements: if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required: if job.requirements.experience_requirements.required:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"} {"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required for req in job.requirements.experience_requirements.required
]) ]
)
if job.requirements.experience_requirements.preferred: if job.requirements.experience_requirements.preferred:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"} {"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred for req in job.requirements.experience_requirements.preferred
]) ]
)
if job.requirements.soft_skills: if job.requirements.soft_skills:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
if job.requirements.experience: if job.requirements.experience:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
if job.requirements.education: if job.requirements.education:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
if job.requirements.certifications: if job.requirements.certifications:
requirements.extend([ requirements.extend(
{"requirement": req, "domain": "Certifications"} [{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
for req in job.requirements.certifications )
])
if job.requirements.preferred_attributes: if job.requirements.preferred_attributes:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"} {"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes for req in job.requirements.preferred_attributes
]) ]
)
return requirements return requirements
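get_requirements_list flattens a Job into uniform {"requirement", "domain"} rows, which is convenient for scoring or displaying each requirement independently. A usage sketch (job_data stands in for a record loaded from the database):

job = Job.model_validate(job_data)
for row in get_requirements_list(job):
    print(f"[{row['domain']}] {row['requirement']}")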

View File

@ -13,15 +13,13 @@ from fastapi.responses import StreamingResponse
import defines import defines
from logger import logger from logger import logger
from models import DocumentType from models import DocumentType
from models import ( from models import Job, ChatMessage, DocumentType, ApiStatusType
Job,
ChatMessage, DocumentType, ApiStatusType
)
from typing import List, Dict from typing import List, Dict
from models import (Job) from models import Job
import utils.llm_proxy as llm_manager import utils.llm_proxy as llm_manager
async def get_last_item(generator): async def get_last_item(generator):
"""Get the last item from an async generator""" """Get the last item from an async generator"""
last_item = None last_item = None
@ -36,7 +34,7 @@ def filter_and_paginate(
limit: int = 20, limit: int = 20,
sort_by: Optional[str] = None, sort_by: Optional[str] = None,
sort_order: str = "desc", sort_order: str = "desc",
filters: Optional[Dict] = None filters: Optional[Dict] = None,
) -> Tuple[List[Any], int]: ) -> Tuple[List[Any], int]:
"""Filter, sort, and paginate items""" """Filter, sort, and paginate items"""
filtered_items = items.copy() filtered_items = items.copy()
@ -47,8 +45,7 @@ def filter_and_paginate(
if isinstance(filtered_items[0], dict) and key in filtered_items[0]: if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
filtered_items = [item for item in filtered_items if item.get(key) == value] filtered_items = [item for item in filtered_items if item.get(key) == value]
elif hasattr(filtered_items[0], key) if filtered_items else False: elif hasattr(filtered_items[0], key) if filtered_items else False:
filtered_items = [item for item in filtered_items filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
if getattr(item, key, None) == value]
# Sort items # Sort items
if sort_by and filtered_items: if sort_by and filtered_items:
@ -72,6 +69,7 @@ def filter_and_paginate(
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse: async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
"""Stream agent response with proper formatting""" """Stream agent response with proper formatting"""
async def message_stream_generator(): async def message_stream_generator():
"""Generator to stream messages with persistence""" """Generator to stream messages with persistence"""
final_message = None final_message = None
@ -97,12 +95,15 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
# metadata and other unnecessary fields for streaming # metadata and other unnecessary fields for streaming
if generated_message.status != ApiStatusType.DONE: if generated_message.status != ApiStatusType.DONE:
from models import ChatMessageStreaming, ChatMessageStatus from models import ChatMessageStreaming, ChatMessageStatus
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
generated_message, ChatMessageStatus
):
raise TypeError( raise TypeError(
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}" f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
) )
json_data = generated_message.model_dump(mode='json', by_alias=True) json_data = generated_message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data) json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n" yield f"data: {json_str}\n\n"
@ -147,16 +148,16 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
extension = pathlib.Path(filename).suffix.lower() extension = pathlib.Path(filename).suffix.lower()
type_mapping = { type_mapping = {
'.pdf': DocumentType.PDF, ".pdf": DocumentType.PDF,
'.docx': DocumentType.DOCX, ".docx": DocumentType.DOCX,
'.doc': DocumentType.DOCX, ".doc": DocumentType.DOCX,
'.txt': DocumentType.TXT, ".txt": DocumentType.TXT,
'.md': DocumentType.MARKDOWN, ".md": DocumentType.MARKDOWN,
'.markdown': DocumentType.MARKDOWN, ".markdown": DocumentType.MARKDOWN,
'.png': DocumentType.IMAGE, ".png": DocumentType.IMAGE,
'.jpg': DocumentType.IMAGE, ".jpg": DocumentType.IMAGE,
'.jpeg': DocumentType.IMAGE, ".jpeg": DocumentType.IMAGE,
'.gif': DocumentType.IMAGE, ".gif": DocumentType.IMAGE,
} }
return type_mapping.get(extension, DocumentType.TXT) return type_mapping.get(extension, DocumentType.TXT)
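Because the suffix is lowercased before lookup, mixed-case filenames resolve correctly, and anything unmapped falls back to DocumentType.TXT. Two illustrative checks:

assert get_document_type_from_filename("cv.PDF") == DocumentType.PDF  # suffix lowercased first
assert get_document_type_from_filename("notes.rst") == DocumentType.TXT  # unmapped extension falls back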
@ -176,17 +177,12 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent: if not chat_agent:
error_message = ChatMessageError( error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
)
yield error_message yield error_message
return return
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
) )
yield status_message yield status_message
@ -199,16 +195,13 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
system_prompt=""" system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown. You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included. Return only the markdown content, no other text. Make sure all content is included.
""" """,
): ):
pass pass
if not message or not isinstance(message, ChatMessage): if not message or not isinstance(message, ChatMessage):
logger.error("❌ Failed to reformat job description to markdown") logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError( error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
yield error_message yield error_message
return return
@ -226,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
async def create_job_from_content(database, current_user, content: str): async def create_job_from_content(database, current_user, content: str):
"""Create a job from content using AI analysis""" """Create a job from content using AI analysis"""
from models import ( from models import (
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus, MOCK_UUID,
ApiActivityType, ChatContextType, JobRequirementsMessage ApiStatusType,
ChatMessageError,
ChatMessageStatus,
ApiActivityType,
ChatContextType,
JobRequirementsMessage,
) )
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, session_id=MOCK_UUID,
content=f"Initiating connection with {current_user.first_name}'s AI agent...", content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO activity=ApiActivityType.INFO,
) )
yield status_message yield status_message
await asyncio.sleep(0) # Let the status message propagate await asyncio.sleep(0) # Let the status message propagate
@ -248,10 +246,7 @@ async def create_job_from_content(database, current_user, content: str):
yield message yield message
if not message or not isinstance(message, ChatMessage): if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError( error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
yield error_message yield error_message
return return
@ -260,8 +255,7 @@ async def create_job_from_content(database, current_user, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent: if not chat_agent:
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, session_id=MOCK_UUID, content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type"
) )
yield error_message yield error_message
return return
@ -269,24 +263,20 @@ async def create_job_from_content(database, current_user, content: str):
status_message = ChatMessageStatus( status_message = ChatMessageStatus(
session_id=MOCK_UUID, session_id=MOCK_UUID,
content="Analyzing document for company and requirement details...", content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING activity=ApiActivityType.SEARCHING,
) )
yield status_message yield status_message
message = None message = None
async for message in chat_agent.generate( async for message in chat_agent.generate(
llm=llm_manager.get_llm(), llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
): ):
if message.status != ApiStatusType.DONE: if message.status != ApiStatusType.DONE:
yield message yield message
if not message or not isinstance(message, JobRequirementsMessage): if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError( error_message = ChatMessageError(
session_id=MOCK_UUID, session_id=MOCK_UUID, content="Job extraction did not convert successfully"
content="Job extraction did not convert successfully"
) )
yield error_message yield error_message
return return
@ -296,62 +286,63 @@ async def create_job_from_content(database, current_user, content: str):
yield job_requirements yield job_requirements
return return
def get_requirements_list(job: Job) -> List[Dict[str, str]]: def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = [] requirements: List[Dict[str, str]] = []
if job.requirements: if job.requirements:
if job.requirements.technical_skills: if job.requirements.technical_skills:
if job.requirements.technical_skills.required: if job.requirements.technical_skills.required:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"} {"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required for req in job.requirements.technical_skills.required
]) ]
)
if job.requirements.technical_skills.preferred: if job.requirements.technical_skills.preferred:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"} {"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred for req in job.requirements.technical_skills.preferred
]) ]
)
if job.requirements.experience_requirements: if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required: if job.requirements.experience_requirements.required:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"} {"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required for req in job.requirements.experience_requirements.required
]) ]
)
if job.requirements.experience_requirements.preferred: if job.requirements.experience_requirements.preferred:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"} {"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred for req in job.requirements.experience_requirements.preferred
]) ]
)
if job.requirements.soft_skills: if job.requirements.soft_skills:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
if job.requirements.experience: if job.requirements.experience:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
if job.requirements.education: if job.requirements.education:
requirements.extend([ requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
if job.requirements.certifications: if job.requirements.certifications:
requirements.extend([ requirements.extend(
{"requirement": req, "domain": "Certifications"} [{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
for req in job.requirements.certifications )
])
if job.requirements.preferred_attributes: if job.requirements.preferred_attributes:
requirements.extend([ requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"} {"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes for req in job.requirements.preferred_attributes
]) ]
)
return requirements return requirements

File diff suppressed because it is too large

View File

@ -1,6 +1,7 @@
from prometheus_client import Counter, Histogram # type: ignore from prometheus_client import Counter, Histogram # type: ignore
from threading import Lock from threading import Lock
def singleton(cls): def singleton(cls):
instance = None instance = None
lock = Lock() lock = Lock()

View File

@ -12,50 +12,62 @@ from logger import logger
from .dependencies import get_current_user_or_guest, get_database


async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
    """Dependency to get rate limiter instance"""
    return RateLimiter(database)


class RateLimitConfig(BaseModel):
    """Rate limit configuration"""

    requests_per_minute: int
    requests_per_hour: int
    requests_per_day: int
    burst_limit: int  # Maximum requests in a short burst
    burst_window_seconds: int = 60  # Window for burst detection


class GuestRateLimitConfig(RateLimitConfig):
    """Rate limits for guest users - more restrictive"""

    requests_per_minute: int = 10
    requests_per_hour: int = 100
    requests_per_day: int = 500
    burst_limit: int = 15
    burst_window_seconds: int = 60


class AuthenticatedUserRateLimitConfig(RateLimitConfig):
    """Rate limits for authenticated users - more generous"""

    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    requests_per_day: int = 10000
    burst_limit: int = 100
    burst_window_seconds: int = 60


class PremiumUserRateLimitConfig(RateLimitConfig):
    """Rate limits for premium/admin users - most generous"""

    requests_per_minute: int = 120
    requests_per_hour: int = 5000
    requests_per_day: int = 50000
    burst_limit: int = 200
    burst_window_seconds: int = 60
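
The four config classes form a tier: guest < authenticated < premium/admin. The class's get_config_for_user (used later in this file) presumably maps user type to one of these tiers; a sketch of that selection, under that assumption:

def pick_config(user_type: str, is_admin: bool) -> RateLimitConfig:
    # Hypothetical selector; the real RateLimiter.get_config_for_user may differ.
    if is_admin:
        return PremiumUserRateLimitConfig()
    if user_type == "guest":
        return GuestRateLimitConfig()
    return AuthenticatedUserRateLimitConfig()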
class RateLimitResult(BaseModel):
    """Result of rate limit check"""

    allowed: bool
    reason: Optional[str] = None
    retry_after_seconds: Optional[int] = None
    remaining_requests: Dict[str, int] = {}
    reset_times: Dict[str, datetime] = {}


class RateLimiter:
    """Rate limiter using Redis for distributed rate limiting"""
@ -78,11 +90,7 @@ class RateLimiter:
        return self.user_config

    async def check_rate_limit(
        self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
    ) -> RateLimitResult:
        """
        Check if user has exceeded rate limits
@ -105,7 +113,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst" "burst": f"{base_key}:burst",
} }
# Add endpoint-specific limiting if provided # Add endpoint-specific limiting if provided
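
These strftime-suffixed keys implement fixed-window counting: every request in the same calendar minute, hour, or day increments the same Redis key, so counts reset at window boundaries rather than sliding. A sketch of the keys produced for one user (the base_key shape is assumed; it is built outside this hunk):

from datetime import UTC, datetime

base_key = "rate_limit:candidate:abc123"  # assumed format
now = datetime(2025, 6, 18, 13, 53, 7, tzinfo=UTC)
keys = {
    "minute": f"{base_key}:minute:{now.strftime('%Y%m%d%H%M')}",  # ...:minute:202506181353
    "hour": f"{base_key}:hour:{now.strftime('%Y%m%d%H')}",  # ...:hour:2025061813
    "day": f"{base_key}:day:{now.strftime('%Y%m%d')}",  # ...:day:20250618
}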
@ -125,7 +133,7 @@ class RateLimiter:
"minute": int(results[0] or 0), "minute": int(results[0] or 0),
"hour": int(results[1] or 0), "hour": int(results[1] or 0),
"day": int(results[2] or 0), "day": int(results[2] or 0),
"burst": int(results[3] or 0) "burst": int(results[3] or 0),
} }
# Check limits # Check limits
@ -133,7 +141,7 @@ class RateLimiter:
"minute": config.requests_per_minute, "minute": config.requests_per_minute,
"hour": config.requests_per_hour, "hour": config.requests_per_hour,
"day": config.requests_per_day, "day": config.requests_per_day,
"burst": config.burst_limit "burst": config.burst_limit,
} }
# Check each limit # Check each limit
@ -146,18 +154,22 @@ class RateLimiter:
elif window == "hour": elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second) retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day": elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second) retry_after = 86400 - (
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
)
else: # burst else: # burst
retry_after = config.burst_window_seconds retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}") logger.warning(
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
)
return RateLimitResult( return RateLimitResult(
allowed=False, allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}", reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after, retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()}, remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time) reset_times=self._calculate_reset_times(current_time),
) )
# If we get here, request is allowed - increment counters # If we get here, request is allowed - increment counters
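
The retry_after arithmetic is simply seconds-until-window-rollover, matching the fixed-window keys above. Worked at 13:53:07: the hour window has 3600 - (53 * 60 + 7) = 413 s left, the day window 86400 - (13 * 3600 + 53 * 60 + 7) = 36413 s. A quick check:

from datetime import UTC, datetime

now = datetime(2025, 6, 18, 13, 53, 7, tzinfo=UTC)
assert 3600 - (now.minute * 60 + now.second) == 413  # seconds left in the hour
assert 86400 - (now.hour * 3600 + now.minute * 60 + now.second) == 36413  # seconds left in the day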
@ -182,17 +194,12 @@ class RateLimiter:
            await pipe.execute()

            # Calculate remaining requests
            remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}

            logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")

            return RateLimitResult(
                allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
            )

        except Exception as e:
@ -206,18 +213,9 @@ class RateLimiter:
        next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
        next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)

        return {"minute": next_minute, "hour": next_hour, "day": next_day}

    async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
        """Get current rate limit status for a user"""
        config = self.get_config_for_user(user_type, is_admin)
        current_time = datetime.now(UTC)
@ -227,7 +225,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst" "burst": f"{base_key}:burst",
} }
try: try:
@ -240,14 +238,14 @@ class RateLimiter:
"minute": int(results[0] or 0), "minute": int(results[0] or 0),
"hour": int(results[1] or 0), "hour": int(results[1] or 0),
"day": int(results[2] or 0), "day": int(results[2] or 0),
"burst": int(results[3] or 0) "burst": int(results[3] or 0),
} }
limits = { limits = {
"minute": config.requests_per_minute, "minute": config.requests_per_minute,
"hour": config.requests_per_hour, "hour": config.requests_per_hour,
"day": config.requests_per_day, "day": config.requests_per_day,
"burst": config.burst_limit "burst": config.burst_limit,
} }
return { return {
@ -258,7 +256,7 @@ class RateLimiter:
"limits": limits, "limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()}, "remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time), "reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump() "config": config.model_dump(),
} }
except Exception as e: except Exception as e:
@ -295,11 +293,9 @@ class RateLimiter:
# Rate Limited Decorator
# ============================


def rate_limited(
    guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
):
    """
    Decorator to easily apply rate limiting to endpoints
@ -320,11 +316,13 @@ def rate_limited(
    ):
        return {"message": "Rate limited endpoint"}
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Extract dependencies from function signature
            import inspect

            inspect.signature(func)

            # Get request, current_user, and rate_limiter from kwargs or args
@ -336,7 +334,7 @@ def rate_limited(
            for param_name, param_value in kwargs.items():
                if isinstance(param_value, Request):
                    request = param_value
                elif hasattr(param_value, "user_type"):  # User-like object
                    current_user = param_value
                elif isinstance(param_value, RateLimiter):
                    rate_limiter = param_value
@ -350,30 +348,33 @@ def rate_limited(
            # Apply rate limiting if we have the required components
            if request and current_user and rate_limiter:
                await apply_custom_rate_limiting(
                    request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
                )

            # Call the original function
            return await func(*args, **kwargs)

        return wrapper

    return decorator
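
A minimal usage sketch for the decorator, with a hypothetical route; the handler must declare the Request, the current user, and the RateLimiter as parameters so the wrapper can find them in kwargs:

from fastapi import APIRouter, Depends, Request

router = APIRouter()


@router.get("/api/1.0/example")  # hypothetical endpoint
@rate_limited(guest_per_minute=5, user_per_minute=30, admin_per_minute=120)
async def example_endpoint(
    request: Request,
    current_user=Depends(get_current_user_or_guest),
    rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
    return {"message": "Rate limited endpoint"}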
async def apply_custom_rate_limiting(
    request: Request,
    current_user,
    rate_limiter: RateLimiter,
    guest_per_minute: int,
    user_per_minute: int,
    admin_per_minute: int,
):
    """Apply custom rate limiting with specified limits"""
    try:
        # Determine user info
        user_id = current_user.id
        user_type = (
            current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
        )
        is_admin = getattr(current_user, "is_admin", False)

        # Determine appropriate limit
        if is_admin:
@ -385,13 +386,17 @@ async def apply_custom_rate_limiting(
        # Create custom rate limit key
        current_time = datetime.now(UTC)
        custom_key = (
            f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
        )

        # Check current usage
        current_count = int(await rate_limiter.redis.get(custom_key) or 0)

        if current_count >= requests_per_minute:
            logger.warning(
                f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
            )
            raise HTTPException(
                status_code=429,
                detail={
@ -399,9 +404,9 @@ async def apply_custom_rate_limiting(
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute", "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
"retryAfter": 60 - current_time.second, "retryAfter": 60 - current_time.second,
"userType": user_type, "userType": user_type,
"endpoint": request.url.path "endpoint": request.url.path,
}, },
headers={"Retry-After": str(60 - current_time.second)} headers={"Retry-After": str(60 - current_time.second)},
) )
# Increment counter # Increment counter
@ -410,7 +415,9 @@ async def apply_custom_rate_limiting(
        pipe.expire(custom_key, 120)  # 2 minutes TTL
        await pipe.execute()

        logger.debug(
            f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
        )

    except HTTPException:
        raise
@ -418,15 +425,13 @@ async def apply_custom_rate_limiting(
logger.error(f"❌ Custom rate limiting error: {e}") logger.error(f"❌ Custom rate limiting error: {e}")
# Fail open # Fail open
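
For a concrete sense of the key shape built by the f-string above: a candidate with id abc123 (hypothetical) hitting /api/1.0/jobs at 13:53 on 2025-06-18 would be counted under the key below, and the 120-second TTL lets the previous minute's key expire shortly after it stops being read:

custom_rate_limit:/api/1.0/jobs:candidate:abc123:minute:202506181353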
# ============================
# Alternative: FastAPI Dependency-Based Rate Limiting
# ============================


def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
    """
    Create a FastAPI dependency for rate limiting
@ -441,23 +446,25 @@ def create_rate_limit_dependency(
    ):
        return {"message": "Rate limited endpoint"}
    """

    async def rate_limit_dependency(
        request: Request,
        current_user=Depends(get_current_user_or_guest),
        rate_limiter: RateLimiter = Depends(get_rate_limiter),
    ):
        await apply_custom_rate_limiting(
            request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
        )
        return True

    return rate_limit_dependency
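
A usage sketch for the dependency variant (route and wiring hypothetical); FastAPI resolves the nested Depends inside rate_limit_dependency automatically:

from fastapi import APIRouter, Depends

router = APIRouter()
jobs_rate_limit = create_rate_limit_dependency(guest_per_minute=1, user_per_minute=5, admin_per_minute=50)


@router.get("/api/1.0/jobs", dependencies=[Depends(jobs_rate_limit)])
async def list_jobs():
    return {"message": "Rate limited endpoint"}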
# ============================
# Rate Limiting Utilities
# ============================


class EndpointRateLimiter:
    """Utility class for endpoint-specific rate limiting"""
@ -477,9 +484,11 @@ class EndpointRateLimiter:
            return True  # No custom limits set

        limits = self.custom_limits[endpoint]
        user_type = (
            current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
        )

        if getattr(current_user, "is_admin", False):
            user_type = "admin"

        limit = limits.get(user_type, limits.get("default", 60))
@ -491,8 +500,7 @@ class EndpointRateLimiter:
        if current_count >= limit:
            raise HTTPException(
                status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
            )

        # Increment counter
@ -501,9 +509,11 @@ class EndpointRateLimiter:
        return True


# Global endpoint rate limiter instance
endpoint_rate_limiter = None


def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
    """Get endpoint rate limiter instance"""
    global endpoint_rate_limiter
@ -511,15 +521,14 @@ def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limit
        endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)

        # Configure endpoint-specific limits
        endpoint_rate_limiter.set_endpoint_limits(
            "/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
        )
        endpoint_rate_limiter.set_endpoint_limits(
            "/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
        )
        endpoint_rate_limiter.set_endpoint_limits(
            "/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
        )

    return endpoint_rate_limiter
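
The per-endpoint limits dict is keyed by user type, with an optional "default" fallback (see limits.get(user_type, limits.get("default", 60)) above). Registering limits for another endpoint follows the same pattern; a sketch with a hypothetical endpoint and numbers:

endpoint_rate_limiter.set_endpoint_limits(
    "/api/1.0/search", {"guest": 3, "default": 20, "admin": 100}
)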

View File

@ -3,37 +3,17 @@ Response utility functions for consistent API responses
""" """
from typing import Any, Optional, Dict, List from typing import Any, Optional, Dict, List
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict: def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
return { return {"success": True, "data": data, "meta": meta}
"success": True,
"data": data,
"meta": meta
}
def create_error_response(code: str, message: str, details: Any = None) -> Dict: def create_error_response(code: str, message: str, details: Any = None) -> Dict:
return { return {"success": False, "error": {"code": code, "message": message, "details": details}}
"success": False,
"error": {
"code": code,
"message": message,
"details": details
}
}
def create_paginated_response(
data: List[Any], def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
page: int,
limit: int,
total: int
) -> Dict:
total_pages = (total + limit - 1) // limit total_pages = (total + limit - 1) // limit
has_more = page < total_pages has_more = page < total_pages
return { return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}
"data": data,
"total": total,
"page": page,
"limit": limit,
"totalPages": total_pages,
"hasMore": has_more
}
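
The ceiling division (total + limit - 1) // limit computes totalPages without float math. A quick check, assuming 95 items at 10 per page:

page = create_paginated_response(data=[], page=3, limit=10, total=95)
assert page["totalPages"] == 10  # (95 + 10 - 1) // 10
assert page["hasMore"] is True  # page 3 of 10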