diff --git a/src/backend/agents/__init__.py b/src/backend/agents/__init__.py index 03280b1..6f06b06 100644 --- a/src/backend/agents/__init__.py +++ b/src/backend/agents/__init__.py @@ -17,11 +17,9 @@ from logger import logger AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses # Maps class_name to (module_name, class_name) -class_registry: Dict[str, Tuple[str, str]] = ( - {} -) +class_registry: Dict[str, Tuple[str, str]] = {} -__all__ = ['get_or_create_agent'] +__all__ = ["get_or_create_agent"] package_dir = pathlib.Path(__file__).parent package_name = __name__ @@ -38,16 +36,11 @@ for path in package_dir.glob("*.py"): # Find all Agent subclasses in the module for name, obj in inspect.getmembers(module, inspect.isclass): - if ( - issubclass(obj, AnyAgent) - and obj is not AnyAgent - and obj is not Agent - and name not in class_registry - ): + if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry: class_registry[name] = (full_module_name, name) globals()[name] = obj logger.info(f"Adding agent: {name}") - __all__.append(name) # type: ignore + __all__.append(name) # type: ignore except ImportError as e: logger.error(traceback.format_exc()) logger.error(f"Error importing {full_module_name}: {e}") diff --git a/src/backend/agents/base.py b/src/backend/agents/base.py index 03ec451..9e8f8c5 100644 --- a/src/backend/agents/base.py +++ b/src/backend/agents/base.py @@ -32,19 +32,45 @@ from pathlib import Path from rag import start_file_watcher, ChromaDBFileWatcher import defines from logger import logger -from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType) +from models import ( + Tunables, + ChatMessageUser, + ChatMessage, + RagEntry, + ChatMessageMetaData, + ApiStatusType, + Candidate, + ChatContextType, +) import utils.llm_proxy as llm_manager from database.manager import RedisDatabase from models import ChromaDBGetResponse from utils.metrics import Metrics -from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate) +from models import ( + ApiActivityType, + ApiMessage, + ChatMessageError, + ChatMessageRagSearch, + ChatMessageStatus, + ChatMessageStreaming, + LLMMessage, + ChatMessage, + ChatOptions, + ChatMessageUser, + Tunables, + ApiStatusType, + ChatMessageMetaData, + Candidate, +) from logger import logger import defines from .registry import agent_registry -from models import ( ChromaDBGetResponse ) +from models import ChromaDBGetResponse + + class CandidateEntity(Candidate): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc @@ -54,18 +80,15 @@ class CandidateEntity(Candidate): async def cleanup(self): """Cleanup resources associated with this entity""" - + # Internal instance members CandidateEntity__agents: List[Agent] = [] CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True) CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True) - CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field( - default=None, exclude=True - ) + CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True) CandidateEntity__metrics: Optional[Metrics] = Field( - default=None, - description="Metrics collector for this agent, used to 
track performance and usage." + default=None, description="Metrics collector for this agent, used to track performance and usage." ) def __init__(self, candidate=None): @@ -78,17 +101,17 @@ class CandidateEntity(Candidate): @classmethod def exists(cls, username: str): # Validate username format (only allow safe characters) - if not re.match(r'^[a-zA-Z0-9_-]+$', username): + if not re.match(r"^[a-zA-Z0-9_-]+$", username): return False # Invalid username characters - + # Check for minimum and maximum length if not (3 <= len(username) <= 32): return False # Invalid username length - + # Use Path for safe path handling and normalization user_dir = Path(defines.user_dir) / username user_info_path = user_dir / defines.user_info_file - + # Ensure the final path is actually within the intended parent directory # to help prevent directory traversal attacks try: @@ -96,33 +119,29 @@ class CandidateEntity(Candidate): return False # Path traversal attempt detected except (ValueError, RuntimeError): # Potential exceptions from resolve() return False - + # Check if file exists return user_info_path.is_file() def get_or_create_agent(self, agent_type: ChatContextType) -> Agent: """ Get or create an agent of the specified type for this candidate. - + Args: agent_type: The type of agent to create (default is 'candidate_chat'). **kwargs: Additional fields required by the specific agent subclass. - + Returns: The created agent instance. """ - + # Only instantiate one agent of each type per user for agent in self.CandidateEntity__agents: if agent.agent_type == agent_type: return agent - return get_or_create_agent( - agent_type=agent_type, - user=self, - prometheus_collector=self.prometheus_collector - ) - + return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector) + # Wrapper properties that map into file_watcher @property def umap_collection(self) -> ChromaDBGetResponse: @@ -132,6 +151,7 @@ class CandidateEntity(Candidate): # Fields managed by initialize() CandidateEntity__initialized: bool = Field(default=False, exclude=True) + @property def metrics(self) -> Metrics: if not self.CandidateEntity__metrics: @@ -154,21 +174,16 @@ class CandidateEntity(Candidate): def observer(self) -> Any: if not self.CandidateEntity__observer: raise ValueError("initialize() has not been called.") - return self.CandidateEntity__observer + return self.CandidateEntity__observer def collect_metrics(self, agent: Agent, response): if not self.metrics: logger.warning("No metrics collector set for this agent.") return - self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc( - response.usage.prompt_eval_count - ) + self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count) self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count) - async def initialize( - self, - prometheus_collector: CollectorRegistry, - database: RedisDatabase): + async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase): if self.CandidateEntity__initialized: # Initialization can only be attempted once; if there are multiple attempts, it means # a subsystem is failing or there is a logic bug in the code. 
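[Reviewer note — not part of the patch] The exists() helper earlier in this file restricts usernames to `^[a-zA-Z0-9_-]+$` and then, as defense in depth, appears to resolve the constructed path and require it to remain under the user directory (the "Potential exceptions from resolve()" comment suggests a resolve()/relative_to() pattern). A minimal, self-contained sketch of that kind of guard, with hypothetical names and paths chosen only for illustration:

    from pathlib import Path

    def is_within(base_dir: str, name: str) -> bool:
        # resolve() collapses any ".." segments before the containment check
        base = Path(base_dir).resolve()
        candidate = (base / name).resolve()
        try:
            candidate.relative_to(base)  # raises ValueError if candidate escaped base
            return True
        except (ValueError, RuntimeError):
            return False

    print(is_within("/tmp/users", "alice"))           # True
    print(is_within("/tmp/users", "../etc/passwd"))   # False

This is only a sketch of the technique; the actual check in exists() operates on defines.user_dir and defines.user_info_file as shown in the hunks above.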
@@ -183,13 +198,13 @@ class CandidateEntity(Candidate): if not prometheus_collector: raise ValueError("prometheus_collector can not be None") - + self.CandidateEntity__prometheus_collector = prometheus_collector self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector) - + user_dir = os.path.join(defines.user_dir, self.username) - vector_db_dir=os.path.join(user_dir, defines.persist_directory) - rag_content_dir=os.path.join(user_dir, defines.rag_content_dir) + vector_db_dir = os.path.join(user_dir, defines.persist_directory) + rag_content_dir = os.path.join(user_dir, defines.rag_content_dir) os.makedirs(vector_db_dir, exist_ok=True) os.makedirs(rag_content_dir, exist_ok=True) @@ -205,17 +220,21 @@ class CandidateEntity(Candidate): ) has_username_rag = any(item.name == self.username for item in self.rags) if not has_username_rag: - self.rags.append(RagEntry( - name=self.username, - description=f"Expert data about {self.full_name}.", - )) + self.rags.append( + RagEntry( + name=self.username, + description=f"Expert data about {self.full_name}.", + ) + ) self.rag_content_size = self.file_watcher.collection.count() + class Agent(BaseModel, ABC): """ Base class for all agent types. This class defines the common attributes and methods for all agent types. """ + class Config: arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase @@ -237,7 +256,7 @@ class Agent(BaseModel, ABC): conversation: List[ChatMessageUser] = Field( default_factory=list, - description="Conversation history for this agent, used to maintain context across messages." + description="Conversation history for this agent, used to maintain context across messages.", ) @property @@ -254,9 +273,7 @@ class Agent(BaseModel, ABC): last_item = item return last_item - def set_optimal_context_size( - self, llm: Any, model: str, prompt: str, ctx_buffer=2048 - ) -> int: + def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int: # Most models average 1.3-1.5 tokens per word word_count = len(prompt.split()) tokens = int(word_count * 1.4) @@ -265,9 +282,7 @@ class Agent(BaseModel, ABC): total_ctx = tokens + ctx_buffer if total_ctx > self.context_size: - logger.info( - f"Increasing context size from {self.context_size} to {total_ctx}" - ) + logger.info(f"Increasing context size from {self.context_size} to {total_ctx}") # Grow the context size if necessary self.context_size = max(self.context_size, total_ctx) @@ -468,28 +483,28 @@ class Agent(BaseModel, ABC): """ if not rag_message.content: return "" - + context = [] for chroma_results in rag_message.content: for index, metadata in enumerate(chroma_results.metadatas): - content = "\n".join([ - line.strip() - for line in chroma_results.documents[index].split("\n") - if line - ]).strip() - context.append(f""" + content = "\n".join( + [line.strip() for line in chroma_results.documents[index].split("\n") if line] + ).strip() + context.append( + f""" Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")} Document reference: {chroma_results.ids[index]} Content: {content} -""") +""" + ) return "\n".join(context) async def generate_rag_results( self, session_id: str, prompt: str, - top_k: int=defines.default_rag_top_k, - threshold: float=defines.default_rag_threshold, + top_k: int = defines.default_rag_top_k, + threshold: float = defines.default_rag_threshold, ) -> AsyncGenerator[ApiMessage, None]: """ Generate RAG results for the given query. 
@@ -501,14 +516,11 @@ Content: {content} A list of dictionaries containing the RAG results. """ if not self.user: - error_message = ChatMessageError( - session_id=session_id, - content="No user set for RAG generation." - ) + error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.") yield error_message return - - results : List[ChromaDBGetResponse] = [] + + results: List[ChromaDBGetResponse] = [] user: CandidateEntity = self.user for rag in user.rags: if not rag.enabled: @@ -517,22 +529,20 @@ Content: {content} status_message = ChatMessageStatus( session_id=session_id, activity=ApiActivityType.SEARCHING, - content = f"Searching RAG context {rag.name}..." + content=f"Searching RAG context {rag.name}...", ) yield status_message try: - chroma_results = await user.file_watcher.find_similar( - query=prompt, top_k=top_k, threshold=threshold - ) + chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold) if not chroma_results: continue - query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore + query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore - umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore - umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore + umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore + umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore - rag_metadata = ChromaDBGetResponse( + rag_metadata = ChromaDBGetResponse( name=rag.name, query=prompt, query_embedding=query_embedding.tolist(), @@ -548,7 +558,7 @@ Content: {content} continue_message = ChatMessageStatus( session_id=session_id, activity=ApiActivityType.SEARCHING, - content=f"Error searching RAG context {rag.name}: {str(e)}" + content=f"Error searching RAG context {rag.name}: {str(e)}", ) yield continue_message @@ -559,25 +569,23 @@ Content: {content} ) yield final_message return - - async def llm_one_shot( - self, - llm: Any, model: str, - session_id: str, prompt: str, system_prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]: + async def llm_one_shot( + self, + llm: Any, + model: str, + session_id: str, + prompt: str, + system_prompt: str, + tunables: Optional[Tunables] = None, + temperature=0.7, + ) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]: if not self.user: - error_message = ChatMessageError( - session_id=session_id, - content="No user set for chat generation." - ) + error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.") yield error_message return - - self.set_optimal_context_size( - llm=llm, model=model, prompt=prompt+system_prompt - ) + + self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt) options = ChatOptions( seed=8911, @@ -591,9 +599,7 @@ Content: {content} ] status_message = ChatMessageStatus( - session_id=session_id, - activity=ApiActivityType.GENERATING, - content="Generating response..." + session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..." ) yield status_message @@ -609,10 +615,7 @@ Content: {content} stream=True, ): if not response: - error_message = ChatMessageError( - session_id=session_id, - content="No response from LLM." 
- ) + error_message = ChatMessageError(session_id=session_id, content="No response from LLM.") yield error_message return @@ -627,49 +630,37 @@ Content: {content} yield streaming_message if not response: - error_message = ChatMessageError( - session_id=session_id, - content="No response from LLM." - ) + error_message = ChatMessageError(session_id=session_id, content="No response from LLM.") yield error_message return - + self.user.collect_metrics(agent=self, response=response) - self.context_tokens = ( - response.usage.prompt_eval_count + response.usage.eval_count - ) + self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count chat_message = ChatMessage( session_id=session_id, tunables=tunables, status=ApiStatusType.DONE, content=content, - metadata = ChatMessageMetaData( + metadata=ChatMessageMetaData( options=options, eval_count=response.usage.eval_count, eval_duration=response.usage.eval_duration, prompt_eval_count=response.usage.prompt_eval_count, prompt_eval_duration=response.usage.prompt_eval_duration, - - ) + ), ) yield chat_message return async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ApiMessage, None]: if not self.user: - error_message = ChatMessageError( - session_id=session_id, - content="No user set for chat generation." - ) + error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.") yield error_message return - + user_message = ChatMessageUser( session_id=session_id, content=prompt, @@ -678,7 +669,7 @@ Content: {content} self.user.metrics.generate_count.labels(agent=self.agent_type).inc() with self.user.metrics.generate_duration.labels(agent=self.agent_type).time(): context = None - rag_message : ChatMessageRagSearch | None = None + rag_message: ChatMessageRagSearch | None = None if self.user: message = None async for message in self.generate_rag_results(session_id=session_id, prompt=prompt): @@ -690,37 +681,32 @@ Content: {content} yield message if not isinstance(message, ChatMessageRagSearch): - raise ValueError( - f"Expected ChatMessageRagSearch, got {type(rag_message)}" - ) - + raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}") + rag_message = message context = self.get_rag_context(rag_message) # Create a pruned down message list based purely on the prompt and responses, # discarding the full preamble generated by prepare_message - messages: List[LLMMessage] = [ - LLMMessage(role="system", content=self.system_prompt) - ] + messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)] # Add the conversation history to the messages - messages.extend([ - LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content) - for m in self.conversation - ]) + messages.extend( + [ + LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content) + for m in self.conversation + ] + ) # Add the RAG context to the messages if available if context: messages.append( LLMMessage( role="user", - content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n\n\nPrompt to respond to:\n{prompt}\n" + content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n\n\nPrompt to respond to:\n{prompt}\n", ) ) else: # Only the actual 
user query is provided with the full context message - messages.append( - LLMMessage(role="user", content=prompt) - ) - + messages.append(LLMMessage(role="user", content=prompt)) # use_tools = message.tunables.enable_tools and len(self.context.tools) > 0 # message.metadata.tools = { @@ -824,16 +810,12 @@ Content: {content} # not use_tools status_message = ChatMessageStatus( - session_id=session_id, - activity=ApiActivityType.GENERATING, - content="Generating response..." + session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..." ) yield status_message # Set the response for streaming - self.set_optimal_context_size( - llm, model, prompt=prompt - ) + self.set_optimal_context_size(llm, model, prompt=prompt) options = ChatOptions( seed=8911, @@ -853,10 +835,7 @@ Content: {content} stream=True, ): if not response: - error_message = ChatMessageError( - session_id=session_id, - content="No response from LLM." - ) + error_message = ChatMessageError(session_id=session_id, content="No response from LLM.") yield error_message return @@ -870,17 +849,12 @@ Content: {content} yield streaming_message if not response: - error_message = ChatMessageError( - session_id=session_id, - content="No response from LLM." - ) + error_message = ChatMessageError(session_id=session_id, content="No response from LLM.") yield error_message return - + self.user.collect_metrics(agent=self, response=response) - self.context_tokens = ( - response.usage.prompt_eval_count + response.usage.eval_count - ) + self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count end_time = time.perf_counter() chat_message = ChatMessage( @@ -888,7 +862,7 @@ Content: {content} tunables=tunables, status=ApiStatusType.DONE, content=content, - metadata = ChatMessageMetaData( + metadata=ChatMessageMetaData( options=options, eval_count=response.usage.eval_count, eval_duration=response.usage.eval_duration, @@ -899,10 +873,9 @@ Content: {content} "llm_streamed": end_time - start_time, "llm_with_tools": 0, # Placeholder for tool processing time }, - ) + ), ) - # Add the user and chat messages to the conversation self.conversation.append(user_message) self.conversation.append(chat_message) @@ -996,12 +969,13 @@ Content: {content} raise ValueError("No Markdown found in the response") + _agents: List[Agent] = [] + def get_or_create_agent( - agent_type: str, - prometheus_collector: CollectorRegistry, - user: Optional[CandidateEntity]=None) -> Agent: + agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None +) -> Agent: """ Get or create and append a new agent of the specified type, ensuring only one agent per type exists. 
@@ -1025,14 +999,16 @@ def get_or_create_agent( for agent_cls in Agent.__subclasses__(): if agent_cls.model_fields["agent_type"].default == agent_type: # Create the agent instance with provided kwargs - agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg] - user=user) + agent = agent_cls( + agent_type=agent_type, # type: ignore[call-arg] + user=user, + ) _agents.append(agent) return agent raise ValueError(f"No agent class found for agent_type: {agent_type}") + # Register the base agent agent_registry.register(Agent._agent_type, Agent) CandidateEntity.model_rebuild() - diff --git a/src/backend/agents/candidate_chat.py b/src/backend/agents/candidate_chat.py index 9d41307..27ee876 100644 --- a/src/backend/agents/candidate_chat.py +++ b/src/backend/agents/candidate_chat.py @@ -5,7 +5,7 @@ from .base import Agent, agent_registry from logger import logger from .registry import agent_registry -from models import ( ApiMessage, Tunables, ApiStatusType) +from models import ApiMessage, Tunables, ApiStatusType system_message = """ @@ -21,21 +21,19 @@ Always <|context|> when possible. Be concise, and never make up information. If Before answering, ensure you have spelled the candidate's name correctly. """ + class CandidateChat(Agent): """ CandidateChat Agent """ - agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore + agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ApiMessage, None]: user = self.user if not user: @@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>. 
{system_message} """ - async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables): + async for message in super().generate( + llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables + ): if message.status == ApiStatusType.ERROR: yield message return yield message + # Register the base agent agent_registry.register(CandidateChat._agent_type, CandidateChat) - diff --git a/src/backend/agents/general.py b/src/backend/agents/general.py index ddb50a1..6670635 100644 --- a/src/backend/agents/general.py +++ b/src/backend/agents/general.py @@ -49,11 +49,13 @@ class Chat(Agent): """ Chat Agent """ - agent_type: Literal["general"] = "general" # type: ignore + + agent_type: Literal["general"] = "general" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration system_prompt: str = system_message + # async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]: # logger.info(f"{self.agent_type} - {inspect.stack()[0].function}") # if not self.context: diff --git a/src/backend/agents/generate_image.py b/src/backend/agents/generate_image.py index ac34f9f..d509bdc 100644 --- a/src/backend/agents/generate_image.py +++ b/src/backend/agents/generate_image.py @@ -4,8 +4,8 @@ from typing import ( ClassVar, Any, AsyncGenerator, - Optional -# override + Optional, + # override ) # NOTE: You must import Optional for late binding to work import random import time @@ -13,7 +13,15 @@ import time import os from .base import Agent, agent_registry -from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables +from models import ( + ApiActivityType, + ChatMessage, + ChatMessageError, + ChatMessageStatus, + ChatMessageStreaming, + ApiStatusType, + Tunables, +) from logger import logger import defines import backstory_traceback as traceback @@ -23,18 +31,16 @@ from image_generator.profile_image import generate_image, ImageRequest seed = int(time.time()) random.seed(seed) + class ImageGenerator(Agent): - agent_type: Literal["generate_image"] = "generate_image" # type: ignore + agent_type: Literal["generate_image"] = "generate_image" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration agent_persist: bool = False - system_prompt: str = "" # No system prompt is used + system_prompt: str = "" # No system prompt is used async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: if not self.user: logger.error("User is not set for ImageGenerator agent.") @@ -54,11 +60,17 @@ class ImageGenerator(Agent): yield status_message logger.info(f"Image generation: {file_path} <- {prompt}") - request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5) + request = ImageRequest( + filepath=file_path, + session_id=session_id, + prompt=prompt, + iterations=4, + height=256, + width=256, + guidance_scale=7.5, + ) generated_message = None - async for generated_message in generate_image( - request=request - ): + async for generated_message in generate_image(request=request): if generated_message.status == 
ApiStatusType.ERROR: yield generated_message return @@ -68,8 +80,7 @@ class ImageGenerator(Agent): if generated_message is None: error_message = ChatMessageError( - session_id=session_id, - content="Image generation failed to produce a valid response." + session_id=session_id, content="Image generation failed to produce a valid response." ) logger.error(f"⚠️ {error_message.content}") yield error_message @@ -78,26 +89,24 @@ class ImageGenerator(Agent): logger.info("Image generation done...") user.profile_image = "profile.png" - + # Image generated generated_image = ChatMessage( session_id=session_id, status=ApiStatusType.DONE, - content = f"{defines.api_prefix}/profile/{user.username}", - metadata=generated_message.metadata + content=f"{defines.api_prefix}/profile/{user.username}", + metadata=generated_message.metadata, ) yield generated_image return - + except Exception as e: - error_message = ChatMessageError( - session_id=session_id, - content=f"Error generating image: {str(e)}" - ) + error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}") logger.error(traceback.format_exc()) logger.error(f"⚠️ {error_message.content}") yield error_message return + # Register the base agent agent_registry.register(ImageGenerator._agent_type, ImageGenerator) diff --git a/src/backend/agents/generate_persona.py b/src/backend/agents/generate_persona.py index 042c393..58a5e88 100644 --- a/src/backend/agents/generate_persona.py +++ b/src/backend/agents/generate_persona.py @@ -8,8 +8,8 @@ from typing import ( Tuple, AsyncGenerator, List, - Optional -# override + Optional, + # override ) # NOTE: You must import Optional for late binding to work import random import re @@ -20,7 +20,16 @@ import os import random from .base import Agent, agent_registry -from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables +from models import ( + ApiActivityType, + ChatMessage, + ChatMessageError, + ApiMessageType, + ChatMessageStatus, + ChatMessageStreaming, + ApiStatusType, + Tunables, +) from logger import logger import defines import backstory_traceback as traceback @@ -43,6 +52,7 @@ emptyUser = { "questions": [], } + def generate_persona_system_prompt(persona: Dict[str, Any]) -> str: return f"""\ You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play. @@ -84,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u You are providing those only for use later by the system when casting individuals for the role. """ + generate_resume_system_prompt = """ You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part of that is to create an in-depth resume for the person. You will be provided with the following information: @@ -115,10 +126,12 @@ import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + class EthnicNameGenerator: def __init__(self): try: - from names_dataset import NameDataset # type: ignore + from names_dataset import NameDataset # type: ignore + self.nd = NameDataset() except ImportError: logger.error("NameDataset not available. 
Please install: pip install names-dataset") @@ -126,56 +139,54 @@ class EthnicNameGenerator: except Exception as e: logger.error(f"Error initializing NameDataset: {e}") self.nd = None - + # US Census 2020 approximate ethnic distribution self.ethnic_weights = { - 'White': 0.576, - 'Hispanic': 0.186, - 'Black': 0.134, - 'Asian': 0.062, - 'Native American': 0.013, - 'Pacific Islander': 0.003, - 'Mixed/Other': 0.026 + "White": 0.576, + "Hispanic": 0.186, + "Black": 0.134, + "Asian": 0.062, + "Native American": 0.013, + "Pacific Islander": 0.003, + "Mixed/Other": 0.026, } - + # Map ethnicities to countries (using alpha-2 codes that NameDataset uses) self.ethnic_country_mapping = { - 'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'], - 'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'], - 'Black': ['US'], # African American names - 'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'], - 'Native American': ['US'], - 'Pacific Islander': ['US'], - 'Mixed/Other': ['US'] + "White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"], + "Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"], + "Black": ["US"], # African American names + "Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"], + "Native American": ["US"], + "Pacific Islander": ["US"], + "Mixed/Other": ["US"], } - + def get_weighted_ethnicity(self) -> str: """Select ethnicity based on US demographic weights""" ethnicities = list(self.ethnic_weights.keys()) weights = list(self.ethnic_weights.values()) return random.choices(ethnicities, weights=weights)[0] - - def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None, - n: int = 50, use_first_names: bool = True) -> List[str]: + + def get_names_by_criteria( + self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True + ) -> List[str]: """Get names matching criteria using NameDataset's get_top_names method""" if not self.nd: return [] - + all_names = [] for country_code in countries: try: # Get top names for this country top_names = self.nd.get_top_names( - n=n, - use_first_names=use_first_names, - country_alpha2=country_code, - gender=gender + n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender ) - + if country_code in top_names: if use_first_names and gender: # For first names with gender specified - gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F' + gender_key = "M" if gender.upper() in ["M", "MALE"] else "F" if gender_key in top_names[country_code]: all_names.extend(top_names[country_code][gender_key]) elif use_first_names: @@ -185,102 +196,98 @@ class EthnicNameGenerator: else: # For last names all_names.extend(top_names[country_code]) - + except Exception as e: logger.debug(f"Error getting names for {country_code}: {e}") continue - + return list(set(all_names)) # Remove duplicates - - def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]: + + def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]: """Generate a name based on ethnicity using the correct NameDataset API""" - if gender == 'random': - gender = random.choice(['Male', 'Female']) - - countries = self.ethnic_country_mapping.get(ethnicity, ['US']) - + if gender == "random": + gender = random.choice(["Male", "Female"]) + + countries = self.ethnic_country_mapping.get(ethnicity, ["US"]) + # Get first names - first_names = self.get_names_by_criteria( - countries=countries, - gender=gender, - 
use_first_names=True - ) - + first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True) + # Get last names - last_names = self.get_names_by_criteria( - countries=countries, - use_first_names=False - ) - + last_names = self.get_names_by_criteria(countries=countries, use_first_names=False) + # Select names or use fallbacks if first_names: first_name = random.choice(first_names) else: first_name = self._get_fallback_first_name(gender, ethnicity) logger.info(f"Using fallback first name for {ethnicity} {gender}") - + if last_names: last_name = random.choice(last_names) else: last_name = self._get_fallback_last_name(ethnicity) logger.info(f"Using fallback last name for {ethnicity}") - + return first_name, last_name, ethnicity, gender - + def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str: """Provide culturally appropriate fallback first names""" fallback_names = { - 'White': { - 'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'], - 'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica'] + "White": { + "Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"], + "Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"], }, - 'Hispanic': { - 'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'], - 'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana'] + "Hispanic": { + "Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"], + "Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"], }, - 'Black': { - 'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'], - 'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah'] + "Black": { + "Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"], + "Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"], + }, + "Asian": { + "Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"], + "Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"], }, - 'Asian': { - 'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'], - 'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh'] - } } - - ethnicity_names = fallback_names.get(ethnicity, fallback_names['White']) - return random.choice(ethnicity_names.get(gender, ethnicity_names['Male'])) - + + ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"]) + return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"])) + def _get_fallback_last_name(self, ethnicity: str) -> str: """Provide culturally appropriate fallback last names""" fallback_surnames = { - 'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'], - 'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'], - 'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'], - 'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao'] + "White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"], + "Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"], + "Black": ["Johnson", "Williams", "Brown", 
"Jones", "Davis", "Miller", "Wilson", "Moore"], + "Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"], } - - return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White'])) - - def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]: + + return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"])) + + def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]: """Generate a random name with ethnicity based on US demographics""" ethnicity = self.get_weighted_ethnicity() return self.get_name_by_ethnicity(ethnicity, gender) - - def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]: + + def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]: """Generate multiple random names""" names = [] for _ in range(count): first, last, ethnicity, actual_gender = self.generate_random_name(gender) - names.append({ - 'full_name': f"{first} {last}", - 'first_name': first, - 'last_name': last, - 'ethnicity': ethnicity, - 'gender': actual_gender - }) + names.append( + { + "full_name": f"{first} {last}", + "first_name": first, + "last_name": last, + "ethnicity": ethnicity, + "gender": actual_gender, + } + ) return names - + + class GeneratePersona(Agent): agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration @@ -305,30 +312,27 @@ class GeneratePersona(Agent): self.full_name = f"{self.first_name} {self.last_name}" async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]: self.randomize() original_prompt = prompt.strip() persona = { -"age": self.age, -"gender": self.gender, -"ethnicity": self.ethnicity, -"full_name": self.full_name, -"first_name": self.first_name, -"last_name": self.last_name, -} - + "age": self.age, + "gender": self.gender, + "ethnicity": self.ethnicity, + "full_name": self.full_name, + "first_name": self.first_name, + "last_name": self.last_name, + } + prompt = f"""\ ```json {json.dumps(persona)} ``` """ - + if original_prompt: prompt += f""" Incorporate the following into the job description: {original_prompt} @@ -340,7 +344,8 @@ Incorporate the following into the job description: {original_prompt} logger.info("🤖 Generating persona...") generating_message = None async for generating_message in self.llm_one_shot( - llm=llm, model=model, + llm=llm, + model=model, session_id=session_id, prompt=prompt, system_prompt=generate_persona_system_prompt(persona=persona), @@ -351,11 +356,10 @@ Incorporate the following into the job description: {original_prompt} yield generating_message if generating_message.status != ApiStatusType.DONE: yield generating_message - + if not generating_message: error_message = ChatMessageError( - session_id=session_id, - content="Persona generation failed to generate a response." + session_id=session_id, content="Persona generation failed to generate a response." 
) yield error_message return @@ -373,7 +377,7 @@ Incorporate the following into the job description: {original_prompt} self.username = persona.get("username", None) if not self.username: raise ValueError("LLM did not generate a username") - self.username = re.sub(r'\s+', '.', self.username) + self.username = re.sub(r"\s+", ".", self.username) user_dir = os.path.join(defines.user_dir, persona["username"]) while os.path.exists(user_dir): match = re.match(r"^(.*?)(\d*)$", persona["username"]) @@ -396,19 +400,14 @@ Incorporate the following into the job description: {original_prompt} location_parts = persona["location"].split(",") if len(location_parts) == 3: city, state, country = [part.strip() for part in location_parts] - persona["location"] = { - "city": city, - "state": state, - "country": country - } + persona["location"] = {"city": city, "state": state, "country": country} else: logger.error(f"Invalid location format: {persona['location']}") persona["location"] = None persona["is_ai"] = True except Exception as e: error_message = ChatMessageError( - session_id=session_id, - content=f"Error parsing LLM response: {str(e)}\n\n{json_str}" + session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}" ) logger.error(f"❌ Error parsing LLM response: {error_message.content}") logger.error(traceback.format_exc()) @@ -420,10 +419,7 @@ Incorporate the following into the job description: {original_prompt} # Persona generated persona_message = ChatMessage( - session_id=session_id, - status=ApiStatusType.DONE, - type=ApiMessageType.JSON, - content = json.dumps(persona) + session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona) ) yield persona_message @@ -432,8 +428,8 @@ Incorporate the following into the job description: {original_prompt} # status_message = ChatMessageStatus( session_id=session_id, - activity = ApiActivityType.THINKING, - content = f"Generating resume for {persona['full_name']}..." + activity=ApiActivityType.THINKING, + content=f"Generating resume for {persona['full_name']}...", ) logger.info(f"🤖 {status_message.content}") yield status_message @@ -456,7 +452,8 @@ Incorporate the following into the job description: {original_prompt} Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}.""" async for generating_message in self.llm_one_shot( - llm=llm, model=model, + llm=llm, + model=model, session_id=session_id, prompt=content, system_prompt=generate_resume_system_prompt, @@ -467,11 +464,10 @@ Make sure at least one of the candidate's job descriptions take into account the raise Exception(generating_message.content) if generating_message.status != ApiStatusType.DONE: yield generating_message - + if not generating_message: error_message = ChatMessageError( - session_id=session_id, - content="Resume generation failed to generate a response." + session_id=session_id, content="Resume generation failed to generate a response." 
) logger.error(f"❌ {error_message.content}") yield error_message @@ -479,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the resume = self.extract_markdown_from_text(generating_message.content) resume_message = ChatMessage( - session_id=session_id, - status=ApiStatusType.DONE, - type=ApiMessageType.TEXT, - content=resume + session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume ) yield resume_message return @@ -502,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the raise ValueError("No JSON found in the response") + # Register the base agent agent_registry.register(GeneratePersona._agent_type, GeneratePersona) diff --git a/src/backend/agents/generate_resume.py b/src/backend/agents/generate_resume.py index 14a571d..ab95884 100644 --- a/src/backend/agents/generate_resume.py +++ b/src/backend/agents/generate_resume.py @@ -4,23 +4,31 @@ from typing import ( ClassVar, Any, AsyncGenerator, - List -# override + List, + # override ) # NOTE: You must import Optional for late binding to work import json from logger import logger from .base import Agent, agent_registry -from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength) +from models import ( + ApiActivityType, + ApiMessage, + ApiStatusType, + ChatMessage, + ChatMessageError, + ChatMessageResume, + ChatMessageStatus, + SkillAssessment, + SkillStrength, +) + class GenerateResume(Agent): - agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore + agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration - def generate_resume_prompt( - self, - skills: List[SkillAssessment] - ): + def generate_resume_prompt(self, skills: List[SkillAssessment]): """ Generate a professional resume based on skill assessment results @@ -35,13 +43,13 @@ class GenerateResume(Agent): """ if not self.user: raise ValueError("User must be bound to agent") - + # Extract and organize skill assessment data skills_by_strength = { - SkillStrength.STRONG: [], - SkillStrength.MODERATE: [], - SkillStrength.WEAK: [], - SkillStrength.NONE: [] + SkillStrength.STRONG: [], + SkillStrength.MODERATE: [], + SkillStrength.WEAK: [], + SkillStrength.NONE: [], } experience_evidence = {} @@ -63,11 +71,7 @@ class GenerateResume(Agent): experience_evidence[source] = [] experience_evidence[source].append( - { - "skill": skill, - "quote": evidence.quote, - "context": evidence.context - } + {"skill": skill, "quote": evidence.quote, "context": evidence.context} ) # Build the system prompt @@ -103,7 +107,7 @@ Phone: {self.user.phone or 'N/A'} ### Weaker Skills (mentioned or implied): {", ".join(skills_by_strength[SkillStrength.WEAK])} """ - + system_prompt += """\ ## EXPERIENCE EVIDENCE: @@ -167,16 +171,16 @@ Format it in clean, ATS-friendly markdown. 
Provide ONLY the resume with no comme ) -> AsyncGenerator[ApiMessage, None]: # Stage 1A: Analyze job requirements status_message = ChatMessageStatus( - session_id=session_id, - content = "Analyzing job requirements", - activity=ApiActivityType.THINKING + session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING ) yield status_message system_prompt, prompt = self.generate_resume_prompt(skills=skills) generated_message = None - async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt): + async for generated_message in self.llm_one_shot( + llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt + ): if generated_message.status == ApiStatusType.ERROR: yield generated_message return @@ -185,22 +189,20 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme if not generated_message: error_message = ChatMessageError( - session_id=session_id, - content="Job requirements analysis failed to generate a response." + session_id=session_id, content="Job requirements analysis failed to generate a response." ) logger.error(f"⚠️ {error_message.content}") yield error_message return - + if not isinstance(generated_message, ChatMessage): error_message = ChatMessageError( - session_id=session_id, - content="Job requirements analysis did not return a valid message." + session_id=session_id, content="Job requirements analysis did not return a valid message." ) logger.error(f"⚠️ {error_message.content}") yield error_message return - + resume_message = ChatMessageResume( session_id=session_id, status=ApiStatusType.DONE, @@ -213,6 +215,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme yield resume_message logger.info("✅ Resume generation completed successfully.") return - + + # Register the base agent agent_registry.register(GenerateResume._agent_type, GenerateResume) diff --git a/src/backend/agents/job_requirements.py b/src/backend/agents/job_requirements.py index b3e7326..2c62a24 100644 --- a/src/backend/agents/job_requirements.py +++ b/src/backend/agents/job_requirements.py @@ -5,17 +5,30 @@ from typing import ( ClassVar, Any, AsyncGenerator, - Optional -# override + Optional, + # override ) # NOTE: You must import Optional for late binding to work import inspect import json from .base import Agent, agent_registry -from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables +from models import ( + ApiActivityType, + ApiMessage, + ChatMessage, + ChatMessageError, + ChatMessageStatus, + ChatMessageStreaming, + ApiStatusType, + Job, + JobRequirements, + JobRequirementsMessage, + Tunables, +) from logger import logger import backstory_traceback as traceback + class JobRequirementsAgent(Agent): agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration @@ -91,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r """Analyze job requirements from job description.""" system_prompt, prompt = self.create_job_analysis_prompt(prompt) status_message = ChatMessageStatus( - session_id=session_id, - content="Analyzing job requirements", - activity=ApiActivityType.THINKING + session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING 
) yield status_message logger.info(f"🔍 {status_message.content}") generated_message = None - async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt): + async for generated_message in self.llm_one_shot( + llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt + ): if generated_message.status == ApiStatusType.ERROR: yield generated_message return @@ -107,12 +120,12 @@ Avoid vague categorizations and be precise about whether skills are explicitly r if not generated_message: error_message = ChatMessageError( - session_id=session_id, - content="Job requirements analysis failed to generate a response.") + session_id=session_id, content="Job requirements analysis failed to generate a response." + ) logger.error(f"⚠️ {error_message.content}") yield error_message return - + yield generated_message return @@ -129,39 +142,34 @@ Avoid vague categorizations and be precise about whether skills are explicitly r display = { "technical_skills": { "required": reqs.technical_skills.required, - "preferred": reqs.technical_skills.preferred + "preferred": reqs.technical_skills.preferred, }, "experience_requirements": { "required": reqs.experience_requirements.required, - "preferred": reqs.experience_requirements.preferred + "preferred": reqs.experience_requirements.preferred, }, "soft_skills": reqs.soft_skills, "experience": reqs.experience, "education": reqs.education, "certifications": reqs.certifications, "preferred_attributes": reqs.preferred_attributes, - "company_values": reqs.company_values + "company_values": reqs.company_values, } - + return display async def generate( self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ApiMessage, None]: if not self.user: - error_message = ChatMessageError( - session_id=session_id, - content="User is not set for this agent." - ) + error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.") logger.error(f"⚠️ {error_message.content}") yield error_message return - + # Stage 1A: Analyze job requirements status_message = ChatMessageStatus( - session_id=session_id, - content = "Analyzing job requirements", - activity=ApiActivityType.THINKING + session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING ) yield status_message @@ -175,13 +183,12 @@ Avoid vague categorizations and be precise about whether skills are explicitly r if not generated_message: error_message = ChatMessageError( - session_id=session_id, - content="Job requirements analysis failed to generate a response." + session_id=session_id, content="Job requirements analysis failed to generate a response." 
) logger.error(f"⚠️ {error_message.content}") yield error_message return - + requirements = None job_requirements_data = "" company = "" @@ -211,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r return except Exception as e: status_message.status = ApiStatusType.ERROR - status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}" + status_message.content = ( + f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}" + ) logger.error(traceback.format_exc()) logger.error(f"⚠️ {status_message.content}") yield status_message @@ -238,5 +247,6 @@ Avoid vague categorizations and be precise about whether skills are explicitly r logger.info("✅ Job requirements analysis completed successfully.") return + # Register the base agent agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent) diff --git a/src/backend/agents/rag_search.py b/src/backend/agents/rag_search.py index 8fe27eb..838e4e0 100644 --- a/src/backend/agents/rag_search.py +++ b/src/backend/agents/rag_search.py @@ -5,20 +5,19 @@ from .base import Agent, agent_registry from logger import logger from .registry import agent_registry -from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables ) +from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables + class Chat(Agent): """ Chat Agent """ - agent_type: Literal["rag_search"] = "rag_search" # type: ignore + + agent_type: Literal["rag_search"] = "rag_search" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ApiMessage, None]: """ Generate a response based on the user message and the provided LLM. @@ -44,14 +43,14 @@ class Chat(Agent): if not isinstance(rag_message, ChatMessageRagSearch): logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}") error_message = ChatMessageError( - session_id=session_id, - content="RAG search did not return a valid response." + session_id=session_id, content="RAG search did not return a valid response." 
) yield error_message return - + rag_message.status = ApiStatusType.DONE yield rag_message + # Register the base agent agent_registry.register(Chat._agent_type, Chat) diff --git a/src/backend/agents/registry.py b/src/backend/agents/registry.py index 226aa86..d5d87e8 100644 --- a/src/backend/agents/registry.py +++ b/src/backend/agents/registry.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import List, Dict, Optional, Type + # We'll use a registry pattern rather than hardcoded strings class AgentRegistry: """Registry for agent types and classes""" diff --git a/src/backend/agents/skill_match.py b/src/backend/agents/skill_match.py index ccb8e1b..f0d3250 100644 --- a/src/backend/agents/skill_match.py +++ b/src/backend/agents/skill_match.py @@ -4,19 +4,29 @@ from typing import ( ClassVar, Any, AsyncGenerator, - Optional -# override + Optional, + # override ) # NOTE: You must import Optional for late binding to work import json from .base import Agent, agent_registry -from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail, -SkillAssessment, Tunables) +from models import ( + ApiMessage, + ChatMessage, + ChatMessageError, + ChatMessageRagSearch, + ChatMessageSkillAssessment, + ApiStatusType, + EvidenceDetail, + SkillAssessment, + Tunables, +) from logger import logger import backstory_traceback as traceback + class SkillMatchAgent(Agent): - agent_type: Literal["skill_match"] = "skill_match" # type: ignore + agent_type: Literal["skill_match"] = "skill_match" # type: ignore _agent_type: ClassVar[str] = agent_type # Add this for registration def generate_skill_assessment_prompt(self, skill, rag_context): @@ -96,30 +106,24 @@ JSON RESPONSE:""" return system_prompt, prompt async def generate( - self, llm: Any, model: str, - session_id: str, prompt: str, - tunables: Optional[Tunables] = None, - temperature=0.7 + self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7 ) -> AsyncGenerator[ApiMessage, None]: if not self.user: error_message = ChatMessageError( session_id=session_id, - content="Agent not attached to user. Attach the agent to a user before generating responses." + content="Agent not attached to user. Attach the agent to a user before generating responses.", ) logger.error(f"⚠️ {error_message.content}") yield error_message return - + skill = prompt.strip() if not skill: - error_message = ChatMessageError( - session_id=session_id, - content="Skill cannot be empty." - ) + error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.") logger.error(f"⚠️ {error_message.content}") yield error_message return - + generated_message = None async for generated_message in self.generate_rag_results(session_id=session_id, prompt=skill): if generated_message.status == ApiStatusType.ERROR: @@ -130,8 +134,7 @@ JSON RESPONSE:""" if generated_message is None: error_message = ChatMessageError( - session_id=session_id, - content="RAG search did not return a valid response." + session_id=session_id, content="RAG search did not return a valid response." ) logger.error(f"⚠️ {error_message.content}") yield error_message @@ -140,19 +143,20 @@ JSON RESPONSE:""" if not isinstance(generated_message, ChatMessageRagSearch): logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}") error_message = ChatMessageError( - session_id=session_id, - content="RAG search did not return a valid response." 
+ session_id=session_id, content="RAG search did not return a valid response." ) yield error_message return - rag_message : ChatMessageRagSearch = generated_message + rag_message: ChatMessageRagSearch = generated_message rag_context = self.get_rag_context(rag_message) logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context") system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context) generated_message = None - async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7): + async for generated_message in self.llm_one_shot( + llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7 + ): if generated_message.status == ApiStatusType.ERROR: logger.error(f"⚠️ {generated_message.content}") yield generated_message @@ -162,8 +166,7 @@ JSON RESPONSE:""" if generated_message is None: error_message = ChatMessageError( - session_id=session_id, - content="Skill assessment failed to generate a response." + session_id=session_id, content="Skill assessment failed to generate a response." ) logger.error(f"⚠️ {error_message.content}") yield error_message @@ -171,8 +174,7 @@ JSON RESPONSE:""" if not isinstance(generated_message, ChatMessage): error_message = ChatMessageError( - session_id=session_id, - content="Skill assessment did not return a valid message." + session_id=session_id, content="Skill assessment did not return a valid message." ) logger.error(f"⚠️ {error_message.content}") yield error_message @@ -195,20 +197,21 @@ JSON RESPONSE:""" EvidenceDetail( source=evidence.get("source", ""), quote=evidence.get("quote", ""), - context=evidence.get("context", "") - ) for evidence in skill_assessment_data.get("evidence_details", []) - ] + context=evidence.get("context", ""), + ) + for evidence in skill_assessment_data.get("evidence_details", []) + ], ) except Exception as e: error_message = ChatMessageError( session_id=session_id, - content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n" + content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n", ) logger.error(traceback.format_exc()) logger.error(f"⚠️ {error_message.content}") yield error_message return - + # if skill_assessment.evidence_strength == "none": # logger.info("⚠️ No evidence found for skill assessment, returning NONE.") # with open("src/tmp.txt", "w") as f: @@ -232,5 +235,6 @@ JSON RESPONSE:""" logger.info("✅ Skill assessment completed successfully.") return + # Register the base agent agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent) diff --git a/src/backend/background_tasks.py b/src/backend/background_tasks.py index c2e298e..03fbb39 100644 --- a/src/backend/background_tasks.py +++ b/src/backend/background_tasks.py @@ -9,99 +9,103 @@ from typing import Optional, List, Dict, Any, Callable from logger import logger from database.manager import DatabaseManager + class BackgroundTaskManager: """Manages background tasks for the application using asyncio instead of threading""" - + def __init__(self, database_manager: DatabaseManager): self.database_manager = database_manager self.running = False self.tasks: List[asyncio.Task] = [] self.main_loop: Optional[asyncio.AbstractEventLoop] = None - + async def cleanup_inactive_guests(self, inactive_hours: int = 24): """Clean up inactive guest sessions""" try: if 
self.database_manager.is_shutting_down: logger.info("Skipping guest cleanup - application shutting down") return 0 - + database = self.database_manager.get_database() cleaned_count = await database.cleanup_inactive_guests(inactive_hours) - + if cleaned_count > 0: logger.info(f"🧹 Background cleanup: removed {cleaned_count} inactive guest sessions") - + return cleaned_count except Exception as e: logger.error(f"❌ Error in guest cleanup: {e}") return 0 - + async def cleanup_expired_verification_tokens(self): """Clean up expired email verification tokens""" try: if self.database_manager.is_shutting_down: logger.info("Skipping token cleanup - application shutting down") return 0 - + database = self.database_manager.get_database() cleaned_count = await database.cleanup_expired_verification_tokens() - + if cleaned_count > 0: logger.info(f"🧹 Background cleanup: removed {cleaned_count} expired verification tokens") - + return cleaned_count except Exception as e: logger.error(f"❌ Error in verification token cleanup: {e}") return 0 - + async def update_guest_statistics(self): """Update guest usage statistics""" try: if self.database_manager.is_shutting_down: logger.info("Skipping stats update - application shutting down") return {} - + database = self.database_manager.get_database() stats = await database.get_guest_statistics() - + # Log interesting statistics - if stats.get('total_guests', 0) > 0: - logger.info(f"📊 Guest stats: {stats['total_guests']} total, " - f"{stats['active_last_hour']} active in last hour, " - f"{stats['converted_guests']} converted") - + if stats.get("total_guests", 0) > 0: + logger.info( + f"📊 Guest stats: {stats['total_guests']} total, " + f"{stats['active_last_hour']} active in last hour, " + f"{stats['converted_guests']} converted" + ) + return stats except Exception as e: logger.error(f"❌ Error updating guest statistics: {e}") return {} - + async def cleanup_old_rate_limit_data(self, days_old: int = 7): """Clean up old rate limiting data""" try: if self.database_manager.is_shutting_down: logger.info("Skipping rate limit cleanup - application shutting down") return 0 - + # Get Redis client safely (using the event loop safe method) from database.manager import redis_manager + redis = await redis_manager.get_client() - + # Clean up rate limit keys older than specified days cutoff_time = datetime.now(UTC) - timedelta(days=days_old) pattern = "rate_limit:*" - + cursor = 0 deleted_count = 0 - + while True: cursor, keys = await redis.scan(cursor, match=pattern, count=100) - + for key in keys: # Check if key is old enough to delete try: ttl = await redis.ttl(key) if ttl == -1: # No expiration set, check creation time - creation_time = await redis.hget(key, "created_at") # type: ignore + creation_time = await redis.hget(key, "created_at") # type: ignore if creation_time: creation_time = datetime.fromisoformat(creation_time).replace(tzinfo=UTC) if creation_time < cutoff_time: @@ -110,41 +114,41 @@ class BackgroundTaskManager: deleted_count += 1 except Exception: continue - + if cursor == 0: break - + if deleted_count > 0: logger.info(f"🧹 Cleaned up {deleted_count} old rate limit keys") - + return deleted_count except Exception as e: logger.error(f"❌ Error cleaning up rate limit data: {e}") return 0 - + async def cleanup_orphaned_data(self): """Clean up orphaned database records""" try: if self.database_manager.is_shutting_down: return 0 - + database = self.database_manager.get_database() - + # Clean up orphaned job requirements orphaned_count = await 
database.cleanup_orphaned_job_requirements() - + if orphaned_count > 0: logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements") - + return orphaned_count except Exception as e: logger.error(f"❌ Error cleaning up orphaned data: {e}") return 0 - + async def _run_periodic_task(self, name: str, task_func: Callable, interval_seconds: int, *args, **kwargs): """Run a periodic task safely in the same event loop""" logger.info(f"🔄 Starting periodic task: {name} (every {interval_seconds}s)") - + while self.running: try: # Verify we're still in the correct event loop @@ -152,34 +156,34 @@ class BackgroundTaskManager: if current_loop != self.main_loop: logger.error(f"Task {name} detected event loop change! Stopping.") break - + # Run the task await task_func(*args, **kwargs) - + except asyncio.CancelledError: logger.info(f"Periodic task {name} was cancelled") break except Exception as e: logger.error(f"❌ Error in periodic task {name}: {e}") # Continue running despite errors - + # Sleep with cancellation support try: await asyncio.sleep(interval_seconds) except asyncio.CancelledError: logger.info(f"Periodic task {name} cancelled during sleep") break - + async def start(self): """Start all background tasks in the current event loop""" if self.running: logger.warning("⚠️ Background task manager already running") return - + # Store the current event loop self.main_loop = asyncio.get_running_loop() self.running = True - + # Define periodic tasks with their intervals (in seconds) periodic_tasks = [ # (name, function, interval_seconds, *args) @@ -189,68 +193,62 @@ class BackgroundTaskManager: ("rate_limit_cleanup", self.cleanup_old_rate_limit_data, 24 * 3600, 7), # Daily, cleanup 7 days old ("orphaned_cleanup", self.cleanup_orphaned_data, 6 * 3600), # Every 6 hours ] - + # Create asyncio tasks for each periodic task for name, func, interval, *args in periodic_tasks: - task = asyncio.create_task( - self._run_periodic_task(name, func, interval, *args), - name=f"background_{name}" - ) + task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}") self.tasks.append(task) logger.info(f"📅 Scheduled background task: {name}") - + # Run initial cleanup tasks immediately (but don't wait for them) asyncio.create_task(self._run_initial_cleanup(), name="initial_cleanup") - + logger.info("🚀 Background task manager started with asyncio tasks") - + async def _run_initial_cleanup(self): """Run some cleanup tasks immediately on startup""" try: logger.info("🧹 Running initial cleanup tasks...") - + # Clean up expired tokens immediately await asyncio.sleep(5) # Give the app time to fully start await self.cleanup_expired_verification_tokens() - + # Clean up very old inactive guests (7 days old) await self.cleanup_inactive_guests(inactive_hours=7 * 24) - + # Update statistics await self.update_guest_statistics() - + logger.info("✅ Initial cleanup tasks completed") - + except Exception as e: logger.error(f"❌ Error in initial cleanup: {e}") - + async def stop(self): """Stop all background tasks gracefully""" logger.info("🛑 Stopping background task manager...") - + self.running = False - + # Cancel all running tasks for task in self.tasks: if not task.done(): task.cancel() - + # Wait for all tasks to complete with timeout if self.tasks: try: - await asyncio.wait_for( - asyncio.gather(*self.tasks, return_exceptions=True), - timeout=30.0 - ) + await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0) logger.info("✅ All background tasks 
stopped gracefully") except asyncio.TimeoutError: logger.warning("⚠️ Some background tasks did not stop within timeout") - + self.tasks.clear() self.main_loop = None - + logger.info("🛑 Background task manager stopped") - + async def get_task_status(self) -> Dict[str, Any]: """Get status of all background tasks""" status = { @@ -258,23 +256,23 @@ class BackgroundTaskManager: "main_loop_id": id(self.main_loop) if self.main_loop else None, "current_loop_id": None, "task_count": len(self.tasks), - "tasks": [] + "tasks": [], } - + try: current_loop = asyncio.get_running_loop() status["current_loop_id"] = id(current_loop) status["loop_matches"] = (id(current_loop) == id(self.main_loop)) if self.main_loop else False except RuntimeError: status["current_loop_id"] = "no_running_loop" - + for task in self.tasks: task_info = { "name": task.get_name(), "done": task.done(), "cancelled": task.cancelled(), } - + if task.done() and not task.cancelled(): try: task.result() # This will raise an exception if the task failed @@ -286,11 +284,11 @@ class BackgroundTaskManager: task_info["status"] = "cancelled" else: task_info["status"] = "running" - + status["tasks"].append(task_info) - + return status - + async def force_run_task(self, task_name: str) -> Any: """Manually trigger a specific background task""" task_map = { @@ -300,10 +298,10 @@ class BackgroundTaskManager: "rate_limit_cleanup": self.cleanup_old_rate_limit_data, "orphaned_cleanup": self.cleanup_orphaned_data, } - + if task_name not in task_map: raise ValueError(f"Unknown task: {task_name}. Available: {list(task_map.keys())}") - + logger.info(f"🔧 Manually running task: {task_name}") result = await task_map[task_name]() logger.info(f"✅ Manual task {task_name} completed") @@ -317,11 +315,12 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun await task_manager.start() return task_manager + # For integration with your existing app startup async def initialize_with_background_tasks(database_manager: DatabaseManager): """Initialize database and background tasks together""" # Start background tasks background_tasks = await setup_background_tasks(database_manager) - + # Return both for your app to manage - return database_manager, background_tasks \ No newline at end of file + return database_manager, background_tasks diff --git a/src/backend/backstory_traceback.py b/src/backend/backstory_traceback.py index bc6f3ab..425d20a 100644 --- a/src/backend/backstory_traceback.py +++ b/src/backend/backstory_traceback.py @@ -3,21 +3,22 @@ import os import sys import defines + def filter_traceback(tb, app_path=None, module_name=None): """ Filter traceback to include only frames from the specified application path or module. - + Args: tb: Traceback object (e.g., from sys.exc_info()[2]) app_path: Directory path of your application (e.g., '/path/to/your/app') module_name: Name of the module to include (e.g., 'myapp') - + Returns: Formatted traceback string with filtered frames. 
""" # Extract stack frames stack = traceback.extract_tb(tb) - + # Filter frames based on app_path or module_name filtered_stack = [] for frame in stack: @@ -27,25 +28,26 @@ def filter_traceback(tb, app_path=None, module_name=None): filtered_stack.append(frame) elif module_name and frame.filename.startswith(module_name): filtered_stack.append(frame) - + # Format the filtered stack trace formatted_stack = traceback.format_list(filtered_stack) - + # Get exception info to include the exception type and message exc_type, exc_value, _ = sys.exc_info() formatted_exc = traceback.format_exception_only(exc_type, exc_value) - + # Combine the filtered stack trace with the exception message - return ''.join(formatted_stack + formatted_exc) + return "".join(formatted_stack + formatted_exc) + def format_exc(app_path=defines.app_path, module_name=None): """ Custom version of traceback.format_exc() that filters stack frames. - + Args: app_path: Directory path of your application module_name: Name of the module to include - + Returns: Formatted traceback string with only relevant frames. """ diff --git a/src/backend/database/__init__.py b/src/backend/database/__init__.py index b320fb4..832d370 100644 --- a/src/backend/database/__init__.py +++ b/src/backend/database/__init__.py @@ -1,4 +1,4 @@ from .core import RedisDatabase from .manager import DatabaseManager, redis_manager -__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager'] \ No newline at end of file +__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"] diff --git a/src/backend/database/constants.py b/src/backend/database/constants.py index e9dfd3d..d2eae49 100644 --- a/src/backend/database/constants.py +++ b/src/backend/database/constants.py @@ -1,15 +1,15 @@ KEY_PREFIXES = { - 'viewers': 'viewer:', - 'candidates': 'candidate:', - 'employers': 'employer:', - 'jobs': 'job:', - 'job_applications': 'job_application:', - 'chat_sessions': 'chat_session:', - 'chat_messages': 'chat_messages:', - 'ai_parameters': 'ai_parameters:', - 'users': 'user:', - 'candidate_documents': 'candidate_documents:', - 'job_requirements': 'job_requirements:', - 'resumes': 'resume:', - 'user_resumes': 'user_resumes:', + "viewers": "viewer:", + "candidates": "candidate:", + "employers": "employer:", + "jobs": "job:", + "job_applications": "job_application:", + "chat_sessions": "chat_session:", + "chat_messages": "chat_messages:", + "ai_parameters": "ai_parameters:", + "users": "user:", + "candidate_documents": "candidate_documents:", + "job_requirements": "job_requirements:", + "resumes": "resume:", + "user_resumes": "user_resumes:", } diff --git a/src/backend/database/core.py b/src/backend/database/core.py index 4122153..71485d2 100644 --- a/src/backend/database/core.py +++ b/src/backend/database/core.py @@ -8,14 +8,15 @@ from .mixins.auth import AuthMixin from .mixins.analytics import AnalyticsMixin from .mixins.job import JobMixin from .mixins.skill import SkillMixin -from .mixins.ai import AIMixin +from .mixins.ai import AIMixin -# RedisDatabase is the main class that combines all mixins for a + +# RedisDatabase is the main class that combines all mixins for a # comprehensive Redis database interface. 
class RedisDatabase( AIMixin, BaseMixin, - ResumeMixin, + ResumeMixin, DocumentMixin, UserMixin, ChatMixin, @@ -27,7 +28,7 @@ class RedisDatabase( """ Main Redis database class combining all mixins """ - + def __init__(self, redis: Redis): self.redis = redis - super().__init__() \ No newline at end of file + super().__init__() diff --git a/src/backend/database/manager.py b/src/backend/database/manager.py index 2e61c70..e2a36a8 100644 --- a/src/backend/database/manager.py +++ b/src/backend/database/manager.py @@ -1,4 +1,4 @@ -from redis.asyncio import (Redis, ConnectionPool) +from redis.asyncio import Redis, ConnectionPool from typing import Optional, Optional import json import logging @@ -9,7 +9,8 @@ from .core import RedisDatabase logger = logging.getLogger(__name__) -# _RedisManager is a singleton class that manages the Redis connection and + +# _RedisManager is a singleton class that manages the Redis connection and # provides methods for connecting, disconnecting, and performing health checks. # # It uses connection pooling for better performance and resource management. @@ -18,14 +19,14 @@ class _RedisManager: self.redis: Optional[Redis] = None self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379") self.redis_db = int(os.getenv("REDIS_DB", "0")) - + # Append database to URL if not already present if not self.redis_url.endswith(f"/{self.redis_db}"): self.redis_url = f"{self.redis_url}/{self.redis_db}" - + self._connection_pool: Optional[ConnectionPool] = None self._is_connected = False - + async def connect(self): """Initialize Redis connection with connection pooling""" if self._is_connected and self.redis: @@ -42,13 +43,11 @@ class _RedisManager: retry_on_timeout=True, socket_keepalive=True, socket_keepalive_options={}, - health_check_interval=30 + health_check_interval=30, ) - - self.redis = Redis( - connection_pool=self._connection_pool - ) - + + self.redis = Redis(connection_pool=self._connection_pool) + if not self.redis: raise RuntimeError("Redis client not initialized") @@ -56,67 +55,67 @@ class _RedisManager: await self.redis.ping() self._is_connected = True logger.info("Successfully connected to Redis") - + # Log Redis info info = await self.redis.info() logger.info(f"Redis version: {info.get('redis_version', 'unknown')}") - + except Exception as e: logger.error(f"Failed to connect to Redis: {e}") self._is_connected = False self.redis = None self._connection_pool = None raise - + async def disconnect(self): """Close Redis connection gracefully""" if not self._is_connected: logger.info("Redis already disconnected") return - + try: if self.redis: # Wait for any pending operations to complete await asyncio.sleep(0.1) - + # Close the client await self.redis.aclose() logger.info("Redis client closed") - + if self._connection_pool: # Close the connection pool await self._connection_pool.aclose() logger.info("Redis connection pool closed") - + self._is_connected = False self.redis = None self._connection_pool = None - + logger.info("Successfully disconnected from Redis") - + except Exception as e: logger.error(f"Error during Redis disconnect: {e}") # Force cleanup even if there's an error self._is_connected = False self.redis = None self._connection_pool = None - + def get_client(self) -> Redis: """Get Redis client instance""" if not self._is_connected or not self.redis: raise RuntimeError("Redis client not initialized or disconnected") return self.redis - + @property def is_connected(self) -> bool: """Check if Redis is connected""" return self._is_connected and self.redis 
is not None - + async def health_check(self) -> dict: """Perform health check on Redis connection""" if not self.is_connected: return {"status": "disconnected", "error": "Redis not connected"} - + if not self.redis: raise RuntimeError("Redis client not initialized") @@ -124,25 +123,25 @@ class _RedisManager: # Test basic operations await self.redis.ping() info = await self.redis.info() - + return { "status": "healthy", "redis_version": info.get("redis_version", "unknown"), "uptime_seconds": info.get("uptime_in_seconds", 0), "connected_clients": info.get("connected_clients", 0), "used_memory_human": info.get("used_memory_human", "unknown"), - "total_commands_processed": info.get("total_commands_processed", 0) + "total_commands_processed": info.get("total_commands_processed", 0), } except Exception as e: logger.error(f"Redis health check failed: {e}") return {"status": "error", "error": str(e)} - + async def force_save(self, background: bool = True) -> bool: """Force Redis to save data to disk""" if not self.is_connected: logger.warning("Cannot save: Redis not connected") return False - + try: if not self.redis: raise RuntimeError("Redis client not initialized") @@ -159,12 +158,12 @@ class _RedisManager: except Exception as e: logger.error(f"Redis save failed: {e}") return False - + async def get_info(self) -> Optional[dict]: """Get Redis server information""" if not self.is_connected: return None - + try: if not self.redis: raise RuntimeError("Redis client not initialized") @@ -173,49 +172,51 @@ class _RedisManager: logger.error(f"Failed to get Redis info: {e}") return None + # Global Redis manager instance redis_manager = _RedisManager() + # DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities # It manages the Redis connection, tracks active requests, and allows for data backup before shutdown. class DatabaseManager: """Enhanced database manager with graceful shutdown capabilities""" - + def __init__(self): self.db: Optional[RedisDatabase] = None self._shutdown_initiated = False self._active_requests = 0 self._shutdown_timeout = int(os.getenv("SHUTDOWN_TIMEOUT", "30")) # seconds self._backup_on_shutdown = os.getenv("BACKUP_ON_SHUTDOWN", "false").lower() == "true" - + async def initialize(self): """Initialize database connection""" try: # Connect to Redis await redis_manager.connect() logger.info("Redis connection established") - + # Create database instance self.db = RedisDatabase(redis_manager.get_client()) - + # Test connection and log stats if not redis_manager.redis: raise RuntimeError("Redis client not initialized") await redis_manager.redis.ping() stats = await self.db.get_stats() logger.info(f"Database initialized successfully. 
Stats: {stats}") - + return self.db - + except Exception as e: logger.error(f"Failed to initialize database: {e}") raise - + async def backup_data(self) -> Optional[str]: """Create a backup of critical data before shutdown""" if not self.db: return None - + try: backup_data = { "timestamp": datetime.now(UTC).isoformat(), @@ -223,41 +224,41 @@ class DatabaseManager: "users": await self.db.get_all_users(), # Add other critical data as needed } - + backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json" - + # Save to local file (you might want to save to cloud storage instead) - with open(backup_filename, 'w') as f: + with open(backup_filename, "w") as f: json.dump(backup_data, f, indent=2, default=str) - + logger.info(f"Backup created: {backup_filename}") return backup_filename - + except Exception as e: logger.error(f"Backup failed: {e}") return None - + async def graceful_shutdown(self): """Perform graceful shutdown with optional backup""" self._shutdown_initiated = True logger.info("Initiating graceful shutdown...") - + # Wait for active requests to complete (with timeout) wait_time = 0 while self._active_requests > 0 and wait_time < self._shutdown_timeout: logger.info(f"Waiting for {self._active_requests} active requests to complete...") await asyncio.sleep(1) wait_time += 1 - + if self._active_requests > 0: logger.warning(f"Shutdown timeout reached. {self._active_requests} requests may be interrupted.") - + # Create backup if configured if self._backup_on_shutdown: backup_file = await self.backup_data() if backup_file: logger.info(f"Pre-shutdown backup completed: {backup_file}") - + # Force Redis to save data to disk try: if redis_manager.redis: @@ -265,10 +266,10 @@ class DatabaseManager: try: await redis_manager.redis.bgsave() logger.info("Background save initiated") - + # Wait a bit for background save to start await asyncio.sleep(0.5) - + except Exception as e: logger.warning(f"Background save failed, trying synchronous save: {e}") try: @@ -277,32 +278,32 @@ class DatabaseManager: logger.info("Synchronous save completed") except Exception as e2: logger.warning(f"Synchronous save also failed (Redis persistence may be disabled): {e2}") - + except Exception as e: logger.error(f"Error during Redis save: {e}") - + # Close Redis connection try: await redis_manager.disconnect() logger.info("Redis connection closed successfully") except Exception as e: logger.error(f"Error closing Redis connection: {e}") - + logger.info("Graceful shutdown completed") - + def increment_requests(self): """Track active requests""" self._active_requests += 1 - + def decrement_requests(self): """Track completed requests""" self._active_requests = max(0, self._active_requests - 1) - + @property def is_shutting_down(self) -> bool: """Check if shutdown is in progress""" return self._shutdown_initiated - + def get_database(self) -> RedisDatabase: """Get database instance""" if self.db is None: @@ -310,5 +311,3 @@ class DatabaseManager: if self._shutdown_initiated: raise RuntimeError("Application is shutting down") return self.db - - \ No newline at end of file diff --git a/src/backend/database/mixins/ai.py b/src/backend/database/mixins/ai.py index 1511342..26abc75 100644 --- a/src/backend/database/mixins/ai.py +++ b/src/backend/database/mixins/ai.py @@ -8,41 +8,42 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class AIMixin(DatabaseProtocol): """Mixin for AI operations""" + async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: """Get AI parameters 
by ID""" key = f"{KEY_PREFIXES['ai_parameters']}{param_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_ai_parameters(self, param_id: str, param_data: Dict): """Set AI parameters data""" key = f"{KEY_PREFIXES['ai_parameters']}{param_id}" await self.redis.set(key, self._serialize(param_data)) - + async def get_all_ai_parameters(self) -> Dict[str, Any]: """Get all AI parameters""" pattern = f"{KEY_PREFIXES['ai_parameters']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - param_id = key.replace(KEY_PREFIXES['ai_parameters'], '') + param_id = key.replace(KEY_PREFIXES["ai_parameters"], "") result[param_id] = self._deserialize(value) - + return result - + async def delete_ai_parameters(self, param_id: str): """Delete AI parameters""" key = f"{KEY_PREFIXES['ai_parameters']}{param_id}" await self.redis.delete(key) - diff --git a/src/backend/database/mixins/analytics.py b/src/backend/database/mixins/analytics.py index 99d37f2..2d21a87 100644 --- a/src/backend/database/mixins/analytics.py +++ b/src/backend/database/mixins/analytics.py @@ -7,6 +7,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) + class AnalyticsMixin: """Mixin for analytics-related database operations""" - diff --git a/src/backend/database/mixins/auth.py b/src/backend/database/mixins/auth.py index 7655238..120df2f 100644 --- a/src/backend/database/mixins/auth.py +++ b/src/backend/database/mixins/auth.py @@ -8,35 +8,37 @@ from .protocols import DatabaseProtocol logger = logging.getLogger(__name__) + class AuthMixin(DatabaseProtocol): """Mixin for auth-related database operations""" - + async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: """Find pending verification token by email address""" try: pattern = "email_verification:*" cursor = 0 email_lower = email.lower() - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: verification_info = json.loads(token_data) - if (verification_info.get("email", "").lower() == email_lower and - not verification_info.get("verified", False)): + if verification_info.get("email", "").lower() == email_lower and not verification_info.get( + "verified", False + ): # Extract token from key token = key.replace("email_verification:", "") verification_info["token"] = token return verification_info - + if cursor == 0: break - + return None - + except Exception as e: logger.error(f"❌ Error finding verification token by email {email}: {e}") return None @@ -47,22 +49,22 @@ class AuthMixin(DatabaseProtocol): pattern = "email_verification:*" cursor = 0 count = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: verification_info = json.loads(token_data) if not verification_info.get("verified", False): count += 1 - + if cursor == 0: break - + return count - + except Exception as e: logger.error(f"❌ Error counting pending verifications: {e}") return 0 @@ -74,29 +76,29 @@ class AuthMixin(DatabaseProtocol): cursor = 0 cleaned_count = 0 current_time = datetime.now(timezone.utc) - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: 
verification_info = json.loads(token_data) expires_at = datetime.fromisoformat(verification_info.get("expires_at", "")) - + if current_time > expires_at: await self.redis.delete(key) cleaned_count += 1 logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}") - + if cursor == 0: break - + if cleaned_count > 0: logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens") - + return cleaned_count - + except Exception as e: logger.error(f"❌ Error cleaning up expired verification tokens: {e}") return 0 @@ -106,22 +108,19 @@ class AuthMixin(DatabaseProtocol): try: key = f"verification_attempts:{email.lower()}" data = await self.redis.get(key) - + if not data: return 0 - + attempts_data = json.loads(data) current_time = datetime.now(timezone.utc) window_start = current_time - timedelta(hours=24) - + # Filter out old attempts - recent_attempts = [ - attempt for attempt in attempts_data - if datetime.fromisoformat(attempt) > window_start - ] - + recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start] + return len(recent_attempts) - + except Exception as e: logger.error(f"❌ Error getting verification attempts count for {email}: {e}") return 0 @@ -131,34 +130,31 @@ class AuthMixin(DatabaseProtocol): try: key = f"verification_attempts:{email.lower()}" current_time = datetime.now(timezone.utc) - + # Get existing attempts data = await self.redis.get(key) attempts_data = json.loads(data) if data else [] - + # Add current attempt attempts_data.append(current_time.isoformat()) - + # Keep only last 24 hours of attempts window_start = current_time - timedelta(hours=24) - recent_attempts = [ - attempt for attempt in attempts_data - if datetime.fromisoformat(attempt) > window_start - ] - + recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start] + # Store with 24 hour expiration await self.redis.setex( key, 24 * 60 * 60, # 24 hours - json.dumps(recent_attempts) + json.dumps(recent_attempts), ) - + return True - + except Exception as e: logger.error(f"❌ Error recording verification attempt for {email}: {e}") return False - + async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: """Store email verification token with user data""" try: @@ -169,16 +165,16 @@ class AuthMixin(DatabaseProtocol): "user_data": user_data, "expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(), "created_at": datetime.now(timezone.utc).isoformat(), - "verified": False + "verified": False, } - + # Store with 24 hour expiration await self.redis.setex( - key, + key, 24 * 60 * 60, # 24 hours in seconds - json.dumps(verification_data, default=str) + json.dumps(verification_data, default=str), ) - + logger.info(f"📧 Stored email verification token for {email}") return True except Exception as e: @@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol): await self.redis.setex( key, 24 * 60 * 60, # Keep for remaining TTL - json.dumps(token_data, default=str) + json.dumps(token_data, default=str), ) return True return False @@ -219,7 +215,7 @@ class AuthMixin(DatabaseProtocol): async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: """Store MFA code for verification""" try: - logger.info("🔐 Storing MFA code for email: %s", email ) + logger.info("🔐 Storing MFA code for email: %s", email) key = f"mfa_code:{email.lower()}:{device_id}" mfa_data = { "code": code, @@ -228,16 +224,16 @@ class 
AuthMixin(DatabaseProtocol): "expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(), "created_at": datetime.now(timezone.utc).isoformat(), "attempts": 0, - "verified": False + "verified": False, } - + # Store with 10 minute expiration await self.redis.setex( key, 10 * 60, # 10 minutes in seconds - json.dumps(mfa_data, default=str) + json.dumps(mfa_data, default=str), ) - + logger.info(f"🔐 Stored MFA code for {email}") return True except Exception as e: @@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol): await self.redis.setex( key, 10 * 60, # Keep original TTL - json.dumps(mfa_data, default=str) + json.dumps(mfa_data, default=str), ) return mfa_data["attempts"] return 0 @@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol): await self.redis.setex( key, 10 * 60, # Keep for remaining TTL - json.dumps(mfa_data, default=str) + json.dumps(mfa_data, default=str), ) return True return False @@ -303,7 +299,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error storing authentication record for {user_id}: {e}") return False - + async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: """Retrieve authentication record for a user""" try: @@ -315,7 +311,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}") return None - + async def delete_authentication(self, user_id: str) -> bool: """Delete authentication record for a user""" try: @@ -326,8 +322,10 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error deleting authentication record for {user_id}: {e}") return False - - async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: + + async def store_refresh_token( + self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str] + ) -> bool: """Store refresh token for a user""" try: key = f"refresh_token:{token}" @@ -337,9 +335,9 @@ class AuthMixin(DatabaseProtocol): "device": device_info.get("device", "unknown"), "ip_address": device_info.get("ip_address", "unknown"), "is_revoked": False, - "created_at": datetime.now(timezone.utc).isoformat() + "created_at": datetime.now(timezone.utc).isoformat(), } - + # Store with expiration ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds()) if ttl_seconds > 0: @@ -352,7 +350,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error storing refresh token for {user_id}: {e}") return False - + async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: """Retrieve refresh token data""" try: @@ -364,7 +362,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error retrieving refresh token: {e}") return None - + async def revoke_refresh_token(self, token: str) -> bool: """Revoke a refresh token""" try: @@ -380,7 +378,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error revoking refresh token: {e}") return False - + async def revoke_all_user_tokens(self, user_id: str) -> bool: """Revoke all refresh tokens for a user""" try: @@ -388,10 +386,10 @@ class AuthMixin(DatabaseProtocol): pattern = "refresh_token:*" cursor = 0 revoked_count = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: @@ -401,16 +399,16 @@ class AuthMixin(DatabaseProtocol): token_info["revoked_at"] 
= datetime.now(timezone.utc).isoformat() await self.redis.set(key, json.dumps(token_info, default=str)) revoked_count += 1 - + if cursor == 0: break - + logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}") return True except Exception as e: logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}") return False - + # Password Reset Token Methods async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: """Store password reset token""" @@ -420,9 +418,9 @@ class AuthMixin(DatabaseProtocol): "email": email.lower(), "expires_at": expires_at.isoformat(), "used": False, - "created_at": datetime.now(timezone.utc).isoformat() + "created_at": datetime.now(timezone.utc).isoformat(), } - + # Store with expiration ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds()) if ttl_seconds > 0: @@ -435,7 +433,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error storing password reset token for {email}: {e}") return False - + async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: """Retrieve password reset token data""" try: @@ -447,7 +445,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error retrieving password reset token: {e}") return None - + async def mark_password_reset_token_used(self, token: str) -> bool: """Mark password reset token as used""" try: @@ -463,7 +461,7 @@ class AuthMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error marking password reset token as used: {e}") return False - + # User Activity and Security Logging async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: """Log security events for audit purposes""" @@ -473,40 +471,39 @@ class AuthMixin(DatabaseProtocol): "timestamp": datetime.now(timezone.utc).isoformat(), "user_id": user_id, "event_type": event_type, - "details": details + "details": details, } - + # Add to list (latest events first) - await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore - + await self.redis.lpush(key, json.dumps(event_data, default=str)) # type: ignore + # Keep only last 100 events per day - await self.redis.ltrim(key, 0, 99)# type: ignore - + await self.redis.ltrim(key, 0, 99) # type: ignore + # Set expiration for 30 days await self.redis.expire(key, 30 * 24 * 60 * 60) - + logger.info(f"🔒 Logged security event {event_type} for user {user_id}") return True except Exception as e: logger.error(f"❌ Error logging security event for {user_id}: {e}") return False - + async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: """Retrieve security log for a user""" try: events = [] for i in range(days): - date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d') + date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d") key = f"security_log:{user_id}:{date}" - - daily_events = await self.redis.lrange(key, 0, -1)# type: ignore + + daily_events = await self.redis.lrange(key, 0, -1) # type: ignore for event_json in daily_events: events.append(json.loads(event_json)) - + # Sort by timestamp (most recent first) events.sort(key=lambda x: x["timestamp"], reverse=True) return events except Exception as e: logger.error(f"❌ Error retrieving security log for {user_id}: {e}") - return [] - + return [] diff --git a/src/backend/database/mixins/base.py b/src/backend/database/mixins/base.py index dbde059..cd5bf04 100644 --- 
a/src/backend/database/mixins/base.py +++ b/src/backend/database/mixins/base.py @@ -5,20 +5,22 @@ from typing import Any, Dict, TYPE_CHECKING from .protocols import DatabaseProtocol from ..constants import KEY_PREFIXES + logger = logging.getLogger(__name__) if TYPE_CHECKING: pass + class BaseMixin(DatabaseProtocol): """Base mixin with core Redis operations and utilities""" - + def _serialize(self, data: Any) -> str: """Serialize data to JSON string for Redis storage""" if data is None: return "" return json.dumps(data, default=str) - + def _deserialize(self, data: str) -> Any: """Deserialize JSON string from Redis""" if not data: @@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol): keys = await self.redis.keys(pattern) if keys: await self.redis.delete(*keys) - - - \ No newline at end of file diff --git a/src/backend/database/mixins/chat.py b/src/backend/database/mixins/chat.py index 08a0bee..58d2a4f 100644 --- a/src/backend/database/mixins/chat.py +++ b/src/backend/database/mixins/chat.py @@ -8,76 +8,77 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class ChatMixin(DatabaseProtocol): """Mixin for chat-related database operations""" - + # Chat Sessions operations async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: """Get a summary of chat activity for a specific candidate""" sessions = await self.get_chat_sessions_by_candidate(candidate_id) - + if not sessions: return { "candidate_id": candidate_id, "total_sessions": 0, "total_messages": 0, "first_chat": None, - "last_chat": None + "last_chat": None, } - + total_messages = 0 for session in sessions: session_id = session.get("id") if session_id: message_count = await self.get_chat_message_count(session_id) total_messages += message_count - + # Sort sessions by creation date sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", "")) - + return { "candidate_id": candidate_id, "total_sessions": len(sessions), "total_messages": total_messages, "first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None, "last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None, - "recent_sessions": sessions[:5] # Last 5 sessions + "recent_sessions": sessions[:5], # Last 5 sessions } - + # Chat Sessions operations async def get_chat_session(self, session_id: str) -> Optional[Dict]: """Get chat session by ID""" key = f"{KEY_PREFIXES['chat_sessions']}{session_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_chat_session(self, session_id: str, session_data: Dict): """Set chat session data""" key = f"{KEY_PREFIXES['chat_sessions']}{session_id}" await self.redis.set(key, self._serialize(session_data)) - + async def get_all_chat_sessions(self) -> Dict[str, Any]: """Get all chat sessions""" pattern = f"{KEY_PREFIXES['chat_sessions']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - session_id = key.replace(KEY_PREFIXES['chat_sessions'], '') + session_id = key.replace(KEY_PREFIXES["chat_sessions"], "") result[session_id] = self._deserialize(value) - + return result async def delete_chat_session(self, session_id: str) -> bool: - '''Delete a chat session from Redis''' + """Delete a chat session from Redis""" try: result = await self.redis.delete(f"chat_session:{session_id}") return result > 0 @@ -86,11 +87,11 @@ class 
ChatMixin(DatabaseProtocol): raise async def delete_chat_message(self, session_id: str, message_id: str) -> bool: - '''Delete a specific chat message from Redis''' + """Delete a specific chat message from Redis""" try: # Remove from the session's message list key = f"{KEY_PREFIXES['chat_messages']}{session_id}" - await self.redis.lrem(key, 0, message_id)# type: ignore + await self.redis.lrem(key, 0, message_id) # type: ignore # Delete the message data itself result = await self.redis.delete(f"chat_message:{message_id}") return result > 0 @@ -102,42 +103,42 @@ class ChatMixin(DatabaseProtocol): async def get_chat_messages(self, session_id: str) -> List[Dict]: """Get chat messages for a session""" key = f"{KEY_PREFIXES['chat_messages']}{session_id}" - messages = await self.redis.lrange(key, 0, -1)# type: ignore - return [self._deserialize(msg) for msg in messages if msg] # type: ignore - + messages = await self.redis.lrange(key, 0, -1) # type: ignore + return [self._deserialize(msg) for msg in messages if msg] # type: ignore + async def add_chat_message(self, session_id: str, message_data: Dict): """Add a chat message to a session""" key = f"{KEY_PREFIXES['chat_messages']}{session_id}" - await self.redis.rpush(key, self._serialize(message_data))# type: ignore - + await self.redis.rpush(key, self._serialize(message_data)) # type: ignore + async def set_chat_messages(self, session_id: str, messages: List[Dict]): """Set all chat messages for a session (replaces existing)""" key = f"{KEY_PREFIXES['chat_messages']}{session_id}" - + # Clear existing messages await self.redis.delete(key) - + # Add new messages if messages: serialized_messages = [self._serialize(msg) for msg in messages] - await self.redis.rpush(key, *serialized_messages)# type: ignore - + await self.redis.rpush(key, *serialized_messages) # type: ignore + async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: """Get all chat messages grouped by session""" pattern = f"{KEY_PREFIXES['chat_messages']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + result = {} for key in keys: - session_id = key.replace(KEY_PREFIXES['chat_messages'], '') - messages = await self.redis.lrange(key, 0, -1)# type: ignore + session_id = key.replace(KEY_PREFIXES["chat_messages"], "") + messages = await self.redis.lrange(key, 0, -1) # type: ignore result[session_id] = [self._deserialize(msg) for msg in messages if msg] - + return result - + async def delete_chat_messages(self, session_id: str): """Delete all chat messages for a session""" key = f"{KEY_PREFIXES['chat_messages']}{session_id}" @@ -148,61 +149,60 @@ class ChatMixin(DatabaseProtocol): """Get all chat sessions for a specific user""" all_sessions = await self.get_all_chat_sessions() user_sessions = [] - + for session_data in all_sessions.values(): if session_data.get("userId") == user_id or session_data.get("guestId") == user_id: user_sessions.append(session_data) - + # Sort by last activity (most recent first) user_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True) return user_sessions - + async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: """Get all chat sessions related to a specific candidate""" all_sessions = await self.get_all_chat_sessions() candidate_sessions = [] - + for session_data in all_sessions.values(): context = session_data.get("context", {}) - if (context.get("relatedEntityType") == "candidate" and - context.get("relatedEntityId") == candidate_id): + if context.get("relatedEntityType") == "candidate" and 
context.get("relatedEntityId") == candidate_id: candidate_sessions.append(session_data) - + # Sort by last activity (most recent first) candidate_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True) return candidate_sessions - + async def update_chat_session_activity(self, session_id: str): """Update the last activity timestamp for a chat session""" session_data = await self.get_chat_session(session_id) if session_data: session_data["lastActivity"] = datetime.now(UTC).isoformat() await self.set_chat_session(session_id, session_data) - + async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: """Get the most recent chat messages for a session""" messages = await self.get_chat_messages(session_id) # Return the last 'limit' messages return messages[-limit:] if len(messages) > limit else messages - + async def get_chat_message_count(self, session_id: str) -> int: """Get the total number of messages in a chat session""" key = f"{KEY_PREFIXES['chat_messages']}{session_id}" - return await self.redis.llen(key)# type: ignore - + return await self.redis.llen(key) # type: ignore + async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: """Search for messages containing specific text in a session""" messages = await self.get_chat_messages(session_id) query_lower = query.lower() - + matching_messages = [] for msg in messages: content = msg.get("content", "").lower() if query_lower in content: matching_messages.append(msg) - + return matching_messages - + # Chat Session Management async def archive_chat_session(self, session_id: str): """Archive a chat session""" @@ -211,38 +211,37 @@ class ChatMixin(DatabaseProtocol): session_data["isArchived"] = True session_data["updatedAt"] = datetime.now(UTC).isoformat() await self.set_chat_session(session_id, session_data) - + async def delete_chat_session_completely(self, session_id: str): """Delete a chat session and all its messages""" # Delete the session await self.delete_chat_session(session_id) # Delete all messages await self.delete_chat_messages(session_id) - + async def cleanup_old_chat_sessions(self, days_old: int = 90): """Archive or delete chat sessions older than specified days""" cutoff_date = datetime.now(UTC) - timedelta(days=days_old) cutoff_iso = cutoff_date.isoformat() - + all_sessions = await self.get_all_chat_sessions() archived_count = 0 - + for session_id, session_data in all_sessions.items(): last_activity = session_data.get("lastActivity", session_data.get("createdAt", "")) - + if last_activity < cutoff_iso and not session_data.get("isArchived", False): await self.archive_chat_session(session_id) archived_count += 1 - + return archived_count - - + # Analytics and Reporting async def get_chat_statistics(self) -> Dict[str, Any]: """Get comprehensive chat statistics""" all_sessions = await self.get_all_chat_sessions() all_messages = await self.get_all_chat_messages() - + stats = { "total_sessions": len(all_sessions), "total_messages": sum(len(messages) for messages in all_messages.values()), @@ -250,34 +249,34 @@ class ChatMixin(DatabaseProtocol): "archived_sessions": 0, "sessions_by_type": {}, "sessions_with_candidates": 0, - "average_messages_per_session": 0 + "average_messages_per_session": 0, } - + # Analyze sessions for session_data in all_sessions.values(): if session_data.get("isArchived", False): stats["archived_sessions"] += 1 else: stats["active_sessions"] += 1 - + # Count by type context_type = session_data.get("context", {}).get("type", "unknown") 
stats["sessions_by_type"][context_type] = stats["sessions_by_type"].get(context_type, 0) + 1 - + # Count sessions with candidate association if session_data.get("context", {}).get("relatedEntityType") == "candidate": stats["sessions_with_candidates"] += 1 - + # Calculate averages if stats["total_sessions"] > 0: stats["average_messages_per_session"] = stats["total_messages"] / stats["total_sessions"] - + return stats async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): """Bulk update multiple chat sessions""" pipe = self.redis.pipeline() - + for session_id, updates in session_updates.items(): session_data = await self.get_chat_session(session_id) if session_data: @@ -285,5 +284,5 @@ class ChatMixin(DatabaseProtocol): session_data["updatedAt"] = datetime.now(UTC).isoformat() key = f"{KEY_PREFIXES['chat_sessions']}{session_id}" pipe.set(key, self._serialize(session_data)) - + await pipe.execute() diff --git a/src/backend/database/mixins/document.py b/src/backend/database/mixins/document.py index d9fd83a..345db7d 100644 --- a/src/backend/database/mixins/document.py +++ b/src/backend/database/mixins/document.py @@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class DocumentMixin(DatabaseProtocol): """Mixin for document-related database operations""" @@ -31,17 +32,17 @@ class DocumentMixin(DatabaseProtocol): try: # Get all document IDs for this candidate key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}" - document_ids = await self.redis.lrange(key, 0, -1)# type: ignore - + document_ids = await self.redis.lrange(key, 0, -1) # type: ignore + if not document_ids: logger.info(f"No documents found for candidate {candidate_id}") return 0 - + deleted_count = 0 - + # Use pipeline for efficient batch operations pipe = self.redis.pipeline() - + # Delete each document's metadata for doc_id in document_ids: pipe.delete(f"document:{doc_id}") @@ -50,13 +51,13 @@ class DocumentMixin(DatabaseProtocol): # Delete the candidate's document list pipe.delete(key) - + # Execute all operations await pipe.execute() - + logger.info(f"Successfully deleted {deleted_count} documents for candidate {candidate_id}") return deleted_count - + except Exception as e: logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}") raise @@ -64,17 +65,17 @@ class DocumentMixin(DatabaseProtocol): async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: """Get all documents for a specific candidate""" key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}" - document_ids = await self.redis.lrange(key, 0, -1) # type: ignore - + document_ids = await self.redis.lrange(key, 0, -1) # type: ignore + if not document_ids: return [] - + # Get all document metadata pipe = self.redis.pipeline() for doc_id in document_ids: pipe.get(f"document:{doc_id}") values = await pipe.execute() - + documents = [] for doc_id, value in zip(document_ids, values): if value: @@ -83,20 +84,20 @@ class DocumentMixin(DatabaseProtocol): documents.append(doc_data) else: # Clean up orphaned document ID - await self.redis.lrem(key, 0, doc_id)# type: ignore + await self.redis.lrem(key, 0, doc_id) # type: ignore logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}") - + return documents async def add_document_to_candidate(self, candidate_id: str, document_id: str): """Add a document ID to a candidate's document list""" key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}" - await self.redis.rpush(key, document_id)# type: 
ignore + await self.redis.rpush(key, document_id) # type: ignore async def remove_document_from_candidate(self, candidate_id: str, document_id: str): """Remove a document ID from a candidate's document list""" key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}" - await self.redis.lrem(key, 0, document_id)# type: ignore + await self.redis.lrem(key, 0, document_id) # type: ignore async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None: """Update document metadata""" @@ -115,29 +116,28 @@ class DocumentMixin(DatabaseProtocol): async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): """Bulk update RAG status for multiple documents""" pipe = self.redis.pipeline() - + for doc_id in document_ids: doc_data = await self.get_document(doc_id) if doc_data and doc_data.get("candidate_id") == candidate_id: doc_data["include_in_rag"] = include_in_rag doc_data["updatedAt"] = datetime.now(UTC).isoformat() pipe.set(f"document:{doc_id}", self._serialize(doc_data)) - + await pipe.execute() async def get_document_count_for_candidate(self, candidate_id: str) -> int: """Get total number of documents for a candidate""" key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}" - return await self.redis.llen(key)# type: ignore + return await self.redis.llen(key) # type: ignore async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: """Search documents by filename for a candidate""" all_documents = await self.get_candidate_documents(candidate_id) query_lower = query.lower() - - return [ - doc for doc in all_documents - if (query_lower in doc.get("filename", "").lower() or - query_lower in doc.get("originalName", "").lower()) - ] + return [ + doc + for doc in all_documents + if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower()) + ] diff --git a/src/backend/database/mixins/job.py b/src/backend/database/mixins/job.py index a5a9161..c908c62 100644 --- a/src/backend/database/mixins/job.py +++ b/src/backend/database/mixins/job.py @@ -8,39 +8,41 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class JobMixin(DatabaseProtocol): """Mixin for job-related database operations""" + async def get_job(self, job_id: str) -> Optional[Dict]: """Get job by ID""" key = f"{KEY_PREFIXES['jobs']}{job_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_job(self, job_id: str, job_data: Dict): """Set job data""" key = f"{KEY_PREFIXES['jobs']}{job_id}" await self.redis.set(key, self._serialize(job_data)) - + async def get_all_jobs(self) -> Dict[str, Any]: """Get all jobs""" pattern = f"{KEY_PREFIXES['jobs']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - job_id = key.replace(KEY_PREFIXES['jobs'], '') + job_id = key.replace(KEY_PREFIXES["jobs"], "") result[job_id] = self._deserialize(value) - + return result - + async def delete_job(self, job_id: str): """Delete job""" key = f"{KEY_PREFIXES['jobs']}{job_id}" @@ -51,10 +53,10 @@ class JobMixin(DatabaseProtocol): try: # Get all documents for the candidate candidate_documents = await self.get_candidate_documents(candidate_id) - + if not candidate_documents: return [] - + # Get job requirements for each document job_requirements = [] for doc in 
candidate_documents: @@ -66,7 +68,7 @@ class JobMixin(DatabaseProtocol): requirements["document_filename"] = doc.get("filename") requirements["document_original_name"] = doc.get("originalName") job_requirements.append(requirements) - + return job_requirements except Exception as e: logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}") @@ -82,15 +84,15 @@ class JobMixin(DatabaseProtocol): try: deleted_count = 0 pipe = self.redis.pipeline() - + for doc_id in document_ids: key = f"{KEY_PREFIXES['job_requirements']}{doc_id}" pipe.delete(key) deleted_count += 1 - + results = await pipe.execute() actual_deleted = sum(1 for result in results if result > 0) - + logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents") return actual_deleted except Exception as e: @@ -103,49 +105,49 @@ class JobMixin(DatabaseProtocol): key = f"{KEY_PREFIXES['job_applications']}{application_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_job_application(self, application_id: str, application_data: Dict): """Set job application data""" key = f"{KEY_PREFIXES['job_applications']}{application_id}" await self.redis.set(key, self._serialize(application_data)) - + async def get_all_job_applications(self) -> Dict[str, Any]: """Get all job applications""" pattern = f"{KEY_PREFIXES['job_applications']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - app_id = key.replace(KEY_PREFIXES['job_applications'], '') + app_id = key.replace(KEY_PREFIXES["job_applications"], "") result[app_id] = self._deserialize(value) - + return result - + async def delete_job_application(self, application_id: str): """Delete job application""" key = f"{KEY_PREFIXES['job_applications']}{application_id}" await self.redis.delete(key) - + async def cleanup_orphaned_job_requirements(self) -> int: """Clean up job requirements for documents that no longer exist""" try: # Get all job requirements all_requirements = await self.get_all_job_requirements() - + if not all_requirements: return 0 - + orphaned_count = 0 pipe = self.redis.pipeline() - + for document_id in all_requirements.keys(): # Check if the document still exists document_exists = await self.get_document(document_id) @@ -155,16 +157,16 @@ class JobMixin(DatabaseProtocol): pipe.delete(key) orphaned_count += 1 logger.info(f"📋 Queued orphaned job requirements for deletion: {document_id}") - + if orphaned_count > 0: await pipe.execute() logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements") - + return orphaned_count except Exception as e: logger.error(f"❌ Error cleaning up orphaned job requirements: {e}") - return 0 - + return 0 + async def get_job_requirements(self, document_id: str) -> Optional[Dict]: """Get cached job requirements analysis for a document""" try: @@ -184,19 +186,19 @@ class JobMixin(DatabaseProtocol): """Save job requirements analysis results for a document""" try: key = f"{KEY_PREFIXES['job_requirements']}{document_id}" - + # Add metadata to the requirements requirements_with_meta = { **requirements, "cached_at": datetime.now(UTC).isoformat(), - "document_id": document_id + "document_id": document_id, } - + await self.redis.set(key, self._serialize(requirements_with_meta)) - + # Optional: Set expiration (e.g., 30 days) to prevent indefinite storage # await 
self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days - + logger.info(f"📋 Saved job requirements for document {document_id}") return True except Exception as e: @@ -221,21 +223,21 @@ class JobMixin(DatabaseProtocol): try: pattern = f"{KEY_PREFIXES['job_requirements']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - document_id = key.replace(KEY_PREFIXES['job_requirements'], '') + document_id = key.replace(KEY_PREFIXES["job_requirements"], "") if value: result[document_id] = self._deserialize(value) - + return result except Exception as e: logger.error(f"❌ Error retrieving all job requirements: {e}") @@ -246,35 +248,30 @@ class JobMixin(DatabaseProtocol): try: pattern = f"{KEY_PREFIXES['job_requirements']}*" keys = await self.redis.keys(pattern) - - stats = { - "total_cached_requirements": len(keys), - "cache_dates": {}, - "documents_with_requirements": [] - } - + + stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []} + if keys: # Get cache dates for analysis pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + for key, value in zip(keys, values): if value: requirements_data = self._deserialize(value) if requirements_data: - document_id = key.replace(KEY_PREFIXES['job_requirements'], '') + document_id = key.replace(KEY_PREFIXES["job_requirements"], "") stats["documents_with_requirements"].append(document_id) - + # Track cache dates cached_at = requirements_data.get("cached_at") if cached_at: cache_date = cached_at[:10] # Extract date part stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1 - + return stats except Exception as e: logger.error(f"❌ Error getting job requirements stats: {e}") return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []} - diff --git a/src/backend/database/mixins/protocols.py b/src/backend/database/mixins/protocols.py index 9b98468..eb8cabb 100644 --- a/src/backend/database/mixins/protocols.py +++ b/src/backend/database/mixins/protocols.py @@ -7,153 +7,412 @@ if TYPE_CHECKING: from models import SkillAssessment + class DatabaseProtocol(Protocol): # Base mixin redis: Redis - def _serialize(self, data) -> str: ... - def _deserialize(self, data: str): ... + + def _serialize(self, data) -> str: + ... + + def _deserialize(self, data: str): + ... # Chat mixin - async def add_chat_message(self, session_id: str, message_data: Dict): ... - async def archive_chat_session(self, session_id: str): ... - async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ... - async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ... - async def delete_chat_messages(self, session_id: str): ... - async def delete_chat_session_completely(self, session_id: str): ... - async def delete_chat_session(self, session_id: str) -> bool: ... - + async def add_chat_message(self, session_id: str, message_data: Dict): + ... + + async def archive_chat_session(self, session_id: str): + ... + + async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): + ... + + async def delete_chat_message(self, session_id: str, message_id: str) -> bool: + ... + + async def delete_chat_messages(self, session_id: str): + ... + + async def delete_chat_session_completely(self, session_id: str): + ... 
+ + async def delete_chat_session(self, session_id: str) -> bool: + ... + # Document mixin - async def add_document_to_candidate(self, candidate_id: str, document_id: str): ... - async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ... - + async def add_document_to_candidate(self, candidate_id: str, document_id: str): + ... + + async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): + ... + # Job mixin - async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ... - async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ... + async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: + ... + + async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: + ... # User mixin - async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ... - async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ... - async def delete_employer(self, employer_id: str): ... - async def delete_guest(self, guest_id: str) -> bool: ... - async def delete_user(self, email: str): ... - async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ... - async def get_all_users(self) -> Dict[str, Any]: ... - async def get_all_viewers(self) -> Dict[str, Any]: ... - async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ... - async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ... - async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ... - async def get_employer(self, employer_id: str) -> Optional[Dict]: ... - async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ... - async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ... - async def get_guest_statistics(self) -> Dict[str, Any]: ... - async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ... - async def get_user_by_username(self, username: str) -> Optional[Dict]: ... - async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ... - async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ... - async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ... - async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ... - async def invalidate_user_skill_cache(self, user_id: str) -> int: ... - async def set_candidate(self, candidate_id: str, candidate_data: Dict): ... - async def set_employer(self, employer_id: str, employer_data: Dict): ... - async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ... - async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ... - async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ... - async def update_user_rag_timestamp(self, user_id: str) -> bool: ... - async def user_exists_by_email(self, email: str) -> bool: ... - async def user_exists_by_username(self, username: str) -> bool: ... + async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: + ... - # Auth mixin - async def cleanup_expired_verification_tokens(self) -> int: ... - async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ... - async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ... 
- async def cleanup_orphaned_job_requirements(self) -> int: ... - async def clear_all_data(self: "DatabaseProtocol"): ... - async def clear_all_skill_match_cache(self) -> int: ... + async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: + ... + + async def delete_employer(self, employer_id: str): + ... + + async def delete_guest(self, guest_id: str) -> bool: + ... + + async def delete_user(self, email: str): + ... + + async def find_candidate_by_username(self, username: str) -> Optional[Dict]: + ... + + async def get_all_users(self) -> Dict[str, Any]: + ... + + async def get_all_viewers(self) -> Dict[str, Any]: + ... + + async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: + ... + + async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: + ... + + async def get_candidate(self, candidate_id: str) -> Optional[Dict]: + ... + + async def get_employer(self, employer_id: str) -> Optional[Dict]: + ... + + async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: + ... + + async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: + ... + + async def get_guest_statistics(self) -> Dict[str, Any]: + ... + + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: + ... + + async def get_user_by_username(self, username: str) -> Optional[Dict]: + ... + + async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: + ... + + async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: + ... + + async def get_user(self, login: str) -> Optional[Dict[str, Any]]: + ... + + async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: + ... + + async def invalidate_user_skill_cache(self, user_id: str) -> int: + ... + + async def set_candidate(self, candidate_id: str, candidate_data: Dict): + ... + + async def set_employer(self, employer_id: str, employer_data: Dict): + ... + + async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: + ... + + async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: + ... + + async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: + ... + + async def update_user_rag_timestamp(self, user_id: str) -> bool: + ... + + async def user_exists_by_email(self, email: str) -> bool: + ... + + async def user_exists_by_username(self, username: str) -> bool: + ... + + # Auth mixin + async def cleanup_expired_verification_tokens(self) -> int: + ... + + async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: + ... + + async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: + ... + + async def cleanup_orphaned_job_requirements(self) -> int: + ... + + async def clear_all_data(self: "DatabaseProtocol"): + ... + + async def clear_all_skill_match_cache(self) -> int: + ... # Resume mixin - async def delete_ai_parameters(self, param_id: str): ... - async def delete_all_candidate_documents(self, candidate_id: str) -> int: ... - async def delete_all_resumes_for_user(self, user_id: str) -> int: ... - async def delete_authentication(self, user_id: str) -> bool: ... - async def delete_document(self, document_id: str): ... + async def delete_ai_parameters(self, param_id: str): + ... - async def delete_job_application(self, application_id: str): ... - async def delete_job_requirements(self, document_id: str) -> bool: ... - async def delete_job(self, job_id: str): ... 
- async def delete_resume(self, user_id: str, resume_id: str) -> bool: ... - async def delete_viewer(self, viewer_id: str): ... - async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ... - async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ... - async def get_all_ai_parameters(self) -> Dict[str, Any]: ... - async def get_all_candidates(self) -> Dict[str, Any]: ... - async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ... - async def get_all_chat_sessions(self) -> Dict[str, Any]: ... - async def get_all_employers(self) -> Dict[str, Any]: ... - async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ... - async def get_all_job_applications(self) -> Dict[str, Any]: ... - async def get_all_job_requirements(self) -> Dict[str, Any]: ... - async def get_all_jobs(self) -> Dict[str, Any]: ... - async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ... - async def get_all_resumes(self) -> Dict[str, List[Dict]]: ... - async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ... - async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ... - async def get_chat_message_count(self, session_id: str) -> int: ... - async def get_chat_messages(self, session_id: str) -> List[Dict]: ... - async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ... - async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ... - async def get_chat_session(self, session_id: str) -> Optional[Dict]: ... - async def get_chat_statistics(self) -> Dict[str, Any]: ... - async def get_document_count_for_candidate(self, candidate_id: str) -> int: ... - async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ... - async def get_document(self, document_id: str) -> Optional[Dict]: ... - async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ... - async def get_job_application(self, application_id: str) -> Optional[Dict]: ... - async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ... - async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ... - async def get_job_requirements_stats(self) -> Dict[str, Any]: ... - async def get_job(self, job_id: str) -> Optional[Dict]: ... - async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ... - async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ... - async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ... - async def get_pending_verifications_count(self) -> int: ... - async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ... - async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ... - async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ... - async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ... - async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ... - async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ... - async def get_stats(self) -> Dict[str, int]: ... - async def get_verification_attempts_count(self, email: str) -> int: ... - async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ... - async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ... 
- async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ... - async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ... - async def mark_email_verified(self, token: str) -> bool: ... - async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ... - async def mark_password_reset_token_used(self, token: str) -> bool: ... - async def record_verification_attempt(self, email: str) -> bool: ... - async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ... - async def revoke_all_user_tokens(self, user_id: str) -> bool: ... - async def revoke_refresh_token(self, token: str) -> bool: ... - async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ... - async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ... - async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ... - async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ... - async def set_ai_parameters(self, param_id: str, param_data: Dict): ... - async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ... - async def set_chat_messages(self, session_id: str, messages: List[Dict]): ... - async def set_chat_session(self, session_id: str, session_data: Dict): ... - async def set_document(self, document_id: str, document_data: Dict): ... - async def set_job_application(self, application_id: str, application_data: Dict): ... - async def set_job(self, job_id: str, job_data: Dict): ... - async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ... - async def set_viewer(self, viewer_id: str, viewer_data: Dict): ... - async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ... - async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ... - async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ... - async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ... - async def update_chat_session_activity(self, session_id: str): ... - async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ... - async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ... + async def delete_all_candidate_documents(self, candidate_id: str) -> int: + ... + async def delete_all_resumes_for_user(self, user_id: str) -> int: + ... + + async def delete_authentication(self, user_id: str) -> bool: + ... + + async def delete_document(self, document_id: str): + ... + + async def delete_job_application(self, application_id: str): + ... + + async def delete_job_requirements(self, document_id: str) -> bool: + ... + + async def delete_job(self, job_id: str): + ... + + async def delete_resume(self, user_id: str, resume_id: str) -> bool: + ... + + async def delete_viewer(self, viewer_id: str): + ... + + async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: + ... + + async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: + ... + + async def get_all_ai_parameters(self) -> Dict[str, Any]: + ... + + async def get_all_candidates(self) -> Dict[str, Any]: + ... + + async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: + ... + + async def get_all_chat_sessions(self) -> Dict[str, Any]: + ... 
+ + async def get_all_employers(self) -> Dict[str, Any]: + ... + + async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: + ... + + async def get_all_job_applications(self) -> Dict[str, Any]: + ... + + async def get_all_job_requirements(self) -> Dict[str, Any]: + ... + + async def get_all_jobs(self) -> Dict[str, Any]: + ... + + async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: + ... + + async def get_all_resumes(self) -> Dict[str, List[Dict]]: + ... + + async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: + ... + + async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: + ... + + async def get_chat_message_count(self, session_id: str) -> int: + ... + + async def get_chat_messages(self, session_id: str) -> List[Dict]: + ... + + async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: + ... + + async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: + ... + + async def get_chat_session(self, session_id: str) -> Optional[Dict]: + ... + + async def get_chat_statistics(self) -> Dict[str, Any]: + ... + + async def get_document_count_for_candidate(self, candidate_id: str) -> int: + ... + + async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: + ... + + async def get_document(self, document_id: str) -> Optional[Dict]: + ... + + async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: + ... + + async def get_job_application(self, application_id: str) -> Optional[Dict]: + ... + + async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: + ... + + async def get_job_requirements(self, document_id: str) -> Optional[Dict]: + ... + + async def get_job_requirements_stats(self) -> Dict[str, Any]: + ... + + async def get_job(self, job_id: str) -> Optional[Dict]: + ... + + async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: + ... + + async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: + ... + + async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: + ... + + async def get_pending_verifications_count(self) -> int: + ... + + async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: + ... + + async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: + ... + + async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: + ... + + async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: + ... + + async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: + ... + + async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: + ... + + async def get_stats(self) -> Dict[str, int]: + ... + + async def get_verification_attempts_count(self, email: str) -> int: + ... + + async def get_viewer(self, viewer_id: str) -> Optional[Dict]: + ... + + async def increment_mfa_attempts(self, email: str, device_id: str) -> int: + ... + + async def invalidate_job_requirements_cache(self, document_id: str) -> bool: + ... + + async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: + ... + + async def mark_email_verified(self, token: str) -> bool: + ... + + async def mark_mfa_verified(self, email: str, device_id: str) -> bool: + ... + + async def mark_password_reset_token_used(self, token: str) -> bool: + ... 
+ + async def record_verification_attempt(self, email: str) -> bool: + ... + + async def remove_document_from_candidate(self, candidate_id: str, document_id: str): + ... + + async def revoke_all_user_tokens(self, user_id: str) -> bool: + ... + + async def revoke_refresh_token(self, token: str) -> bool: + ... + + async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: + ... + + async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: + ... + + async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: + ... + + async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: + ... + + async def set_ai_parameters(self, param_id: str, param_data: Dict): + ... + + async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: + ... + + async def set_chat_messages(self, session_id: str, messages: List[Dict]): + ... + + async def set_chat_session(self, session_id: str, session_data: Dict): + ... + + async def set_document(self, document_id: str, document_data: Dict): + ... + + async def set_job_application(self, application_id: str, application_data: Dict): + ... + + async def set_job(self, job_id: str, job_data: Dict): + ... + + async def set_resume(self, user_id: str, resume_data: Dict) -> bool: + ... + + async def set_viewer(self, viewer_id: str, viewer_data: Dict): + ... + + async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: + ... + + async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: + ... + + async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: + ... + + async def store_refresh_token( + self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str] + ) -> bool: + ... + + async def update_chat_session_activity(self, session_id: str): + ... + + async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None: + ... + + async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: + ... 
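
Editor's note: the protocols.py hunk above is a pure reformat -- each one-line stub `async def foo(...): ...` is expanded so the `...` body sits on its own line, consistent with running an auto-formatter over the file; no signatures change. `DatabaseProtocol` still only declares the cross-mixin surface (the shared `redis` attribute, the `_serialize`/`_deserialize` helpers, and the async method signatures), while the concrete mixins that subclass it supply the implementations. The following is a rough, hedged sketch of that pattern using hypothetical names, not code from this repository:

    # Hypothetical names, illustration only: a minimal Protocol/mixin sketch.
    import json
    from typing import Any, Dict, Optional, Protocol


    class StoreProtocol(Protocol):
        redis: Any  # shared async Redis client, declared once for every mixin

        async def get_item(self, item_id: str) -> Optional[Dict]:
            ...


    class ItemMixin(StoreProtocol):
        """One slice of the database API; other mixins add their own methods."""

        async def get_item(self, item_id: str) -> Optional[Dict]:
            raw = await self.redis.get(f"item:{item_id}")
            return json.loads(raw) if raw else None

In this sketch the mixin inherits the protocol (as the real mixins inherit `DatabaseProtocol`), so a type checker can verify that every method it calls on `self` -- including ones implemented by sibling mixins -- is declared somewhere on the shared protocol.
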
diff --git a/src/backend/database/mixins/resume.py b/src/backend/database/mixins/resume.py index cb0f6c7..9e4672f 100644 --- a/src/backend/database/mixins/resume.py +++ b/src/backend/database/mixins/resume.py @@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class ResumeMixin(DatabaseProtocol): """Mixin for resume-related database operations""" @@ -14,19 +15,19 @@ class ResumeMixin(DatabaseProtocol): """Save a resume for a user""" try: # Generate resume_id if not present - if 'id' not in resume_data: + if "id" not in resume_data: raise ValueError("Resume data must include an 'id' field") - - resume_id = resume_data['id'] - + + resume_id = resume_data["id"] + # Store the resume data key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}" await self.redis.set(key, self._serialize(resume_data)) - + # Add resume_id to user's resume list user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}" - await self.redis.rpush(user_resumes_key, resume_id) # type: ignore - + await self.redis.rpush(user_resumes_key, resume_id) # type: ignore + logger.info(f"📄 Saved resume {resume_id} for user {user_id}") return True except Exception as e: @@ -53,19 +54,19 @@ class ResumeMixin(DatabaseProtocol): try: # Get all resume IDs for this user user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}" - resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore - + resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore + if not resume_ids: logger.info(f"📄 No resumes found for user {user_id}") return [] - + # Get all resume data resumes = [] pipe = self.redis.pipeline() for resume_id in resume_ids: pipe.get(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}") values = await pipe.execute() - + for resume_id, value in zip(resume_ids, values): if value: resume_data = self._deserialize(value) @@ -73,12 +74,12 @@ class ResumeMixin(DatabaseProtocol): resumes.append(resume_data) else: # Clean up orphaned resume ID - await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore + await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}") - + # Sort by created_at timestamp (most recent first) resumes.sort(key=lambda x: x.get("created_at", ""), reverse=True) - + logger.info(f"📄 Retrieved {len(resumes)} resumes for user {user_id}") return resumes except Exception as e: @@ -91,11 +92,11 @@ class ResumeMixin(DatabaseProtocol): # Delete the resume data key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}" result = await self.redis.delete(key) - + # Remove from user's resume list user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}" - await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore - + await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore + if result > 0: logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}") return True @@ -111,31 +112,31 @@ class ResumeMixin(DatabaseProtocol): try: # Get all resume IDs for this user user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}" - resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore - + resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore + if not resume_ids: logger.info(f"📄 No resumes found for user {user_id}") return 0 - + deleted_count = 0 - + # Use pipeline for efficient batch operations pipe = self.redis.pipeline() - + # Delete each resume for resume_id in resume_ids: 
pipe.delete(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}") deleted_count += 1 - + # Delete the user's resume list pipe.delete(user_resumes_key) - + # Execute all operations await pipe.execute() - + logger.info(f"🗑️ Successfully deleted {deleted_count} resumes for user {user_id}") return deleted_count - + except Exception as e: logger.error(f"❌ Error deleting all resumes for user {user_id}: {e}") raise @@ -145,21 +146,21 @@ class ResumeMixin(DatabaseProtocol): try: pattern = f"{KEY_PREFIXES['resumes']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + # Group by user_id user_resumes = {} pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + for key, value in zip(keys, values): if value: # Extract user_id from key format: resume:{user_id}:{resume_id} - key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1) + key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1) if len(key_parts) >= 1: user_id = key_parts[0] resume_data = self._deserialize(value) @@ -167,11 +168,11 @@ class ResumeMixin(DatabaseProtocol): if user_id not in user_resumes: user_resumes[user_id] = [] user_resumes[user_id].append(resume_data) - + # Sort each user's resumes by created_at for user_id in user_resumes: user_resumes[user_id].sort(key=lambda x: x.get("created_at", ""), reverse=True) - + return user_resumes except Exception as e: logger.error(f"❌ Error retrieving all resumes: {e}") @@ -182,20 +183,22 @@ class ResumeMixin(DatabaseProtocol): try: all_resumes = await self.get_all_resumes_for_user(user_id) query_lower = query.lower() - + matching_resumes = [] for resume in all_resumes: # Search in resume content, job_id, candidate_id, etc. - searchable_text = " ".join([ - resume.get("resume", ""), - resume.get("job_id", ""), - resume.get("candidate_id", ""), - str(resume.get("created_at", "")) - ]).lower() - + searchable_text = " ".join( + [ + resume.get("resume", ""), + resume.get("job_id", ""), + resume.get("candidate_id", ""), + str(resume.get("created_at", "")), + ] + ).lower() + if query_lower in searchable_text: matching_resumes.append(resume) - + logger.info(f"📄 Found {len(matching_resumes)} matching resumes for user {user_id}") return matching_resumes except Exception as e: @@ -206,11 +209,8 @@ class ResumeMixin(DatabaseProtocol): """Get all resumes for a specific candidate created by a user""" try: all_resumes = await self.get_all_resumes_for_user(user_id) - candidate_resumes = [ - resume for resume in all_resumes - if resume.get("candidate_id") == candidate_id - ] - + candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id] + logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}") return candidate_resumes except Exception as e: @@ -221,11 +221,8 @@ class ResumeMixin(DatabaseProtocol): """Get all resumes for a specific job created by a user""" try: all_resumes = await self.get_all_resumes_for_user(user_id) - job_resumes = [ - resume for resume in all_resumes - if resume.get("job_id") == job_id - ] - + job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id] + logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}") return job_resumes except Exception as e: @@ -236,24 +233,24 @@ class ResumeMixin(DatabaseProtocol): """Get resume statistics for a user""" try: all_resumes = await self.get_all_resumes_for_user(user_id) - + stats = { "total_resumes": len(all_resumes), 
"resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, - "recent_resumes": [] + "recent_resumes": [], } - + for resume in all_resumes: # Count by candidate candidate_id = resume.get("candidate_id", "unknown") stats["resumes_by_candidate"][candidate_id] = stats["resumes_by_candidate"].get(candidate_id, 0) + 1 - + # Count by job job_id = resume.get("job_id", "unknown") stats["resumes_by_job"][job_id] = stats["resumes_by_job"].get(job_id, 0) + 1 - + # Timeline by date created_at = resume.get("created_at") if created_at: @@ -262,14 +259,20 @@ class ResumeMixin(DatabaseProtocol): stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1 except (IndexError, TypeError): pass - + # Get recent resumes (last 5) stats["recent_resumes"] = all_resumes[:5] - + return stats except Exception as e: logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}") - return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []} + return { + "total_resumes": 0, + "resumes_by_candidate": {}, + "resumes_by_job": {}, + "creation_timeline": {}, + "recent_resumes": [], + } async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: """Update specific fields of a resume""" @@ -278,10 +281,10 @@ class ResumeMixin(DatabaseProtocol): if resume_data: resume_data.update(updates) resume_data["updated_at"] = datetime.now(UTC).isoformat() - + key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}" await self.redis.set(key, self._serialize(resume_data)) - + logger.info(f"📄 Updated resume {resume_id} for user {user_id}") return resume_data return None diff --git a/src/backend/database/mixins/skill.py b/src/backend/database/mixins/skill.py index 84b105b..91add8e 100644 --- a/src/backend/database/mixins/skill.py +++ b/src/backend/database/mixins/skill.py @@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol logger = logging.getLogger(__name__) + class SkillMixin(DatabaseProtocol): """Mixin for Skill-related database operations""" @@ -23,7 +24,7 @@ class SkillMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error getting cached skill match: {e}") return None - + async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: """Invalidate all cached skill matches for a specific candidate""" try: @@ -35,7 +36,7 @@ class SkillMixin(DatabaseProtocol): except Exception as e: logger.error(f"Error invalidating candidate skill cache: {e}") return 0 - + async def clear_all_skill_match_cache(self) -> int: """Clear all skill match cache (useful after major system updates)""" try: @@ -47,7 +48,7 @@ class SkillMixin(DatabaseProtocol): except Exception as e: logger.error(f"Error clearing skill match cache: {e}") return 0 - + async def invalidate_user_skill_cache(self, user_id: str) -> int: """Invalidate all cached skill matches when a user's RAG data is updated""" try: @@ -55,12 +56,12 @@ class SkillMixin(DatabaseProtocol): # You might need to adjust the pattern based on how you associate candidates with users pattern = f"skill_match:{user_id}:*" keys = await self.redis.keys(pattern) - + # Filter keys that belong to candidates owned by this user # This would require additional logic to determine candidate ownership # For now, you might want to clear all cache when any user's RAG data updates # or implement a more sophisticated mapping - + if keys: return await self.redis.delete(*keys) return 0 @@ -73,8 +74,10 @@ class SkillMixin(DatabaseProtocol): try: # 
Cache for 1 hour by default await self.redis.set( - cache_key, - json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling + cache_key, + json.dumps( + assessment.model_dump(mode="json", by_alias=True), default=str + ), # Serialize with datetime handling ) logger.info(f"💾 Skill match cached: {cache_key}") except Exception as e: diff --git a/src/backend/database/mixins/user.py b/src/backend/database/mixins/user.py index 0c8fbbc..7acf8bc 100644 --- a/src/backend/database/mixins/user.py +++ b/src/backend/database/mixins/user.py @@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES logger = logging.getLogger(__name__) + class UserMixin(DatabaseProtocol): """Mixin for user operations""" @@ -21,27 +22,27 @@ class UserMixin(DatabaseProtocol): try: # Ensure last_activity is always set guest_data["last_activity"] = datetime.now(UTC).isoformat() - + # Store in Redis with both hash and individual key for redundancy - await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore - + await self.redis.hset("guests", guest_id, json.dumps(guest_data)) # type: ignore + # Also store with a longer TTL as backup await self.redis.setex( - f"guest_backup:{guest_id}", + f"guest_backup:{guest_id}", 86400 * 7, # 7 days TTL - json.dumps(guest_data) + json.dumps(guest_data), ) - + logger.info(f"💾 Guest stored with backup: {guest_id}") except Exception as e: logger.error(f"❌ Error storing guest {guest_id}: {e}") raise - + async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: """Get guest data with fallback to backup""" try: # Try primary storage first - data = await self.redis.hget("guests", guest_id)# type: ignore + data = await self.redis.hget("guests", guest_id) # type: ignore if data: guest_data = json.loads(data) # Update last activity when accessed @@ -49,24 +50,24 @@ class UserMixin(DatabaseProtocol): await self.set_guest(guest_id, guest_data) logger.info(f"🔍 Guest found in primary storage: {guest_id}") return guest_data - + # Fallback to backup storage backup_data = await self.redis.get(f"guest_backup:{guest_id}") if backup_data: guest_data = json.loads(backup_data) guest_data["last_activity"] = datetime.now(UTC).isoformat() - + # Restore to primary storage await self.set_guest(guest_id, guest_data) logger.info(f"🔄 Guest restored from backup: {guest_id}") return guest_data - + logger.warning(f"⚠️ Guest not found: {guest_id}") return None except Exception as e: logger.error(f"❌ Error getting guest {guest_id}: {e}") return None - + async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: """Get guest data by session ID""" try: @@ -78,23 +79,20 @@ class UserMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error getting guest by session ID {session_id}: {e}") return None - + async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: """Get all guests""" try: - data = await self.redis.hgetall("guests")# type: ignore - return { - guest_id: json.loads(guest_json) - for guest_id, guest_json in data.items() - } + data = await self.redis.hgetall("guests") # type: ignore + return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()} except Exception as e: logger.error(f"❌ Error getting all guests: {e}") return {} - + async def delete_guest(self, guest_id: str) -> bool: """Delete a guest""" try: - result = await self.redis.hdel("guests", guest_id)# type: ignore + result = await self.redis.hdel("guests", guest_id) # type: ignore if result: logger.info(f"🗑️ Guest 
deleted: {guest_id}") return True @@ -102,35 +100,35 @@ class UserMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error deleting guest {guest_id}: {e}") return False - + async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: """Clean up inactive guest sessions with safety checks""" try: all_guests = await self.get_all_guests() current_time = datetime.now(UTC) cutoff_time = current_time - timedelta(hours=inactive_hours) - + deleted_count = 0 preserved_count = 0 - + for guest_id, guest_data in all_guests.items(): try: last_activity_str = guest_data.get("last_activity") created_at_str = guest_data.get("created_at") - + # Skip cleanup if guest is very new (less than 1 hour old) if created_at_str: - created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00')) + created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) if current_time - created_at < timedelta(hours=1): preserved_count += 1 logger.info(f"🛡️ Preserving new guest: {guest_id}") continue - + # Check last activity should_delete = False if last_activity_str: try: - last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00')) + last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00")) if last_activity < cutoff_time: should_delete = True except ValueError: @@ -138,81 +136,81 @@ class UserMixin(DatabaseProtocol): if not created_at_str: should_delete = True else: - # No last activity, but don't delete if guest is new + # No last activity, but don't delete if guest is new if not created_at_str: should_delete = True - + if should_delete: await self.delete_guest(guest_id) deleted_count += 1 else: preserved_count += 1 - + except Exception as e: logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}") preserved_count += 1 # Preserve on error - + if deleted_count > 0: logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}") - + return deleted_count except Exception as e: logger.error(f"❌ Error in guest cleanup: {e}") return 0 - + async def get_guest_statistics(self) -> Dict[str, Any]: """Get guest usage statistics""" try: all_guests = await self.get_all_guests() current_time = datetime.now(UTC) - + stats = { "total_guests": len(all_guests), "active_last_hour": 0, "active_last_day": 0, "converted_guests": 0, "by_ip": {}, - "creation_timeline": {} + "creation_timeline": {}, } - + hour_ago = current_time - timedelta(hours=1) day_ago = current_time - timedelta(days=1) - + for guest_data in all_guests.values(): # Check activity last_activity_str = guest_data.get("last_activity") if last_activity_str: try: - last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00')) + last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00")) if last_activity > hour_ago: stats["active_last_hour"] += 1 if last_activity > day_ago: stats["active_last_day"] += 1 except ValueError: pass - + # Check conversions if guest_data.get("converted_to_user_id"): stats["converted_guests"] += 1 - + # IP tracking ip = guest_data.get("ip_address", "unknown") stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1 - + # Creation timeline created_at_str = guest_data.get("created_at") if created_at_str: try: - created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00')) - date_key = created_at.strftime('%Y-%m-%d') + created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00")) + date_key = created_at.strftime("%Y-%m-%d") stats["creation_timeline"][date_key] = 
stats["creation_timeline"].get(date_key, 0) + 1 except ValueError: pass - + return stats except Exception as e: logger.error(f"❌ Error getting guest statistics: {e}") - return {} + return {} # ================ # Users @@ -222,7 +220,7 @@ class UserMixin(DatabaseProtocol): username_key = f"{KEY_PREFIXES['users']}{username.lower()}" data = await self.redis.get(username_key) return self._deserialize(data) if data else None - + async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: """Get the last time user's RAG data was updated (returns timezone-aware UTC)""" try: @@ -241,19 +239,19 @@ class UserMixin(DatabaseProtocol): return None except Exception as e: logger.error(f"❌ Error getting user RAG update time: {e}") - return None - + return None + async def update_user_rag_timestamp(self, user_id: str) -> bool: """Set the user's RAG data update time (stores as UTC ISO format)""" try: update_time = datetime.now(timezone.utc) - + # Ensure we're storing UTC timezone-aware format if update_time.tzinfo is None: update_time = update_time.replace(tzinfo=timezone.utc) else: update_time = update_time.astimezone(timezone.utc) - + rag_update_key = f"user:{user_id}:rag_last_update" # Store as ISO format with timezone info timestamp_str = update_time.isoformat() # This includes timezone @@ -274,18 +272,18 @@ class UserMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error storing user by ID {user_id}: {e}") return False - + async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: """Get user lookup data by user ID""" try: - data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore + data = await self.redis.hget("user_lookup_by_id", user_id) # type: ignore if data: return json.loads(data) return None except Exception as e: logger.error(f"❌ Error getting user by ID {user_id}: {e}") - return None - + return None + async def user_exists_by_email(self, email: str) -> bool: """Check if a user exists with the given email""" try: @@ -295,7 +293,7 @@ class UserMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error checking user existence by email {email}: {e}") return False - + async def user_exists_by_username(self, username: str) -> bool: """Check if a user exists with the given username""" try: @@ -305,31 +303,31 @@ class UserMixin(DatabaseProtocol): except Exception as e: logger.error(f"❌ Error checking user existence by username {username}: {e}") return False - + async def get_all_users(self) -> Dict[str, Any]: """Get all users""" pattern = f"{KEY_PREFIXES['users']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - email = key.replace(KEY_PREFIXES['users'], '') + email = key.replace(KEY_PREFIXES["users"], "") logger.info(f"🔍 Found user key: {key}, type: {type(value)}") if type(value) == str: result[email] = value else: result[email] = self._deserialize(value) - + return result - + async def delete_user(self, email: str): """Delete user""" key = f"{KEY_PREFIXES['users']}{email}" @@ -340,7 +338,7 @@ class UserMixin(DatabaseProtocol): try: login = login.strip().lower() key = f"users:{login}" - + data = await self.redis.get(key) if data: user_data = json.loads(data) @@ -356,7 +354,7 @@ class UserMixin(DatabaseProtocol): try: login = login.strip().lower() key = f"users:{login}" - + await self.redis.set(key, json.dumps(user_data, default=str)) logger.info(f"👤 Stored 
user data for {login}") return True @@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol): logger.error(f"❌ Error storing user {login}: {e}") return False - # ================ # Employers # ================ @@ -373,37 +370,36 @@ class UserMixin(DatabaseProtocol): key = f"{KEY_PREFIXES['employers']}{employer_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_employer(self, employer_id: str, employer_data: Dict): """Set employer data""" key = f"{KEY_PREFIXES['employers']}{employer_id}" await self.redis.set(key, self._serialize(employer_data)) - + async def get_all_employers(self) -> Dict[str, Any]: """Get all employers""" pattern = f"{KEY_PREFIXES['employers']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - employer_id = key.replace(KEY_PREFIXES['employers'], '') + employer_id = key.replace(KEY_PREFIXES["employers"], "") result[employer_id] = self._deserialize(value) - + return result - + async def delete_employer(self, employer_id: str): """Delete employer""" key = f"{KEY_PREFIXES['employers']}{employer_id}" await self.redis.delete(key) - # ================ # Candidates @@ -413,33 +409,33 @@ class UserMixin(DatabaseProtocol): key = f"{KEY_PREFIXES['candidates']}{candidate_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_candidate(self, candidate_id: str, candidate_data: Dict): """Set candidate data""" key = f"{KEY_PREFIXES['candidates']}{candidate_id}" await self.redis.set(key, self._serialize(candidate_data)) - + async def get_all_candidates(self) -> Dict[str, Any]: """Get all candidates""" pattern = f"{KEY_PREFIXES['candidates']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + # Use pipeline for efficiency pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - candidate_id = key.replace(KEY_PREFIXES['candidates'], '') + candidate_id = key.replace(KEY_PREFIXES["candidates"], "") result[candidate_id] = self._deserialize(value) - + return result - + async def delete_candidate(self: Self, candidate_id: str) -> Dict[str, int]: """ Delete candidate and all related records in a cascading manner @@ -456,20 +452,20 @@ class UserMixin(DatabaseProtocol): "security_logs": 0, "ai_parameters": 0, "candidate_record": 0, - "resumes": 0 + "resumes": 0, } - + logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}") - + # 1. Get candidate data first to retrieve associated information candidate_data = await self.get_candidate(candidate_id) if not candidate_data: logger.warning(f"⚠️ Candidate {candidate_id} not found") return deletion_stats - + candidate_email = candidate_data.get("email", "").lower() candidate_username = candidate_data.get("username", "").lower() - + # 2. Delete all candidate documents and their metadata try: documents_deleted = await self.delete_all_candidate_documents(candidate_id) @@ -477,64 +473,68 @@ class UserMixin(DatabaseProtocol): logger.info(f"🗑️ Deleted {documents_deleted} documents for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting candidate documents: {e}") - + # 3. 
Delete all chat sessions related to this candidate try: candidate_sessions = await self.get_chat_sessions_by_candidate(candidate_id) messages_deleted = 0 - + for session in candidate_sessions: session_id = session.get("id") if session_id: # Count messages before deletion message_count = await self.get_chat_message_count(session_id) messages_deleted += message_count - + # Delete chat session and its messages await self.delete_chat_session_completely(session_id) - + deletion_stats["chat_sessions"] = len(candidate_sessions) deletion_stats["chat_messages"] = messages_deleted - logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}") + logger.info( + f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}" + ) except Exception as e: logger.error(f"❌ Error deleting chat sessions: {e}") - + # 4. Delete job applications from this candidate try: all_applications = await self.get_all_job_applications() candidate_applications = [] - + for app_id, app_data in all_applications.items(): if app_data.get("candidateId") == candidate_id: candidate_applications.append(app_id) - + # Delete each application for app_id in candidate_applications: await self.delete_job_application(app_id) - + deletion_stats["job_applications"] = len(candidate_applications) logger.info(f"🗑️ Deleted {len(candidate_applications)} job applications for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting job applications: {e}") - + # 5. Delete user records (by email and username if they exist) try: user_records_deleted = 0 - + # Delete by email if candidate_email and await self.user_exists_by_email(candidate_email): await self.delete_user(candidate_email) user_records_deleted += 1 logger.info(f"🗑️ Deleted user record by email: {candidate_email}") - + # Delete by username (if different from email) - if (candidate_username and - candidate_username != candidate_email and - await self.user_exists_by_username(candidate_username)): + if ( + candidate_username + and candidate_username != candidate_email + and await self.user_exists_by_username(candidate_username) + ): await self.delete_user(candidate_username) user_records_deleted += 1 logger.info(f"🗑️ Deleted user record by username: {candidate_username}") - + # Delete user by ID if exists user_by_id = await self.get_user_by_id(candidate_id) if user_by_id: @@ -542,12 +542,12 @@ class UserMixin(DatabaseProtocol): await self.redis.delete(key) user_records_deleted += 1 logger.info(f"🗑️ Deleted user record by ID: {candidate_id}") - + deletion_stats["user_records"] = user_records_deleted logger.info(f"🗑️ Deleted {user_records_deleted} user records for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting user records: {e}") - + # 6. Delete authentication records try: auth_deleted = await self.delete_authentication(candidate_id) @@ -556,57 +556,56 @@ class UserMixin(DatabaseProtocol): logger.info(f"🗑️ Deleted authentication record for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting authentication records: {e}") - + # 7. Revoke all refresh tokens for this user try: await self.revoke_all_user_tokens(candidate_id) logger.info(f"🗑️ Revoked all refresh tokens for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error revoking refresh tokens: {e}") - + # 8. 
Delete security logs for this user try: security_logs_deleted = 0 # Security logs are stored by date, so we need to scan for them pattern = f"security_log:{candidate_id}:*" cursor = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + if keys: await self.redis.delete(*keys) security_logs_deleted += len(keys) - + if cursor == 0: break - + deletion_stats["security_logs"] = security_logs_deleted if security_logs_deleted > 0: logger.info(f"🗑️ Deleted {security_logs_deleted} security log entries for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting security logs: {e}") - + # 9. Delete AI parameters that might be specific to this candidate try: all_ai_params = await self.get_all_ai_parameters() candidate_ai_params = [] - + for param_id, param_data in all_ai_params.items(): - if (param_data.get("candidateId") == candidate_id or - param_data.get("userId") == candidate_id): + if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id: candidate_ai_params.append(param_id) - + # Delete each AI parameter set for param_id in candidate_ai_params: await self.delete_ai_parameters(param_id) - + deletion_stats["ai_parameters"] = len(candidate_ai_params) if len(candidate_ai_params) > 0: logger.info(f"🗑️ Deleted {len(candidate_ai_params)} AI parameter sets for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting AI parameters: {e}") - + # 10. Delete email verification tokens if any exist try: if candidate_email: @@ -614,10 +613,10 @@ class UserMixin(DatabaseProtocol): pattern = "email_verification:*" cursor = 0 tokens_deleted = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: @@ -625,25 +624,27 @@ class UserMixin(DatabaseProtocol): if verification_info.get("email", "").lower() == candidate_email: await self.redis.delete(key) tokens_deleted += 1 - + if cursor == 0: break - + if tokens_deleted > 0: - logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}") + logger.info( + f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}" + ) except Exception as e: logger.error(f"❌ Error deleting email verification tokens: {e}") - + # 11. Delete password reset tokens if any exist try: if candidate_email: pattern = "password_reset:*" cursor = 0 tokens_deleted = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await self.redis.get(key) if token_data: @@ -651,37 +652,37 @@ class UserMixin(DatabaseProtocol): if reset_info.get("email", "").lower() == candidate_email: await self.redis.delete(key) tokens_deleted += 1 - + if cursor == 0: break - + if tokens_deleted > 0: logger.info(f"🗑️ Deleted {tokens_deleted} password reset tokens for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting password reset tokens: {e}") - + # 12. 
Delete MFA codes if any exist try: if candidate_email: pattern = f"mfa_code:{candidate_email}:*" cursor = 0 mfa_codes_deleted = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) - + if keys: await self.redis.delete(*keys) mfa_codes_deleted += len(keys) - + if cursor == 0: break - + if mfa_codes_deleted > 0: logger.info(f"🗑️ Deleted {mfa_codes_deleted} MFA codes for candidate {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting MFA codes: {e}") - + # 13. Finally, delete the candidate record itself try: key = f"{KEY_PREFIXES['candidates']}{candidate_id}" @@ -690,24 +691,24 @@ class UserMixin(DatabaseProtocol): logger.info(f"🗑️ Deleted candidate record for {candidate_id}") except Exception as e: logger.error(f"❌ Error deleting candidate record: {e}") - + # 14. Delete resumes associated with this candidate across all users try: all_resumes = await self.get_all_resumes() candidate_resumes_deleted = 0 - + for user_id, user_resumes in all_resumes.items(): resumes_to_delete = [] for resume in user_resumes: if resume.get("candidate_id") == candidate_id: resumes_to_delete.append(resume.get("resume_id")) - + # Delete each resume for this candidate for resume_id in resumes_to_delete: if resume_id: await self.delete_resume(user_id, resume_id) candidate_resumes_deleted += 1 - + deletion_stats["resumes"] = candidate_resumes_deleted if candidate_resumes_deleted > 0: logger.info(f"🗑️ Deleted {candidate_resumes_deleted} resumes for candidate {candidate_id}") @@ -717,14 +718,16 @@ class UserMixin(DatabaseProtocol): # 15. Log the deletion as a security event (if we have admin/system user context) try: total_items_deleted = sum(deletion_stats.values()) - logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. " - f"Total items deleted: {total_items_deleted}") + logger.info( + f"✅ Completed cascading delete for candidate {candidate_id}. " + f"Total items deleted: {total_items_deleted}" + ) logger.info(f"📊 Deletion breakdown: {deletion_stats}") except Exception as e: logger.error(f"❌ Error logging deletion summary: {e}") - + return deletion_stats - + except Exception as e: logger.error(f"❌ Critical error during candidate deletion {candidate_id}: {e}") raise @@ -748,25 +751,25 @@ class UserMixin(DatabaseProtocol): "candidate_record": 0, "resumes": 0, } - + logger.info(f"🗑️ Starting batch deletion for {len(candidate_ids)} candidates") - + for candidate_id in candidate_ids: try: deletion_stats = await self.delete_candidate(candidate_id) batch_results[candidate_id] = deletion_stats - + # Add to totals for key, value in deletion_stats.items(): total_stats[key] += value - + except Exception as e: logger.error(f"❌ Failed to delete candidate {candidate_id}: {e}") batch_results[candidate_id] = {"error": str(e)} - + logger.info(f"✅ Completed batch deletion. 
Total items deleted: {sum(total_stats.values())}") logger.info(f"📊 Batch totals: {total_stats}") - + return { "individual_results": batch_results, "totals": total_stats, @@ -774,70 +777,70 @@ class UserMixin(DatabaseProtocol): "total_candidates_processed": len(candidate_ids), "successful_deletions": len([r for r in batch_results.values() if "error" not in r]), "failed_deletions": len([r for r in batch_results.values() if "error" in r]), - "total_items_deleted": sum(total_stats.values()) - } + "total_items_deleted": sum(total_stats.values()), + }, } - + except Exception as e: logger.error(f"❌ Critical error during batch candidate deletion: {e}") - raise + raise async def find_candidate_by_username(self, username: str) -> Optional[Dict]: """Find candidate by username""" all_candidates = await self.get_all_candidates() username_lower = username.lower() - + for candidate_data in all_candidates.values(): if candidate_data.get("username", "").lower() == username_lower: return candidate_data - + return None async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: """Get multiple candidates by their usernames efficiently""" all_candidates = await self.get_all_candidates() username_set = {username.lower() for username in usernames} - + result = {} for candidate_data in all_candidates.values(): candidate_username = candidate_data.get("username", "").lower() if candidate_username in username_set: result[candidate_username] = candidate_data - + return result async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: """Get a summary of chat activity for a specific candidate""" sessions = await self.get_chat_sessions_by_candidate(candidate_id) - + if not sessions: return { "candidate_id": candidate_id, "total_sessions": 0, "total_messages": 0, "first_chat": None, - "last_chat": None + "last_chat": None, } - + total_messages = 0 for session in sessions: session_id = session.get("id") if session_id: message_count = await self.get_chat_message_count(session_id) total_messages += message_count - + # Sort sessions by creation date sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", "")) - + return { "candidate_id": candidate_id, "total_sessions": len(sessions), "total_messages": total_messages, "first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None, "last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None, - "recent_sessions": sessions[:5] # Last 5 sessions + "recent_sessions": sessions[:5], # Last 5 sessions } - + # ================ # Viewers # ================ @@ -846,35 +849,34 @@ class UserMixin(DatabaseProtocol): key = f"{KEY_PREFIXES['viewers']}{viewer_id}" data = await self.redis.get(key) return self._deserialize(data) if data else None - + async def set_viewer(self, viewer_id: str, viewer_data: Dict): """Set viewer data""" key = f"{KEY_PREFIXES['viewers']}{viewer_id}" await self.redis.set(key, self._serialize(viewer_data)) - + async def get_all_viewers(self) -> Dict[str, Any]: """Get all viewers""" pattern = f"{KEY_PREFIXES['viewers']}*" keys = await self.redis.keys(pattern) - + if not keys: return {} - + # Use pipeline for efficiency pipe = self.redis.pipeline() for key in keys: pipe.get(key) values = await pipe.execute() - + result = {} for key, value in zip(keys, values): - viewer_id = key.replace(KEY_PREFIXES['viewers'], '') + viewer_id = key.replace(KEY_PREFIXES["viewers"], "") result[viewer_id] = self._deserialize(value) - + return result - + async def 
delete_viewer(self, viewer_id: str): """Delete viewer""" key = f"{KEY_PREFIXES['viewers']}{viewer_id}" await self.redis.delete(key) - diff --git a/src/backend/defines.py b/src/backend/defines.py index 5c517b9..45566dc 100644 --- a/src/backend/defines.py +++ b/src/backend/defines.py @@ -3,20 +3,20 @@ import os ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint user_dir = "/opt/backstory/users" -user_info_file = "info.json" # Relative to "{user_dir}/{user}" +user_info_file = "info.json" # Relative to "{user_dir}/{user}" default_username = "jketreno" -rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}" +rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}" # Path to candidate full resume -resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user} +resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user} resume_doc = "resume.md" -persist_directory = "db" # Relative to "{user_dir}/{user}" +persist_directory = "db" # Relative to "{user_dir}/{user}" -# Model name License Notes -# model = "deepseek-r1:7b" # MIT Tool calls don"t work +# Model name License Notes +# model = "deepseek-r1:7b" # MIT Tool calls don"t work # model = "gemma3:4b" # Gemma Requires newer ollama https://ai.google.dev/gemma/terms # model = "llama3.2" # Llama Good results; qwen seems slightly better https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/LICENSE.txt # model = "mistral:7b" # Apache 2.0 Tool calls don"t work -model = "qwen2.5:7b" # Apache 2.0 Good results +model = "qwen2.5:7b" # Apache 2.0 Good results # model = "qwen3:8b" # Apache 2.0 Requires newer ollama model = os.getenv("MODEL_NAME", model) @@ -40,7 +40,7 @@ logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper() # RAG and Vector DB settings ## Where to read RAG content -chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers) +chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers) # Maximum number of entries for ChromaDB to find default_rag_top_k = 50 @@ -60,7 +60,7 @@ cert_path = "/opt/backstory/keys/cert.pem" host = os.getenv("BACKSTORY_HOST", "0.0.0.0") port = int(os.getenv("BACKSTORY_PORT", "8911")) api_prefix = "/api/1.0" -debug=os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes") +debug = os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes") # Used for filtering tracebacks -app_path="/opt/backstory/src/backend" \ No newline at end of file +app_path = "/opt/backstory/src/backend" diff --git a/src/backend/device_manager.py b/src/backend/device_manager.py index cfad1c6..0dc849b 100644 --- a/src/backend/device_manager.py +++ b/src/backend/device_manager.py @@ -1,31 +1,32 @@ -from fastapi import Request +from fastapi import Request from database.manager import RedisDatabase import hashlib from logger import logger from datetime import datetime, timezone -from user_agents import parse +from user_agents import parse import json + class DeviceManager: def __init__(self, database: RedisDatabase): self.database = database - + def generate_device_fingerprint(self, request: Request) -> str: """Generate device fingerprint from request""" user_agent = request.headers.get("user-agent", "") accept_language = request.headers.get("accept-language", "") - + # Create fingerprint fingerprint_data = f"{user_agent}|{accept_language}" fingerprint = hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16] - + return fingerprint - + 
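# Illustrative sketch, not part of the patch: generate_device_fingerprint() above
# reduces to a SHA-256 digest of "user-agent|accept-language" truncated to 16 hex
# characters, so the same browser/locale pair always maps to the same device id.
# The header values here are hypothetical examples, not taken from the diff.
import hashlib

example_ua = "Mozilla/5.0 (X11; Linux x86_64)"  # hypothetical User-Agent header
example_lang = "en-US,en;q=0.9"                 # hypothetical Accept-Language header
example_device_id = hashlib.sha256(f"{example_ua}|{example_lang}".encode()).hexdigest()[:16]
print(example_device_id)  # 16-character hex fingerprint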
def parse_device_info(self, request: Request) -> dict: """Parse device information from request""" user_agent_string = request.headers.get("user-agent", "") user_agent = parse(user_agent_string) - + return { "device_id": self.generate_device_fingerprint(request), "device_name": f"{user_agent.browser.family} on {user_agent.os.family}", @@ -34,9 +35,9 @@ class DeviceManager: "os": user_agent.os.family, "os_version": user_agent.os.version_string, "ip_address": request.client.host if request.client else "unknown", - "user_agent": user_agent_string + "user_agent": user_agent_string, } - + async def is_trusted_device(self, user_id: str, device_id: str) -> bool: """Check if device is trusted for user""" try: @@ -46,7 +47,7 @@ class DeviceManager: except Exception as e: logger.error(f"Error checking trusted device: {e}") return False - + async def add_trusted_device(self, user_id: str, device_id: str, device_info: dict): """Add device to trusted devices""" try: @@ -54,20 +55,20 @@ class DeviceManager: device_data = { **device_info, "added_at": datetime.now(timezone.utc).isoformat(), - "last_used": datetime.now(timezone.utc).isoformat() + "last_used": datetime.now(timezone.utc).isoformat(), } - + # Store for 90 days await self.database.redis.setex( - key, + key, 90 * 24 * 60 * 60, # 90 days in seconds - json.dumps(device_data, default=str) + json.dumps(device_data, default=str), ) - + logger.info(f"🔒 Added trusted device {device_id} for user {user_id}") except Exception as e: logger.error(f"Error adding trusted device: {e}") - + async def update_device_last_used(self, user_id: str, device_id: str): """Update last used timestamp for device""" try: @@ -79,7 +80,7 @@ class DeviceManager: await self.database.redis.setex( key, 90 * 24 * 60 * 60, # Reset 90 day expiry - json.dumps(device_info, default=str) + json.dumps(device_info, default=str), ) except Exception as e: logger.error(f"Error updating device last used: {e}") diff --git a/src/backend/email_service.py b/src/backend/email_service.py index 29b82c1..d00f7fd 100644 --- a/src/backend/email_service.py +++ b/src/backend/email_service.py @@ -1,8 +1,8 @@ import os from typing import Tuple from logger import logger -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart +from email.mime.text import MIMEText +from email.mime.multipart import MIMEMultipart import smtplib import asyncio from email_templates import EMAIL_TEMPLATES @@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta import json from database.manager import RedisDatabase + class EmailService: def __init__(self): # Configure these in your .env file @@ -26,64 +27,48 @@ class EmailService: def _get_template(self, template_name: str) -> dict: """Get email template by name""" return EMAIL_TEMPLATES.get(template_name, {}) - + def _format_template(self, template: str, **kwargs) -> str: """Format template with provided variables""" return template.format( - app_name=self.app_name, - from_name=self.from_name, - frontend_url=self.frontend_url, - **kwargs + app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs ) - + async def send_verification_email( - self, - to_email: str, - verification_token: str, - user_name: str, - user_type: str = "user" + self, to_email: str, verification_token: str, user_name: str, user_type: str = "user" ): """Send email verification email using template""" try: template = self._get_template("verification") verification_link = 
f"{self.frontend_url}/login/verify-email?token={verification_token}" - - subject = self._format_template( - template["subject"], - user_name=user_name, - to_email=to_email - ) - + + subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email) + html_content = self._format_template( template["html"], user_name=user_name, user_type=user_type, to_email=to_email, - verification_link=verification_link + verification_link=verification_link, ) - + await self._send_email(to_email, subject, html_content) logger.info(f"📧 Verification email sent to {to_email}") - + except Exception as e: logger.error(f"❌ Failed to send verification email to {to_email}: {e}") raise - + async def send_mfa_email( - self, - to_email: str, - mfa_code: str, - device_name: str, - user_name: str, - ip_address: str = "Unknown" + self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown" ): """Send MFA code email using template""" try: template = self._get_template("mfa") login_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC") - + subject = self._format_template(template["subject"]) - + html_content = self._format_template( template["html"], user_name=user_name, @@ -91,64 +76,56 @@ class EmailService: ip_address=ip_address, login_time=login_time, mfa_code=mfa_code, - to_email=to_email + to_email=to_email, ) - + await self._send_email(to_email, subject, html_content) logger.info(f"📧 MFA code sent to {to_email} for device {device_name}") - + except Exception as e: logger.error(f"❌ Failed to send MFA email to {to_email}: {e}") raise - - async def send_password_reset_email( - self, - to_email: str, - reset_token: str, - user_name: str - ): + + async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str): """Send password reset email using template""" try: template = self._get_template("password_reset") reset_link = f"{self.frontend_url}/login/reset-password?token={reset_token}" - + subject = self._format_template(template["subject"]) - + html_content = self._format_template( - template["html"], - user_name=user_name, - reset_link=reset_link, - to_email=to_email + template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email ) - + await self._send_email(to_email, subject, html_content) logger.info(f"📧 Password reset email sent to {to_email}") - + except Exception as e: logger.error(f"❌ Failed to send password reset email to {to_email}: {e}") raise - + async def _send_email(self, to_email: str, subject: str, html_content: str): """Send email using SMTP with improved error handling""" try: if not self.email_user: raise ValueError("Email user is not configured") # Create message - msg = MIMEMultipart('alternative') - msg['From'] = f"{self.from_name} <{self.email_user}>" - msg['To'] = to_email - msg['Subject'] = subject - msg['Reply-To'] = self.email_user - + msg = MIMEMultipart("alternative") + msg["From"] = f"{self.from_name} <{self.email_user}>" + msg["To"] = to_email + msg["Subject"] = subject + msg["Reply-To"] = self.email_user + # Add HTML content - html_part = MIMEText(html_content, 'html', 'utf-8') + html_part = MIMEText(html_content, "html", "utf-8") msg.attach(html_part) - + # Send email with connection pooling and retry logic max_retries = 3 if not self.smtp_server or self.smtp_port == 0 or not self.email_user or not self.email_password: raise ValueError("SMTP configuration is not set in the environment variables") - + for attempt in range(max_retries): try: with 
smtplib.SMTP(self.smtp_server, self.smtp_port) as server: @@ -157,71 +134,61 @@ class EmailService: text = msg.as_string() server.sendmail(self.email_user, to_email, text) break # Success, exit retry loop - except smtplib.SMTPException as e: if attempt == max_retries - 1: # Last attempt raise logger.warning(f"⚠️ SMTP attempt {attempt + 1} failed, retrying: {e}") await asyncio.sleep(1) # Wait before retry - + logger.debug(f"📧 Email sent successfully to {to_email}") - + except Exception as e: logger.error(f"❌ SMTP error sending to {to_email}: {e}") raise + class EmailRateLimiter: def __init__(self, database: RedisDatabase): self.database = database - + async def can_send_email(self, email: str, email_type: str, limit: int = 5, window_minutes: int = 60) -> bool: """Check if email can be sent based on rate limiting""" try: key = f"email_rate_limit:{email_type}:{email.lower()}" current_time = datetime.now(timezone.utc) window_start = current_time - timedelta(minutes=window_minutes) - + # Get current count count_data = await self.database.redis.get(key) if not count_data: # First email, allow it await self._record_email_sent(key, current_time, window_minutes) return True - + email_records = json.loads(count_data) - + # Filter out old records - recent_records = [ - record for record in email_records - if datetime.fromisoformat(record) > window_start - ] - + recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start] + if len(recent_records) >= limit: logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})") return False - + # Add current email to records recent_records.append(current_time.isoformat()) - await self.database.redis.setex( - key, - window_minutes * 60, - json.dumps(recent_records) - ) - + await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records)) + return True - + except Exception as e: logger.error(f"❌ Error checking email rate limit: {e}") # On error, allow the email to be safe return True - + async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int): """Record that an email was sent""" - await self.database.redis.setex( - key, - ttl_minutes * 60, - json.dumps([timestamp.isoformat()]) - ) + await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()])) + class VerificationEmailRateLimiter: def __init__(self, database: RedisDatabase): @@ -229,7 +196,7 @@ class VerificationEmailRateLimiter: self.max_attempts_per_hour = 3 # Maximum 3 emails per hour self.max_attempts_per_day = 10 # Maximum 10 emails per day self.cooldown_minutes = 5 # 5 minute cooldown between emails - + async def can_send_verification_email(self, email: str) -> Tuple[bool, str]: """ Check if verification email can be sent based on rate limiting @@ -238,49 +205,49 @@ class VerificationEmailRateLimiter: try: email_lower = email.lower() current_time = datetime.now(timezone.utc) - + # Check daily limit daily_count = await self.database.get_verification_attempts_count(email) if daily_count >= self.max_attempts_per_day: - return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day." - + return ( + False, + f"Daily limit reached. 
You can request up to {self.max_attempts_per_day} verification emails per day.", + ) + # Check hourly limit hour_ago = current_time - timedelta(hours=1) hourly_key = f"verification_attempts:{email_lower}" data = await self.database.redis.get(hourly_key) - + if data: attempts_data = json.loads(data) - recent_attempts = [ - attempt for attempt in attempts_data - if datetime.fromisoformat(attempt) > hour_ago - ] - + recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago] + if len(recent_attempts) >= self.max_attempts_per_hour: - return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour." - + return ( + False, + f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.", + ) + # Check cooldown period if recent_attempts: last_attempt = max(datetime.fromisoformat(attempt) for attempt in recent_attempts) time_since_last = current_time - last_attempt - + if time_since_last.total_seconds() < self.cooldown_minutes * 60: remaining_minutes = self.cooldown_minutes - int(time_since_last.total_seconds() / 60) return False, f"Please wait {remaining_minutes} more minute(s) before requesting another email." - + return True, "OK" - + except Exception as e: logger.error(f"❌ Error checking verification email rate limit: {e}") # On error, be conservative and deny return False, "Rate limit check failed. Please try again later." - + async def record_email_sent(self, email: str): """Record that a verification email was sent""" await self.database.record_verification_attempt(email) - email_service = EmailService() - - \ No newline at end of file diff --git a/src/backend/email_templates.py b/src/backend/email_templates.py index 51c81c0..96c343f 100644 --- a/src/backend/email_templates.py +++ b/src/backend/email_templates.py @@ -129,9 +129,8 @@ EMAIL_TEMPLATES = { - """ + """, }, - "mfa": { "subject": "Security Code for Backstory", "html": """ @@ -274,9 +273,8 @@ EMAIL_TEMPLATES = { - """ + """, }, - "password_reset": { "subject": "Reset your Backstory password", "html": """ @@ -386,6 +384,6 @@ EMAIL_TEMPLATES = { - """ - } -} \ No newline at end of file + """, + }, +} diff --git a/src/backend/entities/entity_manager.py b/src/backend/entities/entity_manager.py index df65be2..d7e9b73 100644 --- a/src/backend/entities/entity_manager.py +++ b/src/backend/entities/entity_manager.py @@ -3,16 +3,17 @@ import weakref from datetime import datetime, timedelta from typing import Dict, Optional from contextlib import asynccontextmanager -from pydantic import BaseModel # type: ignore +from pydantic import BaseModel # type: ignore from models import Candidate from agents.base import CandidateEntity from database.manager import RedisDatabase -from prometheus_client import CollectorRegistry # type: ignore +from prometheus_client import CollectorRegistry # type: ignore + class EntityManager(BaseModel): """Manages lifecycle of CandidateEntity instances""" - + def __init__(self, default_ttl_minutes: int = 30): self._entities: Dict[str, CandidateEntity] = {} self._weak_refs: Dict[str, weakref.ReferenceType] = {} @@ -25,7 +26,7 @@ class EntityManager(BaseModel): """Start background cleanup task""" if self._cleanup_task is None: self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) - + async def stop_cleanup_task(self): """Stop background cleanup task""" if self._cleanup_task: @@ -36,49 +37,44 @@ class EntityManager(BaseModel): pass self._cleanup_task = None - def 
initialize( - self, - prometheus_collector: CollectorRegistry, - database: RedisDatabase): + def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase): """Initialize the EntityManager with Prometheus collector""" self._prometheus_collector = prometheus_collector self._database = database async def get_entity(self, candidate: "Candidate") -> CandidateEntity: """Get or create CandidateEntity with proper reference tracking""" - + # Check if entity exists and is still valid if id in self._entities: entity = self._entities[candidate.id] entity.last_accessed = datetime.now() entity.reference_count += 1 return entity - + if not self._prometheus_collector or not self._database: raise ValueError("EntityManager has not been initialized with required components.") - + entity = CandidateEntity(candidate=candidate) - await entity.initialize( - prometheus_collector=self._prometheus_collector, - database=self._database) - + await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database) + # Store with reference tracking self._entities[candidate.id] = entity self._weak_refs[candidate.id] = weakref.ref(entity, self._on_entity_deleted(candidate.id)) - + entity.reference_count = 1 entity.last_accessed = datetime.now() - + return entity - + async def remove_entity(self, candidate_id: str) -> bool: """ Immediately remove and cleanup a candidate entity from active persistence. This should be called when a candidate is being deleted from the system. - + Args: candidate_id: The ID of the candidate entity to remove - + Returns: bool: True if entity was found and removed, False if not found """ @@ -88,36 +84,38 @@ class EntityManager(BaseModel): if not entity: print(f"Entity {candidate_id} not found in active persistence") return False - + # Remove from tracking dictionaries self._entities.pop(candidate_id, None) self._weak_refs.pop(candidate_id, None) - + # Cleanup the entity await entity.cleanup() - + print(f"Successfully removed entity {candidate_id} from active persistence") return True - + except Exception as e: print(f"Error removing entity {candidate_id}: {e}") return False - + def _on_entity_deleted(self, user_id: str): """Callback when entity is garbage collected""" + def cleanup_callback(weak_ref): self._entities.pop(user_id, None) self._weak_refs.pop(user_id, None) print(f"Entity {user_id} garbage collected") + return cleanup_callback - + async def release_entity(self, user_id: str): """Explicitly release reference to entity""" if user_id in self._entities: entity = self._entities[user_id] entity.reference_count = max(0, entity.reference_count - 1) entity.last_accessed = datetime.now() - + async def _periodic_cleanup(self): """Background task to clean up expired entities""" while True: @@ -128,20 +126,19 @@ class EntityManager(BaseModel): break except Exception as e: print(f"Error in cleanup task: {e}") - + async def _cleanup_expired_entities(self): """Remove entities that have expired based on TTL and reference count""" current_time = datetime.now() expired_entities = [] - + for user_id, entity in list(self._entities.items()): time_since_access = current_time - entity.last_accessed - + # Remove if TTL exceeded and no active references - if (time_since_access > timedelta(minutes=self._ttl_minutes) - and entity.reference_count == 0): + if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0: expired_entities.append(user_id) - + for user_id in expired_entities: entity = self._entities.pop(user_id, None) 
self._weak_refs.pop(user_id, None) @@ -153,6 +150,7 @@ class EntityManager(BaseModel): # Global entity manager instance entity_manager = EntityManager(default_ttl_minutes=30) + @asynccontextmanager async def get_candidate_entity(candidate: Candidate): """Context manager for safe entity access with automatic reference management""" @@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate): finally: await entity_manager.release_entity(candidate.id) + EntityManager.model_rebuild() diff --git a/src/backend/focused_test.py b/src/backend/focused_test.py index ad697ea..a6ab048 100644 --- a/src/backend/focused_test.py +++ b/src/backend/focused_test.py @@ -6,20 +6,17 @@ without getting caught up in serialization format complexities import sys from datetime import datetime -from models import ( - UserStatus, UserType, SkillLevel, EmploymentType, - Candidate, Employer, Location, Skill -) +from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill def test_model_creation(): """Test that we can create models successfully""" print("🧪 Testing model creation...") - + # Create supporting objects location = Location(city="Austin", country="USA") skill = Skill(name="Python", category="Programming", level=SkillLevel.ADVANCED) - + # Create candidate candidate = Candidate( email="test@example.com", @@ -37,9 +34,9 @@ def test_model_creation(): preferred_job_types=[EmploymentType.FULL_TIME], location=location, languages=[], - certifications=[] + certifications=[], ) - + # Create employer employer = Employer( user_type=UserType.EMPLOYER, @@ -54,72 +51,75 @@ def test_model_creation(): industry="Technology", company_size="50-200", company_description="A test company", - location=location + location=location, ) - + print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}") print(f"✅ Employer: {employer.company_name}") print(f"✅ User types: {candidate.user_type}, {employer.user_type}") - + return candidate, employer + def test_json_api_format(): """Test JSON serialization in API format (the most important use case)""" print("\n📡 Testing JSON API format...") - + candidate, employer = test_model_creation() - + # Serialize to JSON (API format) candidate_json = candidate.model_dump_json(by_alias=True) employer_json = employer.model_dump_json(by_alias=True) - + print(f"✅ Candidate JSON: {len(candidate_json)} chars") print(f"✅ Employer JSON: {len(employer_json)} chars") - + # Deserialize from JSON candidate_back = Candidate.model_validate_json(candidate_json) employer_back = Employer.model_validate_json(employer_json) - + # Verify data integrity assert candidate_back.email == candidate.email assert candidate_back.first_name == candidate.first_name assert employer_back.company_name == employer.company_name - + print("✅ JSON round-trip successful") print("✅ Data integrity verified") - + return True + def test_api_dict_format(): """Test dictionary format with aliases (for API requests/responses)""" print("\n📊 Testing API dictionary format...") - + candidate, employer = test_model_creation() - + # Create API format dictionaries candidate_dict = candidate.model_dump(by_alias=True) employer_dict = employer.model_dump(by_alias=True) - + # Verify camelCase aliases are used assert "firstName" in candidate_dict assert "lastName" in candidate_dict assert "createdAt" in candidate_dict assert "companyName" in employer_dict - + print("✅ API format dictionaries created") print("✅ CamelCase aliases verified") - + # Test deserializing from API format candidate_back = 
Candidate.model_validate(candidate_dict) employer_back = Employer.model_validate(employer_dict) - + assert candidate_back.email == candidate.email assert employer_back.company_name == employer.company_name - + print("✅ API format round-trip successful") - + return True + def test_validation_constraints(): """Test that validation constraints work""" print("\n🔒 Testing validation constraints...") @@ -135,45 +135,47 @@ def test_validation_constraints(): status=UserStatus.ACTIVE, first_name="Jane", last_name="Doe", - full_name="Jane Doe" + full_name="Jane Doe", ) print("❌ Validation should have failed but didn't") return False except ValueError as e: - print(f"✅ Validation error caught: {e}") + print(f"✅ Validation error caught: {e}") return True + def test_enum_values(): """Test that enum values work correctly""" print("\n📋 Testing enum values...") - + # Test that enum values are properly handled candidate, employer = test_model_creation() - + # Check enum values in serialization candidate_dict = candidate.model_dump(by_alias=True) - + assert candidate_dict["status"] == "active" assert candidate_dict["userType"] == "candidate" assert employer.user_type == UserType.EMPLOYER - + print("✅ Enum values correctly serialized") print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}") - + return True + def main(): """Run all focused tests""" print("🎯 Focused Pydantic Model Tests") print("=" * 40) - + try: test_model_creation() test_json_api_format() - test_api_dict_format() + test_api_dict_format() test_validation_constraints() test_enum_values() - + print("\n🎉 All focused tests passed!") print("=" * 40) print("✅ Models work correctly") @@ -181,16 +183,18 @@ def main(): print("✅ Validation constraints work") print("✅ Enum values work") print("✅ Ready for type generation!") - + return True - + except Exception as e: print(f"\n❌ Test failed: {type(e).__name__}: {e}") import traceback + traceback.print_exc() print(f"\n❌ {traceback.format_exc()}") return False + if __name__ == "__main__": success = main() - sys.exit(0 if success else 1) \ No newline at end of file + sys.exit(0 if success else 1) diff --git a/src/backend/helpers/check_serializable.py b/src/backend/helpers/check_serializable.py index 5f6ff86..7791acf 100644 --- a/src/backend/helpers/check_serializable.py +++ b/src/backend/helpers/check_serializable.py @@ -1,7 +1,8 @@ -from pydantic import BaseModel +from pydantic import BaseModel import json from typing import Any, List, Set + def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]: """ Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion. 
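# Usage sketch, not part of the patch: per its docstring, check_serializable() walks an
# object's fields and returns a list of error strings for values json.dumps cannot handle.
# The import path and the Example model below are assumptions for illustration; the exact
# error message format depends on the helper's implementation.
from helpers.check_serializable import check_serializable
from pydantic import BaseModel


class Example(BaseModel):
    model_config = {"arbitrary_types_allowed": True}
    name: str = "ok"           # JSON-serializable
    handle: object = object()  # not JSON-serializable


errors = check_serializable(Example(), path="Example")
print(errors)  # expected to flag Example.handle as non-serializable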
@@ -57,4 +58,4 @@ def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited # Remove the current object from visited to allow processing in other branches visited.discard(obj_id) - return errors \ No newline at end of file + return errors diff --git a/src/backend/helpers/model_cast.py b/src/backend/helpers/model_cast.py index 9a31226..402b9e2 100644 --- a/src/backend/helpers/model_cast.py +++ b/src/backend/helpers/model_cast.py @@ -1,5 +1,5 @@ from typing import Type, TypeVar -from pydantic import BaseModel +from pydantic import BaseModel import copy from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType @@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType" assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType" -T = TypeVar('T', bound=BaseModel) +T = TypeVar("T", bound=BaseModel) + def cast_to_model(model_cls: Type[T], source: BaseModel) -> T: data = {field: getattr(source, field) for field in model_cls.__fields__} return model_cls(**data) + def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T: data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__} return model_cls(**data) + def cast_to_base_user_with_type(user) -> BaseUserWithType: """ Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType. diff --git a/src/backend/image_generator/image_model_cache.py b/src/backend/image_generator/image_model_cache.py index d13905c..ff71bfa 100644 --- a/src/backend/image_generator/image_model_cache.py +++ b/src/backend/image_generator/image_model_cache.py @@ -4,10 +4,11 @@ import re import time from typing import Any -import torch -from diffusers import StableDiffusionPipeline, FluxPipeline +import torch +from diffusers import StableDiffusionPipeline, FluxPipeline -class ImageModelCache: # Stay loaded for 3 hours + +class ImageModelCache: # Stay loaded for 3 hours def __init__(self, timeout_seconds: float = 3 * 60 * 60): self._pipe = None self._model_name = None @@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours cached_model_type = self._get_model_type(self._model_name) if self._model_name else None if ( - self._pipe is not None and - self._model_name == model and - self._device == device and - current_model_type == cached_model_type and - current_time - self._last_access_time < self._timeout_seconds + self._pipe is not None + and self._model_name == model + and self._device == device + and current_model_type == cached_model_type + and current_time - self._last_access_time < self._timeout_seconds ): self._last_access_time = current_time return self._pipe @@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours model, torch_dtype=torch.float16 if device == "cuda" else torch.float32, ) + def dummy_safety_checker(images, clip_input): return images, [False] * len(images) + pipe.safety_checker = dummy_safety_checker else: pipe = FluxPipeline.from_pretrained( @@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours torch_dtype=torch.float16 if device == "cuda" else torch.float32, ) try: - pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors') + pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors") except Exception as e: raise Exception(f"Failed to load LoRA weights: {str(e)}") @@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded 
for 3 hours async def cleanup_if_expired(self): async with self._lock: - if ( - self._pipe is not None and - time.time() - self._last_access_time >= self._timeout_seconds - ): + if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds: await self._unload_model() async def _periodic_cleanup(self): diff --git a/src/backend/image_generator/profile_image.py b/src/backend/image_generator/profile_image.py index a813848..9c0bf3a 100644 --- a/src/backend/image_generator/profile_image.py +++ b/src/backend/image_generator/profile_image.py @@ -1,5 +1,5 @@ from __future__ import annotations -from pydantic import BaseModel +from pydantic import BaseModel from typing import Any, AsyncGenerator import traceback import asyncio @@ -29,6 +29,7 @@ TIME_ESTIMATES = { } } + class ImageRequest(BaseModel): session_id: str filepath: str @@ -39,34 +40,40 @@ class ImageRequest(BaseModel): width: int = 256 guidance_scale: float = 7.5 + # Global model cache instance model_cache = ImageModelCache() + def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str): """Background worker for Flux image generation""" try: # Flux: Run generation in the background and yield progress updates - status_queue.put(ChatMessageStatus( - session_id=params.session_id, - content="Initializing image generation.", - activity=ApiActivityType.GENERATING_IMAGE, - )) + status_queue.put( + ChatMessageStatus( + session_id=params.session_id, + content="Initializing image generation.", + activity=ApiActivityType.GENERATING_IMAGE, + ) + ) # Start the generation task start_gen_time = time.time() - + # Simulate your pipe call with progress updates def status_callback(pipeline, step, timestep, callback_kwargs): # Send progress updates - progress = int((step+1) / params.iterations * 100) - - status_queue.put(ChatMessageStatus( - session_id=params.session_id, - content=f"Processing step {step+1}/{params.iterations} ({progress}%)", - activity=ApiActivityType.GENERATING_IMAGE, - )) + progress = int((step + 1) / params.iterations * 100) + + status_queue.put( + ChatMessageStatus( + session_id=params.session_id, + content=f"Processing step {step+1}/{params.iterations} ({progress}%)", + activity=ApiActivityType.GENERATING_IMAGE, + ) + ) return callback_kwargs - + # Replace this block with your actual Flux pipe call: image = pipe( params.prompt, @@ -76,30 +83,36 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task width=params.width, callback_on_step_end=status_callback, ).images[0] - + gen_time = time.time() - start_gen_time per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time - + logger.info(f"Saving to {params.filepath}") image.save(params.filepath) # Final completion status - status_queue.put(ChatMessage( - session_id=params.session_id, - status=ApiStatusType.DONE, - content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", - )) - + status_queue.put( + ChatMessage( + session_id=params.session_id, + status=ApiStatusType.DONE, + content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.", + ) + ) + except Exception as e: logger.error(traceback.format_exc()) logger.error(e) - status_queue.put(ChatMessageError( - session_id=params.session_id, - content=f"Error during image generation: {str(e)}", - )) + status_queue.put( + ChatMessageError( + session_id=params.session_id, + content=f"Error during image generation: {str(e)}", + ) + ) -async def async_generate_image(pipe: Any, 
params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]: +async def async_generate_image( + pipe: Any, params: ImageRequest +) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]: """ Single async function that handles background Flux generation with status streaming """ @@ -109,40 +122,36 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato try: # Start background worker thread - worker_thread = Thread( - target=flux_worker, - args=(pipe, params, status_queue, task_id), - daemon=True - ) + worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True) worker_thread.start() - + # Initial status status_message = ChatMessageStatus( session_id=params.session_id, content=f"Starting image generation with task ID {task_id}", - activity=ApiActivityType.THINKING + activity=ApiActivityType.THINKING, ) yield status_message - + # Stream status updates completed = False last_heartbeat = time.time() - + while not completed and worker_thread.is_alive(): try: # Try to get status update (non-blocking) status_update = status_queue.get_nowait() - + # Send status update yield status_update - + # Check if completed if status_update.status == ApiStatusType.DONE: logger.info(f"Image generation completed for task {task_id}") completed = True - + last_heartbeat = time.time() - + except queue.Empty: # No new status, send heartbeat if needed current_time = time.time() @@ -154,10 +163,10 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato ) yield heartbeat last_heartbeat = current_time - + # Brief sleep to prevent busy waiting await asyncio.sleep(0.1) - + # Handle thread completion or timeout if not completed: if worker_thread.is_alive(): @@ -175,20 +184,18 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato content=f"Generation completed for task {task_id}.", ) yield final_status - + except Exception as e: - error_status = ChatMessageError( - session_id=params.session_id, - content=f'Server error: {str(e)}' - ) + error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}") logger.error(error_status) yield error_status - + finally: # Cleanup: ensure thread completion - if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive(): + if worker_thread and "worker_thread" in locals() and worker_thread.is_alive(): worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup + def status(session_id: str, status: str) -> ChatMessageStatus: """Update chat message status and return it.""" chat_message = ChatMessageStatus( @@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus: ) return chat_message + async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]: """Generate an image with specified dimensions and yield status updates with time estimates.""" session_id = request.session_id @@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N try: # Validate prompt if not prompt: - error_message = ChatMessageError( - session_id=session_id, - content="Prompt cannot be empty." 
- ) + error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.") logger.error(error_message.content) yield error_message return @@ -216,8 +221,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N # Validate dimensions if request.height <= 0 or request.width <= 0: error_message = ChatMessageError( - session_id=session_id, - content="Height and width must be positive integers." + session_id=session_id, content="Height and width must be positive integers." ) logger.error(error_message.content) yield error_message @@ -234,13 +238,16 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N estimates = TIME_ESTIMATES[model_type][device] resolution_scale = (request.height * request.width) / (512 * 512) estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale - + # Initialize or get cached pipeline start_time = time.time() yield status(session_id, "Loading generative image model...") pipe = await model_cache.get_pipeline(request.model, device) load_time = time.time() - start_time - yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",) + yield status( + session_id, + f"Model loaded in {load_time:.1f} seconds.", + ) progress = None async for progress in async_generate_image(pipe, request): @@ -252,13 +259,12 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N if not progress: error_message = ChatMessageError( - session_id=session_id, - content="Image generation failed to produce a valid response." + session_id=session_id, content="Image generation failed to produce a valid response." ) logger.error(f"⚠️ {error_message.content}") yield error_message return - + # Final result total_time = time.time() - start_time chat_message = ChatMessage( @@ -269,11 +275,8 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N yield chat_message except Exception as e: - error_message = ChatMessageError( - session_id=session_id, - content=f"Error during image generation: {str(e)}" - ) + error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}") logger.error(traceback.format_exc()) logger.error(error_message.content) yield error_message - return \ No newline at end of file + return diff --git a/src/backend/json_extractor.py b/src/backend/json_extractor.py index 7694559..fba8701 100644 --- a/src/backend/json_extractor.py +++ b/src/backend/json_extractor.py @@ -2,13 +2,14 @@ import json import re from typing import List, Union + def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]: """ Extract JSON blocks from text, even if surrounded by markdown or noisy text. If allow_multiple is True, returns all JSON blocks; otherwise, only the first. 
""" found = [] - + # First try to extract from code blocks (most reliable) code_block_pattern = r"```(?:json)?\s*([\s\S]+?)\s*```" for match in re.finditer(code_block_pattern, text): @@ -20,24 +21,25 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]: return [parsed] except json.JSONDecodeError: continue - + # If no valid code blocks found, look for standalone JSON objects/arrays if not found: standalone_json = _extract_standalone_json(text, allow_multiple) found.extend(standalone_json) - + if not found: raise ValueError("No valid JSON block found in the text") - + return found + def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]: """Extract standalone JSON objects or arrays from text using proper brace counting.""" found = [] i = 0 - + while i < len(text): - if text[i] in '{[': + if text[i] in "{[": # Found potential JSON start json_str = _extract_complete_json_at_position(text, i) if json_str: @@ -52,31 +54,32 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un except json.JSONDecodeError: pass i += 1 - + return found + def _extract_complete_json_at_position(text: str, start_pos: int) -> str: """ Extract a complete JSON object or array starting at the given position. Uses proper brace/bracket counting and string escape handling. """ - if start_pos >= len(text) or text[start_pos] not in '{[': + if start_pos >= len(text) or text[start_pos] not in "{[": return "" - + start_char = text[start_pos] - end_char = '}' if start_char == '{' else ']' - + end_char = "}" if start_char == "{" else "]" + count = 1 i = start_pos + 1 in_string = False escape_next = False - + while i < len(text) and count > 0: char = text[i] - + if escape_next: escape_next = False - elif char == '\\' and in_string: + elif char == "\\" and in_string: escape_next = True elif char == '"' and not escape_next: in_string = not in_string @@ -85,13 +88,14 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str: count += 1 elif char == end_char: count -= 1 - + i += 1 - + if count == 0: return text[start_pos:i] return "" + def extract_json_from_text(text: str) -> str: """Extract JSON string from text that may contain other content.""" return json.dumps(extract_json_blocks(text, allow_multiple=False)[0]) diff --git a/src/backend/logger.py b/src/backend/logger.py index 4d8e928..8d8acee 100644 --- a/src/backend/logger.py +++ b/src/backend/logger.py @@ -2,11 +2,11 @@ import os import warnings import logging import defines + + def _setup_logging(level=defines.logging_level) -> logging.Logger: os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" - warnings.filterwarnings( - "ignore", message="Overriding a previously registered kernel" - ) + warnings.filterwarnings("ignore", message="Overriding a previously registered kernel") warnings.filterwarnings("ignore", message="Warning only once for all operators") warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*") warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to") @@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger: # Create a custom formatter formatter = logging.Formatter( - fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S" + fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) # Create a handler (e.g., StreamHandler for console output) @@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> 
logging.Logger: logger = logging.getLogger(__name__) return logger + logger = _setup_logging(level=defines.logging_level) -logger.debug(f"Logging initialized with level: {defines.logging_level}") \ No newline at end of file +logger.debug(f"Logging initialized with level: {defines.logging_level}") diff --git a/src/backend/main.py b/src/backend/main.py index bdb8b19..c13f380 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -39,7 +39,7 @@ from utils.dependencies import get_database, set_db_manager from routes import ( admin, auth, - candidates, + candidates, chat, employers, jobs, @@ -58,6 +58,7 @@ background_task_manager = None prev_int = signal.getsignal(signal.SIGINT) prev_term = signal.getsignal(signal.SIGTERM) + def signal_handler(signum, frame): logger.info(f"⚠️ Received signal {signum!r}, shutting down…") # now call the old handler (it might raise KeyboardInterrupt or exit) @@ -66,14 +67,16 @@ def signal_handler(signum, frame): elif signum == signal.SIGTERM and callable(prev_term): prev_term(signum, frame) + # Global background task manager background_task_manager: Optional[BackgroundTaskManager] = None + @asynccontextmanager async def lifespan(app: FastAPI): # Startup global db_manager, background_task_manager - + logger.info("🚀 Starting Backstory API with enhanced background tasks") logger.info(f"📝 API Documentation available at: http://{defines.host}:{defines.port}{defines.api_prefix}/docs") logger.info("🔗 API endpoints prefixed with: /api/1.0") @@ -81,10 +84,10 @@ async def lifespan(app: FastAPI): # Initialize database db_manager = DatabaseManager() await db_manager.initialize() - + # Set the database manager in dependencies set_db_manager(db_manager) - + entity_manager.initialize(prometheus_collector=prometheus_collector, database=db_manager.get_database()) # Initialize background task manager @@ -96,26 +99,27 @@ async def lifespan(app: FastAPI): signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) - + logger.info("🚀 Application startup completed with background tasks") - + yield # Application is running - + except Exception as e: logger.error(f"❌ Failed to start application: {e}") raise - + finally: # Shutdown logger.info("Application shutdown requested") - + # Stop background tasks first if background_task_manager: await background_task_manager.stop() - + if db_manager: await db_manager.graceful_shutdown() + app = FastAPI( lifespan=lifespan, title="Backstory API", @@ -129,12 +133,10 @@ app = FastAPI( ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true" if ssl_enabled: - allow_origins = ["https://battle-linux.ketrenos.com:3000", - "https://backstory-beta.ketrenos.com"] + allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"] else: - allow_origins = ["http://battle-linux.ketrenos.com:3000", - "http://backstory-beta.ketrenos.com"] - + allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"] + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -144,12 +146,14 @@ app.add_middleware( allow_headers=["*"], ) + # ============================ # Debug data type failures # ============================ @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): import traceback + logger.error(traceback.format_exc()) logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}") @@ -158,6 +162,7 @@ 
async def validation_exception_handler(request: Request, exc: RequestValidationE content=json.dumps({"detail": str(exc)}), ) + # ============================ # Create API router with prefix # ============================ @@ -181,44 +186,43 @@ api_router.include_router(users.router) # Health Check and Info Endpoints # ============================ + @app.get("/health") async def health_check( - database = Depends(get_database), + database=Depends(get_database), ): """Health check endpoint""" try: if not redis_manager.redis: raise RuntimeError("Redis client not initialized") - + # Test Redis connection await redis_manager.redis.ping() - + # Get database stats stats = await database.get_stats() - + # Redis info redis_info = await redis_manager.redis.info() - + return { "status": "healthy", "timestamp": datetime.utcnow().isoformat(), - "database": { - "status": "connected", - "stats": stats - }, + "database": {"status": "connected", "stats": stats}, "redis": { "version": redis_info.get("redis_version", "unknown"), "uptime": redis_info.get("uptime_in_seconds", 0), - "memory_used": redis_info.get("used_memory_human", "unknown") - } + "memory_used": redis_info.get("used_memory_human", "unknown"), + }, } - + except RuntimeError as e: return {"status": "shutting_down", "message": str(e)} except Exception as e: logger.error(f"❌ Health check failed: {e}") return {"status": "error", "message": str(e)} + # ============================ # Include Router in App # ============================ @@ -231,11 +235,14 @@ app.include_router(api_router) # ============================ logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}") + @app.middleware("http") async def log_requests(request: Request, call_next): try: if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path): - logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}") + logger.info( + f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}" + ) response = await call_next(request) if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path): if response.status_code < 200 or response.status_code >= 300: @@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next): return response except Exception as e: import traceback + logger.error(traceback.format_exc()) logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}") return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"}) + # ============================ # Request tracking middleware # ============================ @@ -258,7 +267,7 @@ async def track_requests(request, call_next): """Middleware to track active requests during shutdown""" if db_manager.is_shutting_down: return JSONResponse(status_code=503, content={"error": "Application is shutting down"}) - + db_manager.increment_requests() try: response = await call_next(request) @@ -266,6 +275,7 @@ async def track_requests(request, call_next): finally: db_manager.decrement_requests() + # ============================ # FastAPI Metrics # ============================ @@ -277,7 +287,7 @@ instrumentator = Instrumentator( should_ignore_untemplated=True, should_group_untemplated=True, excluded_handlers=[f"{defines.api_prefix}/metrics"], - registry=prometheus_collector + registry=prometheus_collector, ) # Instrument the FastAPI app @@ -291,15 
+301,17 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics") # Static File Serving # ============================ + @app.get("/{path:path}") async def serve_static(path: str, request: Request): full_path = os.path.join(defines.static_content, path) - + if os.path.exists(full_path) and os.path.isfile(full_path): return FileResponse(full_path) return FileResponse(os.path.join(defines.static_content, "index.html")) - + + # Root endpoint when no static files @app.get("/", include_in_schema=False) async def root(): @@ -309,21 +321,23 @@ async def root(): "version": "1.0.0", "api_prefix": defines.api_prefix, "documentation": f"{defines.api_prefix}/docs", - "health": f"{defines.api_prefix}/health" + "health": f"{defines.api_prefix}/health", } + async def periodic_verification_cleanup(): """Background task to periodically clean up expired verification tokens""" try: database = get_database() cleaned_count = await database.cleanup_expired_verification_tokens() - + if cleaned_count > 0: logger.info(f"🧹 Periodic cleanup: removed {cleaned_count} expired verification tokens") - + except Exception as e: logger.error(f"❌ Error in periodic verification cleanup: {e}") + if __name__ == "__main__": host = defines.host port = defines.port @@ -341,4 +355,4 @@ if __name__ == "__main__": ) else: logger.info(f"Starting web server at http://{host}:{port}") - uvicorn.run(app="main:app", host=host, port=port, log_config=None) \ No newline at end of file + uvicorn.run(app="main:app", host=host, port=port, log_config=None) diff --git a/src/backend/models.py b/src/backend/models.py index 9403a2b..433e952 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -4,39 +4,44 @@ from datetime import datetime, UTC from enum import Enum import uuid from utils.auth_utils import ( - validate_password_strength, + validate_password_strength, sanitize_login_input, ) import defines # Generic type variable -T = TypeVar('T') +T = TypeVar("T") # ============================ # Enums # ============================ + class UserType(str, Enum): CANDIDATE = "candidate" EMPLOYER = "employer" GUEST = "guest" + class UserGender(str, Enum): FEMALE = "female" MALE = "male" + class UserStatus(str, Enum): ACTIVE = "active" INACTIVE = "inactive" PENDING = "pending" BANNED = "banned" + class SkillLevel(str, Enum): BEGINNER = "beginner" INTERMEDIATE = "intermediate" ADVANCED = "advanced" EXPERT = "expert" + class EmploymentType(str, Enum): FULL_TIME = "full-time" PART_TIME = "part-time" @@ -44,6 +49,7 @@ class EmploymentType(str, Enum): INTERNSHIP = "internship" FREELANCE = "freelance" + class InterviewType(str, Enum): PHONE = "phone" VIDEO = "video" @@ -51,6 +57,7 @@ class InterviewType(str, Enum): TECHNICAL = "technical" BEHAVIORAL = "behavioral" + class ApplicationStatus(str, Enum): APPLIED = "applied" REVIEWING = "reviewing" @@ -60,12 +67,14 @@ class ApplicationStatus(str, Enum): ACCEPTED = "accepted" WITHDRAWN = "withdrawn" + class InterviewRecommendation(str, Enum): STRONG_HIRE = "strong_hire" HIRE = "hire" NO_HIRE = "no_hire" STRONG_NO_HIRE = "strong_no_hire" + class ChatSenderType(str, Enum): USER = "user" ASSISTANT = "assistant" @@ -75,24 +84,32 @@ class ChatSenderType(str, Enum): WARNING = "warning" ERROR = "error" + class SkillStatus(str, Enum): PENDING = "pending" COMPLETE = "complete" WAITING = "waiting" ERROR = "error" + class SkillStrength(str, Enum): STRONG = "strong" MODERATE = "moderate" WEAK = "weak" NONE = "none" + class EvidenceDetail(BaseModel): - source: str = Field(..., alias=str("source"), 
description="The source of the evidence (e.g., resume section, position, project)") - quote: str = Field(..., alias=str("quote"), description="Exact text from the resume or other source showing evidence") + source: str = Field( + ..., alias=str("source"), description="The source of the evidence (e.g., resume section, position, project)" + ) + quote: str = Field( + ..., alias=str("quote"), description="Exact text from the resume or other source showing evidence" + ) context: str = Field(..., alias=str("context"), description="Brief explanation of how this demonstrates the skill") model_config = ConfigDict(populate_by_name=True) + class ChromaDBGetResponse(BaseModel): # Chroma fields ids: List[str] = [] @@ -109,31 +126,53 @@ class ChromaDBGetResponse(BaseModel): umap_embedding_2d: Optional[List[float]] = Field(default=None, alias=str("umapEmbedding2D")) umap_embedding_3d: Optional[List[float]] = Field(default=None, alias=str("umapEmbedding3D")) + class SkillAssessment(BaseModel): - candidate_id: str = Field(..., alias=str('candidateId')) + candidate_id: str = Field(..., alias=str("candidateId")) skill: str = Field(..., alias=str("skill"), description="The skill being assessed") - skill_modified: Optional[str] = Field(default="", alias=str("skillModified"), description="The skill rephrased by LLM during skill match") - evidence_found: bool = Field(..., alias=str("evidenceFound"), description="Whether evidence was found for the skill") - evidence_strength: SkillStrength = Field(..., alias=str("evidenceStrength"), description="Strength of evidence found for the skill") - assessment: str = Field(..., alias=str("assessment"), description="Short (one to two sentence) assessment of the candidate's proficiency with the skill") - description: str = Field(..., alias=str("description"), description="Short (two to three sentence) description of what the skill is, independent of whether the candidate has that skill or not") - evidence_details: List[EvidenceDetail] = Field(default_factory=list, alias=str("evidenceDetails"), description="List of evidence details supporting the skill assessment") + skill_modified: Optional[str] = Field( + default="", alias=str("skillModified"), description="The skill rephrased by LLM during skill match" + ) + evidence_found: bool = Field( + ..., alias=str("evidenceFound"), description="Whether evidence was found for the skill" + ) + evidence_strength: SkillStrength = Field( + ..., alias=str("evidenceStrength"), description="Strength of evidence found for the skill" + ) + assessment: str = Field( + ..., + alias=str("assessment"), + description="Short (one to two sentence) assessment of the candidate's proficiency with the skill", + ) + description: str = Field( + ..., + alias=str("description"), + description="Short (two to three sentence) description of what the skill is, independent of whether the candidate has that skill or not", + ) + evidence_details: List[EvidenceDetail] = Field( + default_factory=list, + alias=str("evidenceDetails"), + description="List of evidence details supporting the skill assessment", + ) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("createdAt")) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("updatedAt")) rag_results: List[ChromaDBGetResponse] = Field(default_factory=list, alias=str("ragResults")) model_config = ConfigDict(populate_by_name=True) - + + class ApiMessageType(str, Enum): BINARY = "binary" TEXT = "text" JSON = "json" + class ApiStatusType(str, Enum): STREAMING = 
"streaming" STATUS = "status" DONE = "done" ERROR = "error" + class ChatContextType(str, Enum): JOB_SEARCH = "job_search" JOB_REQUIREMENTS = "job_requirements" @@ -148,17 +187,20 @@ class ChatContextType(str, Enum): RAG_SEARCH = "rag_search" SKILL_MATCH = "skill_match" + class AIModelType(str, Enum): QWEN2_5 = "qwen2.5" FLUX_SCHNELL = "flux-schnell" + class MFAMethod(str, Enum): APP = "app" SMS = "sms" EMAIL = "email" + class VectorStoreType(str, Enum): - CHROMA = "chroma", + CHROMA = ("chroma",) # FAISS = "faiss", # PINECONE = "pinecone" # QDRANT = "qdrant" @@ -166,6 +208,7 @@ class VectorStoreType(str, Enum): # MILVUS = "milvus" # WEAVIATE = "weaviate" + class DataSourceType(str, Enum): DOCUMENT = "document" WEBSITE = "website" @@ -173,6 +216,7 @@ class DataSourceType(str, Enum): DATABASE = "database" INTERNAL = "internal" + class ProcessingStepType(str, Enum): EXTRACT = "extract" TRANSFORM = "transform" @@ -181,12 +225,14 @@ class ProcessingStepType(str, Enum): FILTER = "filter" SUMMARIZE = "summarize" + class SearchType(str, Enum): SIMILARITY = "similarity" MMR = "mmr" HYBRID = "hybrid" KEYWORD = "keyword" + class ActivityType(str, Enum): LOGIN = "login" SEARCH = "search" @@ -196,33 +242,39 @@ class ActivityType(str, Enum): UPDATE_PROFILE = "update_profile" CHAT = "chat" + class ThemePreference(str, Enum): LIGHT = "light" DARK = "dark" SYSTEM = "system" + class NotificationType(str, Enum): EMAIL = "email" PUSH = "push" IN_APP = "in_app" + class FontSize(str, Enum): SMALL = "small" MEDIUM = "medium" LARGE = "large" + class SalaryPeriod(str, Enum): HOUR = "hour" DAY = "day" MONTH = "month" YEAR = "year" + class LanguageProficiency(str, Enum): BASIC = "basic" CONVERSATIONAL = "conversational" FLUENT = "fluent" NATIVE = "native" + class SocialPlatform(str, Enum): LINKEDIN = "linkedin" TWITTER = "twitter" @@ -232,12 +284,14 @@ class SocialPlatform(str, Enum): WEBSITE = "website" OTHER = "other" + class ColorBlindMode(str, Enum): PROTANOPIA = "protanopia" DEUTERANOPIA = "deuteranopia" TRITANOPIA = "tritanopia" NONE = "none" + class SortOrder(str, Enum): ASC = "asc" DESC = "desc" @@ -246,24 +300,27 @@ class SortOrder(str, Enum): class LoginRequest(BaseModel): login: str # Can be email or username password: str - - @field_validator('login') + + @field_validator("login") def sanitize_login(cls, v): return sanitize_login_input(v) - - @field_validator('password') + + @field_validator("password") def validate_password_not_empty(cls, v): if not v or not v.strip(): - raise ValueError('Password cannot be empty') + raise ValueError("Password cannot be empty") return v + # ============================ # MFA Models # ============================ + class EmailVerificationRequest(BaseModel): token: str + class MFARequest(BaseModel): username: str password: str @@ -272,6 +329,7 @@ class MFARequest(BaseModel): email: str = Field(..., alias=str("email")) model_config = ConfigDict(populate_by_name=True) + class MFAVerifyRequest(BaseModel): email: EmailStr code: str @@ -279,6 +337,7 @@ class MFAVerifyRequest(BaseModel): remember_device: bool = Field(False, alias=str("rememberDevice")) model_config = ConfigDict(populate_by_name=True) + class MFAData(BaseModel): message: str device_id: str = Field(..., alias=str("deviceId")) @@ -287,27 +346,33 @@ class MFAData(BaseModel): email: str model_config = ConfigDict(populate_by_name=True) + class MFARequestResponse(BaseModel): mfa_required: bool = Field(..., alias=str("mfaRequired")) mfa_data: Optional[MFAData] = Field(default=None, alias=str("mfaData")) model_config = 
ConfigDict(populate_by_name=True) + class ResendVerificationRequest(BaseModel): email: EmailStr + # ============================ # Supporting Models # ============================ + class Tunables(BaseModel): enable_rag: bool = Field(default=True, alias=str("enableRAG")) enable_tools: bool = Field(default=True, alias=str("enableTools")) enable_context: bool = Field(default=True, alias=str("enableContext")) + class CandidateQuestion(BaseModel): question: str tunables: Optional[Tunables] = None + class Location(BaseModel): city: str state: Optional[str] = None @@ -319,6 +384,7 @@ class Location(BaseModel): hybrid_options: Optional[List[str]] = Field(default=None, alias=str("hybridOptions")) address: Optional[str] = None + class Skill(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str @@ -326,6 +392,7 @@ class Skill(BaseModel): level: SkillLevel years_of_experience: Optional[int] = Field(default=None, alias=str("yearsOfExperience")) + class WorkExperience(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) company_name: str = Field(..., alias=str("companyName")) @@ -338,6 +405,7 @@ class WorkExperience(BaseModel): location: Location achievements: Optional[List[str]] = None + class Education(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) institution: str @@ -350,10 +418,12 @@ class Education(BaseModel): achievements: Optional[List[str]] = None location: Optional[Location] = None + class Language(BaseModel): language: str proficiency: LanguageProficiency + class Certification(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str @@ -363,15 +433,18 @@ class Certification(BaseModel): credential_id: Optional[str] = Field(default=None, alias=str("credentialId")) credential_url: Optional[HttpUrl] = Field(default=None, alias=str("credentialUrl")) + class SocialLink(BaseModel): platform: SocialPlatform url: HttpUrl + class DesiredSalary(BaseModel): amount: float currency: str period: SalaryPeriod + class SalaryRange(BaseModel): min: float max: float @@ -379,12 +452,14 @@ class SalaryRange(BaseModel): period: SalaryPeriod is_visible: bool = Field(..., alias=str("isVisible")) + class PointOfContact(BaseModel): name: str position: str email: EmailStr phone: Optional[str] = None + class RefreshToken(BaseModel): token: str expires_at: datetime = Field(..., alias=str("expiresAt")) @@ -393,6 +468,7 @@ class RefreshToken(BaseModel): is_revoked: bool = Field(..., alias=str("isRevoked")) revoked_reason: Optional[str] = Field(default=None, alias=str("revokedReason")) + class Attachment(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) file_name: str = Field(..., alias=str("fileName")) @@ -404,35 +480,42 @@ class Attachment(BaseModel): processing_result: Optional[Any] = Field(default=None, alias=str("processingResult")) thumbnail_url: Optional[str] = Field(default=None, alias=str("thumbnailUrl")) + class MessageReaction(BaseModel): user_id: str = Field(..., alias=str("userId")) reaction: str timestamp: datetime + class EditHistory(BaseModel): content: str edited_at: datetime = Field(..., alias=str("editedAt")) edited_by: str = Field(..., alias=str("editedBy")) + class CustomQuestion(BaseModel): question: str answer: str + class CandidateContact(BaseModel): email: EmailStr phone: Optional[str] = None + class ApplicationDecision(BaseModel): status: Literal["accepted", "rejected"] reason: Optional[str] = None date: datetime by: str + class NotificationPreference(BaseModel): type: NotificationType 
events: List[str] is_enabled: bool = Field(..., alias=str("isEnabled")) + class AccessibilitySettings(BaseModel): font_size: FontSize = Field(..., alias=str("fontSize")) high_contrast: bool = Field(..., alias=str("highContrast")) @@ -440,6 +523,7 @@ class AccessibilitySettings(BaseModel): screen_reader: bool = Field(..., alias=str("screenReader")) color_blind_mode: Optional[ColorBlindMode] = Field(default=None, alias=str("colorBlindMode")) + class ProcessingStep(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) type: ProcessingStepType @@ -447,6 +531,7 @@ class ProcessingStep(BaseModel): order: int depends_on: Optional[List[str]] = Field(default=None, alias=str("dependsOn")) + class RetrievalParameters(BaseModel): search_type: SearchType = Field(..., alias=str("searchType")) top_k: int = Field(..., alias=str("topK")) @@ -456,15 +541,18 @@ class RetrievalParameters(BaseModel): filter_options: Optional[Dict[str, Any]] = Field(default=None, alias=str("filterOptions")) context_window: int = Field(..., alias=str("contextWindow")) + class ErrorDetail(BaseModel): code: str message: str details: Optional[Any] = None + # ============================ # Main Models # ============================ + # Generic base user with user_type for API responses class BaseUserWithType(BaseModel): user_type: UserType = Field(..., alias=str("userType")) @@ -472,6 +560,7 @@ class BaseUserWithType(BaseModel): last_activity: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("lastActivity")) model_config = ConfigDict(populate_by_name=True, use_enum_values=True) + # Base user model without user_type field class BaseUser(BaseUserWithType): email: EmailStr @@ -486,7 +575,7 @@ class BaseUser(BaseUserWithType): profile_image: Optional[str] = Field(default=None, alias=str("profileImage")) status: UserStatus is_admin: bool = Field(default=False, alias=str("isAdmin")) - + model_config = ConfigDict(populate_by_name=True, use_enum_values=True) @@ -495,6 +584,7 @@ class RagEntry(BaseModel): description: str = "" enabled: bool = True + class RagContentMetadata(BaseModel): source_file: str = Field(..., alias=str("sourceFile")) line_begin: int = Field(..., alias=str("lineBegin")) @@ -505,28 +595,34 @@ class RagContentMetadata(BaseModel): metadata: Dict[str, Any] = Field(default_factory=dict) model_config = ConfigDict(populate_by_name=True) + class RagContentResponse(BaseModel): id: str content: str metadata: RagContentMetadata + class DocumentType(str, Enum): PDF = "pdf" DOCX = "docx" TXT = "txt" MARKDOWN = "markdown" IMAGE = "image" - + + class DocumentOptions(BaseModel): include_in_rag: bool = Field(default=True, alias=str("includeInRag")) is_job_document: Optional[bool] = Field(default=False, alias=str("isJobDocument")) overwrite: Optional[bool] = Field(default=False, alias=str("overwrite")) model_config = ConfigDict(populate_by_name=True) + class RAGDocumentRequest(BaseModel): """Request model for RAG document content""" + id: str + class Document(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) owner_id: str = Field(..., alias=str("ownerId")) @@ -539,6 +635,7 @@ class Document(BaseModel): rag_chunks: Optional[int] = Field(default=0, alias=str("ragChunks")) model_config = ConfigDict(populate_by_name=True) + class DocumentContentResponse(BaseModel): document_id: str = Field(..., alias=str("documentId")) filename: str @@ -547,36 +644,40 @@ class DocumentContentResponse(BaseModel): size: int model_config = ConfigDict(populate_by_name=True) + class 
DocumentListResponse(BaseModel): documents: List[Document] total: int model_config = ConfigDict(populate_by_name=True) + class DocumentUpdateRequest(BaseModel): filename: Optional[str] = None options: Optional[DocumentOptions] = None model_config = ConfigDict(populate_by_name=True) + class Candidate(BaseUser): user_type: UserType = Field(UserType.CANDIDATE, alias=str("userType")) username: str description: Optional[str] = None resume: Optional[str] = None skills: Optional[List[Skill]] = None - experience: Optional[List[WorkExperience]] = None - questions: Optional[List[CandidateQuestion]] = None - education: Optional[List[Education]] = None + experience: Optional[List[WorkExperience]] = None + questions: Optional[List[CandidateQuestion]] = None + education: Optional[List[Education]] = None preferred_job_types: Optional[List[EmploymentType]] = Field(default=None, alias=str("preferredJobTypes")) desired_salary: Optional[DesiredSalary] = Field(default=None, alias=str("desiredSalary")) availability_date: Optional[datetime] = Field(default=None, alias=str("availabilityDate")) summary: Optional[str] = None - languages: Optional[List[Language]] = None + languages: Optional[List[Language]] = None certifications: Optional[List[Certification]] = None job_applications: Optional[List["JobApplication"]] = Field(default=None, alias=str("jobApplications")) rags: List[RagEntry] = Field(default_factory=list) - rag_content_size : int = 0 + rag_content_size: int = 0 is_public: bool = Field(default=True, alias=str("isPublic")) + class CandidateAI(Candidate): user_type: UserType = Field(UserType.CANDIDATE, alias=str("userType")) is_AI: bool = Field(True, alias=str("isAI")) @@ -584,6 +685,7 @@ class CandidateAI(Candidate): gender: Optional[UserGender] = None ethnicity: Optional[str] = None + class Employer(BaseUser): user_type: UserType = Field(UserType.EMPLOYER, alias=str("userType")) company_name: str = Field(..., alias=str("companyName")) @@ -597,6 +699,7 @@ class Employer(BaseUser): social_links: Optional[List[SocialLink]] = Field(default=None, alias=str("socialLinks")) poc: Optional[PointOfContact] = None + class Guest(BaseUser): user_type: UserType = Field(UserType.GUEST, alias=str("userType")) session_id: str = Field(..., alias=str("sessionId")) @@ -609,6 +712,7 @@ class Guest(BaseUser): is_public: bool = Field(default=False, alias=str("isPublic")) model_config = ConfigDict(populate_by_name=True, use_enum_values=True) + class Authentication(BaseModel): user_id: str = Field(..., alias=str("userId")) password_hash: str = Field(..., alias=str("passwordHash")) @@ -624,30 +728,35 @@ class Authentication(BaseModel): locked_until: Optional[datetime] = Field(default=None, alias=str("lockedUntil")) model_config = ConfigDict(populate_by_name=True) + class AuthResponse(BaseModel): access_token: str = Field(..., alias=str("accessToken")) refresh_token: str = Field(..., alias=str("refreshToken")) user: Union[Candidate, Employer, Guest] # Add Guest support expires_at: int = Field(..., alias=str("expiresAt")) user_type: Optional[str] = Field(default=UserType.GUEST, alias=str("userType")) # Explicit user type - is_guest: Optional[bool] = Field(default=True, alias=str("isGuest")) # Guest indicator + is_guest: Optional[bool] = Field(default=True, alias=str("isGuest")) # Guest indicator model_config = ConfigDict(populate_by_name=True) + class GuestCleanupRequest(BaseModel): """Request to cleanup inactive guests""" + inactive_hours: int = Field(24, alias=str("inactiveHours")) model_config = ConfigDict(populate_by_name=True) 
+ class Requirements(BaseModel): required: List[str] = Field(default_factory=list) preferred: List[str] = Field(default_factory=list) - @model_validator(mode='before') + @model_validator(mode="before") def validate_requirements(cls, values): if not isinstance(values, dict): raise ValueError("Requirements must be a dictionary with 'required' and 'preferred' keys.") return values - + + class JobRequirements(BaseModel): technical_skills: Requirements = Field(..., alias=str("technicalSkills")) experience_requirements: Requirements = Field(..., alias=str("experienceRequirements")) @@ -659,6 +768,7 @@ class JobRequirements(BaseModel): company_values: Optional[List[str]] = Field(default=None, alias=str("companyValues")) model_config = ConfigDict(populate_by_name=True) + class JobDetails(BaseModel): location: Location salary_range: Optional[SalaryRange] = Field(default=None, alias=str("salaryRange")) @@ -675,6 +785,7 @@ class JobDetails(BaseModel): views: int = 0 application_count: int = Field(0, alias=str("applicationCount")) + class Job(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) owner_id: str = Field(..., alias=str("ownerId")) @@ -690,6 +801,7 @@ class Job(BaseModel): details: Optional[JobDetails] = None model_config = ConfigDict(populate_by_name=True) + class InterviewFeedback(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) interview_id: str = Field(..., alias=str("interviewId")) @@ -707,6 +819,7 @@ class InterviewFeedback(BaseModel): skill_assessments: Optional[List[SkillAssessment]] = Field(default=None, alias=str("skillAssessments")) model_config = ConfigDict(populate_by_name=True) + class InterviewSchedule(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) application_id: str = Field(..., alias=str("applicationId")) @@ -721,6 +834,7 @@ class InterviewSchedule(BaseModel): meeting_link: Optional[HttpUrl] = Field(default=None, alias=str("meetingLink")) model_config = ConfigDict(populate_by_name=True) + class JobApplication(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) job_id: str = Field(..., alias=str("jobId")) @@ -737,8 +851,10 @@ class JobApplication(BaseModel): decision: Optional[ApplicationDecision] = None model_config = ConfigDict(populate_by_name=True) + class GuestSessionResponse(BaseModel): """Response for guest session creation""" + access_token: str = Field(..., alias=str("accessToken")) refresh_token: str = Field(..., alias=str("refreshToken")) user: Guest @@ -746,23 +862,29 @@ class GuestSessionResponse(BaseModel): user_type: Literal["guest"] = Field("guest", alias=str("userType")) is_guest: bool = Field(True, alias=str("isGuest")) model_config = ConfigDict(populate_by_name=True) - + + class ChatContext(BaseModel): type: ChatContextType related_entity_id: Optional[str] = Field(default=None, alias=str("relatedEntityId")) - related_entity_type: Optional[Literal["job", "candidate", "employer"]] = Field(default=None, alias=str("relatedEntityType")) + related_entity_type: Optional[Literal["job", "candidate", "employer"]] = Field( + default=None, alias=str("relatedEntityType") + ) additional_context: Optional[Dict[str, Any]] = Field({}, alias=str("additionalContext")) model_config = ConfigDict(populate_by_name=True) + class ChatOptions(BaseModel): seed: Optional[int] = 8911 num_ctx: Optional[int] = Field(default=None, alias=str("numCtx")) # Number of context tokens - temperature: Optional[float] = Field(default=0.7) # Higher temperature to encourage tool usage + temperature: Optional[float] = 
Field(default=0.7) # Higher temperature to encourage tool usage model_config = ConfigDict(populate_by_name=True) + # Add rate limiting configuration models class RateLimitConfig(BaseModel): """Rate limit configuration""" + requests_per_minute: int = Field(..., alias=str("requestsPerMinute")) requests_per_hour: int = Field(..., alias=str("requestsPerHour")) requests_per_day: int = Field(..., alias=str("requestsPerDay")) @@ -770,8 +892,10 @@ class RateLimitConfig(BaseModel): burst_window_seconds: int = Field(60, alias=str("burstWindowSeconds")) model_config = ConfigDict(populate_by_name=True) + class RateLimitResult(BaseModel): """Result of rate limit check""" + allowed: bool reason: Optional[str] = None retry_after_seconds: Optional[int] = Field(default=None, alias=str("retryAfterSeconds")) @@ -779,8 +903,10 @@ class RateLimitResult(BaseModel): reset_times: Dict[str, datetime] = Field(default_factory=dict, alias=str("resetTimes")) model_config = ConfigDict(populate_by_name=True) + class RateLimitStatus(BaseModel): """Rate limit status for a user""" + user_id: str = Field(..., alias=str("userId")) user_type: str = Field(..., alias=str("userType")) is_admin: bool = Field(..., alias=str("isAdmin")) @@ -790,11 +916,12 @@ class RateLimitStatus(BaseModel): reset_times: Dict[str, datetime] = Field(..., alias=str("resetTimes")) config: RateLimitConfig model_config = ConfigDict(populate_by_name=True) - + # Add guest conversion request models class GuestConversionRequest(BaseModel): """Request to convert guest to permanent user""" + account_type: Literal["candidate", "employer"] = Field(..., alias=str("accountType")) email: EmailStr username: str @@ -802,7 +929,7 @@ class GuestConversionRequest(BaseModel): first_name: str = Field(..., alias=str("firstName")) last_name: str = Field(..., alias=str("lastName")) phone: Optional[str] = None - + # Employer-specific fields (optional) company_name: Optional[str] = Field(default=None, alias=str("companyName")) industry: Optional[str] = None @@ -810,24 +937,26 @@ class GuestConversionRequest(BaseModel): company_description: Optional[str] = Field(default=None, alias=str("companyDescription")) website_url: Optional[HttpUrl] = Field(default=None, alias=str("websiteUrl")) model_config = ConfigDict(populate_by_name=True) - - @field_validator('username') + + @field_validator("username") def validate_username(cls, v): if not v or len(v.strip()) < 3: - raise ValueError('Username must be at least 3 characters long') + raise ValueError("Username must be at least 3 characters long") return v.strip().lower() - - @field_validator('password') + + @field_validator("password") def validate_password_strength(cls, v): # Import here to avoid circular imports is_valid, issues = validate_password_strength(v) if not is_valid: - raise ValueError('; '.join(issues)) + raise ValueError("; ".join(issues)) return v + # Add guest statistics response model class GuestStatistics(BaseModel): """Guest usage statistics""" + total_guests: int = Field(..., alias=str("totalGuests")) active_last_hour: int = Field(..., alias=str("activeLastHour")) active_last_day: int = Field(..., alias=str("activeLastDay")) @@ -836,7 +965,9 @@ class GuestStatistics(BaseModel): creation_timeline: Dict[str, int] = Field(..., alias=str("creationTimeline")) model_config = ConfigDict(populate_by_name=True) -from utils.llm_proxy import (LLMMessage) + +from utils.llm_proxy import LLMMessage + class ApiMessage(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) @@ -847,23 +978,27 @@ class 
ApiMessage(BaseModel): timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC), alias=str("timestamp")) model_config = ConfigDict(populate_by_name=True) + MOCK_UUID = str(uuid.uuid4()) + class ChatMessageStreaming(ApiMessage): status: ApiStatusType = ApiStatusType.STREAMING type: ApiMessageType = ApiMessageType.TEXT content: str + class ApiActivityType(str, Enum): - SYSTEM = "system" # Used solely on frontend - INFO = "info" # Used solely on frontend - SEARCHING = "searching" # Used when generating RAG information - THINKING = "thinking" # Used when determing if AI will use tools - GENERATING = "generating" # Used when AI is generating a response - CONVERTING = "converting" # Used when AI is generating a response - GENERATING_IMAGE = "generating_image" # Used when AI is generating an image - TOOLING = "tooling" # Used when AI is using tools - HEARTBEAT = "heartbeat" # Used for periodic updates + SYSTEM = "system" # Used solely on frontend + INFO = "info" # Used solely on frontend + SEARCHING = "searching" # Used when generating RAG information + THINKING = "thinking" # Used when determining if AI will use tools + GENERATING = "generating" # Used when AI is generating a response + CONVERTING = "converting" # Used when AI is generating a response + GENERATING_IMAGE = "generating_image" # Used when AI is generating an image + TOOLING = "tooling" # Used when AI is using tools + HEARTBEAT = "heartbeat" # Used for periodic updates + class ChatMessageStatus(ApiMessage): sender_id: Optional[str] = Field(default=MOCK_UUID, alias=str("senderId")) @@ -872,21 +1007,25 @@ class ChatMessageStatus(ApiMessage): activity: ApiActivityType content: Any + class ChatMessageError(ApiMessage): sender_id: Optional[str] = Field(default=MOCK_UUID, alias=str("senderId")) status: ApiStatusType = ApiStatusType.ERROR type: ApiMessageType = ApiMessageType.TEXT content: str + class ChatMessageRagSearch(ApiMessage): type: ApiMessageType = Field(default=ApiMessageType.JSON) dimensions: int = 2 | 3 content: List[ChromaDBGetResponse] = [] + class JobRequirementsMessage(ApiMessage): type: ApiMessageType = ApiMessageType.JSON job: Job = Field(..., alias=str("job")) + class DocumentMessage(ApiMessage): type: ApiMessageType = ApiMessageType.JSON sender_id: Optional[str] = Field(default=MOCK_UUID, alias=str("senderId")) @@ -895,6 +1034,7 @@ class DocumentMessage(ApiMessage): converted: bool = Field(False, alias=str("converted")) model_config = ConfigDict(populate_by_name=True) + class ChatMessageMetaData(BaseModel): model: AIModelType = AIModelType.QWEN2_5 temperature: float = 0.7 @@ -914,6 +1054,7 @@ class ChatMessageMetaData(BaseModel): timers: Dict[str, float] = Field(default_factory=dict) model_config = ConfigDict(populate_by_name=True) + class ChatMessageUser(ApiMessage): type: ApiMessageType = ApiMessageType.TEXT status: ApiStatusType = ApiStatusType.DONE @@ -921,19 +1062,22 @@ class ChatMessageUser(ApiMessage): content: str = "" tunables: Optional[Tunables] = None + class ChatMessage(ChatMessageUser): role: ChatSenderType = ChatSenderType.ASSISTANT metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData()) - #attachments: Optional[List[Attachment]] = None - #reactions: Optional[List[MessageReaction]] = None - #is_edited: bool = Field(False, 
alias=str("isEdited")) + # edit_history: Optional[List[EditHistory]] = Field(default=None, alias=str("editHistory")) + class ChatMessageSkillAssessment(ChatMessageUser): role: ChatSenderType = ChatSenderType.ASSISTANT metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData()) skill_assessment: SkillAssessment = Field(..., alias=str("skillAssessment")) + class ChatMessageResume(ChatMessageUser): role: ChatSenderType = ChatSenderType.ASSISTANT metadata: ChatMessageMetaData = Field(default=ChatMessageMetaData()) @@ -942,6 +1086,7 @@ class ChatMessageResume(ChatMessageUser): prompt: Optional[str] = Field(default=None, alias=str("prompt")) model_config = ConfigDict(populate_by_name=True) + class Resume(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) job_id: str = Field(..., alias=str("jobId")) @@ -953,16 +1098,19 @@ class Resume(BaseModel): candidate: Optional[Candidate] = None model_config = ConfigDict(populate_by_name=True) + class ResumeMessage(ChatMessageUser): role: ChatSenderType = ChatSenderType.ASSISTANT resume: Resume model_config = ConfigDict(populate_by_name=True) + class GPUInfo(BaseModel): name: str memory: int discrete: bool + class SystemInfo(BaseModel): installed_RAM: str = Field(..., alias=str("installedRAM")) graphics_cards: List[GPUInfo] = Field(..., alias=str("graphicsCards")) @@ -972,6 +1120,7 @@ class SystemInfo(BaseModel): max_context_length: int = Field(default=defines.max_context, alias=str("maxContextLength")) model_config = ConfigDict(populate_by_name=True) + class ChatSession(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: Optional[str] = Field(default=None, alias=str("userId")) @@ -991,6 +1140,7 @@ class ChatSession(BaseModel): raise ValueError("Either user_id or guest_id must be provided") return self + class DataSourceConfiguration(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) rag_config_id: str = Field(..., alias=str("ragConfigId")) @@ -1005,6 +1155,7 @@ class DataSourceConfiguration(BaseModel): metadata: Optional[Dict[str, Any]] = None model_config = ConfigDict(populate_by_name=True) + class RAGConfiguration(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: str = Field(..., alias=str("userId")) @@ -1020,6 +1171,7 @@ class RAGConfiguration(BaseModel): is_active: bool = Field(..., alias=str("isActive")) model_config = ConfigDict(populate_by_name=True) + class UserActivity(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: Optional[str] = Field(default=None, alias=str("userId")) @@ -1038,6 +1190,7 @@ class UserActivity(BaseModel): raise ValueError("Either user_id or guest_id must be provided") return self + class Analytics(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) entity_type: Literal["job", "candidate", "chat", "system", "employer"] = Field(..., alias=str("entityType")) @@ -1049,6 +1202,7 @@ class Analytics(BaseModel): segment: Optional[str] = None model_config = ConfigDict(populate_by_name=True) + class UserPreference(BaseModel): user_id: str = Field(..., alias=str("userId")) theme: ThemePreference @@ -1060,6 +1214,7 @@ class UserPreference(BaseModel): email_frequency: Literal["immediate", "daily", "weekly", "never"] = Field(..., alias=str("emailFrequency")) model_config = ConfigDict(populate_by_name=True) + # ============================ # API Request/Response Models # ============================ @@ -1072,20 +1227,21 @@ class CreateCandidateRequest(BaseModel): # Add other required 
candidate fields as needed phone: Optional[str] = None model_config = ConfigDict(populate_by_name=True) - - @field_validator('username') + + @field_validator("username") def validate_username(cls, v): if not v or len(v.strip()) < 3: - raise ValueError('Username must be at least 3 characters long') + raise ValueError("Username must be at least 3 characters long") return v.strip().lower() - - @field_validator('password') + + @field_validator("password") def validate_password_strength(cls, v): is_valid, issues = validate_password_strength(v) if not is_valid: - raise ValueError('; '.join(issues)) + raise ValueError("; ".join(issues)) return v + # Create Employer Endpoint (similar pattern) class CreateEmployerRequest(BaseModel): email: EmailStr @@ -1093,32 +1249,34 @@ class CreateEmployerRequest(BaseModel): password: str company_name: str = Field(..., alias=str("companyName")) industry: str - company_size: str = Field(..., alias=str("companySize")) + company_size: str = Field(..., alias=str("companySize")) company_description: str = Field(..., alias=str("companyDescription")) # Add other required employer fields website_url: Optional[str] = Field(default=None, alias=str("websiteUrl")) phone: Optional[str] = None model_config = ConfigDict(populate_by_name=True) - - @field_validator('username') + + @field_validator("username") def validate_username(cls, v): if not v or len(v.strip()) < 3: - raise ValueError('Username must be at least 3 characters long') + raise ValueError("Username must be at least 3 characters long") return v.strip().lower() - - @field_validator('password') + + @field_validator("password") def validate_password_strength(cls, v): is_valid, issues = validate_password_strength(v) if not is_valid: - raise ValueError('; '.join(issues)) + raise ValueError("; ".join(issues)) return v + class ChatQuery(BaseModel): prompt: str tunables: Optional[Tunables] = None agent_options: Optional[Dict[str, Any]] = Field(default=None, alias=str("agentOptions")) model_config = ConfigDict(populate_by_name=True) + class PaginatedRequest(BaseModel): page: Annotated[int, Field(ge=1)] = 1 limit: Annotated[int, Field(ge=1, le=100)] = 20 @@ -1127,6 +1285,7 @@ class PaginatedRequest(BaseModel): filters: Optional[Dict[str, Any]] = None model_config = ConfigDict(populate_by_name=True) + class SearchQuery(BaseModel): query: str filters: Optional[Dict[str, Any]] = None @@ -1136,6 +1295,7 @@ class SearchQuery(BaseModel): sort_order: Optional[SortOrder] = Field(default=None, alias=str("sortOrder")) model_config = ConfigDict(populate_by_name=True) + class PaginatedResponse(BaseModel): data: List[Any] # Will be typed specifically when used total: int @@ -1145,12 +1305,14 @@ class PaginatedResponse(BaseModel): has_more: bool = Field(..., alias=str("hasMore")) model_config = ConfigDict(populate_by_name=True) + class ApiResponse(BaseModel): success: bool data: Optional[Any] = None # Will be typed specifically when used error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None + # Specific typed response models for common use cases class CandidateResponse(BaseModel): success: bool @@ -1158,30 +1320,35 @@ class CandidateResponse(BaseModel): error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None + class EmployerResponse(BaseModel): success: bool data: Optional[Employer] = None error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None + class JobResponse(BaseModel): success: bool data: Optional[Job] = None error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] 
= None + class CandidateListResponse(BaseModel): success: bool data: Optional[List[Candidate]] = None error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None + class JobListResponse(BaseModel): success: bool data: Optional[List["Job"]] = None error: Optional[ErrorDetail] = None meta: Optional[Dict[str, Any]] = None + User = Union[Candidate, CandidateAI, Employer, Guest] # Forward references resolution @@ -1189,4 +1356,4 @@ Candidate.update_forward_refs() Employer.update_forward_refs() ChatSession.update_forward_refs() JobApplication.update_forward_refs() -Job.update_forward_refs() \ No newline at end of file +Job.update_forward_refs() diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml new file mode 100644 index 0000000..afc37d4 --- /dev/null +++ b/src/backend/pyproject.toml @@ -0,0 +1,23 @@ +[tool.ruff] +# Set line length (default is 88, Black-compatible) +line-length = 120 +exclude = [ + ".git", + "__pycache__", + "migrations", + "venv", + "dist", + "build", +] +target-version = "py312" + +[tool.ruff.format] +# Use Ruff's formatter (Black-compatible by default) +quote-style = "double" # Enforce double quotes +indent-style = "space" # Use spaces for indentation +skip-magic-trailing-comma = false # Add trailing commas where applicable + +[tool.ruff.per-file-ignores] +# Ignore specific rules for certain files (e.g., tests or scripts) +"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests +"scripts/**" = ["E402"] # Ignore import order in scripts \ No newline at end of file diff --git a/src/backend/rag/__init__.py b/src/backend/rag/__init__.py index d499be5..d292b54 100644 --- a/src/backend/rag/__init__.py +++ b/src/backend/rag/__init__.py @@ -1,7 +1,3 @@ from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry -__all__ = [ - "ChromaDBFileWatcher", - "start_file_watcher", - "RagEntry" -] - \ No newline at end of file + +__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"] diff --git a/src/backend/rag/markdown_chunker.py b/src/backend/rag/markdown_chunker.py index b5c2e3d..6ab19fb 100644 --- a/src/backend/rag/markdown_chunker.py +++ b/src/backend/rag/markdown_chunker.py @@ -7,10 +7,12 @@ import logging import defines + class Chunk(TypedDict): text: str metadata: Dict[str, Any] + def clear_chunk(chunk: Chunk): chunk["text"] = "" chunk["metadata"] = { @@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk): } return chunk + class MarkdownChunker: def __init__(self): # Initialize the Markdown parser @@ -76,10 +79,7 @@ class MarkdownChunker: # Initialize a chunk structure chunk: Chunk = { "text": "", - "metadata": { - "source_file": file_path, - "lines": total_lines - }, + "metadata": {"source_file": file_path, "lines": total_lines}, } clear_chunk(chunk) @@ -89,9 +89,7 @@ class MarkdownChunker: return chunks def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]: - return { - k: ("" if v is None else v) for k, v in metadata.items() if v is not None - } + return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None} def _extract_text_from_children(self, node: SyntaxTreeNode) -> str: lines = [] @@ -114,7 +112,7 @@ class MarkdownChunker: chunks: List[Chunk], chunk: Chunk, level: int, - buffer: int = defines.chunk_buffer + buffer: int = defines.chunk_buffer, ) -> int: is_list = False # Handle heading nodes @@ -141,7 +139,7 @@ class MarkdownChunker: if node.nester_tokens: opening, closing = node.nester_tokens if opening and opening.map: - ( begin, end ) = opening.map + (begin, 
end) = opening.map metadata = chunk["metadata"] metadata["chunk_begin"] = max(0, begin - buffer) metadata["chunk_end"] = min(metadata["lines"], end + buffer) @@ -186,7 +184,7 @@ class MarkdownChunker: if node.nester_tokens: opening, closing = node.nester_tokens if opening and opening.map: - ( begin, end ) = opening.map + (begin, end) = opening.map metadata = chunk["metadata"] metadata["chunk_begin"] = max(0, begin - buffer) metadata["chunk_end"] = min(metadata["lines"], end + buffer) @@ -198,9 +196,7 @@ class MarkdownChunker: # Recursively process children if not is_list: for child in node.children: - level = self._process_node( - child, current_headings, chunks, chunk, level=level - ) + level = self._process_node(child, current_headings, chunks, chunk, level=level) # After root-level recursion, finalize any remaining chunk if node.type == "document": @@ -211,7 +207,7 @@ class MarkdownChunker: if node.nester_tokens: opening, closing = node.nester_tokens if opening and opening.map: - ( begin, end ) = opening.map + (begin, end) = opening.map metadata = chunk["metadata"] metadata["chunk_begin"] = max(0, begin - buffer) metadata["chunk_end"] = min(metadata["lines"], end + buffer) diff --git a/src/backend/rag/rag.py b/src/backend/rag/rag.py index 05fe91d..59f3319 100644 --- a/src/backend/rag/rag.py +++ b/src/backend/rag/rag.py @@ -11,10 +11,10 @@ import json import numpy as np # type: ignore import traceback -import chromadb # type: ignore +import chromadb # type: ignore from watchdog.observers import Observer # type: ignore from watchdog.events import FileSystemEventHandler # type: ignore -import umap # type: ignore +import umap # type: ignore from markitdown import MarkItDown # type: ignore from chromadb.api.models.Collection import Collection # type: ignore @@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"] DEFAULT_CHUNK_SIZE = 750 DEFAULT_CHUNK_OVERLAP = 100 + class RagEntry(BaseModel): name: str description: str = "" enabled: bool = True - + + class ChromaDBFileWatcher(FileSystemEventHandler): def __init__( self, @@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') # Path for storing file hash state - self.hash_state_path = os.path.join( - self.persist_directory, f"{collection_name}_hash_state.json" - ) + self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json") # Flag to track if this is a new collection self.is_new_collection = False @@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): process_all: If True, process all files regardless of hash status """ # Check for new or modified files - file_paths = glob.glob( - os.path.join(self.watch_directory, "**/*"), recursive=True - ) + file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True) files_checked = 0 files_processed = 0 files_to_process = [] @@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler): continue # If file is new, changed, or we're processing all files - if ( - process_all - or file_path not in self.file_hashes - or self.file_hashes[file_path] != current_hash - ): + if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash: self.file_hashes[file_path] = current_hash files_to_process.append(file_path) - logging.info( - f"File {'found' if process_all else 'changed'}: {file_path}" - ) + logging.info(f"File {'found' if process_all else 'changed'}: {file_path}") - logging.info( - f"Found 
{len(files_to_process)} files to process after scanning {files_checked} files" - ) + logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files") # Check for deleted files deleted_files = [] @@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): if not os.path.exists(file_path): deleted_files.append(file_path) # Schedule removal - asyncio.run_coroutine_threadsafe( - self.remove_file_from_collection(file_path), self.loop - ) + asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop) # Don't block on result, just let it run logging.info(f"File deleted: {file_path}") @@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): if not current_hash: # File might have been deleted or is inaccessible return - if ( - file_path in self.file_hashes - and self.file_hashes[file_path] == current_hash - ): + if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash: # File hasn't actually changed in content logging.info(f"Hash has not changed for {file_path}") return @@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): if results and "ids" in results and results["ids"]: self.collection.delete(ids=results["ids"]) await self.database.update_user_rag_timestamp(self.user_id) - logging.info( - f"Removed {len(results['ids'])} chunks for deleted file: {file_path}" - ) + logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}") # Remove from hash dictionary if file_path in self.file_hashes: @@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler): def _update_umaps(self): # Update the UMAP embeddings - self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get( - include=["embeddings", "documents", "metadatas"] - )) + self._umap_collection = ChromaDBGetResponse.model_validate( + self._collection.get(include=["embeddings", "documents", "metadatas"]) + ) if not self._umap_collection or not len(self._umap_collection.embeddings): logging.warning("⚠️ No embeddings found in the collection.") return # During initialization - logging.info( - f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors" - ) + logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors") vectors = np.array(self._umap_collection.embeddings) self._umap_model_2d = umap.UMAP( n_components=2, @@ -323,14 +304,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler): n_neighbors=30, min_dist=0.1, ) - self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore + self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore # logging.info( # f"2D UMAP model n_components: {self._umap_model_2d.n_components}" # ) # Should be 2 - logging.info( - f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors" - ) + logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors") self._umap_model_3d = umap.UMAP( n_components=3, random_state=8911, @@ -338,7 +317,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): n_neighbors=30, min_dist=0.01, ) - self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)# type: ignore + self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors) # type: ignore # logging.info( # f"3D UMAP model n_components: {self._umap_model_3d.n_components}" # ) # Should be 3 @@ -350,7 +329,7 @@ class 
ChromaDBFileWatcher(FileSystemEventHandler): os.makedirs(self.persist_directory) # Initialize ChromaDB client - chroma_client = chromadb.PersistentClient( + chroma_client = chromadb.PersistentClient( path=self.persist_directory, settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore ) @@ -373,13 +352,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): self.is_new_collection = True logging.info(f"Recreating collection: {self.collection_name}") - return chroma_client.get_or_create_collection( - name=self.collection_name, metadata={"hnsw:space": "cosine"} - ) + return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"}) async def get_embedding(self, text: str) -> np.ndarray: """Generate and normalize an embedding for the given text.""" - + # Get embedding try: response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text) @@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # Generate a more unique ID based on content and metadata path_hash = "" if "path" in metadata: - path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[ - :8 - ] + path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8] content_hash = hashlib.md5(text.encode()).hexdigest()[:8] chunk_id = f"{path_hash}_{i}_{content_hash}" @@ -438,7 +413,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): logging.error(traceback.format_exc()) logging.error(chunk) - def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None: + def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None: source_file = meta.get("source_file") try: source_file = meta["source_file"] @@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): return file_path = event.src_path - asyncio.run_coroutine_threadsafe( - self.remove_file_from_collection(file_path), self.loop - ) + asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop) logging.info(f"File deleted: {file_path}") def on_moved(self, event): @@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): try: # Remove existing entries for this file existing_results = self.collection.get(where={"path": file_path}) - if ( - existing_results - and "ids" in existing_results - and existing_results["ids"] - ): + if existing_results and "ids" in existing_results and existing_results["ids"]: self.collection.delete(ids=existing_results["ids"]) await self.database.update_user_rag_timestamp(self.user_id) @@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler): p = Path(file_path) p_as_md = p.with_suffix(".md") if p_as_md.exists(): - logging.info( - f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}" - ) + logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}") # If file_path.md doesn't exist or file_path is newer than file_path.md, # fire off markitdown - if (not p_as_md.exists()) or ( - p.stat().st_mtime > p_as_md.stat().st_mtime - ): + if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime): self._markitdown(file_path, p_as_md) return @@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler): # Process all files regardless of hash state num_processed = await self.scan_directory(process_all=True) - logging.info( - f"Vectorstore initialized with {self.collection.count()} documents" - ) + logging.info(f"Vectorstore initialized with {self.collection.count()} documents") self._update_umaps() @@ 
-676,7 +639,7 @@ def start_file_watcher( persist_directory=persist_directory, collection_name=collection_name, recreate=recreate, - database=database + database=database, ) # Process all files if: diff --git a/src/backend/routes/__init__.py b/src/backend/routes/__init__.py index cc5bf08..6ae262e 100644 --- a/src/backend/routes/__init__.py +++ b/src/backend/routes/__init__.py @@ -13,14 +13,4 @@ from . import employers from . import admin from . import system -__all__ = [ - "auth", - "candidates", - "resumes", - "jobs", - "chat", - "users", - "employers", - "admin", - "system" -] \ No newline at end of file +__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"] diff --git a/src/backend/routes/admin.py b/src/backend/routes/admin.py index e26b169..faf317c 100644 --- a/src/backend/routes/admin.py +++ b/src/backend/routes/admin.py @@ -5,8 +5,18 @@ import json from datetime import datetime, timezone, UTC from fastapi import ( - APIRouter, HTTPException, Depends, Body, Request, HTTPException, - Depends, Query, Path, Body, APIRouter, Request + APIRouter, + HTTPException, + Depends, + Body, + Request, + HTTPException, + Depends, + Query, + Path, + Body, + APIRouter, + Request, ) from fastapi.responses import JSONResponse @@ -14,402 +24,346 @@ from fastapi.responses import JSONResponse from utils.rate_limiter import RateLimiter, get_rate_limiter from database.manager import RedisDatabase from logger import logger -from utils.dependencies import ( - get_current_admin, get_current_user_or_guest, get_database, background_task_manager -) +from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager from utils.responses import ( - create_paginated_response, create_success_response, create_error_response, + create_paginated_response, + create_success_response, + create_error_response, ) # Create router for authentication endpoints router = APIRouter(prefix="/admin", tags=["admin"]) + @router.post("/tasks/cleanup-guests") async def manual_guest_cleanup( inactive_hours: int = Body(24, embed=True), - current_user = Depends(get_current_admin), - admin_user = Depends(get_current_admin) + current_user=Depends(get_current_admin), + admin_user=Depends(get_current_admin), ): """Manually trigger guest cleanup (admin only)""" try: if not background_task_manager: return JSONResponse( status_code=500, - content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"), ) - + cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours) - + logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned") - - return create_success_response({ - "message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.", - "cleaned_count": cleaned_count, - "triggered_by": admin_user.id - }) - - except Exception as e: - logger.error(f"❌ Manual guest cleanup error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("CLEANUP_ERROR", str(e)) + + return create_success_response( + { + "message": f"Guest cleanup completed. 
Removed {cleaned_count} inactive sessions.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id, + } ) + except Exception as e: + logger.error(f"❌ Manual guest cleanup error: {e}") + return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e))) + + @router.post("/tasks/cleanup-tokens") -async def manual_token_cleanup( - admin_user = Depends(get_current_admin) -): +async def manual_token_cleanup(admin_user=Depends(get_current_admin)): """Manually trigger verification token cleanup (admin only)""" try: global background_task_manager - + if not background_task_manager: return JSONResponse( status_code=500, - content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"), ) - + cleaned_count = await background_task_manager.cleanup_expired_verification_tokens() - + logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned") - - return create_success_response({ - "message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.", - "cleaned_count": cleaned_count, - "triggered_by": admin_user.id - }) - - except Exception as e: - logger.error(f"❌ Manual token cleanup error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("CLEANUP_ERROR", str(e)) + + return create_success_response( + { + "message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id, + } ) + except Exception as e: + logger.error(f"❌ Manual token cleanup error: {e}") + return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e))) + + @router.post("/tasks/cleanup-rate-limits") -async def manual_rate_limit_cleanup( - days_old: int = Body(7, embed=True), - admin_user = Depends(get_current_admin) -): +async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)): """Manually trigger rate limit data cleanup (admin only)""" try: global background_task_manager - + if not background_task_manager: return JSONResponse( status_code=500, - content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available") + content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"), ) - + cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old) - + logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned") - - return create_success_response({ - "message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.", - "cleaned_count": cleaned_count, - "triggered_by": admin_user.id - }) - + + return create_success_response( + { + "message": f"Rate limit cleanup completed. 
Removed {cleaned_count} old keys.", + "cleaned_count": cleaned_count, + "triggered_by": admin_user.id, + } + ) + except Exception as e: logger.error(f"❌ Manual rate limit cleanup error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("CLEANUP_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e))) + # ======================================== # System Health and Maintenance Endpoints # ======================================== + @router.get("/system/health") -async def get_system_health( - request: Request, - admin_user = Depends(get_current_admin) -): +async def get_system_health(request: Request, admin_user=Depends(get_current_admin)): """Get comprehensive system health status (admin only)""" try: # Database health - database_manager = getattr(request.app.state, 'database_manager', None) + database_manager = getattr(request.app.state, "database_manager", None) db_health = {"status": "unavailable", "healthy": False} - + if database_manager: try: database_manager.get_database() from database.manager import redis_manager + redis_health = await redis_manager.health_check() db_health = { "status": redis_health.get("status", "unknown"), "healthy": redis_health.get("status") == "healthy", - "details": redis_health + "details": redis_health, } except Exception as e: - db_health = { - "status": "error", - "healthy": False, - "error": str(e) - } - + db_health = {"status": "error", "healthy": False, "error": str(e)} + # Background task health - background_task_manager = getattr(request.app.state, 'background_task_manager', None) + background_task_manager = getattr(request.app.state, "background_task_manager", None) task_health = {"status": "unavailable", "healthy": False} - + if background_task_manager: try: task_status = await background_task_manager.get_task_status() running_tasks = len([t for t in task_status["tasks"] if t["status"] == "running"]) failed_tasks = len([t for t in task_status["tasks"] if t["status"] == "failed"]) - + task_health = { "status": "healthy" if task_status["running"] and failed_tasks == 0 else "degraded", "healthy": task_status["running"] and failed_tasks == 0, "running_tasks": running_tasks, "failed_tasks": failed_tasks, - "total_tasks": task_status["task_count"] + "total_tasks": task_status["task_count"], } except Exception as e: - task_health = { - "status": "error", - "healthy": False, - "error": str(e) - } - + task_health = {"status": "error", "healthy": False, "error": str(e)} + # Overall health overall_healthy = db_health["healthy"] and task_health["healthy"] - - return create_success_response({ - "timestamp": datetime.now(UTC).isoformat(), - "overall_healthy": overall_healthy, - "components": { - "database": db_health, - "background_tasks": task_health + + return create_success_response( + { + "timestamp": datetime.now(UTC).isoformat(), + "overall_healthy": overall_healthy, + "components": {"database": db_health, "background_tasks": task_health}, } - }) - + ) + except Exception as e: logger.error(f"❌ Error getting system health: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("HEALTH_CHECK_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e))) + @router.post("/maintenance/cleanup") async def run_maintenance_cleanup( - request: Request, - admin_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + request: Request, 
admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Run comprehensive maintenance cleanup (admin only)""" try: cleanup_results = {} - + # Run various cleanup operations cleanup_operations = [ ("inactive_guests", lambda: database.cleanup_inactive_guests(72)), # 3 days ("expired_tokens", lambda: database.cleanup_expired_verification_tokens()), ("orphaned_job_requirements", lambda: database.cleanup_orphaned_job_requirements()), ] - + for operation_name, operation_func in cleanup_operations: try: result = await operation_func() cleanup_results[operation_name] = { "success": True, "cleaned_count": result, - "message": f"Cleaned {result} items" + "message": f"Cleaned {result} items", } except Exception as e: - cleanup_results[operation_name] = { - "success": False, - "error": str(e), - "message": f"Failed: {str(e)}" - } - + cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"} + # Calculate totals total_cleaned = sum( - result.get("cleaned_count", 0) - for result in cleanup_results.values() - if result.get("success", False) + result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False) ) - - successful_operations = len([ - r for r in cleanup_results.values() - if r.get("success", False) - ]) - - return create_success_response({ - "message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.", - "total_cleaned": total_cleaned, - "successful_operations": successful_operations, - "details": cleanup_results - }) - + + successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)]) + + return create_success_response( + { + "message": f"Maintenance cleanup completed. 
{total_cleaned} items cleaned across {successful_operations} operations.", + "total_cleaned": total_cleaned, + "successful_operations": successful_operations, + "details": cleanup_results, + } + ) + except Exception as e: logger.error(f"❌ Error in maintenance cleanup: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("CLEANUP_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e))) + # ======================================== # Background Task Statistics # ======================================== + @router.get("/tasks/stats") async def get_task_statistics( - request: Request, - admin_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Get background task execution statistics (admin only)""" try: # Get guest statistics guest_stats = await database.get_guest_statistics() - + # Get background task manager status - background_task_manager = getattr(request.app.state, 'background_task_manager', None) + background_task_manager = getattr(request.app.state, "background_task_manager", None) task_manager_stats = {} - + if background_task_manager: task_status = await background_task_manager.get_task_status() task_manager_stats = { "running": task_status["running"], "task_count": task_status["task_count"], - "task_breakdown": {} + "task_breakdown": {}, } - + # Count tasks by status for task in task_status["tasks"]: status = task["status"] task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1 - - return create_success_response({ - "guest_statistics": guest_stats, - "task_manager": task_manager_stats, - "timestamp": datetime.now(UTC).isoformat() - }) - + + return create_success_response( + { + "guest_statistics": guest_stats, + "task_manager": task_manager_stats, + "timestamp": datetime.now(UTC).isoformat(), + } + ) + except Exception as e: logger.error(f"❌ Error getting task statistics: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATS_ERROR", str(e)) - ) - + return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e))) + + # ======================================== # Background Task Status Endpoints # ======================================== + @router.get("/tasks/status") -async def get_background_task_status( - request: Request, - admin_user = Depends(get_current_admin) -): +async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)): """Get background task manager status (admin only)""" try: # Get background task manager from app state - background_task_manager = getattr(request.app.state, 'background_task_manager', None) - + background_task_manager = getattr(request.app.state, "background_task_manager", None) + if not background_task_manager: - return create_success_response({ - "running": False, - "message": "Background task manager not initialized", - "tasks": [], - "task_count": 0 - }) - + return create_success_response( + {"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0} + ) + # Get comprehensive task status using the new method task_status = await background_task_manager.get_task_status() - + # Add additional system info system_info = { "uptime_seconds": None, # Could calculate from start time if stored - "last_cleanup": None, # Could track last cleanup time + 
"last_cleanup": None, # Could track last cleanup time } - + # Format the response - return create_success_response({ - "running": task_status["running"], - "task_count": task_status["task_count"], - "loop_status": { - "main_loop_id": task_status["main_loop_id"], - "current_loop_id": task_status["current_loop_id"], - "loop_matches": task_status.get("loop_matches", False) - }, - "tasks": task_status["tasks"], - "system_info": system_info - }) - - except Exception as e: - logger.error(f"❌ Get task status error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATUS_ERROR", str(e)) + return create_success_response( + { + "running": task_status["running"], + "task_count": task_status["task_count"], + "loop_status": { + "main_loop_id": task_status["main_loop_id"], + "current_loop_id": task_status["current_loop_id"], + "loop_matches": task_status.get("loop_matches", False), + }, + "tasks": task_status["tasks"], + "system_info": system_info, + } ) + except Exception as e: + logger.error(f"❌ Get task status error: {e}") + return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e))) + + @router.post("/tasks/run/{task_name}") -async def run_background_task( - task_name: str, - request: Request, - admin_user = Depends(get_current_admin) -): +async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)): """Manually trigger a specific background task (admin only)""" try: - background_task_manager = getattr(request.app.state, 'background_task_manager', None) - + background_task_manager = getattr(request.app.state, "background_task_manager", None) + if not background_task_manager: return JSONResponse( status_code=503, - content=create_error_response( - "MANAGER_UNAVAILABLE", - "Background task manager not initialized" - ) + content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"), ) - + # List of available tasks - available_tasks = [ - "guest_cleanup", - "token_cleanup", - "guest_stats", - "rate_limit_cleanup", - "orphaned_cleanup" - ] - + available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"] + if task_name not in available_tasks: return JSONResponse( status_code=400, content=create_error_response( - "INVALID_TASK", - f"Unknown task: {task_name}. Available: {available_tasks}" - ) + "INVALID_TASK", f"Unknown task: {task_name}. 
Available: {available_tasks}" + ), ) - + # Run the task result = await background_task_manager.force_run_task(task_name) - - return create_success_response({ - "task_name": task_name, - "result": result, - "message": f"Task {task_name} completed successfully" - }) - - except ValueError as e: - return JSONResponse( - status_code=400, - content=create_error_response("INVALID_TASK", str(e)) - ) - except Exception as e: - logger.error(f"❌ Error running task {task_name}: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("TASK_EXECUTION_ERROR", str(e)) + + return create_success_response( + {"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"} ) + except ValueError as e: + return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e))) + except Exception as e: + logger.error(f"❌ Error running task {task_name}: {e}") + return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e))) + + @router.get("/tasks/list") -async def list_available_tasks( - admin_user = Depends(get_current_admin) -): +async def list_available_tasks(admin_user=Depends(get_current_admin)): """List all available background tasks (admin only)""" try: tasks = [ @@ -417,163 +371,146 @@ async def list_available_tasks( "name": "guest_cleanup", "description": "Clean up inactive guest sessions", "interval": "6 hours", - "parameters": ["inactive_hours (default: 48)"] + "parameters": ["inactive_hours (default: 48)"], }, { - "name": "token_cleanup", + "name": "token_cleanup", "description": "Clean up expired email verification tokens", "interval": "12 hours", - "parameters": [] + "parameters": [], }, { "name": "guest_stats", - "description": "Update guest usage statistics", + "description": "Update guest usage statistics", "interval": "1 hour", - "parameters": [] + "parameters": [], }, { "name": "rate_limit_cleanup", "description": "Clean up old rate limiting data", - "interval": "24 hours", - "parameters": ["days_old (default: 7)"] + "interval": "24 hours", + "parameters": ["days_old (default: 7)"], }, { "name": "orphaned_cleanup", "description": "Clean up orphaned database records", "interval": "6 hours", - "parameters": [] - } + "parameters": [], + }, ] - - return create_success_response({ - "total_tasks": len(tasks), - "tasks": tasks - }) - + + return create_success_response({"total_tasks": len(tasks), "tasks": tasks}) + except Exception as e: logger.error(f"❌ Error listing tasks: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("LIST_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e))) + @router.post("/tasks/restart") -async def restart_background_tasks( - request: Request, - admin_user = Depends(get_current_admin) -): +async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)): """Restart the background task manager (admin only)""" try: - database_manager = getattr(request.app.state, 'database_manager', None) - background_task_manager = getattr(request.app.state, 'background_task_manager', None) - + database_manager = getattr(request.app.state, "database_manager", None) + background_task_manager = getattr(request.app.state, "background_task_manager", None) + if not database_manager: return JSONResponse( - status_code=503, - content=create_error_response( - "DATABASE_UNAVAILABLE", - "Database manager not available" - ) + status_code=503, 
content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available") ) - + # Stop existing background tasks if background_task_manager: await background_task_manager.stop() logger.info("🛑 Stopped existing background task manager") - + # Create and start new background task manager from background_tasks import BackgroundTaskManager + new_background_task_manager = BackgroundTaskManager(database_manager) await new_background_task_manager.start() - + # Update app state request.app.state.background_task_manager = new_background_task_manager - + # Get status of new manager status = await new_background_task_manager.get_task_status() - - return create_success_response({ - "message": "Background task manager restarted successfully", - "new_status": status - }) - + + return create_success_response( + {"message": "Background task manager restarted successfully", "new_status": status} + ) + except Exception as e: logger.error(f"❌ Error restarting background tasks: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("RESTART_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e))) + # ============================ # Task Monitoring and Metrics # ============================ + class TaskMetrics: """Collect metrics for background tasks""" - + def __init__(self): self.task_runs = {} self.task_durations = {} self.task_errors = {} - + def record_task_run(self, task_name: str, duration: float, success: bool = True): """Record a task execution""" if task_name not in self.task_runs: self.task_runs[task_name] = 0 self.task_durations[task_name] = [] self.task_errors[task_name] = 0 - + self.task_runs[task_name] += 1 self.task_durations[task_name].append(duration) - + if not success: self.task_errors[task_name] += 1 - + # Keep only last 100 durations to prevent memory growth if len(self.task_durations[task_name]) > 100: self.task_durations[task_name] = self.task_durations[task_name][-100:] - + def get_metrics(self) -> dict: """Get task metrics summary""" metrics = {} - + for task_name in self.task_runs: durations = self.task_durations[task_name] avg_duration = sum(durations) / len(durations) if durations else 0 - + metrics[task_name] = { "total_runs": self.task_runs[task_name], "total_errors": self.task_errors[task_name], - "success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0, + "success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] + if self.task_runs[task_name] > 0 + else 0, "average_duration": avg_duration, - "last_runs": durations[-10:] if durations else [] + "last_runs": durations[-10:] if durations else [], } - + return metrics + # Global task metrics task_metrics = TaskMetrics() + @router.get("/tasks/metrics") -async def get_task_metrics( - admin_user = Depends(get_current_admin) -): +async def get_task_metrics(admin_user=Depends(get_current_admin)): """Get background task metrics (admin only)""" try: global task_metrics metrics = task_metrics.get_metrics() - - return create_success_response({ - "metrics": metrics, - "timestamp": datetime.now(UTC).isoformat() - }) - + + return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()}) + except Exception as e: logger.error(f"❌ Get task metrics error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("METRICS_ERROR", str(e)) - ) + return 
JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e))) # ============================ @@ -581,135 +518,125 @@ async def get_task_metrics( # ============================ # @router.get("/verification-stats") async def get_verification_statistics( - current_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Get verification statistics (admin only)""" try: if not current_user.is_admin: raise HTTPException(status_code=403, detail="Admin access required") - + stats = { "pending_verifications": await database.get_pending_verifications_count(), - "expired_tokens_cleaned": await database.cleanup_expired_verification_tokens() + "expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(), } - + return create_success_response(stats) - + except Exception as e: logger.error(f"❌ Error getting verification stats: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATS_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e))) + @router.post("/cleanup-verifications") async def cleanup_verification_tokens( - current_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Manually trigger cleanup of expired verification tokens (admin only)""" try: if not current_user.is_admin: raise HTTPException(status_code=403, detail="Admin access required") - + cleaned_count = await database.cleanup_expired_verification_tokens() - + logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned") - - return create_success_response({ - "message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.", - "cleaned_count": cleaned_count - }) - + + return create_success_response( + { + "message": f"Cleanup completed. 
Removed {cleaned_count} expired verification tokens.", + "cleaned_count": cleaned_count, + } + ) + except Exception as e: logger.error(f"❌ Error in manual cleanup: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("CLEANUP_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e))) + @router.get("/pending-verifications") async def get_pending_verifications( - current_user = Depends(get_current_admin), + current_user=Depends(get_current_admin), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get list of pending email verifications (admin only)""" try: if not current_user.is_admin: raise HTTPException(status_code=403, detail="Admin access required") - + pattern = "email_verification:*" cursor = 0 pending_verifications = [] current_time = datetime.now(timezone.utc) - + while True: cursor, keys = await database.redis.scan(cursor, match=pattern, count=100) - + for key in keys: token_data = await database.redis.get(key) if token_data: verification_info = json.loads(token_data) if not verification_info.get("verified", False): expires_at = datetime.fromisoformat(verification_info.get("expires_at", "")) - - pending_verifications.append({ - "email": verification_info.get("email"), - "user_type": verification_info.get("user_type"), - "created_at": verification_info.get("created_at"), - "expires_at": verification_info.get("expires_at"), - "is_expired": current_time > expires_at, - "resend_count": verification_info.get("resend_count", 0) - }) - + + pending_verifications.append( + { + "email": verification_info.get("email"), + "user_type": verification_info.get("user_type"), + "created_at": verification_info.get("created_at"), + "expires_at": verification_info.get("expires_at"), + "is_expired": current_time > expires_at, + "resend_count": verification_info.get("resend_count", 0), + } + ) + if cursor == 0: break - + # Sort by creation date (newest first) pending_verifications.sort(key=lambda x: x["created_at"], reverse=True) - + # Apply pagination total = len(pending_verifications) start = (page - 1) * limit end = start + limit paginated_verifications = pending_verifications[start:end] - - paginated_response = create_paginated_response( - paginated_verifications, - page, limit, total - ) - + + paginated_response = create_paginated_response(paginated_verifications, page, limit, total) + return create_success_response(paginated_response) - + except Exception as e: logger.error(f"❌ Error getting pending verifications: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.get("/rate-limits/info") async def get_user_rate_limit_status( - current_user = Depends(get_current_user_or_guest), + current_user=Depends(get_current_user_or_guest), rate_limiter: RateLimiter = Depends(get_rate_limiter), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get rate limit status for a user (admin only)""" try: # Get user to determine type user_data = await database.get_user_by_id(current_user.id) if not user_data: - return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User not found") - ) - + return JSONResponse(status_code=404, 
content=create_error_response("USER_NOT_FOUND", "User not found")) + user_type = user_data.get("type", "unknown") is_admin = False - + if user_type == "candidate": candidate_data = await database.get_candidate(current_user.id) if candidate_data: @@ -718,38 +645,33 @@ async def get_user_rate_limit_status( employer_data = await database.get_employer(current_user.id) if employer_data: is_admin = employer_data.get("is_admin", False) - + status = await rate_limiter.get_user_rate_limit_status(current_user.id, user_type, is_admin) - + return create_success_response(status) - + except Exception as e: logger.error(f"❌ Get rate limit status error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATUS_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e))) + @router.get("/rate-limits/{user_id}") async def get_anyone_rate_limit_status( user_id: str = Path(...), - admin_user = Depends(get_current_admin), + admin_user=Depends(get_current_admin), rate_limiter: RateLimiter = Depends(get_rate_limiter), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get rate limit status for a user (admin only)""" try: # Get user to determine type user_data = await database.get_user_by_id(user_id) if not user_data: - return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found")) + user_type = user_data.get("type", "unknown") is_admin = False - + if user_type == "candidate": candidate_data = await database.get_candidate(user_id) if candidate_data: @@ -758,54 +680,43 @@ async def get_anyone_rate_limit_status( employer_data = await database.get_employer(user_id) if employer_data: is_admin = employer_data.get("is_admin", False) - + status = await rate_limiter.get_user_rate_limit_status(user_id, user_type, is_admin) - + return create_success_response(status) - + except Exception as e: logger.error(f"❌ Get rate limit status error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATUS_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e))) + @router.post("/rate-limits/{user_id}/reset") async def reset_user_rate_limits( user_id: str = Path(...), - admin_user = Depends(get_current_admin), + admin_user=Depends(get_current_admin), rate_limiter: RateLimiter = Depends(get_rate_limiter), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Reset rate limits for a user (admin only)""" try: # Get user to determine type user_data = await database.get_user_by_id(user_id) if not user_data: - return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found")) + user_type = user_data.get("type", "unknown") success = await rate_limiter.reset_user_rate_limits(user_id, user_type) - + if success: logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}") - return create_success_response({ - "message": f"Rate limits reset for {user_type} {user_id}", - "resetBy": admin_user.id - }) + return create_success_response( + {"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id} + ) else: return 
JSONResponse( - status_code=500, - content=create_error_response("RESET_FAILED", "Failed to reset rate limits") + status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits") ) - + except Exception as e: logger.error(f"❌ Reset rate limits error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("RESET_ERROR", str(e)) - ) - + return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e))) diff --git a/src/backend/routes/auth.py b/src/backend/routes/auth.py index 6efd667..3a20a9b 100644 --- a/src/backend/routes/auth.py +++ b/src/backend/routes/auth.py @@ -20,19 +20,26 @@ from device_manager import DeviceManager from email_service import VerificationEmailRateLimiter, email_service from logger import logger from models import ( - LoginRequest, CreateCandidateRequest, Candidate, - Employer, Guest, AuthResponse, MFARequest, - MFAData, MFAVerifyRequest, ResendVerificationRequest, - MFARequestResponse, MFARequestResponse -) -from utils.dependencies import ( - get_current_admin, get_database, get_current_user, create_access_token + LoginRequest, + CreateCandidateRequest, + Candidate, + Employer, + Guest, + AuthResponse, + MFARequest, + MFAData, + MFAVerifyRequest, + ResendVerificationRequest, + MFARequestResponse, + MFARequestResponse, ) +from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token from utils.responses import create_success_response, create_error_response from utils.rate_limiter import get_rate_limiter from utils.auth_utils import ( - AuthenticationManager, SecurityConfig, - validate_password_strength, + AuthenticationManager, + SecurityConfig, + validate_password_strength, ) # Create router for authentication endpoints @@ -43,59 +50,56 @@ if JWT_SECRET_KEY == "": raise ValueError("JWT_SECRET_KEY environment variable is not set") ALGORITHM = "HS256" + # ============================ # Password Reset Endpoints # ============================ class PasswordResetRequest(BaseModel): email: EmailStr + class PasswordResetConfirm(BaseModel): token: str new_password: str - - @field_validator('new_password') + + @field_validator("new_password") def validate_password_strength(cls, v): is_valid, issues = validate_password_strength(v) if not is_valid: - raise ValueError('; '.join(issues)) + raise ValueError("; ".join(issues)) return v + @router.post("/guest") async def create_guest_session_enhanced( request: Request, database: RedisDatabase = Depends(get_database), - rate_limiter: RateLimiter = Depends(get_rate_limiter) + rate_limiter: RateLimiter = Depends(get_rate_limiter), ): """Create a guest session with enhanced validation and persistence""" try: # Apply rate limiting for guest creation ip_address = request.client.host if request.client else "unknown" - + # Check rate limits for guest session creation rate_result = await rate_limiter.check_rate_limit( - user_id=ip_address, - user_type="guest_creation", - is_admin=False, - endpoint="/guest" + user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest" ) - + if not rate_result.allowed: logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}") return JSONResponse( status_code=429, - content=create_error_response( - "RATE_LIMITED", - rate_result.reason or "Too many guest sessions created" - ), - headers={"Retry-After": str(rate_result.retry_after_seconds or 300)} + content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"), + 
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}, ) - + # Generate unique guest identifier with timestamp for uniqueness current_time = datetime.now(UTC) guest_id = str(uuid.uuid4()) session_id = f"guest_{int(current_time.timestamp())}_{secrets.token_hex(8)}" guest_username = f"guest-{session_id[-12:]}" - + # Verify username is unique (unlikely but possible collision) while True: existing_user = await database.get_user(guest_username) @@ -105,7 +109,7 @@ async def create_guest_session_enhanced( guest_username = f"guest-{session_id[-16:]}" else: break - + # Create guest user data with comprehensive info guest_data = { "id": guest_id, @@ -113,7 +117,7 @@ async def create_guest_session_enhanced( "username": guest_username, "email": f"{guest_username}@guest.backstory.ketrenos.com", "first_name": "Guest", - "last_name": "User", + "last_name": "User", "full_name": "Guest User", "user_type": "guest", "created_at": current_time.isoformat(), @@ -126,12 +130,12 @@ async def create_guest_session_enhanced( "user_agent": request.headers.get("user-agent", "Unknown"), "converted_to_user_id": None, "browser_session": True, # Mark as browser session - "persistent": True, # Mark as persistent + "persistent": True, # Mark as persistent } - + # Store guest with enhanced persistence await database.set_guest(guest_id, guest_data) - + # Create user lookup records user_auth_data = { "id": guest_id, @@ -139,38 +143,37 @@ async def create_guest_session_enhanced( "email": guest_data["email"], "username": guest_username, "session_id": session_id, - "created_at": current_time.isoformat() + "created_at": current_time.isoformat(), } - + await database.set_user(guest_data["email"], user_auth_data) await database.set_user(guest_username, user_auth_data) await database.set_user_by_id(guest_id, user_auth_data) - + # Create authentication tokens with longer expiry for guests access_token = create_access_token( data={"sub": guest_id, "type": "guest"}, - expires_delta=timedelta(hours=48) # Longer expiry for guests + expires_delta=timedelta(hours=48), # Longer expiry for guests ) refresh_token = create_access_token( data={"sub": guest_id, "type": "refresh_guest"}, - expires_delta=timedelta(days=14) # 2 weeks refresh for guests + expires_delta=timedelta(days=14), # 2 weeks refresh for guests ) - + # Verify guest was stored correctly verification = await database.get_guest(guest_id) if not verification: logger.error(f"❌ Failed to verify guest storage: {guest_id}") return JSONResponse( - status_code=500, - content=create_error_response("STORAGE_ERROR", "Failed to create guest session") + status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session") ) - + # Create guest object for response guest = Guest.model_validate(guest_data) - + # Log successful creation logger.info(f"👤 Guest session created and verified: {guest_username} (ID: {guest_id}) from IP: {ip_address}") - + # Create auth response auth_response = { "accessToken": access_token, @@ -178,68 +181,61 @@ async def create_guest_session_enhanced( "user": guest.model_dump(by_alias=True), "expiresAt": int((current_time + timedelta(hours=48)).timestamp()), "userType": "guest", - "isGuest": True + "isGuest": True, } - + return create_success_response(auth_response) - + except Exception as e: logger.error(f"❌ Guest session creation error: {e}") import traceback + logger.error(traceback.format_exc()) return JSONResponse( - status_code=500, - content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest 
session") + status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session") ) - + + @router.post("/guest/convert") async def convert_guest_to_user( registration_data: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Convert a guest session to a permanent user account""" try: # Verify current user is a guest if current_user.user_type != "guest": return JSONResponse( - status_code=400, - content=create_error_response("NOT_GUEST", "Only guest users can be converted") + status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted") ) - + guest: Guest = current_user account_type = registration_data.get("accountType", "candidate") - + if account_type == "candidate": # Validate candidate registration data try: candidate_request = CreateCandidateRequest.model_validate(registration_data) except ValidationError as e: - return JSONResponse( - status_code=400, - content=create_error_response("VALIDATION_ERROR", str(e)) - ) - + return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e))) + # Check if email/username already exists auth_manager = AuthenticationManager(database) user_exists, conflict_field = await auth_manager.check_user_exists( - candidate_request.email, - candidate_request.username + candidate_request.email, candidate_request.username ) - + if user_exists: return JSONResponse( status_code=409, - content=create_error_response( - "USER_EXISTS", - f"A user with this {conflict_field} already exists" - ) + content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"), ) - + # Create candidate candidate_id = str(uuid.uuid4()) current_time = datetime.now(timezone.utc) - + candidate_data = { "id": candidate_id, "user_type": "candidate", @@ -253,76 +249,78 @@ async def convert_guest_to_user( "updated_at": current_time.isoformat(), "status": "active", "is_admin": False, - "converted_from_guest": guest.id + "converted_from_guest": guest.id, } - + candidate = Candidate.model_validate(candidate_data) - + # Create authentication await auth_manager.create_user_authentication(candidate_id, candidate_request.password) - + # Store candidate await database.set_candidate(candidate_id, candidate.model_dump()) - + # Update user lookup records user_auth_data = { "id": candidate_id, "type": "candidate", "email": candidate.email, - "username": candidate.username + "username": candidate.username, } - + await database.set_user(candidate.email, user_auth_data) await database.set_user(candidate.username, user_auth_data) await database.set_user_by_id(candidate_id, user_auth_data) - + # Mark guest as converted guest_data = guest.model_dump() guest_data["converted_to_user_id"] = candidate_id guest_data["updated_at"] = current_time.isoformat() await database.set_guest(guest.id, guest_data) - + # Create new tokens for the candidate access_token = create_access_token(data={"sub": candidate_id}) refresh_token = create_access_token( data={"sub": candidate_id, "type": "refresh"}, - expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) + expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS), ) - + auth_response = AuthResponse( access_token=access_token, refresh_token=refresh_token, user=candidate, - expires_at=int((current_time + 
timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) + expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()), ) - + logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}") - - return create_success_response({ - "message": "Guest account successfully converted to candidate", - "auth": auth_response.model_dump(by_alias=True), - "conversionType": "candidate" - }) - + + return create_success_response( + { + "message": "Guest account successfully converted to candidate", + "auth": auth_response.model_dump(by_alias=True), + "conversionType": "candidate", + } + ) + else: return JSONResponse( status_code=400, - content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported") + content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"), ) - + except Exception as e: logger.error(f"❌ Guest conversion error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account") + status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account") ) - + + @router.post("/logout") async def logout( access_token: str = Body(..., alias="accessToken"), refresh_token: str = Body(..., alias="refreshToken"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Logout endpoint - revokes both access and refresh tokens""" logger.info(f"🔑 User {current_user.id} is logging out") @@ -333,51 +331,50 @@ async def logout( user_id = refresh_payload.get("sub") token_type = refresh_payload.get("type") refresh_exp = refresh_payload.get("exp") - + if not user_id or token_type != "refresh": return JSONResponse( - status_code=401, - content=create_error_response("INVALID_TOKEN", "Invalid refresh token") + status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token") ) except jwt.PyJWTError as e: logger.warning(f"⚠️ Invalid refresh token during logout: {e}") return JSONResponse( - status_code=401, - content=create_error_response("INVALID_TOKEN", "Invalid refresh token") + status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token") ) - + # Verify that the refresh token belongs to the current user if user_id != current_user.id: return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Token does not belong to current user") + status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user") ) - + # Get Redis client redis = redis_manager.get_client() - + # Revoke refresh token (blacklist it until its natural expiration) refresh_ttl = max(0, refresh_exp - int(datetime.now(UTC).timestamp())) if refresh_ttl > 0: await redis.setex( - f"blacklisted_token:{refresh_token}", - refresh_ttl, - json.dumps({ - "user_id": user_id, - "token_type": "refresh", - "revoked_at": datetime.now(UTC).isoformat(), - "reason": "user_logout" - }) + f"blacklisted_token:{refresh_token}", + refresh_ttl, + json.dumps( + { + "user_id": user_id, + "token_type": "refresh", + "revoked_at": datetime.now(UTC).isoformat(), + "reason": "user_logout", + } + ), ) logger.info(f"🔒 Blacklisted refresh token for user {user_id}") - + # If access token is provided, revoke it too if access_token: try: access_payload = jwt.decode(access_token, JWT_SECRET_KEY, 
algorithms=[ALGORITHM]) access_user_id = access_payload.get("sub") access_exp = access_payload.get("exp") - + # Verify access token belongs to same user if access_user_id == user_id: access_ttl = max(0, access_exp - int(datetime.now(UTC).timestamp())) @@ -385,12 +382,14 @@ async def logout( await redis.setex( f"blacklisted_token:{access_token}", access_ttl, - json.dumps({ - "user_id": user_id, - "token_type": "access", - "revoked_at": datetime.now(UTC).isoformat(), - "reason": "user_logout" - }) + json.dumps( + { + "user_id": user_id, + "token_type": "access", + "revoked_at": datetime.now(UTC).isoformat(), + "reason": "user_logout", + } + ), ) logger.info(f"🔒 Blacklisted access token for user {user_id}") else: @@ -398,64 +397,53 @@ async def logout( except jwt.PyJWTError as e: logger.warning(f"⚠️ Invalid access token during logout (non-critical): {e}") # Don't fail logout if access token is invalid - + # Optional: Revoke all tokens for this user (for "logout from all devices") # Uncomment the following lines if you want to implement this feature: - # + # # await redis.setex( # f"user_tokens_revoked:{user_id}", # timedelta(days=30).total_seconds(), # Max refresh token lifetime # datetime.now(UTC).isoformat() # ) - + logger.info(f"🔑 User {user_id} logged out successfully") - return create_success_response({ - "message": "Logged out successfully", - "tokensRevoked": { - "refreshToken": True, - "accessToken": bool(access_token) + return create_success_response( + { + "message": "Logged out successfully", + "tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)}, } - }) - - except Exception as e: - logger.error(f"❌ Logout error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("LOGOUT_ERROR", str(e)) ) + except Exception as e: + logger.error(f"❌ Logout error: {e}") + return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e))) + + @router.post("/logout-all") -async def logout_all_devices( - current_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) -): +async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)): """Logout from all devices by revoking all tokens for the user""" try: redis = redis_manager.get_client() - + # Set a timestamp that invalidates all tokens issued before this moment await redis.setex( f"user_tokens_revoked:{current_user.id}", int(timedelta(days=30).total_seconds()), # Max refresh token lifetime - datetime.now(UTC).isoformat() + datetime.now(UTC).isoformat(), ) - + logger.info(f"🔒 All tokens revoked for user {current_user.id}") - return create_success_response({ - "message": "Logged out from all devices successfully" - }) - + return create_success_response({"message": "Logged out from all devices successfully"}) + except Exception as e: logger.error(f"❌ Logout all devices error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("LOGOUT_ALL_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e))) -@router.post("/refresh") + +@router.post("/refresh") async def refresh_token_endpoint( - refresh_token: str = Body(..., alias="refreshToken"), - database: RedisDatabase = Depends(get_database) + refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database) ): """Refresh token endpoint""" try: @@ -463,16 +451,15 @@ async def refresh_token_endpoint( payload = 
jwt.decode(refresh_token, JWT_SECRET_KEY, algorithms=[ALGORITHM]) user_id = payload.get("sub") token_type = payload.get("type") - + if not user_id or token_type != "refresh": return JSONResponse( - status_code=401, - content=create_error_response("INVALID_TOKEN", "Invalid refresh token") + status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token") ) - + # Create new access token access_token = create_access_token(data={"sub": user_id}) - + # Get user user = None candidate_data = await database.get_candidate(user_id) @@ -482,84 +469,77 @@ async def refresh_token_endpoint( employer_data = await database.get_employer(user_id) if employer_data: user = Employer.model_validate(employer_data) - + if not user: - return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found")) + auth_response = AuthResponse( access_token=access_token, refresh_token=refresh_token, # Keep same refresh token user=user, - expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()) + expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()), ) - + return create_success_response(auth_response.model_dump(by_alias=True)) - + except jwt.PyJWTError: - return JSONResponse( - status_code=401, - content=create_error_response("INVALID_TOKEN", "Invalid refresh token") - ) + return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")) except Exception as e: logger.error(f"❌ Token refresh error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("REFRESH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e))) + @router.post("/resend-verification") async def resend_verification_email( request: ResendVerificationRequest, background_tasks: BackgroundTasks, - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Resend verification email with comprehensive rate limiting and validation""" try: email_lower = request.email.lower().strip() - + # Initialize rate limiter rate_limiter = VerificationEmailRateLimiter(database) - + # Check rate limiting can_send, reason = await rate_limiter.can_send_verification_email(email_lower) if not can_send: logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}") - return JSONResponse( - status_code=429, - content=create_error_response("RATE_LIMITED", reason) - ) - + return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason)) + # Clean up expired tokens first await database.cleanup_expired_verification_tokens() - + # Check if user already exists and is verified user_data = await database.get_user(email_lower) if user_data: # User exists and is verified - don't reveal this for security logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}") await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse - return create_success_response({ - "message": "If your email is in our system and pending verification, a new verification email has been sent." - }) - + return create_success_response( + { + "message": "If your email is in our system and pending verification, a new verification email has been sent." 
+ } + ) + # Look for pending verification token verification_data = await database.find_verification_token_by_email(email_lower) - + if not verification_data: # No pending verification found - don't reveal this for security logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}") await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse - return create_success_response({ - "message": "If your email is in our system and pending verification, a new verification email has been sent." - }) - + return create_success_response( + { + "message": "If your email is in our system and pending verification, a new verification email has been sent." + } + ) + # Check if verification token has expired expires_at = datetime.fromisoformat(verification_data["expires_at"]) current_time = datetime.now(timezone.utc) - + if current_time > expires_at: # Token expired - clean it up and inform user await database.redis.delete(f"email_verification:{verification_data['token']}") @@ -567,36 +547,35 @@ async def resend_verification_email( return JSONResponse( status_code=400, content=create_error_response( - "TOKEN_EXPIRED", - "Your verification link has expired. Please register again to create a new account." - ) + "TOKEN_EXPIRED", + "Your verification link has expired. Please register again to create a new account.", + ), ) - + # Generate new verification token (invalidate old one) old_token = verification_data["token"] new_token = secrets.token_urlsafe(32) - + # Update verification data with new token and reset attempts - verification_data.update({ - "token": new_token, - "expires_at": (current_time + timedelta(hours=24)).isoformat(), - "resent_at": current_time.isoformat(), - "resend_count": verification_data.get("resend_count", 0) + 1 - }) - + verification_data.update( + { + "token": new_token, + "expires_at": (current_time + timedelta(hours=24)).isoformat(), + "resent_at": current_time.isoformat(), + "resend_count": verification_data.get("resend_count", 0) + 1, + } + ) + # Store new token and remove old one await database.redis.delete(f"email_verification:{old_token}") await database.store_email_verification_token( - email_lower, - new_token, - verification_data["user_type"], - verification_data["user_data"] + email_lower, new_token, verification_data["user_type"], verification_data["user_data"] ) - + # Get user name for email user_data_container = verification_data["user_data"] user_type = verification_data["user_type"] - + if user_type == "candidate": candidate_data = user_data_container["candidate_data"] user_name = candidate_data.get("fullName", "User") @@ -605,93 +584,83 @@ async def resend_verification_email( user_name = employer_data.get("companyName", "User") else: user_name = "User" - + # Record email attempt await rate_limiter.record_email_sent(email_lower) - + # Send new verification email in background - background_tasks.add_task( - email_service.send_verification_email, - email_lower, - new_token, - user_name, - user_type - ) - + background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type) + # Log security event await database.log_security_event( - verification_data["user_data"].get("candidate_data", {}).get("id") or - verification_data["user_data"].get("employer_data", {}).get("id") or "unknown", + verification_data["user_data"].get("candidate_data", {}).get("id") + or verification_data["user_data"].get("employer_data", {}).get("id") + or "unknown", "verification_resend", { 
"email": email_lower, "user_type": user_type, "resend_count": verification_data.get("resend_count", 1), "old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging - "ip_address": "unknown" # You can extract this from request if needed + "ip_address": "unknown", # You can extract this from request if needed + }, + ) + + logger.info( + f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})" + ) + + return create_success_response( + { + "message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.", + "resendCount": verification_data.get("resend_count", 1), } ) - - logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})") - - return create_success_response({ - "message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.", - "resendCount": verification_data.get("resend_count", 1) - }) - + except ValueError as ve: logger.warning(f"⚠️ Invalid resend verification request: {ve}") - return JSONResponse( - status_code=400, - content=create_error_response("VALIDATION_ERROR", str(ve)) - ) + return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve))) except Exception as e: logger.error(f"❌ Resend verification email error: {e}") return JSONResponse( status_code=500, - content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.") + content=create_error_response( + "RESEND_FAILED", "An error occurred while processing your request. Please try again later." + ), ) - + + @router.post("/mfa/request") async def request_mfa( request: MFARequest, background_tasks: BackgroundTasks, http_request: Request, - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Request MFA for login from new device""" try: # Verify credentials first auth_manager = AuthenticationManager(database) - is_valid, user_data, error_message = await auth_manager.verify_user_credentials( - request.email, - request.password - ) - + is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password) + if not is_valid or not user_data: - return JSONResponse( - status_code=401, - content=create_error_response("AUTH_FAILED", "Invalid credentials") - ) - + return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials")) + # Check if device is trusted device_manager = DeviceManager(database) device_manager.parse_device_info(http_request) - + is_trusted = await device_manager.is_trusted_device(user_data["id"], request.device_id) - + if is_trusted: # Device is trusted, proceed with normal login await device_manager.update_device_last_used(user_data["id"], request.device_id) - - return create_success_response({ - "mfa_required": False, - "message": "Device is trusted, proceed with login" - }) - + + return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"}) + # Generate MFA code mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code - + # Store MFA code # Get user name for email user_name = "User" @@ -709,25 +678,18 @@ async def request_mfa( if not email: return JSONResponse( - status_code=400, - content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA") + status_code=400, 
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA") ) - + # Store MFA code await database.store_mfa_code(email, mfa_code, request.device_id) logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}") # Send MFA code via email - background_tasks.add_task( - email_service.send_mfa_email, - email, - mfa_code, - request.device_name, - user_name - ) - + background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name) + logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}") - + mfa_data = MFAData( message="New device detected. We've sent a security code to your email address.", code_sent=mfa_code, @@ -735,59 +697,52 @@ async def request_mfa( device_id=request.device_id, device_name=request.device_name, ) - mfa_response = MFARequestResponse( - mfa_required=True, - mfa_data=mfa_data - ) + mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data) return create_success_response(mfa_response) - + except Exception as e: logger.error(f"❌ MFA request error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request") + status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request") ) + @router.post("/login") async def login( request: LoginRequest, http_request: Request, background_tasks: BackgroundTasks, - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """login with automatic MFA email sending for new devices""" try: # Initialize managers auth_manager = AuthenticationManager(database) device_manager = DeviceManager(database) - + # Parse device information device_info = device_manager.parse_device_info(http_request) device_id = device_info["device_id"] - + # Verify credentials first - is_valid, user_data, error_message = await auth_manager.verify_user_credentials( - request.login, - request.password - ) - + is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password) + if not is_valid or not user_data: logger.warning(f"⚠️ Failed login attempt for: {request.login}") return JSONResponse( - status_code=401, - content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials") + status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials") ) - + # Check if device is trusted is_trusted = await device_manager.is_trusted_device(user_data["id"], device_id) - + if not is_trusted: # New device detected - automatically send MFA email logger.info(f"🔐 New device detected for {request.login}, sending MFA email") - + # Generate MFA code mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code - + # Get user name and details for email user_name = "User" email = None @@ -801,30 +756,24 @@ async def login( if employer_data: user_name = employer_data.get("company_name", "User") email = employer_data.get("email", None) - + if not email: return JSONResponse( - status_code=400, - content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA") + status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA") ) - + # Store MFA code await database.store_mfa_code(email, mfa_code, device_id) - + # Ensure email is lowercase # Get IP address for security info ip_address = http_request.client.host if http_request.client else "Unknown" - + # Send MFA code via email in 
background background_tasks.add_task( - email_service.send_mfa_email, - email, - mfa_code, - device_info["device_name"], - user_name, - ip_address + email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address ) - + # Log security event await database.log_security_event( user_data["id"], @@ -834,12 +783,12 @@ async def login( "device_name": device_info["device_name"], "ip_address": ip_address, "user_agent": device_info.get("user_agent", ""), - "auto_sent": True - } + "auto_sent": True, + }, ) - + logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}") - + mfa_response = MFARequestResponse( mfa_required=True, mfa_data=MFAData( @@ -847,22 +796,22 @@ async def login( email=email, device_id=device_id, device_name=device_info["device_name"], - code_sent=mfa_code - ) + code_sent=mfa_code, + ), ) return create_success_response(mfa_response.model_dump(by_alias=True)) - + # Trusted device - proceed with normal login await device_manager.update_device_last_used(user_data["id"], device_id) await auth_manager.update_last_login(user_data["id"]) - + # Create tokens access_token = create_access_token(data={"sub": user_data["id"]}) refresh_token = create_access_token( data={"sub": user_data["id"], "type": "refresh"}, - expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) + expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS), ) - + # Get user object user = None if user_data["type"] == "candidate": @@ -873,13 +822,12 @@ async def login( employer_data = await database.get_employer(user_data["id"]) if employer_data: user = Employer.model_validate(employer_data) - + if not user: return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User profile not found") + status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found") ) - + # Log successful login from trusted device await database.log_security_event( user_data["id"], @@ -888,55 +836,56 @@ async def login( "device_id": device_id, "device_name": device_info["device_name"], "ip_address": http_request.client.host if http_request.client else "Unknown", - "trusted_device": True - } + "trusted_device": True, + }, ) - + # Create response auth_response = AuthResponse( access_token=access_token, refresh_token=refresh_token, user=user, - expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) + expires_at=int( + (datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp() + ), ) - + logger.info(f"🔑 User {request.login} logged in successfully from trusted device") - + return create_success_response(auth_response.model_dump(by_alias=True)) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Login error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("LOGIN_ERROR", "An error occurred during login") + status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login") ) @router.post("/mfa/verify") -async def verify_mfa( - request: MFAVerifyRequest, - http_request: Request, - database: RedisDatabase = Depends(get_database) -): +async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)): """Verify MFA code and complete login with error handling""" try: # Get MFA data mfa_data = await database.get_mfa_code(request.email, request.device_id) - + if not 
mfa_data: logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}") return JSONResponse( status_code=404, - content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.") + content=create_error_response( + "NO_MFA_SESSION", "No active MFA session found. Please try logging in again." + ), ) - + if mfa_data.get("verified"): return JSONResponse( status_code=400, - content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.") + content=create_error_response( + "ALREADY_VERIFIED", "This MFA code has already been used. Please login again." + ), ) - + # Check expiration expires_at = datetime.fromisoformat(mfa_data["expires_at"]) if datetime.now(timezone.utc) > expires_at: @@ -944,9 +893,9 @@ async def verify_mfa( await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}") return JSONResponse( status_code=400, - content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.") + content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."), ) - + # Check attempts current_attempts = mfa_data.get("attempts", 0) if current_attempts >= 5: @@ -954,55 +903,49 @@ async def verify_mfa( await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}") return JSONResponse( status_code=429, - content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.") + content=create_error_response( + "TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again." + ), ) - + # Verify code if mfa_data["code"] != request.code: await database.increment_mfa_attempts(request.email, request.device_id) remaining_attempts = 5 - (current_attempts + 1) - + return JSONResponse( status_code=400, content=create_error_response( - "INVALID_CODE", - f"Invalid MFA code. {remaining_attempts} attempts remaining." - ) + "INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining." 
+ ), ) - + # Mark as verified await database.mark_mfa_verified(request.email, request.device_id) - + # Get user data user_data = await database.get_user(request.email) if not user_data: - return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found")) + # Add device to trusted devices if requested if request.remember_device: device_manager = DeviceManager(database) device_info = device_manager.parse_device_info(http_request) - await device_manager.add_trusted_device( - user_data["id"], - request.device_id, - device_info - ) + await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info) logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}") - + # Update last login auth_manager = AuthenticationManager(database) await auth_manager.update_last_login(user_data["id"]) - + # Create tokens access_token = create_access_token(data={"sub": user_data["id"]}) refresh_token = create_access_token( data={"sub": user_data["id"], "type": "refresh"}, - expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS) + expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS), ) - + # Get user object user = None if user_data["type"] == "candidate": @@ -1013,13 +956,12 @@ async def verify_mfa( employer_data = await database.get_employer(user_data["id"]) if employer_data: user = Employer.model_validate(employer_data) - + if not user: return JSONResponse( - status_code=404, - content=create_error_response("USER_NOT_FOUND", "User profile not found") + status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found") ) - + # Log successful MFA verification and login await database.log_security_event( user_data["id"], @@ -1028,10 +970,10 @@ async def verify_mfa( "device_id": request.device_id, "ip_address": http_request.client.host if http_request.client else "Unknown", "device_remembered": request.remember_device, - "attempts_used": current_attempts + 1 - } + "attempts_used": current_attempts + 1, + }, ) - + await database.log_security_event( user_data["id"], "login", @@ -1039,38 +981,37 @@ async def verify_mfa( "device_id": request.device_id, "ip_address": http_request.client.host if http_request.client else "Unknown", "mfa_verified": True, - "new_device": True - } + "new_device": True, + }, ) - + # Clean up MFA session await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}") - + # Create response auth_response = AuthResponse( access_token=access_token, refresh_token=refresh_token, user=user, - expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()) + expires_at=int( + (datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp() + ), ) - + logger.info(f"✅ MFA verified and login completed for {request.email}") - + return create_success_response(auth_response.model_dump(by_alias=True)) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ MFA verification error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA") + status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA") ) + @router.post("/password-reset/request") -async def request_password_reset( - request: PasswordResetRequest, - 
    database: RedisDatabase = Depends(get_database)
-):
+async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
     """Request password reset"""
     try:
         # Check if user exists
@@ -1078,52 +1019,46 @@ async def request_password_reset(
         if not user_data:
             # Don't reveal whether email exists or not
             return create_success_response({"message": "If the email exists, a reset link will be sent"})
-
+
         auth_manager = AuthenticationManager(database)
-
+
         # Generate reset token
         reset_token = auth_manager.password_security.generate_secure_token()
         reset_expiry = datetime.now(timezone.utc) + timedelta(hours=1) # 1 hour expiry
-
+
         # Update authentication record
         auth_record = await database.get_authentication(user_data["id"])
         if auth_record:
             auth_record["resetPasswordToken"] = reset_token
             auth_record["resetPasswordExpiry"] = reset_expiry.isoformat()
             await database.set_authentication(user_data["id"], auth_record)
-
+
         # TODO: Send email with reset token
         logger.info(f"🔐 Password reset requested for: {request.email}")
-
+
         return create_success_response({"message": "If the email exists, a reset link will be sent"})
-
+
     except Exception as e:
         logger.error(f"❌ Password reset request error: {e}")
         return JSONResponse(
-            status_code=500,
-            content=create_error_response("RESET_ERROR", "An error occurred processing the request")
+            status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
         )
 
+
 @router.post("/password-reset/confirm")
-async def confirm_password_reset(
-    request: PasswordResetConfirm,
-    database: RedisDatabase = Depends(get_database)
-):
+async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
     """Confirm password reset with token"""
     try:
         # Find user by reset token
         # This would require a way to lookup by token - you might need to modify your database structure
-
+
         # For now, this is a placeholder - you'd need to implement token lookup
         # in your Redis database structure
-
+
         return create_success_response({"message": "Password reset successfully"})
-
+
     except Exception as e:
         logger.error(f"❌ Password reset confirm error: {e}")
         return JSONResponse(
-            status_code=500,
-            content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
+            status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
         )
-
-
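# ---------------------------------------------------------------------------
# Note on the placeholder above: confirm_password_reset cannot yet look a user
# up by reset token. The sketch below shows one minimal way to close that gap
# with the Redis client that RedisDatabase already wraps. It is illustrative
# only and is not part of this patch; the helper names
# (store_password_reset_token, pop_user_id_for_reset_token) and the key layout
# are assumptions, not existing project APIs.
# ---------------------------------------------------------------------------

async def store_password_reset_token(redis, token: str, user_id: str, ttl_seconds: int = 3600) -> None:
    # Key the mapping by token so the confirm endpoint can resolve it in O(1);
    # the TTL mirrors the one-hour expiry used by request_password_reset.
    await redis.set(f"password_reset_token:{token}", user_id, ex=ttl_seconds)


async def pop_user_id_for_reset_token(redis, token: str) -> str | None:
    # Fetch then delete so a token can only be redeemed once.
    key = f"password_reset_token:{token}"
    user_id = await redis.get(key)
    if user_id is not None:
        await redis.delete(key)
    return user_id.decode() if isinstance(user_id, bytes) else user_id

# request_password_reset would call store_password_reset_token(...) next to the
# auth_record update, and confirm_password_reset would call
# pop_user_id_for_reset_token(...) before hashing and storing the new password.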
diff --git a/src/backend/routes/candidates.py b/src/backend/routes/candidates.py
index 4ec9d83..4104668 100644
--- a/src/backend/routes/candidates.py
+++ b/src/backend/routes/candidates.py
@@ -19,19 +19,54 @@ import backstory_traceback as backstory_traceback
 from agents.generate_resume import GenerateResume
 import agents.base as agents
 from utils.rate_limiter import rate_limited
-from utils.helpers import filter_and_paginate, get_document_type_from_filename, get_skill_cache_key, get_requirements_list
+from utils.helpers import (
+    filter_and_paginate,
+    get_document_type_from_filename,
+    get_skill_cache_key,
+    get_requirements_list,
+)
 from database.manager import RedisDatabase
 from email_service import email_service
 from logger import logger
 from models import (
-    MOCK_UUID, ApiActivityType, ApiMessageType, ApiStatusType, CandidateAI, ChatContextType, ChatMessageError, ChatMessageRagSearch, ChatMessageResume, ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming, ChatMessageUser, ChatSession, Document, DocumentContentResponse, DocumentListResponse, DocumentMessage, DocumentOptions, DocumentType, DocumentUpdateRequest, Job, JobRequirements, CreateCandidateRequest, Candidate, RAGDocumentRequest,
-    RagContentMetadata, RagContentResponse, SkillAssessment, UserType
+    MOCK_UUID,
+    ApiActivityType,
+    ApiMessageType,
+    ApiStatusType,
+    CandidateAI,
+    ChatContextType,
+    ChatMessageError,
+    ChatMessageRagSearch,
+    ChatMessageResume,
+    ChatMessageSkillAssessment,
+    ChatMessageStatus,
+    ChatMessageStreaming,
+    ChatMessageUser,
+    ChatSession,
+    Document,
+    DocumentContentResponse,
+    DocumentListResponse,
+    DocumentMessage,
+    DocumentOptions,
+    DocumentType,
+    DocumentUpdateRequest,
+    Job,
+    JobRequirements,
+    CreateCandidateRequest,
+    Candidate,
+    RAGDocumentRequest,
+    RagContentMetadata,
+    RagContentResponse,
+    SkillAssessment,
+    UserType,
 )
 from utils.dependencies import (
-    get_current_admin, get_database, get_current_user, get_current_user_or_guest,
-    prometheus_collector
+    get_current_admin,
+    get_database,
+    get_current_user,
+    get_current_user_or_guest,
+    prometheus_collector,
 )
-from utils.rate_limiter import rate_limited
 from utils.responses import create_paginated_response, create_success_response, create_error_response
 import utils.llm_proxy as llm_manager
 import entities.entity_manager as entities
@@ -41,6 +76,7 @@ from utils.auth_utils import AuthenticationManager
 
 # Create router for authentication endpoints
 router = APIRouter(prefix="/candidates", tags=["candidates"])
+
 # ============================
 # Candidate Endpoints
 # ============================
@@ -49,25 +85,27 @@ async def create_candidate_ai(
     background_tasks: BackgroundTasks,
     user_message: ChatMessageUser = Body(...),
     admin: Candidate = Depends(get_current_admin),
-    database: RedisDatabase = Depends(get_database)
+    database: RedisDatabase = Depends(get_database),
 ):
     """Create a new candidate using AI-generated data"""
     try:
-
-        generate_agent = agents.get_or_create_agent(
-            agent_type=ChatContextType.GENERATE_PERSONA,
-            prometheus_collector=prometheus_collector)
+        async with entities.get_candidate_entity(candidate=admin) as admin_entity:
+            generate_agent = agents.get_or_create_agent(
+                agent_type=ChatContextType.GENERATE_PERSONA,
+                prometheus_collector=prometheus_collector,
+                user=admin_entity,
+            )
         if not generate_agent:
             logger.warning("⚠️ Unable to create AI generation agent.")
             return JSONResponse(
                 status_code=400,
-                content=create_error_response("AGENT_NOT_FOUND", "Unable to create AI generation agent")
+                content=create_error_response("AGENT_NOT_FOUND", "Unable to create AI generation agent"),
             )
-
+
         persona_message = None
         resume_message = None
-        state = 0 # 0 -- create persona, 1 -- create resume
+        state = 0  # 0 -- create persona, 1 -- create resume
         async for generated_message in generate_agent.generate(
             llm=llm_manager.get_llm(),
             model=defines.model,
@@ -75,15 +113,14 @@ async def create_candidate_ai(
             prompt=user_message.content,
         ):
             if isinstance(generated_message, ChatMessageError):
-                error_message : ChatMessageError = generated_message
+                error_message: ChatMessageError = generated_message
                 logger.error(f"❌ AI generation error: {error_message.content}")
                 return JSONResponse(
-                    status_code=500,
-                    content=create_error_response("AI_GENERATION_ERROR", error_message.content)
+                    status_code=500, content=create_error_response("AI_GENERATION_ERROR", error_message.content)
                 )
             if isinstance(generated_message, ChatMessageRagSearch):
                 raise ValueError("AI generation returned a RAG search message instead of a persona")
-
+
             if generated_message.status == ApiStatusType.DONE and state == 0:
                 persona_message = generated_message
                 state = 1 # Switch to resume generation
@@ -94,34 +131,40 @@ async def create_candidate_ai(
             logger.error("❌ AI generation failed: No message generated")
             return JSONResponse(
                 status_code=500,
-                content=create_error_response("AI_GENERATION_FAILED", "Failed to generate AI candidate data")
+                content=create_error_response("AI_GENERATION_FAILED", "Failed to generate AI candidate data"),
             )
-
+
         if not isinstance(persona_message, ChatMessageUser):
             logger.error(f"❌ AI generation returned unexpected message type: {type(persona_message)}")
             return JSONResponse(
                 status_code=500,
-                content=create_error_response("AI_GENERATION_ERROR", "AI generation did not return a valid user message")
+                content=create_error_response(
+                    "AI_GENERATION_ERROR", "AI generation did not return a valid user message"
+                ),
             )
-
+
         if not isinstance(resume_message, ChatMessageUser):
             logger.error(f"❌ AI generation returned unexpected resume message type: {type(resume_message)}")
             return JSONResponse(
                 status_code=500,
-                content=create_error_response("AI_GENERATION_ERROR", "AI generation did not return a valid resume message")
-            )
-
+                content=create_error_response(
+                    "AI_GENERATION_ERROR", "AI generation did not return a valid resume message"
+                ),
+            )
+
         try:
             current_time = datetime.now(timezone.utc)
             candidate_data = json.loads(persona_message.content)
-            candidate_data.update({
-                "user_type": "candidate",
-                "created_at": current_time.isoformat(),
-                "updated_at": current_time.isoformat(),
-                "status": "active", # Directly active for AI-generated candidates
-                "is_admin": False, # Default to non-admin
-                "is_AI": True, # Mark as AI-generated
-            })
+            candidate_data.update(
+                {
+                    "user_type": "candidate",
+                    "created_at": current_time.isoformat(),
+                    "updated_at": current_time.isoformat(),
+                    "status": "active",  # Directly active for AI-generated candidates
+                    "is_admin": False,  # Default to non-admin
+                    "is_AI": True,  # Mark as AI-generated
+                }
+            )
             candidate = CandidateAI.model_validate(candidate_data)
         except ValidationError as e:
             logger.error("❌ AI candidate data validation failed")
@@ -132,7 +175,7 @@ async def create_candidate_ai(
                 print(f"Field: {error['loc'][0]}, Error: {error['msg']}")
             return JSONResponse(
                 status_code=400,
-                content=create_error_response("AI_VALIDATION_FAILED", "AI-generated data validation failed")
+                content=create_error_response("AI_VALIDATION_FAILED", "AI-generated data validation failed"),
             )
         except Exception:
             # Log the error and return a validation error response
@@ -141,21 +184,21 @@ async def create_candidate_ai(
             logger.error(json.dumps(persona_message.content, indent=2))
             return JSONResponse(
                 status_code=400,
-                content=create_error_response("AI_VALIDATION_FAILED", "AI-generated data validation failed")
+                content=create_error_response("AI_VALIDATION_FAILED", "AI-generated data validation failed"),
             )
-
+
         logger.info(f"🤖 AI-generated candidate {candidate.username} created with email {candidate.email}")
         candidate_data = candidate.model_dump(by_alias=False, exclude_unset=False)
         # Store in database
         await database.set_candidate(candidate.id, candidate_data)
-
+
         user_auth_data = {
             "id": candidate.id,
             "type": "candidate",
             "email": candidate.email,
-            "username": candidate.username
+            "username": candidate.username,
         }
-
+
         await database.set_user(candidate.email, user_auth_data)
         await database.set_user(candidate.username, user_auth_data)
         await database.set_user_by_id(candidate.id, user_auth_data)
@@ -164,7 +207,7 @@ async def create_candidate_ai(
         if resume_message:
             document_id = str(uuid.uuid4())
             document_type = DocumentType.MARKDOWN
-            document_content = resume_message.content.encode('utf-8')
+            document_content = resume_message.content.encode("utf-8")
             document_filename = "resume.md"
 
             document_data = Document(
@@ -174,8 +217,8 @@ async def create_candidate_ai(
                 type=document_type,
                 size=len(document_content),
                 upload_date=datetime.now(UTC),
-                owner_id=candidate.id
-            )
+                owner_id=candidate.id,
+            )
             file_path = os.path.join(defines.user_dir, candidate.username, "rag-content", document_filename)
             # Ensure the directory exists
             rag_content_dir = pathlib.Path(defines.user_dir) / candidate.username / "rag-content"
@@ -183,64 +226,62 @@ async def create_candidate_ai(
             try:
                 with open(file_path, "wb") as f:
                     f.write(document_content)
-
+
                 logger.info(f"📁 File saved to disk: {file_path}")
-
+
             except Exception as e:
                 logger.error(f"❌ Failed to save file to disk: {e}")
                 return JSONResponse(
-                    status_code=500,
-                    content=create_error_response("FILE_SAVE_ERROR", "Failed to resume file to disk")
+                    status_code=500, content=create_error_response("FILE_SAVE_ERROR", "Failed to save resume file to disk")
                 )
 
             # Store document metadata in database
             await database.set_document(document_id, document_data.model_dump())
             await database.add_document_to_candidate(candidate.id, document_id)
-            logger.info(f"📄 Document metadata saved for candidate {candidate.id}: {document_id}")
+            logger.info(f"📄 Document metadata saved for candidate {candidate.id}: {document_id}")
 
-        logger.info(f"✅ AI-generated candidate created: {candidate_data['email']}, resume is {len(document_content) if document_content else 0} bytes")
+        logger.info(
+            f"✅ AI-generated candidate created: {candidate_data['email']}, resume is {len(document_content) if document_content else 0} bytes"
+        )
+
+        return create_success_response(
+            {
+                "message": "AI-generated candidate created successfully",
+                "candidate": candidate_data,
+                "resume": document_content,
+            }
+        )
 
-        return create_success_response({
-            "message": "AI-generated candidate created successfully",
-            "candidate": candidate_data,
-            "resume": document_content,
-        })
-
     except Exception as e:
         logger.error(backstory_traceback.format_exc())
         logger.error(f"❌ AI Candidate creation error: {e}")
         return JSONResponse(
             status_code=500,
-            content=create_error_response("AI_CREATION_FAILED", "Failed to create AI-generated candidate")
+            content=create_error_response("AI_CREATION_FAILED", "Failed to create AI-generated candidate"),
         )
-
+
+
 @router.post("")
 async def create_candidate_with_verification(
-    request: CreateCandidateRequest,
-    background_tasks: BackgroundTasks,
-    database: RedisDatabase = Depends(get_database)
+    request: CreateCandidateRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
 ):
     """Create a new candidate with email verification"""
     try:
         # Initialize authentication manager
         auth_manager = AuthenticationManager(database)
-
+
         # Check if user already exists
-        user_exists, conflict_field = await auth_manager.check_user_exists(
-            request.email,
-            request.username
-        )
-
+        user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
+
         if user_exists and conflict_field:
-            logger.warning(f"⚠️ Attempted to create user with existing {conflict_field}: {getattr(request, conflict_field)}")
+            logger.warning(
+                f"⚠️ Attempted to create user with existing {conflict_field}: {getattr(request, conflict_field)}"
+            )
             return JSONResponse(
                 status_code=409,
-                content=create_error_response(
-                    "USER_EXISTS",
-                    f"A user with this {conflict_field} already exists"
-                )
+                content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
             )
-
+
         # Generate candidate data (but don't activate yet)
         candidate_id = str(uuid.uuid4())
         current_time = datetime.now(timezone.utc)
@@ -248,7 +289,7 @@ async def create_candidate_with_verification(
         is_admin = False
         if len(all_candidates) == 0:
             is_admin = True
-
+
         candidate_data = {
             "id": candidate_id,
             "userType": "candidate",
@@ -263,10 +304,10 @@ async def create_candidate_with_verification(
             "status": "pending", # Not active until email verified
             "isAdmin": is_admin,
         }
-
+
         # Generate verification token
         verification_token = secrets.token_urlsafe(32)
-
+
         # Store verification token with user data
         await database.store_email_verification_token(
             request.email,
@@ -275,100 +316,123 @@ async def create_candidate_with_verification(
             {
                 "candidate_data": candidate_data,
                 "password": request.password, # Store temporarily for verification
-                "username": request.username
-            }
+                "username": request.username,
+            },
         )
-
+
         # Send verification email in background
         background_tasks.add_task(
             email_service.send_verification_email,
             request.email,
             verification_token,
-            f"{request.first_name} {request.last_name}"
+            f"{request.first_name} {request.last_name}",
        )
-
+
         logger.info(f"✅ Candidate registration initiated for: {request.email}")
-
-        return create_success_response({
-            "message": f"Registration successful! Please check your email to verify your account. {'As the first user on this sytem, you have admin priveledges.' if is_admin else ''}",
-            "email": request.email,
-            "verificationRequired": True
-        })
-
+
+        return create_success_response(
+            {
+                "message": f"Registration successful! Please check your email to verify your account. {'As the first user on this system, you have admin privileges.' if is_admin else ''}",
+                "email": request.email,
+                "verificationRequired": True,
+            }
+        )
+
     except Exception as e:
         logger.error(f"❌ Candidate creation error: {e}")
         return JSONResponse(
-            status_code=500,
-            content=create_error_response("CREATION_FAILED", "Failed to create candidate account")
+            status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create candidate account")
         )
 
+
 @router.post("/documents/upload")
 async def upload_candidate_document(
     file: UploadFile = File(...),
     options_data: str = Form(..., alias="options"),
-    current_user = Depends(get_current_user),
-    database: RedisDatabase = Depends(get_database)
+    current_user=Depends(get_current_user),
+    database: RedisDatabase = Depends(get_database),
 ):
     try:
         # Parse the JSON string and create DocumentOptions object
         options_dict = json.loads(options_data)
-        options : DocumentOptions = DocumentOptions.model_validate(options_dict)
+        options: DocumentOptions = DocumentOptions.model_validate(options_dict)
     except (json.JSONDecodeError, ValidationError):
         return StreamingResponse(
-            iter([json.dumps(ChatMessageError(
-                session_id=MOCK_UUID, # No session ID for document uploads
-                content="Invalid options format. Please provide valid JSON."
-            ).model_dump(mode='json', by_alias=True))]),
-            media_type="text/event-stream"
+            iter(
+                [
+                    json.dumps(
+                        ChatMessageError(
+                            session_id=MOCK_UUID,  # No session ID for document uploads
+                            content="Invalid options format. Please provide valid JSON.",
+                        ).model_dump(mode="json", by_alias=True)
+                    )
+                ]
+            ),
+            media_type="text/event-stream",
         )
-
+
     # Check file size (limit to 10MB)
     max_size = 10 * 1024 * 1024 # 10MB
     file_content = await file.read()
     if len(file_content) > max_size:
         logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
         return StreamingResponse(
-            iter([json.dumps(ChatMessageError(
-                session_id=MOCK_UUID, # No session ID for document uploads
-                content="File size exceeds 10MB limit"
-            ).model_dump(mode='json', by_alias=True))]),
-            media_type="text/event-stream"
+            iter(
+                [
+                    json.dumps(
+                        ChatMessageError(
+                            session_id=MOCK_UUID,  # No session ID for document uploads
+                            content="File size exceeds 10MB limit",
+                        ).model_dump(mode="json", by_alias=True)
+                    )
+                ]
+            ),
+            media_type="text/event-stream",
         )
     if len(file_content) == 0:
         logger.info(f"⚠️ File is empty: {file.filename}")
         return StreamingResponse(
-            iter([json.dumps(ChatMessageError(
-                session_id=MOCK_UUID, # No session ID for document uploads
-                content="File is empty"
-            ).model_dump(mode='json', by_alias=True))]),
-            media_type="text/event-stream"
+            iter(
+                [
+                    json.dumps(
+                        ChatMessageError(
+                            session_id=MOCK_UUID,  # No session ID for document uploads
+                            content="File is empty",
+                        ).model_dump(mode="json", by_alias=True)
+                    )
+                ]
+            ),
+            media_type="text/event-stream",
         )
     """Upload a document for the current candidate"""
+
     async def upload_stream_generator(file_content):
         # Verify user is a candidate
         if current_user.user_type != "candidate":
             logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
             error_message = ChatMessageError(
                 session_id=MOCK_UUID, # No session ID for document uploads
-                content="Only candidates can upload documents"
+                content="Only candidates can upload documents",
             )
             yield error_message
             return
-
+
         candidate: Candidate = current_user
-        file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
+        file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else ""  # Sanitize filename
         if not file.filename or file.filename.strip() == "":
             logger.warning("⚠️ File upload attempt with missing filename")
             error_message = ChatMessageError(
                 session_id=MOCK_UUID, # No session ID for document uploads
-                content="File must have a valid filename"
+                content="File must have a valid filename",
             )
             yield error_message
             return
-
-        logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
-
+
+        logger.info(
+            f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
+        )
+
         directory = "rag-content" if options.include_in_rag else "files"
         directory = "jobs" if options.is_job_document else directory
@@ -382,7 +446,7 @@ async def upload_candidate_document(
                 logger.warning(f"⚠️ File already exists: {file_path}")
                 error_message = ChatMessageError(
                     session_id=MOCK_UUID, # No session ID for document uploads
-                    content=f"File with this name already exists in the '{directory}' directory"
+                    content=f"File with this name already exists in the '{directory}' directory",
                 )
                 yield error_message
                 return
@@ -391,27 +455,27 @@ async def upload_candidate_document(
                 status_message = ChatMessageStatus(
                     session_id=MOCK_UUID, # No session ID for document uploads
                     content=f"Overwriting existing file: {file.filename}",
-                    activity=ApiActivityType.INFO
+                    activity=ApiActivityType.INFO,
                 )
                 yield status_message
 
         # Validate file type
-        allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
+        allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
         file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
-
+
         if file_extension not in allowed_types:
             logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
             error_message = ChatMessageError(
                 session_id=MOCK_UUID, # No session ID for document uploads
-                content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
+                content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
             )
             yield error_message
             return
-
+
         # Create document metadata
         document_id = str(uuid.uuid4())
         document_type = get_document_type_from_filename(file.filename or "unknown.txt")
-
+
         document_data = Document(
             id=document_id,
             filename=file.filename or f"document_{document_id}",
@@ -420,19 +484,19 @@ async def upload_candidate_document(
             size=len(file_content),
             upload_date=datetime.now(UTC),
             options=options,
-            owner_id=candidate.id
+            owner_id=candidate.id,
         )
-
+
         # Save file to disk
         directory = os.path.join(defines.user_dir, candidate.username, directory)
         file_path = os.path.join(directory, file.filename)
-
+
         try:
             with open(file_path, "wb") as f:
                 f.write(file_content)
-
+
             logger.info(f"📁 File saved to disk: {file_path}")
-
+
         except Exception as e:
             logger.error(f"❌ Failed to save file to disk: {e}")
             error_message = ChatMessageError(
@@ -448,17 +512,16 @@ async def upload_candidate_document(
             p_as_md = p.with_suffix(".md")
             # If file_path.md doesn't exist or file_path is newer than file_path.md,
             # fire off markitdown
-            if (not p_as_md.exists()) or (
-                p.stat().st_mtime > p_as_md.stat().st_mtime
-            ):
+            if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
                 status_message = ChatMessageStatus(
                     session_id=MOCK_UUID, # No session ID for document uploads
                     content=f"Converting content from {document_type}...",
-                    activity=ApiActivityType.CONVERTING
+                    activity=ApiActivityType.CONVERTING,
                 )
                 yield status_message
                 try:
-                    from markitdown import MarkItDown # type: ignore
+                    from markitdown import MarkItDown  # type: ignore
+
                     md = MarkItDown(enable_plugins=False) # Set to True to enable plugins
                     result = md.convert(file_path, output_format="markdown")
                     p_as_md.write_text(result.text_content)
@@ -488,11 +551,13 @@ async def upload_candidate_document(
             content=file_content,
         )
         yield chat_message
+
     try:
+
         async def to_json(method):
             try:
                 async for message in method:
-                    json_data = message.model_dump(mode='json', by_alias=True)
+                    json_data = message.model_dump(mode="json", by_alias=True)
                     json_str = json.dumps(json_data)
                     yield f"data: {json_str}\n\n".encode("utf-8")
             except Exception as e:
@@ -516,18 +581,25 @@ async def upload_candidate_document(
         logger.error(backstory_traceback.format_exc())
         logger.error(f"❌ Document upload error: {e}")
         return StreamingResponse(
-            iter([json.dumps(ChatMessageError(
-                session_id=MOCK_UUID, # No session ID for document uploads
-                content="Failed to upload document"
-            ).model_dump(mode='json', by_alias=True))]),
-            media_type="text/event-stream"
+            iter(
+                [
+                    json.dumps(
+                        ChatMessageError(
+                            session_id=MOCK_UUID,  # No session ID for document uploads
+                            content="Failed to upload document",
+                        ).model_dump(mode="json", by_alias=True)
+                    )
+                ]
+            ),
+            media_type="text/event-stream",
         )
 
+
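# ---------------------------------------------------------------------------
# The upload endpoint above streams its progress as Server-Sent Events:
# to_json() emits "data: {json}\n\n" chunks and the response media type is
# "text/event-stream". The client sketch below shows how such a stream could
# be consumed; it is illustrative only and not part of this patch. The base
# URL, the bearer-token header, and the camelCase "includeInRag" key are
# assumptions about deployment and model aliases, not facts taken from this
# diff.
# ---------------------------------------------------------------------------

import asyncio
import json

import httpx


async def upload_document(token: str, path: str = "resume.md") -> None:
    with open(path, "rb") as fh:
        files = {"file": (path, fh, "text/markdown")}
        data = {"options": json.dumps({"includeInRag": True})}  # option key assumed from DocumentOptions aliases
        headers = {"Authorization": f"Bearer {token}"}  # auth scheme assumed
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:  # base URL assumed
            async with client.stream(
                "POST", "/candidates/documents/upload", headers=headers, files=files, data=data
            ) as response:
                # Each SSE chunk carries one ChatMessage* payload serialized by to_json().
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        message = json.loads(line[len("data: "):])
                        print(message.get("status"), message.get("content"))


if __name__ == "__main__":
    asyncio.run(upload_document("YOUR_TOKEN"))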
 @router.post("/profile/upload")
 async def upload_candidate_profile(
     file: UploadFile = File(...),
-    current_user = Depends(get_current_user),
-    database: RedisDatabase = Depends(get_database)
+    current_user=Depends(get_current_user),
+    database: RedisDatabase = Depends(get_database),
 ):
     """Upload a document for the current candidate"""
     try:
@@ -535,79 +607,71 @@ async def upload_candidate_profile(
         if current_user.user_type != "candidate":
             logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
             return JSONResponse(
-                status_code=403,
-                content=create_error_response("FORBIDDEN", "Only candidates can upload their profile")
+                status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can upload their profile")
             )
-
+
         candidate: Candidate = current_user
         # Validate file type
-        allowed_types = ['.png', '.jpg', '.jpeg', '.gif']
+        allowed_types = [".png", ".jpg", ".jpeg", ".gif"]
         file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
-
+
         if file_extension not in allowed_types:
             logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
             return JSONResponse(
                 status_code=400,
                 content=create_error_response(
-                    "INVALID_FILE_TYPE",
-                    f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
-                )
+                    "INVALID_FILE_TYPE",
+                    f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
+                ),
             )
-
+
         # Check file size (limit to 5MB)
         max_size = 5 * 1024 * 1024 # 2MB
         file_content = await file.read()
         if len(file_content) > max_size:
             logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
             return JSONResponse(
-                status_code=400,
-                content=create_error_response("FILE_TOO_LARGE", "File size exceeds 10MB limit")
+                status_code=400, content=create_error_response("FILE_TOO_LARGE", "File size exceeds 5MB limit")
             )
-
+
         # Save file to disk as "profile."
         _, extension = os.path.splitext(file.filename or "")
         file_path = os.path.join(defines.user_dir, candidate.username)
         os.makedirs(file_path, exist_ok=True)
         file_path = os.path.join(file_path, f"profile{extension}")
-
+
         try:
             with open(file_path, "wb") as f:
                 f.write(file_content)
-
+
             logger.info(f"📁 File saved to disk: {file_path}")
-
+
         except Exception as e:
             logger.error(f"❌ Failed to save file to disk: {e}")
             return JSONResponse(
-                status_code=500,
-                content=create_error_response("FILE_SAVE_ERROR", "Failed to save file to disk")
+                status_code=500, content=create_error_response("FILE_SAVE_ERROR", "Failed to save file to disk")
             )
-        updates = {
-            "updated_at": datetime.now(UTC).isoformat(),
-            "profile_image": f"profile{extension}"
-        }
+        updates = {"updated_at": datetime.now(UTC).isoformat(), "profile_image": f"profile{extension}"}
         candidate_dict = candidate.model_dump()
         candidate_dict.update(updates)
         updated_candidate = Candidate.model_validate(candidate_dict)
         await database.set_candidate(candidate.id, updated_candidate.model_dump())
         logger.info(f"📄 Profile image uploaded: {updated_candidate.profile_image} for candidate {candidate.id}")
-
+
         return create_success_response(True)
-
+
     except Exception as e:
         logger.error(backstory_traceback.format_exc())
         logger.error(f"❌ Document upload error: {e}")
-        return JSONResponse(
-            status_code=500,
-            content=create_error_response("UPLOAD_ERROR", "Failed to upload document")
-        )
+        return JSONResponse(status_code=500, content=create_error_response("UPLOAD_ERROR", "Failed to upload document"))
 
+
 @router.get("/profile/{username}")
 async def get_candidate_profile_image(
     username: str = Path(..., description="Username of the candidate"),
     # current_user = Depends(get_current_user),
-    database: RedisDatabase = Depends(get_database)
+
database: RedisDatabase = Depends(get_database), ): """Get profile image of a candidate by username""" try: @@ -619,49 +683,39 @@ async def get_candidate_profile_image( # Filter by search query candidates_list = [ - c for c in candidates_list - if (query_lower == c.email.lower() or - query_lower == c.username.lower()) + c for c in candidates_list if (query_lower == c.email.lower() or query_lower == c.username.lower()) ] if not len(candidates_list): - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Candidate not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Candidate not found")) + candidate = Candidate.model_validate(candidates_list[0]) if not candidate.profile_image: logger.warning(f"⚠️ Candidate {candidate.username} has no profile image set") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Profile image not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Profile image not found")) file_path = os.path.join(defines.user_dir, candidate.username, candidate.profile_image) file_path = pathlib.Path(file_path) if not file_path.exists(): logger.error(f"❌ Profile image file not found on disk: {file_path}") return JSONResponse( - status_code=404, - content=create_error_response("FILE_NOT_FOUND", "Profile image file not found on disk") + status_code=404, content=create_error_response("FILE_NOT_FOUND", "Profile image file not found on disk") ) return FileResponse( file_path, media_type=f"image/{file_path.suffix[1:]}", # Get extension without dot - filename=candidate.profile_image + filename=candidate.profile_image, ) except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Get candidate profile image failed: {str(e)}") return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", "Failed to retrieve profile image") + status_code=500, content=create_error_response("FETCH_ERROR", "Failed to retrieve profile image") ) + @router.get("/documents") async def get_candidate_documents( - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): """Get all documents for the current candidate""" try: @@ -669,141 +723,132 @@ async def get_candidate_documents( if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can access documents") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access documents") ) - + candidate: Candidate = current_user - + # Get documents from database documents_data = await database.get_candidate_documents(candidate.id) documents = [Document.model_validate(doc_data) for doc_data in documents_data] - + # Sort by upload date (newest first) documents.sort(key=lambda x: x.upload_date, reverse=True) - - response_data = DocumentListResponse( - documents=documents, - total=len(documents) - ) - + + response_data = DocumentListResponse(documents=documents, total=len(documents)) + return create_success_response(response_data.model_dump(by_alias=True)) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Get candidate documents error: {e}") return JSONResponse( - status_code=500, - 
content=create_error_response("FETCH_ERROR", "Failed to retrieve documents") + status_code=500, content=create_error_response("FETCH_ERROR", "Failed to retrieve documents") ) + @router.get("/documents/{document_id}/content") async def get_document_content( document_id: str = Path(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Get document content by ID""" try: # Verify user is a candidate if current_user.user_type != "candidate": return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can access documents") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access documents") ) - + candidate: Candidate = current_user - + # Get document metadata document_data = await database.get_document(document_id) if not document_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Document not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Document not found")) + document = Document.model_validate(document_data) - + # Verify document belongs to current candidate if document.owner_id != candidate.id: return JSONResponse( status_code=403, - content=create_error_response("FORBIDDEN", "Cannot access another candidate's document") + content=create_error_response("FORBIDDEN", "Cannot access another candidate's document"), ) - - file_path = os.path.join(defines.user_dir, candidate.username, "rag-content" if document.options.include_in_rag else "files", document.original_name) + + file_path = os.path.join( + defines.user_dir, + candidate.username, + "rag-content" if document.options.include_in_rag else "files", + document.original_name, + ) file_path = pathlib.Path(file_path) if document.type not in [DocumentType.TXT, DocumentType.MARKDOWN]: - file_path = file_path.with_suffix('.md') + file_path = file_path.with_suffix(".md") if not file_path.exists(): logger.error(f"❌ Document file not found on disk: {file_path}") return JSONResponse( - status_code=404, - content=create_error_response("FILE_NOT_FOUND", "Document file not found on disk") + status_code=404, content=create_error_response("FILE_NOT_FOUND", "Document file not found on disk") ) - + try: with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + response = DocumentContentResponse( document_id=document_id, filename=document.filename, type=document.type, content=content, - size=document.size + size=document.size, ) return create_success_response(response.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Failed to read document file: {e}") return JSONResponse( - status_code=500, - content=create_error_response("READ_ERROR", "Failed to read document content") + status_code=500, content=create_error_response("READ_ERROR", "Failed to read document content") ) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Get document content error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", "Failed to retrieve document content") + status_code=500, content=create_error_response("FETCH_ERROR", "Failed to retrieve document content") ) + @router.patch("/documents/{document_id}") async def update_document( document_id: str = Path(...), updates: DocumentUpdateRequest = Body(...), - current_user = Depends(get_current_user), - database: 
RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Update document metadata (filename, RAG status)""" try: # Verify user is a candidate if current_user.user_type != "candidate": return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can update documents") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can update documents") ) - + candidate: Candidate = current_user - + # Get document metadata document_data = await database.get_document(document_id) if not document_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Document not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Document not found")) + document = Document.model_validate(document_data) - + # Verify document belongs to current candidate if document.owner_id != candidate.id: return JSONResponse( status_code=403, - content=create_error_response("FORBIDDEN", "Cannot update another candidate's document") + content=create_error_response("FORBIDDEN", "Cannot update another candidate's document"), ) update_options = updates.options if updates.options else DocumentOptions() if document.options.include_in_rag != update_options.include_in_rag: @@ -814,7 +859,7 @@ async def update_document( os.makedirs(file_dir, exist_ok=True) rag_path = os.path.join(rag_dir, document.original_name) file_path = os.path.join(file_dir, document.original_name) - + if update_options.include_in_rag: src = pathlib.Path(file_path) dst = pathlib.Path(rag_path) @@ -845,42 +890,38 @@ async def update_document( update_dict["filename"] = updates.filename.strip() if update_options.include_in_rag is not None: update_dict["include_in_rag"] = update_options.include_in_rag - + if not update_dict: return JSONResponse( - status_code=400, - content=create_error_response("NO_UPDATES", "No valid updates provided") + status_code=400, content=create_error_response("NO_UPDATES", "No valid updates provided") ) - + # Add timestamp update_dict["updatedAt"] = datetime.now(UTC).isoformat() - + # Update in database updated_data = await database.update_document(document_id, update_dict) if not updated_data: return JSONResponse( - status_code=500, - content=create_error_response("UPDATE_FAILED", "Failed to update document") + status_code=500, content=create_error_response("UPDATE_FAILED", "Failed to update document") ) - + updated_document = Document.model_validate(updated_data) - + logger.info(f"📄 Document updated: {document_id} for candidate {candidate.username}") - + return create_success_response(updated_document.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Update document error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("UPDATE_ERROR", "Failed to update document") - ) + return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", "Failed to update document")) + @router.delete("/documents/{document_id}") async def delete_document( document_id: str = Path(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Delete a document and its file""" try: @@ -888,35 +929,36 @@ async def delete_document( if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized delete attempt by user type: 
{current_user.user_type}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can delete documents") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can delete documents") ) - + candidate: Candidate = current_user - + # Get document metadata document_data = await database.get_document(document_id) if not document_data: logger.warning(f"⚠️ Document not found for deletion: {document_id}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Document not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Document not found")) + document = Document.model_validate(document_data) - + # Verify document belongs to current candidate if document.owner_id != candidate.id: logger.warning(f"⚠️ Unauthorized delete attempt on document {document_id} by candidate {candidate.username}") return JSONResponse( status_code=403, - content=create_error_response("FORBIDDEN", "Cannot delete another candidate's document") + content=create_error_response("FORBIDDEN", "Cannot delete another candidate's document"), ) - + # Delete file from disk - file_path = os.path.join(defines.user_dir, candidate.username, "rag-content" if document.options.include_in_rag else "files", document.original_name) + file_path = os.path.join( + defines.user_dir, + candidate.username, + "rag-content" if document.options.include_in_rag else "files", + document.original_name, + ) file_path = pathlib.Path(file_path) - + try: if file_path.exists(): file_path.unlink() @@ -934,85 +976,72 @@ async def delete_document( except Exception as e: logger.error(f"❌ Failed to delete file from disk: {e}") # Continue with metadata deletion even if file deletion fails - + # Remove from database await database.remove_document_from_candidate(candidate.id, document_id) await database.delete_document(document_id) - + logger.info(f"🗑️ Document deleted: {document_id} for candidate {candidate.username}") - - return create_success_response({ - "message": "Document deleted successfully", - "documentId": document_id - }) - + + return create_success_response({"message": "Document deleted successfully", "documentId": document_id}) + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Delete document error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("DELETE_ERROR", "Failed to delete document") - ) + return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete document")) + @router.get("/documents/search") async def search_candidate_documents( query: str = Query(..., min_length=1), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Search candidate documents by filename""" try: # Verify user is a candidate if current_user.user_type != "candidate": return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can search documents") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can search documents") ) - + candidate: Candidate = current_user - + # Search documents documents_data = await database.search_candidate_documents(candidate.id, query) documents = [Document.model_validate(doc_data) for doc_data in documents_data] - + # Sort by upload date (newest first) documents.sort(key=lambda x: 
x.upload_date, reverse=True) - - response_data = DocumentListResponse( - documents=documents, - total=len(documents) - ) - + + response_data = DocumentListResponse(documents=documents, total=len(documents)) + return create_success_response(response_data.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Search documents error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("SEARCH_ERROR", "Failed to search documents") + status_code=500, content=create_error_response("SEARCH_ERROR", "Failed to search documents") ) + @router.post("/rag-content") async def post_candidate_vector_content( - rag_document: RAGDocumentRequest = Body(...), - current_user = Depends(get_current_user) + rag_document: RAGDocumentRequest = Body(...), current_user=Depends(get_current_user) ): try: if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized access attempt by user type: {current_user.user_type}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") ) - candidate : Candidate = current_user + candidate: Candidate = current_user async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: collection = candidate_entity.umap_collection if not collection: logger.warning(f"⚠️ No UMAP collection found for candidate {candidate.username}") - return JSONResponse( - {"error": "No UMAP collection found"}, status_code=404 - ) + return JSONResponse({"error": "No UMAP collection found"}, status_code=404) for index, id in enumerate(collection.ids): if id == rag_document.id: @@ -1027,39 +1056,29 @@ async def post_candidate_vector_content( return JSONResponse(f"No content found for document id {rag_document.id}.", 404) return create_success_response(rag_response.model_dump(by_alias=True)) - logger.warning(f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}") + logger.warning( + f"⚠️ Document id {rag_document.id} not found in UMAP collection for candidate {candidate.username}" + ) return JSONResponse(f"Document id {rag_document.id} not found.", 404) except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Post candidate content error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.post("/rag-vectors") -async def post_candidate_vectors( - dimensions: int = Body(...), - current_user = Depends(get_current_user) -): +async def post_candidate_vectors(dimensions: int = Body(...), current_user=Depends(get_current_user)): try: if current_user.user_type != "candidate": return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") ) - candidate : Candidate = current_user + candidate: Candidate = current_user async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: collection = candidate_entity.umap_collection if not collection: - results = { - "ids": [], - "metadatas": [], - "documents": [], - "embeddings": [], - "size": 0 - } + results = {"ids": [], "metadatas": [], "documents": [], "embeddings": [], "size": 0} return 
create_success_response(results) if dimensions == 2: umap_embedding = candidate_entity.file_watcher.umap_embedding_2d @@ -1067,38 +1086,30 @@ async def post_candidate_vectors( umap_embedding = candidate_entity.file_watcher.umap_embedding_3d if len(umap_embedding) == 0: - results = { - "ids": [], - "metadatas": [], - "documents": [], - "embeddings": [], - "size": 0 - } + results = {"ids": [], "metadatas": [], "documents": [], "embeddings": [], "size": 0} return create_success_response(results) - + result = { "ids": collection.ids, "metadatas": collection.metadatas, "documents": collection.documents, "embeddings": umap_embedding.tolist(), - "size": candidate_entity.file_watcher.collection.count() + "size": candidate_entity.file_watcher.collection.count(), } return create_success_response(result) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Post candidate vectors error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.delete("/{candidate_id}") async def delete_candidate( candidate_id: str = Path(...), - admin_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + admin_user=Depends(get_current_admin), + database: RedisDatabase = Depends(get_database), ): """Delete a candidate""" try: @@ -1106,27 +1117,23 @@ async def delete_candidate( if not admin_user.is_admin: logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only admins can delete candidates") + status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete candidates") ) - + # Get candidate data candidate_data = await database.get_candidate(candidate_id) if not candidate_data: logger.warning(f"⚠️ Candidate not found for deletion: {candidate_id}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Candidate not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Candidate not found")) await entities.entity_manager.remove_entity(candidate_id) # Delete candidate from database await database.delete_candidate(candidate_id) - + # Optionally delete files and documents associated with the candidate await database.delete_all_candidate_documents(candidate_id) - + file_path = os.path.join(defines.user_dir, candidate_data["username"]) if os.path.exists(file_path): try: @@ -1136,63 +1143,56 @@ async def delete_candidate( logger.error(f"❌ Failed to delete candidate files directory: {e}") logger.info(f"🗑️ Candidate deleted: {candidate_id} by admin {admin_user.id}") - - return create_success_response({ - "message": "Candidate deleted successfully", - "candidateId": candidate_id - }) - + + return create_success_response({"message": "Candidate deleted successfully", "candidateId": candidate_id}) + except Exception as e: logger.error(f"❌ Delete candidate error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("DELETE_ERROR", "Failed to delete candidate") + status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete candidate") ) + @router.patch("/{candidate_id}") async def update_candidate( candidate_id: str = Path(...), updates: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + 
current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Update a candidate""" try: candidate_data = await database.get_candidate(candidate_id) if not candidate_data: logger.warning(f"⚠️ Candidate not found for update: {candidate_id}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Candidate not found") - ) - - is_AI = candidate_data.get("is_AI", False) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Candidate not found")) + + is_AI = candidate_data.get("is_AI", False) candidate = CandidateAI.model_validate(candidate_data) if is_AI else Candidate.model_validate(candidate_data) - + # Check authorization (user can only update their own profile) if current_user.is_admin is False and candidate.id != current_user.id: logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on candidate {candidate_id}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot update another user's profile") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's profile") ) - + # Apply updates updates["updatedAt"] = datetime.now(UTC).isoformat() logger.info(f"🔄 Updating candidate {candidate_id} with data: {updates}") candidate_dict = candidate.model_dump() candidate_dict.update(updates) - updated_candidate = CandidateAI.model_validate(candidate_dict) if is_AI else Candidate.model_validate(candidate_dict) + updated_candidate = ( + CandidateAI.model_validate(candidate_dict) if is_AI else Candidate.model_validate(candidate_dict) + ) await database.set_candidate(candidate_id, updated_candidate.model_dump()) - + return create_success_response(updated_candidate.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Update candidate error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("UPDATE_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e))) + @router.get("") async def get_candidates( @@ -1201,8 +1201,8 @@ async def get_candidates( sortBy: Optional[str] = Query(None, alias="sortBy"), sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"), filters: Optional[str] = Query(None), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): """Get paginated list of candidates""" try: @@ -1210,29 +1210,31 @@ async def get_candidates( filter_dict = None if filters: filter_dict = json.loads(filters) - + # Get all candidates from Redis all_candidates_data = await database.get_all_candidates() - candidates_list = [Candidate.model_validate(data) if not data.get("is_AI") else CandidateAI.model_validate(data) for data in all_candidates_data.values()] - candidates_list = [c for c in candidates_list if c.is_public or (current_user.userType != UserType.GUEST and c.id == current_user.id)] + candidates_list = [ + Candidate.model_validate(data) if not data.get("is_AI") else CandidateAI.model_validate(data) + for data in all_candidates_data.values() + ] + candidates_list = [ + c + for c in candidates_list + if c.is_public or (current_user.userType != UserType.GUEST and c.id == current_user.id) + ] + + paginated_candidates, total = filter_and_paginate(candidates_list, page, limit, sortBy, sortOrder, filter_dict) - paginated_candidates, total = 
filter_and_paginate( - candidates_list, page, limit, sortBy, sortOrder, filter_dict - ) - paginated_response = create_paginated_response( - [c.model_dump(by_alias=True) for c in paginated_candidates], - page, limit, total + [c.model_dump(by_alias=True) for c in paginated_candidates], page, limit, total ) return create_success_response(paginated_response) - + except Exception as e: logger.error(f"❌ Get candidates error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("FETCH_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e))) + @router.get("/search") async def search_candidates( @@ -1240,7 +1242,7 @@ async def search_candidates( filters: Optional[str] = Query(None), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Search candidates""" try: @@ -1248,69 +1250,65 @@ async def search_candidates( filter_dict = {} if filters: filter_dict = json.loads(filters) - + # Get all candidates from Redis all_candidates_data = await database.get_all_candidates() candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()] - + # Filter by search query if query: query_lower = query.lower() candidates_list = [ - c for c in candidates_list - if (query_lower in c.first_name.lower() or - query_lower in c.last_name.lower() or - query_lower in c.email.lower() or - query_lower in c.username.lower() or - any(query_lower in skill.name.lower() for skill in c.skills or [])) + c + for c in candidates_list + if ( + query_lower in c.first_name.lower() + or query_lower in c.last_name.lower() + or query_lower in c.email.lower() + or query_lower in c.username.lower() + or any(query_lower in skill.name.lower() for skill in c.skills or []) + ) ] - - paginated_candidates, total = filter_and_paginate( - candidates_list, page, limit, filters=filter_dict - ) - + + paginated_candidates, total = filter_and_paginate(candidates_list, page, limit, filters=filter_dict) + paginated_response = create_paginated_response( - [c.model_dump(by_alias=True) for c in paginated_candidates], - page, limit, total - ) - - return create_success_response(paginated_response) - - except Exception as e: - logger.error(f"❌ Search candidates error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("SEARCH_FAILED", str(e)) + [c.model_dump(by_alias=True) for c in paginated_candidates], page, limit, total ) + return create_success_response(paginated_response) + + except Exception as e: + logger.error(f"❌ Search candidates error: {e}") + return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e))) + + @router.post("/rag-search") -async def post_candidate_rag_search( - query: str = Body(...), - current_user = Depends(get_current_user) -): +async def post_candidate_rag_search(query: str = Body(...), current_user=Depends(get_current_user)): """Get chat activity summary for a candidate""" try: if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized RAG search attempt by user {current_user.id}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") + status_code=403, content=create_error_response("FORBIDDEN", "Only candidates can access this endpoint") ) - - candidate : Candidate = current_user + + candidate: Candidate = current_user chat_type = 
ChatContextType.RAG_SEARCH - # Get RAG search data + # Get RAG search data async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: # Entity automatically released when done chat_agent = candidate_entity.get_or_create_agent(agent_type=chat_type) if not chat_agent: return JSONResponse( status_code=400, - content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type") + content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"), ) - - user_message = ChatMessageUser(sender_id=candidate.id, session_id=MOCK_UUID, content=query, timestamp=datetime.now(UTC)) - rag_message : Any = None + + user_message = ChatMessageUser( + sender_id=candidate.id, session_id=MOCK_UUID, content=query, timestamp=datetime.now(UTC) + ) + rag_message: Any = None async for generated_message in chat_agent.generate( llm=llm_manager.get_llm(), model=defines.model, @@ -1318,28 +1316,23 @@ async def post_candidate_rag_search( prompt=user_message.content, ): rag_message = generated_message - + if not rag_message: return JSONResponse( status_code=500, - content=create_error_response("NO_RESPONSE", "No response generated for the RAG search") + content=create_error_response("NO_RESPONSE", "No response generated for the RAG search"), ) - final_message : ChatMessageRagSearch = rag_message + final_message: ChatMessageRagSearch = rag_message return create_success_response(final_message.content[0].model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Get candidate chat summary error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("SUMMARY_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("SUMMARY_ERROR", str(e))) + # reference can be candidateId, username, or email @router.get("/{reference}") -async def get_candidate( - reference: str = Path(...), - database: RedisDatabase = Depends(get_database) -): +async def get_candidate(reference: str = Path(...), database: RedisDatabase = Depends(get_database)): """Get a candidate by username""" try: # Normalize reference to lowercase for case-insensitive search @@ -1348,42 +1341,38 @@ async def get_candidate( all_candidates_data = await database.get_all_candidates() if not all_candidates_data: logger.warning("⚠️ No candidates found in database") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "No candidates found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No candidates found")) + candidate_data = None for candidate in all_candidates_data.values(): - if (candidate.get("id", "").lower() == query_lower or - candidate.get("username", "").lower() == query_lower or - candidate.get("email", "").lower() == query_lower): + if ( + candidate.get("id", "").lower() == query_lower + or candidate.get("username", "").lower() == query_lower + or candidate.get("email", "").lower() == query_lower + ): candidate_data = candidate break if not candidate_data: logger.warning(f"⚠️ Candidate not found for reference: {reference}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Candidate not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Candidate not found")) - candidate = Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) + candidate = ( + Candidate.model_validate(candidate_data) + if not 
candidate_data.get("is_AI") + else CandidateAI.model_validate(candidate_data) + ) return create_success_response(candidate.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Get candidate error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.get("/{username}/chat-summary") async def get_candidate_chat_summary( - username: str = Path(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + username: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): """Get chat activity summary for a candidate""" try: @@ -1392,53 +1381,53 @@ async def get_candidate_chat_summary( if not candidate_data: return JSONResponse( status_code=404, - content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found") + content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"), ) - + summary = await database.get_candidate_chat_summary(candidate_data["id"]) summary["candidate"] = { "username": candidate_data.get("username"), "fullName": candidate_data.get("fullName"), - "email": candidate_data.get("email") + "email": candidate_data.get("email"), } - + return create_success_response(summary) - + except Exception as e: logger.error(f"❌ Get candidate chat summary error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("SUMMARY_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("SUMMARY_ERROR", str(e))) + @router.post("/{candidate_id}/skill-match") async def get_candidate_skill_match( candidate_id: str = Path(...), skill: str = Body(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ) -> StreamingResponse: - """Get skill match for a candidate against a skill with caching""" + async def message_stream_generator(): candidate_data = await database.get_candidate(candidate_id) if not candidate_data: error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"Candidate with ID '{candidate_id}' not found" + content=f"Candidate with ID '{candidate_id}' not found", ) yield error_message return - + candidate = Candidate.model_validate(candidate_data) - + cache_key = get_skill_cache_key(candidate.id, skill) - + # Get cached assessment if it exists - assessment : SkillAssessment | None = await database.get_cached_skill_match(cache_key) + assessment: SkillAssessment | None = await database.get_cached_skill_match(cache_key) if assessment and assessment.skill.lower() != skill.lower(): - logger.warning(f"❌ Cached skill match for {candidate.username} does not match requested skill: {assessment.skill} != {skill} ({cache_key}). Regenerating...") + logger.warning( + f"❌ Cached skill match for {candidate.username} does not match requested skill: {assessment.skill} != {skill} ({cache_key}). Regenerating..." 
+ ) assessment = None # Determine if we need to regenerate the assessment @@ -1453,39 +1442,41 @@ async def get_candidate_skill_match( logger.info(f"🔄 Out-of-date cached entry for {candidate.username} skill {assessment.skill}") assessment = None else: - logger.info(f"✅ Using cached skill match for {candidate.username} skill {assessment.skill}: {cache_key}") + logger.info( + f"✅ Using cached skill match for {candidate.username} skill {assessment.skill}: {cache_key}" + ) else: logger.info(f"💾 No cached skill match data: {cache_key}, {candidate.id}, {skill}") - + if assessment: # Return cached assessment skill_message = ChatMessageSkillAssessment( session_id=MOCK_UUID, # No session ID for document uploads content=f"Cached skill match found for {candidate.username}", - skill_assessment=assessment + skill_assessment=assessment, ) yield skill_message return - + logger.info(f"🔍 Generating skill match for candidate {candidate.username} for skill: {skill}") - + async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.SKILL_MATCH) if not agent: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="No skill match agent found for this candidate" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="No skill match agent found for this candidate", + ) yield error_message return - + # Generate new skill match final_message = None async for generated_message in agent.generate( llm=llm_manager.get_llm(), - model=defines.model, - session_id=MOCK_UUID, - prompt=skill, + model=defines.model, + session_id=MOCK_UUID, + prompt=skill, ): if generated_message.status == ApiStatusType.ERROR: if isinstance(generated_message, ChatMessageError): @@ -1494,63 +1485,66 @@ async def get_candidate_skill_match( content = "An error occurred during AI generation" error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"AI generation error: {content}" + content=f"AI generation error: {content}", ) logger.error(f"❌ {error_message.content}") yield error_message return - + # If the message is not done, convert it to a ChatMessageBase to remove # metadata and other unnecessary fields for streaming if generated_message.status != ApiStatusType.DONE: - if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus): + if not isinstance(generated_message, ChatMessageStreaming) and not isinstance( + generated_message, ChatMessageStatus + ): raise TypeError( f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}" ) - yield generated_message# Convert to ChatMessageBase for streaming + yield generated_message # Convert to ChatMessageBase for streaming # Store reference to the complete AI message if generated_message.status == ApiStatusType.DONE: final_message = generated_message - break - + break + if final_message is None: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="No match found for the given skill" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="No match found for the given skill", + ) yield error_message return - + if not isinstance(final_message, ChatMessageSkillAssessment): error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Skill match response is not valid" - ) + session_id=MOCK_UUID, # No 
session ID for document uploads + content="Skill match response is not valid", + ) yield error_message return - - skill_match : ChatMessageSkillAssessment = final_message + + skill_match: ChatMessageSkillAssessment = final_message assessment = skill_match.skill_assessment if not assessment: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Skill assessment could not be generated" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="Skill assessment could not be generated", + ) yield error_message return - + await database.cache_skill_match(cache_key, assessment) logger.info(f"💾 Cached new skill match for candidate {candidate.id} as {cache_key}") logger.info(f"✅ Skill match: {assessment.evidence_strength} {skill}") yield skill_match return - + try: + async def to_json(method): try: async for message in method: - json_data = message.model_dump(mode='json', by_alias=True) + json_data = message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) yield f"data: {json_str}\n\n".encode("utf-8") except Exception as e: @@ -1574,19 +1568,26 @@ async def get_candidate_skill_match( logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Document upload error: {e}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to generate skill assessment" - ).model_dump(mode='json', by_alias=True))]), - media_type="text/event-stream" - ) + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to generate skill assessment", + ).model_dump(mode="json", by_alias=True) + ) + ] + ), + media_type="text/event-stream", + ) + @router.post("/job-score") async def get_candidate_job_score( job_requirements: JobRequirements = Body(...), skills: List[SkillAssessment] = Body(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): # Initialize counters required_skills_total = 0 @@ -1656,69 +1657,67 @@ async def get_candidate_job_score( preferred_skills_matched += 1 # If no skills were found, return empty statistics if required_skills_total == 0 and preferred_skills_total == 0: - return create_success_response({ - "required_skills": { - "total": 0, - "matched": 0, - "percentage": 0.0, - }, - "preferred_skills": { - "total": 0, - "matched": 0, - "percentage": 0.0, - }, - "overall_match": { - "total": 0, - "matched": 0, - "percentage": 0.0, - }, - }) + return create_success_response( + { + "required_skills": { + "total": 0, + "matched": 0, + "percentage": 0.0, + }, + "preferred_skills": { + "total": 0, + "matched": 0, + "percentage": 0.0, + }, + "overall_match": { + "total": 0, + "matched": 0, + "percentage": 0.0, + }, + } + ) # Calculate percentages - required_match_percent = ( - (required_skills_matched / required_skills_total * 100) - if required_skills_total > 0 - else 0 - ) + required_match_percent = (required_skills_matched / required_skills_total * 100) if required_skills_total > 0 else 0 preferred_match_percent = ( - (preferred_skills_matched / preferred_skills_total * 100) - if preferred_skills_total > 0 - else 0 + (preferred_skills_matched / preferred_skills_total * 100) if preferred_skills_total > 0 else 0 ) overall_total = required_skills_total + preferred_skills_total overall_matched = 
required_skills_matched + preferred_skills_matched - overall_match_percent = ( - (overall_matched / overall_total * 100) if overall_total > 0 else 0 + overall_match_percent = (overall_matched / overall_total * 100) if overall_total > 0 else 0 + + return create_success_response( + { + "required_skills": { + "total": required_skills_total, + "matched": required_skills_matched, + "percentage": round(required_match_percent, 1), + }, + "preferred_skills": { + "total": preferred_skills_total, + "matched": preferred_skills_matched, + "percentage": round(preferred_match_percent, 1), + }, + "overall_match": { + "total": overall_total, + "matched": overall_matched, + "percentage": round(overall_match_percent, 1), + }, + } ) - return create_success_response({ - "required_skills": { - "total": required_skills_total, - "matched": required_skills_matched, - "percentage": round(required_match_percent, 1), - }, - "preferred_skills": { - "total": preferred_skills_total, - "matched": preferred_skills_matched, - "percentage": round(preferred_match_percent, 1), - }, - "overall_match": { - "total": overall_total, - "matched": overall_matched, - "percentage": round(overall_match_percent, 1), - }, - }) @router.post("/{candidate_id}/{job_id}/generate-resume") async def generate_resume( candidate_id: str = Path(...), job_id: str = Path(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ) -> StreamingResponse: skills: List[SkillAssessment] = [] - + """Get skill match for a candidate against a requirement with caching""" + async def message_stream_generator(): logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}") @@ -1727,10 +1726,10 @@ async def generate_resume( logger.error(f"❌ Candidate with ID '{candidate_id}' not found") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"Candidate with ID '{candidate_id}' not found" + content=f"Candidate with ID '{candidate_id}' not found", ) yield error_message - return + return candidate = Candidate.model_validate(candidate_data) job_data = await database.get_job(job_id) @@ -1738,7 +1737,7 @@ async def generate_resume( logger.error(f"❌ Job with ID '{job_id}' not found") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"Job with ID '{job_id}' not found" + content=f"Job with ID '{job_id}' not found", ) yield error_message return @@ -1747,21 +1746,25 @@ async def generate_resume( uninitalized = False requirements = get_requirements_list(job) - logger.info(f"🔍 Checking skill match for candidate {candidate.username} against job {job.id}'s {len(requirements)} requirements.") + logger.info( + f"🔍 Checking skill match for candidate {candidate.username} against job {job.id}'s {len(requirements)} requirements." 
+ ) for req in requirements: - skill = req.get('requirement', None) + skill = req.get("requirement", None) if not skill: logger.warning(f"⚠️ No 'requirement' found in entry: {req}") continue cache_key = get_skill_cache_key(candidate.id, skill) - assessment : SkillAssessment | None = await database.get_cached_skill_match(cache_key) + assessment: SkillAssessment | None = await database.get_cached_skill_match(cache_key) if not assessment: logger.info(f"💾 No cached skill match data: {cache_key}, {candidate.id}, {skill}") uninitalized = True break if assessment and assessment.skill.lower() != skill.lower(): - logger.warning(f"❌ Cached skill match for {candidate.username} does not match requested skill: {assessment.skill} != {skill} ({cache_key}).") + logger.warning( + f"❌ Cached skill match for {candidate.username} does not match requested skill: {assessment.skill} != {skill} ({cache_key})." + ) uninitalized = True break @@ -1772,29 +1775,31 @@ async def generate_resume( logger.error("❌ Uninitialized skill match data, cannot generate resume") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Uninitialized skill match data, cannot generate resume" + content="Uninitialized skill match data, cannot generate resume", ) yield error_message return - - logger.info(f"🔍 Generating resume for candidate {candidate.username}, job {job.id}, with {len(skills)} skills.") - + + logger.info( + f"🔍 Generating resume for candidate {candidate.username}, job {job.id}, with {len(skills)} skills." + ) + async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.GENERATE_RESUME) if not agent: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="No skill match agent found for this candidate" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="No skill match agent found for this candidate", + ) yield error_message return - + final_message = None if not isinstance(agent, GenerateResume): error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Agent is not a GenerateResume instance" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="Agent is not a GenerateResume instance", + ) yield error_message return async for generated_message in agent.generate_resume( @@ -1811,46 +1816,49 @@ async def generate_resume( logger.error(f"❌ AI generation error: {content}") yield f"data: {json.dumps({'status': 'error'})}\n\n" return - + # If the message is not done, convert it to a ChatMessageBase to remove # metadata and other unnecessary fields for streaming if generated_message.status != ApiStatusType.DONE: - if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus): + if not isinstance(generated_message, ChatMessageStreaming) and not isinstance( + generated_message, ChatMessageStatus + ): raise TypeError( f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}" ) - yield generated_message# Convert to ChatMessageBase for streaming + yield generated_message # Convert to ChatMessageBase for streaming # Store reference to the complete AI message if generated_message.status == ApiStatusType.DONE: final_message = generated_message - break - + break + if final_message is None: error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document 
uploads - content="No skill match found for the given requirement" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="No skill match found for the given requirement", + ) yield error_message return - + if not isinstance(final_message, ChatMessageResume): error_message = ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Skill match response is not valid" - ) + session_id=MOCK_UUID, # No session ID for document uploads + content="Skill match response is not valid", + ) yield error_message return - - resume : ChatMessageResume = final_message + + resume: ChatMessageResume = final_message yield resume return - + try: + async def to_json(method): try: async for message in method: - json_data = message.model_dump(mode='json', by_alias=True) + json_data = message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) yield f"data: {json_str}\n\n".encode("utf-8") except Exception as e: @@ -1874,21 +1882,28 @@ async def generate_resume( logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Document upload error: {e}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to generate skill assessment" - ).model_dump(mode='json', by_alias=True))]), - media_type="text/event-stream" - ) + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to generate skill assessment", + ).model_dump(mode="json", by_alias=True) + ) + ] + ), + media_type="text/event-stream", + ) + @rate_limited(guest_per_minute=5, user_per_minute=30, admin_per_minute=100) @router.get("/{username}/chat-sessions") async def get_candidate_chat_sessions( username: str = Path(...), - current_user = Depends(get_current_user_or_guest), + current_user=Depends(get_current_user_or_guest), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get all chat sessions related to a specific candidate""" try: @@ -1896,70 +1911,65 @@ async def get_candidate_chat_sessions( # Find candidate by username all_candidates_data = await database.get_all_candidates() candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()] - - matching_candidates = [ - c for c in candidates_list - if c.username.lower() == username.lower() - ] - + + matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()] + if not matching_candidates: return JSONResponse( status_code=404, - content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found") + content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"), ) - + candidate = matching_candidates[0] - + # Get all chat sessions all_sessions_data = await database.get_all_chat_sessions() sessions_list = [] - + for index, session_data in enumerate(all_sessions_data.values()): try: session = ChatSession.model_validate(session_data) if session.user_id != current_user.id: # User can only access their own sessions - logger.info(f"🔗 Skipping session {session.id} - not owned by user {current_user.id} (created by {session.user_id})") + logger.info( + f"🔗 Skipping session {session.id} - not owned by user {current_user.id} (created by {session.user_id})" + ) continue # Check if this session is related to the candidate context = 
session.context - if (context and - context.related_entity_type == "candidate" and - context.related_entity_id == candidate.id): + if context and context.related_entity_type == "candidate" and context.related_entity_id == candidate.id: sessions_list.append(session) except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Failed to validate session ({index}): {e}") logger.error(f"❌ Session data: {session_data}") continue - + # Sort by last activity (most recent first) sessions_list.sort(key=lambda x: x.last_activity, reverse=True) - + # Apply pagination total = len(sessions_list) start = (page - 1) * limit end = start + limit paginated_sessions = sessions_list[start:end] - + paginated_response = create_paginated_response( - [s.model_dump(by_alias=True) for s in paginated_sessions], - page, limit, total + [s.model_dump(by_alias=True) for s in paginated_sessions], page, limit, total ) - - return create_success_response({ - "candidate": { - "id": candidate.id, - "username": candidate.username, - "fullName": candidate.full_name, - "email": candidate.email - }, - "sessions": paginated_response - }) - + + return create_success_response( + { + "candidate": { + "id": candidate.id, + "username": candidate.username, + "fullName": candidate.full_name, + "email": candidate.email, + }, + "sessions": paginated_response, + } + ) + except Exception as e: logger.error(f"❌ Get candidate chat sessions error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) diff --git a/src/backend/routes/chat.py b/src/backend/routes/chat.py index b152833..0b8d6cc 100644 --- a/src/backend/routes/chat.py +++ b/src/backend/routes/chat.py @@ -4,96 +4,69 @@ Chat routes import json import uuid from datetime import datetime, UTC -from typing import (Dict, Any) +from typing import Dict, Any -from fastapi import ( - APIRouter, Depends, Body, Depends, Query, Path, - Body, APIRouter -) +from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter from fastapi.responses import JSONResponse from database.manager import RedisDatabase from logger import logger -from utils.dependencies import ( - get_database, get_current_user, get_current_user_or_guest -) -from utils.responses import ( - create_success_response, create_error_response, create_paginated_response -) -from utils.helpers import ( - stream_agent_response -) +from utils.dependencies import get_database, get_current_user, get_current_user_or_guest +from utils.responses import create_success_response, create_error_response, create_paginated_response +from utils.helpers import stream_agent_response import backstory_traceback import entities.entity_manager as entities -from models import ( - Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage -) +from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage # Create router for authentication endpoints router = APIRouter(prefix="/chat", tags=["chat"]) + @router.post("/sessions/{session_id}/archive") async def archive_chat_session( - session_id: str = Path(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): """Archive a chat session""" try: session_data = await database.get_chat_session(session_id) if 
not session_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) + # Check if user owns this session or is admin if session_data.get("userId") != current_user.id: return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot archive another user's session") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session") ) - + await database.archive_chat_session(session_id) - - return create_success_response({ - "message": "Chat session archived successfully", - "sessionId": session_id - }) - + + return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id}) + except Exception as e: logger.error(f"❌ Archive chat session error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("ARCHIVE_ERROR", str(e)) - ) - + return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e))) + + @router.get("/statistics") -async def get_chat_statistics( - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) -): +async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)): """Get chat statistics (admin/analytics endpoint)""" try: stats = await database.get_chat_statistics() return create_success_response(stats) except Exception as e: logger.error(f"❌ Get chat statistics error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("STATS_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e))) @router.post("/sessions") async def create_chat_session( session_data: Dict[str, Any] = Body(...), current_user: BaseUserWithType = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Create a new chat session with optional candidate username association""" try: @@ -101,31 +74,30 @@ async def create_chat_session( username = session_data.get("username") candidate_id = None candidate_data = None - + # If username is provided, look up the candidate if username: logger.info(f"🔍 Looking up candidate with username: {username}") - + # Get all candidates and find by username all_candidates_data = await database.get_all_candidates() candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()] - + # Find candidate by username (case-insensitive) - matching_candidates = [ - c for c in candidates_list - if c.username.lower() == username.lower() - ] - + matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()] + if not matching_candidates: return JSONResponse( status_code=404, - content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found") + content=create_error_response( + "CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found" + ), ) - + candidate_data = matching_candidates[0] candidate_id = candidate_data.id logger.info(f"✅ Found candidate: {candidate_data.full_name} (ID: {candidate_id})") - + # Add required fields session_id = str(uuid.uuid4()) session_data["id"] = session_id @@ -138,7 +110,7 @@ async def create_chat_session( if candidate_id and candidate_data: context["relatedEntityId"] = 
candidate_id context["relatedEntityType"] = "candidate" - + # Add candidate info to additional context for AI reference additional_context = context.get("additionalContext", {}) additional_context["candidateInfo"] = { @@ -148,72 +120,83 @@ async def create_chat_session( "username": candidate_data.username, "skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [], "experience": len(candidate_data.experience) if candidate_data.experience else 0, - "location": candidate_data.location.city if candidate_data.location else "Unknown" + "location": candidate_data.location.city if candidate_data.location else "Unknown", } context["additionalContext"] = additional_context - + # Set a descriptive title if not provided if not session_data.get("title"): session_data["title"] = f"Chat about {candidate_data.full_name}" - + session_data["context"] = context - + # Create chat session chat_session = ChatSession.model_validate(session_data) await database.set_chat_session(chat_session.id, chat_session.model_dump()) - - logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" + - (f" about candidate {candidate_data.full_name}" if candidate_data else "")) - + + logger.info( + f"✅ Chat session created: {chat_session.id} for user {current_user.id}" + + (f" about candidate {candidate_data.full_name}" if candidate_data else "") + ) + return create_success_response(chat_session.model_dump(by_alias=True)) - + except Exception as e: logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Chat session creation error: {e}") logger.info(json.dumps(session_data, indent=2)) - return JSONResponse( - status_code=400, - content=create_error_response("CREATION_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e))) + @router.post("/sessions/messages/stream") async def post_chat_session_message_stream( user_message: ChatMessageUser = Body(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): """Post a message to a chat session and stream the response with persistence""" try: chat_session_data = await database.get_chat_session(user_message.session_id) if not chat_session_data: logger.info("🔗 Chat session not found for session ID: " + user_message.session_id) - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) chat_session = ChatSession.model_validate(chat_session_data) chat_type = chat_session.context.type - candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None + candidate_info = ( + chat_session.context.additional_context.get("candidateInfo", {}) + if chat_session.context and chat_session.context.additional_context + else None + ) # Get candidate info if this chat is about a specific candidate if candidate_info: - logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}") + logger.info( + f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}" + ) else: - logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} 
accessed by user {current_user.id}") + logger.info( + f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}" + ) return JSONResponse( status_code=400, - content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association") + content=create_error_response( + "CANDIDATE_REQUIRED", "This chat session requires a candidate association" + ), ) candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None - candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None + candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None if not candidate: - logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}") + logger.info( + f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}" + ) return JSONResponse( status_code=404, - content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session") + content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"), ) - logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}") + logger.info( + f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}" + ) async with entities.get_candidate_entity(candidate=candidate) as candidate_entity: # Entity automatically released when done @@ -222,13 +205,13 @@ async def post_chat_session_message_stream( logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}") return JSONResponse( status_code=400, - content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type") + content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"), ) # Persist user message to database await database.add_chat_message(user_message.session_id, user_message.model_dump()) logger.info(f"💬 User message saved to database for session {user_message.session_id}") - + # Update session last activity chat_session_data["lastActivity"] = datetime.now(UTC).isoformat() await database.set_chat_session(user_message.session_id, chat_session_data) @@ -239,35 +222,30 @@ async def post_chat_session_message_stream( database=database, chat_session_data=chat_session_data, ) - + except Exception: logger.error(backstory_traceback.format_exc()) logger.error("❌ Chat message streaming error") - return JSONResponse( - status_code=500, - content=create_error_response("STREAMING_ERROR", "") - ) + return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", "")) + @router.get("/sessions/{session_id}/messages") async def get_chat_session_messages( session_id: str = Path(...), - current_user = Depends(get_current_user_or_guest), + current_user=Depends(get_current_user_or_guest), page: int = Query(1, ge=1), limit: int = Query(50, ge=1, le=100), # Increased default for chat messages - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get persisted chat messages for a session""" try: chat_session_data = await database.get_chat_session(session_id) if not chat_session_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) + return 
JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) # Get messages from database chat_messages = await database.get_chat_messages(session_id) - + # Convert to ChatMessage objects and sort by timestamp messages_list = [] for msg_data in chat_messages: @@ -277,69 +255,61 @@ async def get_chat_session_messages( except Exception as e: logger.warning(f"⚠️ Failed to validate message: {e}") continue - + # Sort by timestamp (oldest first for chat history) messages_list.sort(key=lambda x: x.timestamp) - + # Apply pagination total = len(messages_list) start = (page - 1) * limit end = start + limit paginated_messages = messages_list[start:end] - + paginated_response = create_paginated_response( - [m.model_dump(by_alias=True) for m in paginated_messages], - page, limit, total + [m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total ) return create_success_response(paginated_response) - + except Exception as e: logger.error(f"❌ Get chat messages error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.patch("/sessions/{session_id}") async def update_chat_session( session_id: str = Path(...), updates: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): """Update a chat session's properties""" try: # Get the existing session session_data = await database.get_chat_session(session_id) if not session_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) + session = ChatSession.model_validate(session_data) - + # Check authorization - user can only update their own sessions if session.user_id != current_user.id: return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot update another user's chat session") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session") ) - + # Validate and apply updates allowed_fields = {"title", "context", "isArchived", "systemPrompt"} filtered_updates = {k: v for k, v in updates.items() if k in allowed_fields} - + if not filtered_updates: return JSONResponse( - status_code=400, - content=create_error_response("INVALID_UPDATES", "No valid fields provided for update") + status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update") ) - + # Apply updates to session data session_dict = session.model_dump() - + # Handle special field mappings (camelCase to snake_case) if "isArchived" in filtered_updates: session_dict["is_archived"] = filtered_updates["isArchived"] @@ -351,7 +321,7 @@ async def update_chat_session( # Merge context updates with existing context existing_context = session_dict.get("context", {}) context_updates = filtered_updates["context"] - + # Update specific context fields while preserving others for context_key, context_value in context_updates.items(): if context_key == "additionalContext": @@ -368,141 +338,116 @@ async def update_chat_session( snake_key = "related_entity_type" elif context_key == "aiParameters": snake_key = "ai_parameters" - + 
existing_context[snake_key] = context_value - + session_dict["context"] = existing_context - + # Update last activity timestamp session_dict["last_activity"] = datetime.now(UTC).isoformat() - + # Validate the updated session updated_session = ChatSession.model_validate(session_dict) - + # Save to database await database.set_chat_session(session_id, updated_session.model_dump()) - + logger.info(f"✅ Chat session {session_id} updated by user {current_user.id}") - + return create_success_response(updated_session.model_dump(by_alias=True)) - + except ValueError as ve: logger.warning(f"⚠️ Validation error updating chat session: {ve}") - return JSONResponse( - status_code=400, - content=create_error_response("VALIDATION_ERROR", str(ve)) - ) + return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve))) except Exception as e: logger.error(f"❌ Update chat session error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("UPDATE_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e))) + @router.delete("/sessions/{session_id}") async def delete_chat_session( session_id: str = Path(...), - current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): """Delete a chat session and all its messages""" try: # Get the session to verify it exists and check ownership session_data = await database.get_chat_session(session_id) if not session_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) + session = ChatSession.model_validate(session_data) - + # Check authorization - user can only delete their own sessions if session.user_id != current_user.id: return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session") ) - + # Delete all messages associated with this session try: await database.delete_chat_messages(session_id) chat_messages = await database.get_chat_messages(session_id) - message_count = len(chat_messages) + message_count = len(chat_messages) logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}") - + except Exception as e: logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}") # Continue with session deletion even if message deletion fails - + # Delete the session itself await database.delete_chat_session(session_id) - + logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}") - - return create_success_response({ - "success": True, - "message": "Chat session deleted successfully", - "sessionId": session_id - }) - + + return create_success_response( + {"success": True, "message": "Chat session deleted successfully", "sessionId": session_id} + ) + except Exception as e: logger.error(f"❌ Delete chat session error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("DELETE_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e))) + @router.patch("/sessions/{session_id}/reset") async def reset_chat_session( session_id: str = Path(...), - 
current_user = Depends(get_current_user_or_guest), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user_or_guest), + database: RedisDatabase = Depends(get_database), ): """Delete a chat session and all its messages""" try: # Get the session to verify it exists and check ownership session_data = await database.get_chat_session(session_id) if not session_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Chat session not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found")) + session = ChatSession.model_validate(session_data) - + # Check authorization - user can only delete their own sessions if session.user_id != current_user.id: return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session") ) - + # Delete all messages associated with this session try: await database.delete_chat_messages(session_id) chat_messages = await database.get_chat_messages(session_id) - message_count = len(chat_messages) + message_count = len(chat_messages) logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}") - + except Exception as e: logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}") # Continue with session deletion even if message deletion fails - - + logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}") - - return create_success_response({ - "success": True, - "message": "Chat session reset successfully", - "sessionId": session_id - }) - - except Exception as e: - logger.error(f"❌ Reset chat session error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("RESET_ERROR", str(e)) + + return create_success_response( + {"success": True, "message": "Chat session reset successfully", "sessionId": session_id} ) - + except Exception as e: + logger.error(f"❌ Reset chat session error: {e}") + return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e))) diff --git a/src/backend/routes/debug.py b/src/backend/routes/debug.py index 3ae8cfd..1aae8ec 100644 --- a/src/backend/routes/debug.py +++ b/src/backend/routes/debug.py @@ -9,58 +9,52 @@ from fastapi.responses import JSONResponse from database.manager import RedisDatabase from logger import logger -from utils.dependencies import ( - get_current_admin, get_database -) +from utils.dependencies import get_current_admin, get_database from utils.responses import create_success_response, create_error_response # Create router for authentication endpoints router = APIRouter(prefix="/auth", tags=["authentication"]) + @router.get("/guest/{guest_id}") async def debug_guest_session( - guest_id: str = Path(...), - admin_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Debug guest session issues (admin only)""" try: # Check primary storage - primary_data = await database.redis.hget("guests", guest_id) # type: ignore + primary_data = await database.redis.hget("guests", guest_id) # type: ignore primary_exists = primary_data is not None - - # Check backup storage + + # Check backup storage backup_data = await database.redis.get(f"guest_backup:{guest_id}") backup_exists = backup_data is not 
None - + # Check user lookup user_lookup = await database.get_user_by_id(guest_id) - + # Get TTL info primary_ttl = await database.redis.ttl("guests") backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}") - + debug_info = { "guest_id": guest_id, "primary_storage": { "exists": primary_exists, "data": json.loads(primary_data) if primary_data else None, - "ttl": primary_ttl + "ttl": primary_ttl, }, "backup_storage": { - "exists": backup_exists, + "exists": backup_exists, "data": json.loads(backup_data) if backup_data else None, - "ttl": backup_ttl + "ttl": backup_ttl, }, "user_lookup": user_lookup, - "timestamp": datetime.now(UTC).isoformat() + "timestamp": datetime.now(UTC).isoformat(), } - + return create_success_response(debug_info) - + except Exception as e: logger.error(f"❌ Debug guest session error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("DEBUG_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e))) diff --git a/src/backend/routes/employers.py b/src/backend/routes/employers.py index c5529ae..56ecdd8 100644 --- a/src/backend/routes/employers.py +++ b/src/backend/routes/employers.py @@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse from database.manager import RedisDatabase from logger import logger -from models import ( - CreateEmployerRequest -) -from utils.dependencies import ( - get_database -) +from models import CreateEmployerRequest +from utils.dependencies import get_database from utils.responses import create_success_response, create_error_response from email_service import email_service from utils.auth_utils import AuthenticationManager @@ -23,34 +19,27 @@ from utils.auth_utils import AuthenticationManager # Create router for job endpoints router = APIRouter(prefix="/employers", tags=["employers"]) + @router.post("") async def create_employer_with_verification( - request: CreateEmployerRequest, - background_tasks: BackgroundTasks, - database: RedisDatabase = Depends(get_database) + request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database) ): """Create a new employer with email verification""" try: # Similar to candidate creation but for employer auth_manager = AuthenticationManager(database) - - user_exists, conflict_field = await auth_manager.check_user_exists( - request.email, - request.username - ) - + + user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username) + if user_exists and conflict_field: return JSONResponse( status_code=409, - content=create_error_response( - "USER_EXISTS", - f"A user with this {conflict_field} already exists" - ) + content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"), ) - + employer_id = str(uuid.uuid4()) current_time = datetime.now(timezone.utc) - + employer_data = { "id": employer_id, "email": request.email, @@ -64,46 +53,35 @@ async def create_employer_with_verification( "updatedAt": current_time.isoformat(), "status": "pending", # Not active until verified "userType": "employer", - "location": { - "city": "", - "country": "", - "remote": False - }, - "socialLinks": [] + "location": {"city": "", "country": "", "remote": False}, + "socialLinks": [], } - + verification_token = secrets.token_urlsafe(32) - + await database.store_email_verification_token( request.email, verification_token, "employer", + {"employer_data": employer_data, "password": request.password, "username": 
request.username}, + ) + + background_tasks.add_task( + email_service.send_verification_email, request.email, verification_token, request.company_name + ) + + logger.info(f"✅ Employer registration initiated for: {request.email}") + + return create_success_response( { - "employer_data": employer_data, - "password": request.password, - "username": request.username + "message": "Registration successful! Please check your email to verify your account.", + "email": request.email, + "verificationRequired": True, } ) - - background_tasks.add_task( - email_service.send_verification_email, - request.email, - verification_token, - request.company_name - ) - - logger.info(f"✅ Employer registration initiated for: {request.email}") - - return create_success_response({ - "message": "Registration successful! Please check your email to verify your account.", - "email": request.email, - "verificationRequired": True - }) - + except Exception as e: logger.error(f"❌ Employer creation error: {e}") return JSONResponse( - status_code=500, - content=create_error_response("CREATION_FAILED", "Failed to create employer account") + status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account") ) - diff --git a/src/backend/routes/jobs.py b/src/backend/routes/jobs.py index 4e0f9b6..bef9715 100644 --- a/src/backend/routes/jobs.py +++ b/src/backend/routes/jobs.py @@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu from database.manager import RedisDatabase from logger import logger from models import ( - MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer -) -from utils.dependencies import ( - get_current_admin, get_database, get_current_user + MOCK_UUID, + ApiActivityType, + ApiStatusType, + ChatContextType, + ChatMessage, + ChatMessageError, + ChatMessageStatus, + DocumentType, + Job, + JobRequirementsMessage, + Candidate, + Employer, ) +from utils.dependencies import get_current_admin, get_database, get_current_user from utils.responses import create_paginated_response, create_success_response, create_error_response import utils.llm_proxy as llm_manager import entities.entity_manager as entities + # Create router for job endpoints router = APIRouter(prefix="/jobs", tags=["jobs"]) @@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid if not chat_agent: error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="No agent found for job requirements chat type" + content="No agent found for job requirements chat type", ) yield error_message return status_message = ChatMessageStatus( session_id=MOCK_UUID, # No session ID for document uploads content="Reformatting job description as markdown...", - activity=ApiActivityType.CONVERTING + activity=ApiActivityType.CONVERTING, ) yield status_message @@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid system_prompt=""" You are a document editor. Take the provided job description and reformat as legible markdown. Return only the markdown content, no other text. Make sure all content is included. -""" +""", ): pass @@ -66,11 +76,11 @@ Return only the markdown content, no other text. 
Make sure all content is includ logger.error("❌ Failed to reformat job description to markdown") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to reformat job description" + content="Failed to reformat job description", ) yield error_message return - chat_message : ChatMessage = message + chat_message: ChatMessage = message try: chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content) except Exception: @@ -84,11 +94,11 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida status_message = ChatMessageStatus( session_id=MOCK_UUID, # No session ID for document uploads content=f"Initiating connection with {current_user.first_name}'s AI agent...", - activity=ApiActivityType.INFO + activity=ApiActivityType.INFO, ) yield status_message - await asyncio.sleep(0) # Let the status message propagate - + await asyncio.sleep(0) # Let the status message propagate + async with entities.get_candidate_entity(candidate=current_user) as candidate_entity: message = None async for message in reformat_as_markdown(database, candidate_entity, content): @@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida if not message or not isinstance(message, ChatMessage): error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to reformat job description" + content="Failed to reformat job description", ) yield error_message return @@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida if not chat_agent: error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="No agent found for job requirements chat type" + content="No agent found for job requirements chat type", ) yield error_message return status_message = ChatMessageStatus( session_id=MOCK_UUID, # No session ID for document uploads content="Analyzing document for company and requirement details...", - activity=ApiActivityType.SEARCHING + activity=ApiActivityType.SEARCHING, ) yield status_message message = None async for message in chat_agent.generate( - llm=llm_manager.get_llm(), - model=defines.model, - session_id=MOCK_UUID, - prompt=markdown_message.content + llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content ): if message.status != ApiStatusType.DONE: yield message @@ -132,23 +139,22 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida if not message or not isinstance(message, JobRequirementsMessage): error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Job extraction did not convert successfully" + content="Job extraction did not convert successfully", ) yield error_message return - - job_requirements : JobRequirementsMessage = message + + job_requirements: JobRequirementsMessage = message logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}") yield job_requirements return - @router.post("") async def create_job( job_data: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Create a new job""" try: @@ -158,24 +164,21 @@ async def create_job( job.id = str(uuid.uuid4()) job.owner_id = current_user.id job.owner = current_user - + 
await database.set_job(job.id, job.model_dump()) - + return create_success_response(job.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Job creation error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("CREATION_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e))) @router.post("") async def create_candidate_job( job_data: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Create a new job""" isinstance(current_user, Employer) @@ -187,46 +190,39 @@ async def create_candidate_job( job.id = str(uuid.uuid4()) job.owner_id = current_user.id job.owner = current_user - + await database.set_job(job.id, job.model_dump()) - + return create_success_response(job.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Job creation error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("CREATION_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e))) @router.patch("/{job_id}") async def update_job( job_id: str = Path(...), updates: Dict[str, Any] = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Update a candidate""" try: job_data = await database.get_job(job_id) if not job_data: logger.warning(f"⚠️ Job not found for update: {job_data}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Job not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found")) + job = Job.model_validate(job_data) - + # Check authorization (user can only update their own profile) if current_user.is_admin is False and job.owner_id != current_user.id: logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}") return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Cannot update another user's job") + status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job") ) - + # Apply updates updates["updatedAt"] = datetime.now(UTC).isoformat() logger.info(f"🔄 Updating job {job_id} with data: {updates}") @@ -234,36 +230,33 @@ async def update_job( job_dict.update(updates) updated_job = Job.model_validate(job_dict) await database.set_job(job_id, updated_job.model_dump()) - + return create_success_response(updated_job.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Update job error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("UPDATE_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e))) + @router.post("/from-content") async def create_job_from_description( - content: str = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database) ): """Upload a document for the current candidate""" + async def content_stream_generator(content): # Verify user is a candidate if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized 
upload attempt by user type: {current_user.user_type}") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Only candidates can upload documents" + content="Only candidates can upload documents", ) yield error_message return - + logger.info(f"📁 Received file content: size='{len(content)} bytes'") - + last_yield_was_streaming = False async for message in create_job_from_content(database=database, current_user=current_user, content=content): if message.status != ApiStatusType.STREAMING: @@ -277,10 +270,11 @@ async def create_job_from_description( return try: + async def to_json(method): try: async for message in method: - json_data = message.model_dump(mode='json', by_alias=True) + json_data = message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) yield f"data: {json_str}\n\n".encode("utf-8") except Exception as e: @@ -304,18 +298,25 @@ async def create_job_from_description( logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Document upload error: {e}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to upload document" - ).model_dump(by_alias=True)).encode("utf-8")]), - media_type="text/event-stream" + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to upload document", + ).model_dump(by_alias=True) + ).encode("utf-8") + ] + ), + media_type="text/event-stream", ) + @router.post("/upload") async def create_job_from_file( file: UploadFile = File(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Upload a job document for the current candidate and create a Job""" # Check file size (limit to 10MB) @@ -324,66 +325,81 @@ async def create_job_from_file( if len(file_content) > max_size: logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="File size exceeds 10MB limit" - ).model_dump(by_alias=True)).encode("utf-8")]), - media_type="text/event-stream" + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="File size exceeds 10MB limit", + ).model_dump(by_alias=True) + ).encode("utf-8") + ] + ), + media_type="text/event-stream", ) if len(file_content) == 0: logger.info(f"⚠️ File is empty: {file.filename}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="File is empty" - ).model_dump(by_alias=True)).encode("utf-8")]), - media_type="text/event-stream" + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="File is empty", + ).model_dump(by_alias=True) + ).encode("utf-8") + ] + ), + media_type="text/event-stream", ) """Upload a document for the current candidate""" + async def upload_stream_generator(file_content): # Verify user is a candidate if current_user.user_type != "candidate": logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="Only candidates can upload documents" + content="Only candidates 
can upload documents", ) yield error_message return - - file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename + + file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename if not file.filename or file.filename.strip() == "": logger.warning("⚠️ File upload attempt with missing filename") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content="File must have a valid filename" + content="File must have a valid filename", ) yield error_message return - - logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'") - + + logger.info( + f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'" + ) + # Validate file type - allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif'] + allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"] file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else "" - + if file_extension not in allowed_types: logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}" + content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}", ) yield error_message return - + document_type = get_document_type_from_filename(file.filename or "unknown.txt") - + if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT: status_message = ChatMessageStatus( session_id=MOCK_UUID, # No session ID for document uploads content=f"Converting content from {document_type}...", - activity=ApiActivityType.CONVERTING + activity=ApiActivityType.CONVERTING, ) yield status_message try: @@ -391,8 +407,8 @@ async def create_job_from_file( stream = io.BytesIO(file_content) stream_info = StreamInfo( extension=file_extension, # e.g., ".pdf" - url=file.filename # optional, helps with logging and guessing - ) + url=file.filename, # optional, helps with logging and guessing + ) result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown") file_content = result.text_content logger.info(f"✅ Converted {file.filename} to Markdown format") @@ -404,16 +420,19 @@ async def create_job_from_file( yield error_message logger.error(f"❌ Error converting {file.filename} to Markdown: {e}") return - - async for message in create_job_from_content(database=database, current_user=current_user, content=file_content): + + async for message in create_job_from_content( + database=database, current_user=current_user, content=file_content + ): yield message return try: + async def to_json(method): try: async for message in method: - json_data = message.model_dump(mode='json', by_alias=True) + json_data = message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) yield f"data: {json_str}\n\n".encode("utf-8") except Exception as e: @@ -437,40 +456,39 @@ async def create_job_from_file( logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Document upload error: {e}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to upload document" - ).model_dump(mode='json', 
by_alias=True)).encode("utf-8")]), - media_type="text/event-stream" + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to upload document", + ).model_dump(mode="json", by_alias=True) + ).encode("utf-8") + ] + ), + media_type="text/event-stream", ) + @router.get("/{job_id}") -async def get_job( - job_id: str = Path(...), - database: RedisDatabase = Depends(get_database) -): +async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)): """Get a job by ID""" try: job_data = await database.get_job(job_id) if not job_data: - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Job not found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found")) + # Increment view count job_data["views"] = job_data.get("views", 0) + 1 await database.set_job(job_id, job_data) - + job = Job.model_validate(job_data) return create_success_response(job.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Get job error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) + @router.get("") async def get_jobs( @@ -479,37 +497,32 @@ async def get_jobs( sortBy: Optional[str] = Query(None, alias="sortBy"), sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"), filters: Optional[str] = Query(None), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Get paginated list of jobs""" try: filter_dict = None if filters: filter_dict = json.loads(filters) - + # Get all jobs from Redis all_jobs_data = await database.get_all_jobs() jobs_list = [] for job in all_jobs_data.values(): jobs_list.append(Job.model_validate(job)) - - paginated_jobs, total = filter_and_paginate( - jobs_list, page, limit, sortBy, sortOrder, filter_dict - ) - + + paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict) + paginated_response = create_paginated_response( - [j.model_dump(by_alias=True) for j in paginated_jobs], - page, limit, total + [j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total ) - + return create_success_response(paginated_response) - + except Exception as e: logger.error(f"❌ Get jobs error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("FETCH_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e))) + @router.get("/search") async def search_jobs( @@ -517,84 +530,67 @@ async def search_jobs( filters: Optional[str] = Query(None), page: int = Query(1, ge=1), limit: int = Query(20, ge=1, le=100), - database: RedisDatabase = Depends(get_database) + database: RedisDatabase = Depends(get_database), ): """Search jobs""" try: filter_dict = {} if filters: filter_dict = json.loads(filters) - + # Get all jobs from Redis all_jobs_data = await database.get_all_jobs() jobs_list = [Job.model_validate(data) for data in all_jobs_data.values() if data.get("is_active", True)] - + if query: query_lower = query.lower() jobs_list = [ - j for j in jobs_list - if ((j.title and query_lower in j.title.lower()) or - (j.description and query_lower in j.description.lower()) or - any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])) + j + for j in jobs_list + if ( + 
(j.title and query_lower in j.title.lower()) + or (j.description and query_lower in j.description.lower()) + or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []) + ) ] - - paginated_jobs, total = filter_and_paginate( - jobs_list, page, limit, filters=filter_dict - ) - + + paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict) + paginated_response = create_paginated_response( - [j.model_dump(by_alias=True) for j in paginated_jobs], - page, limit, total + [j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total ) - + return create_success_response(paginated_response) - + except Exception as e: logger.error(f"❌ Search jobs error: {e}") - return JSONResponse( - status_code=400, - content=create_error_response("SEARCH_FAILED", str(e)) - ) + return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e))) @router.delete("/{job_id}") async def delete_job( - job_id: str = Path(...), - admin_user = Depends(get_current_admin), - database: RedisDatabase = Depends(get_database) + job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database) ): """Delete a Job""" try: # Check if admin user if not admin_user.is_admin: logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}") - return JSONResponse( - status_code=403, - content=create_error_response("FORBIDDEN", "Only admins can delete") - ) - + return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete")) + # Get candidate data job_data = await database.get_job(job_id) if not job_data: logger.warning(f"⚠️ Candidate not found for deletion: {job_id}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "Job not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found")) # Delete job from database await database.delete_job(job_id) - + logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}") - - return create_success_response({ - "message": "Job deleted successfully", - "jobId": job_id - }) - + + return create_success_response({"message": "Job deleted successfully", "jobId": job_id}) + except Exception as e: logger.error(f"❌ Delete job error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("DELETE_ERROR", "Failed to delete job") - ) + return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job")) diff --git a/src/backend/routes/providers.py b/src/backend/routes/providers.py index c222a2f..ed4d3c2 100644 --- a/src/backend/routes/providers.py +++ b/src/backend/routes/providers.py @@ -6,40 +6,37 @@ from utils.llm_proxy import LLMProvider, get_llm router = APIRouter(prefix="/providers", tags=["providers"]) + @router.get("/models") async def list_models(provider: Optional[str] = None): """List available models for a provider""" try: llm = get_llm() - + provider_enum = None if provider: try: provider_enum = LLMProvider(provider.lower()) except ValueError: - raise HTTPException( - status_code=400, - detail=f"Unsupported provider: {provider}" - ) - + raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}") + models = await llm.list_models(provider_enum) - return { - "provider": provider_enum.value if provider_enum else llm.default_provider.value, - "models": models - } - + return {"provider": provider_enum.value if provider_enum else 
llm.default_provider.value, "models": models} + except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @router.get("") async def list_providers(): """List all configured providers""" llm = get_llm() return { "providers": [provider.value for provider in llm._initialized_providers], - "default": llm.default_provider.value + "default": llm.default_provider.value, } + @router.post("/{provider}/set-default") async def set_default_provider(provider: str): """Set the default provider""" @@ -51,6 +48,7 @@ async def set_default_provider(provider: str): except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) + # Health check endpoint @router.get("/health") async def health_check(): @@ -59,5 +57,5 @@ async def health_check(): return { "status": "healthy", "providers_configured": len(llm._initialized_providers), - "default_provider": llm.default_provider.value + "default_provider": llm.default_provider.value, } diff --git a/src/backend/routes/resumes.py b/src/backend/routes/resumes.py index 9b0220c..1bb628a 100644 --- a/src/backend/routes/resumes.py +++ b/src/backend/routes/resumes.py @@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse import backstory_traceback as backstory_traceback from database.manager import RedisDatabase from logger import logger -from models import ( - MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage -) -from utils.dependencies import ( - get_database, get_current_user -) +from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage +from utils.dependencies import get_database, get_current_user from utils.responses import create_success_response - + # Create router for authentication endpoints router = APIRouter(prefix="/resumes", tags=["resumes"]) + @router.post("/{candidate_id}/{job_id}") async def create_candidate_resume( - candidate_id: str = Path(..., description="ID of the candidate"), + candidate_id: str = Path(..., description="ID of the candidate"), job_id: str = Path(..., description="ID of the job"), resume_content: str = Body(...), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Create a new resume for a candidate/job combination""" + async def message_stream_generator(): logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}") @@ -39,10 +37,10 @@ async def create_candidate_resume( logger.error(f"❌ Candidate with ID '{candidate_id}' not found") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"Candidate with ID '{candidate_id}' not found" + content=f"Candidate with ID '{candidate_id}' not found", ) yield error_message - return + return candidate = Candidate.model_validate(candidate_data) job_data = await database.get_job(job_id) @@ -50,14 +48,16 @@ async def create_candidate_resume( logger.error(f"❌ Job with ID '{job_id}' not found") error_message = ChatMessageError( session_id=MOCK_UUID, # No session ID for document uploads - content=f"Job with ID '{job_id}' not found" + content=f"Job with ID '{job_id}' not found", ) yield error_message return job = Job.model_validate(job_data) - logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'") - + logger.info( + f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'" + ) + # Job and Candidate are valid. 
Save the resume resume = Resume( job_id=job_id, @@ -66,28 +66,26 @@ async def create_candidate_resume( ) resume_message: ResumeMessage = ResumeMessage( session_id=MOCK_UUID, # No session ID for document uploads - resume=resume + resume=resume, ) - + # Save to database success = await database.set_resume(current_user.id, resume.model_dump()) if not success: - error_message = ChatMessageError( - session_id=MOCK_UUID, - content="Failed to save resume to database" - ) + error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database") yield error_message return - + logger.info(f"✅ Successfully saved resume {resume_message.resume.id} for user {current_user.id}") yield resume_message return - + try: + async def to_json(method): try: async for message in method: - json_data = message.model_dump(mode='json', by_alias=True) + json_data = message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) yield f"data: {json_str}\n\n".encode("utf-8") except Exception as e: @@ -111,22 +109,26 @@ async def create_candidate_resume( logger.error(backstory_traceback.format_exc()) logger.error(f"❌ Resume creation error: {e}") return StreamingResponse( - iter([json.dumps(ChatMessageError( - session_id=MOCK_UUID, # No session ID for document uploads - content="Failed to create resume" - ).model_dump(mode='json', by_alias=True))]), - media_type="text/event-stream" + iter( + [ + json.dumps( + ChatMessageError( + session_id=MOCK_UUID, # No session ID for document uploads + content="Failed to create resume", + ).model_dump(mode="json", by_alias=True) + ) + ] + ), + media_type="text/event-stream", ) + @router.get("") -async def get_user_resumes( - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) -): +async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)): """Get all resumes for the current user""" try: resumes_data = await database.get_all_resumes_for_user(current_user.id) - resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data] + resumes: List[Resume] = [Resume.model_validate(data) for data in resumes_data] for resume in resumes: job_data = await database.get_job(resume.job_id) if job_data: @@ -135,156 +137,130 @@ async def get_user_resumes( if candidate_data: resume.candidate = Candidate.model_validate(candidate_data) resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date - return create_success_response({ - "resumes": resumes, - "count": len(resumes) - }) + return create_success_response({"resumes": resumes, "count": len(resumes)}) except Exception as e: logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve resumes") + @router.get("/{resume_id}") async def get_resume( resume_id: str = Path(..., description="ID of the resume"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Get a specific resume by ID""" try: resume = await database.get_resume(current_user.id, resume_id) if not resume: raise HTTPException(status_code=404, detail="Resume not found") - - return { - "success": True, - "resume": resume - } + + return {"success": True, "resume": resume} except HTTPException: raise except Exception as e: logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}") 
raise HTTPException(status_code=500, detail="Failed to retrieve resume") + @router.delete("/{resume_id}") async def delete_resume( resume_id: str = Path(..., description="ID of the resume"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Delete a specific resume""" try: success = await database.delete_resume(current_user.id, resume_id) if not success: raise HTTPException(status_code=404, detail="Resume not found") - - return { - "success": True, - "message": f"Resume {resume_id} deleted successfully" - } + + return {"success": True, "message": f"Resume {resume_id} deleted successfully"} except HTTPException: raise except Exception as e: logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}") raise HTTPException(status_code=500, detail="Failed to delete resume") + @router.get("/candidate/{candidate_id}") async def get_resumes_by_candidate( candidate_id: str = Path(..., description="ID of the candidate"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Get all resumes for a specific candidate""" try: resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id) - return { - "success": True, - "candidate_id": candidate_id, - "resumes": resumes, - "count": len(resumes) - } + return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)} except Exception as e: logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes") + @router.get("/job/{job_id}") async def get_resumes_by_job( job_id: str = Path(..., description="ID of the job"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Get all resumes for a specific job""" try: resumes = await database.get_resumes_by_job(current_user.id, job_id) - return { - "success": True, - "job_id": job_id, - "resumes": resumes, - "count": len(resumes) - } + return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)} except Exception as e: logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve job resumes") + @router.get("/search") async def search_resumes( q: str = Query(..., description="Search query"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Search resumes by content""" try: resumes = await database.search_resumes_for_user(current_user.id, q) - return { - "success": True, - "query": q, - "resumes": resumes, - "count": len(resumes) - } + return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)} except Exception as e: logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}") raise HTTPException(status_code=500, detail="Failed to search resumes") + @router.get("/stats") async def get_resume_statistics( - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), database: RedisDatabase = 
Depends(get_database) ): """Get resume statistics for the current user""" try: stats = await database.get_resume_statistics(current_user.id) - return { - "success": True, - "statistics": stats - } + return {"success": True, "statistics": stats} except Exception as e: logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics") + @router.put("/{resume_id}") async def update_resume( resume_id: str = Path(..., description="ID of the resume"), resume: str = Body(..., description="Updated resume content"), - current_user = Depends(get_current_user), - database: RedisDatabase = Depends(get_database) + current_user=Depends(get_current_user), + database: RedisDatabase = Depends(get_database), ): """Update the content of a specific resume""" try: - updates = { - "resume": resume, - "updated_at": datetime.now(UTC).isoformat() - } - + updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()} + updated_resume_data = await database.update_resume(current_user.id, resume_id, updates) if not updated_resume_data: logger.warning(f"⚠️ Resume {resume_id} not found for user {current_user.id}") raise HTTPException(status_code=404, detail="Resume not found") updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None - - return create_success_response({ - "success": True, - "message": f"Resume {resume_id} updated successfully", - "resume": updated_resume - }) + + return create_success_response( + {"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume} + ) except HTTPException: raise except Exception as e: diff --git a/src/backend/routes/system.py b/src/backend/routes/system.py index b13eb0e..c1a3fc8 100644 --- a/src/backend/routes/system.py +++ b/src/backend/routes/system.py @@ -10,6 +10,7 @@ from utils.responses import create_success_response # Create router for authentication endpoints router = APIRouter(prefix="/system", tags=["system"]) + async def get_redis() -> Redis: """Dependency to get Redis client""" return redis_manager.get_client() @@ -19,9 +20,11 @@ async def get_redis() -> Redis: async def get_system_info(request: Request): """Get system information""" from system_info import system_info # Import system_info function from system_info module + system = system_info() - return create_success_response(system.model_dump(mode='json')) + return create_success_response(system.model_dump(mode="json")) + @router.get("/redis/stats") async def redis_stats(redis: Redis = Depends(get_redis)): @@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)): "total_commands_processed": info.get("total_commands_processed"), "keyspace_hits": info.get("keyspace_hits"), "keyspace_misses": info.get("keyspace_misses"), - "uptime_in_seconds": info.get("uptime_in_seconds") + "uptime_in_seconds": info.get("uptime_in_seconds"), } except Exception as e: raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}") diff --git a/src/backend/routes/users.py b/src/backend/routes/users.py index 6832d4c..6698016 100644 --- a/src/backend/routes/users.py +++ b/src/backend/routes/users.py @@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse from database.manager import RedisDatabase from logger import logger -from models import ( - BaseUserWithType -) -from utils.dependencies import ( - get_database -) +from models import BaseUserWithType +from utils.dependencies import get_database from utils.responses 
import create_success_response, create_error_response # Create router for job endpoints router = APIRouter(prefix="/users", tags=["users"]) + # reference can be candidateId, username, or email @router.get("/users/{reference}") -async def get_user( - reference: str = Path(...), - database: RedisDatabase = Depends(get_database) -): +async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)): """Get a candidate by username""" try: # Normalize reference to lowercase for case-insensitive search @@ -32,16 +26,15 @@ async def get_user( all_candidate_data = await database.get_all_candidates() if not all_candidate_data: logger.warning("⚠️ No users found in database") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "No users found") - ) - + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found")) + user_data = None for user in all_candidate_data.values(): - if (user.get("id", "").lower() == query_lower or - user.get("username", "").lower() == query_lower or - user.get("email", "").lower() == query_lower): + if ( + user.get("id", "").lower() == query_lower + or user.get("username", "").lower() == query_lower + or user.get("email", "").lower() == query_lower + ): user_data = user break @@ -49,32 +42,24 @@ async def get_user( all_guest_data = await database.get_all_guests() if not all_guest_data: logger.warning("⚠️ No guests found in database") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "No users found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found")) for user in all_guest_data.values(): - if (user.get("id", "").lower() == query_lower or - user.get("username", "").lower() == query_lower or - user.get("email", "").lower() == query_lower): + if ( + user.get("id", "").lower() == query_lower + or user.get("username", "").lower() == query_lower + or user.get("email", "").lower() == query_lower + ): user_data = user break - if not user_data: + if not user_data: logger.warning(f"⚠️ User nor Guest found for reference: {reference}") - return JSONResponse( - status_code=404, - content=create_error_response("NOT_FOUND", "User not found") - ) + return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found")) user = BaseUserWithType.model_validate(user_data) return create_success_response(user.model_dump(by_alias=True)) - + except Exception as e: logger.error(f"❌ Get user error: {e}") - return JSONResponse( - status_code=500, - content=create_error_response("FETCH_ERROR", str(e)) - ) - + return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e))) diff --git a/src/backend/system_info.py b/src/backend/system_info.py index 46e5dc3..3803255 100644 --- a/src/backend/system_info.py +++ b/src/backend/system_info.py @@ -4,6 +4,7 @@ import subprocess import math from models import SystemInfo + def get_installed_ram(): try: with open("/proc/meminfo", "r") as f: @@ -19,9 +20,7 @@ def get_graphics_cards(): gpus = [] try: # Run the ze-monitor utility - result = subprocess.run( - ["ze-monitor"], capture_output=True, text=True, check=True - ) + result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True) # Clean up the output (remove leading/trailing whitespace and newlines) output = result.stdout.strip() @@ -71,6 +70,7 @@ def get_cpu_info(): except Exception as e: return f"Error retrieving CPU info: {e}" + def system_info() -> 
SystemInfo: """ Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length. @@ -85,4 +85,4 @@ def system_info() -> SystemInfo: embedding_model=defines.embedding_model, max_context_length=defines.max_context, ) - return system \ No newline at end of file + return system diff --git a/src/backend/tools/basetools.py b/src/backend/tools/basetools.py index 8727110..d67d600 100644 --- a/src/backend/tools/basetools.py +++ b/src/backend/tools/basetools.py @@ -1,17 +1,17 @@ import os -from pydantic import BaseModel +from pydantic import BaseModel from typing import List, Optional, Any, Dict from datetime import datetime from typing import ( Any, ) -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup -from geopy.geocoders import Nominatim -import pytz +from geopy.geocoders import Nominatim +import pytz import requests -import yfinance as yf +import yfinance as yf import logging @@ -122,9 +122,7 @@ def get_forecast(grid_endpoint): # Process the forecast data into a simpler format forecast = { - "location": data["properties"] - .get("relativeLocation", {}) - .get("properties", {}), + "location": data["properties"].get("relativeLocation", {}).get("properties", {}), "updated": data["properties"].get("updated", ""), "periods": [], } @@ -189,9 +187,7 @@ def TickerValue(ticker_symbols): if ticker_symbol == "": continue - url = ( - f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}" - ) + url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}" response = requests.get(url) data = response.json() @@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols): logging.error(f"Error fetching data for {ticker_symbol}: {e}") logging.error(traceback.format_exc()) - results.append( - {"error": f"Error fetching data for {ticker_symbol}: {str(e)}"} - ) + results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}) return results[0] if len(results) == 1 else results @@ -278,8 +272,10 @@ def DateTime(timezone="America/Los_Angeles"): except Exception as e: return {"error": f"Invalid timezone {timezone}: {str(e)}"} + async def GenerateImage(llm, model: str, prompt: str): - return { "image_id": "image-a830a83-bd831" } + return {"image_id": "image-a830a83-bd831"} + async def AnalyzeSite(llm, model: str, url: str, question: str): """ @@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str): return f"Error processing the website content: {str(e)}" - # %% class Function(BaseModel): name: str @@ -355,171 +350,181 @@ class Function(BaseModel): parameters: Dict[str, Any] returns: Optional[Dict[str, Any]] = {} + class Tool(BaseModel): type: str function: Function -tools : List[Tool] = [ -# Tool.model_validate({ -# "type": "function", -# "function": { -# "name": "GenerateImage", -# "description": """\ -# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION: -# 1. Call this tool when users request images, drawings, or visual content -# 2. This tool returns an image_id (e.g., "img_abc123") -# 3. MANDATORY: You must respond with EXACTLY this format: -# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url) -# 5. FORBIDDEN: DO NOT create fake URLs or file paths -# 6. FORBIDDEN: DO NOT use any other image embedding format - -# CORRECT EXAMPLE: -# User: "Draw a cat" -# Tool returns: {"image_id": "img_xyz789"} -# Your response: "Here's your cat image: " - -# WRONG EXAMPLES (DO NOT DO THIS): -# - ![](https://example.com/...) 
-# - ![Cat image](any_url) -# - - -# The format is the ONLY way to display images in this system. -# """, -# "parameters": { -# "type": "object", -# "properties": { -# "prompt": { -# "type": "string", -# "description": "Detailed image description including style, colors, subject, composition" -# } -# }, -# "required": ["prompt"] -# }, -# "returns": { -# "type": "object", -# "properties": { -# "image_id": { -# "type": "string", -# "description": "Unique identifier for the generated image. Use this EXACTLY in " -# } -# } -# } -# } -# }), - Tool.model_validate({ - "type": "function", - "function": { - "name": "TickerValue", - "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'", - "parameters": { - "type": "object", - "properties": { - "ticker": { - "type": "string", - "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').", +tools: List[Tool] = [ + # Tool.model_validate({ + # "type": "function", + # "function": { + # "name": "GenerateImage", + # "description": """\ + # CRITICAL INSTRUCTIONS FOR IMAGE GENERATION: + # 1. Call this tool when users request images, drawings, or visual content + # 2. This tool returns an image_id (e.g., "img_abc123") + # 3. MANDATORY: You must respond with EXACTLY this format: + # 4. FORBIDDEN: DO NOT use markdown image syntax ![](url) + # 5. FORBIDDEN: DO NOT create fake URLs or file paths + # 6. FORBIDDEN: DO NOT use any other image embedding format + # CORRECT EXAMPLE: + # User: "Draw a cat" + # Tool returns: {"image_id": "img_xyz789"} + # Your response: "Here's your cat image: " + # WRONG EXAMPLES (DO NOT DO THIS): + # - ![](https://example.com/...) + # - ![Cat image](any_url) + # - + # The format is the ONLY way to display images in this system. + # """, + # "parameters": { + # "type": "object", + # "properties": { + # "prompt": { + # "type": "string", + # "description": "Detailed image description including style, colors, subject, composition" + # } + # }, + # "required": ["prompt"] + # }, + # "returns": { + # "type": "object", + # "properties": { + # "image_id": { + # "type": "string", + # "description": "Unique identifier for the generated image. Use this EXACTLY in " + # } + # } + # } + # } + # }), + Tool.model_validate( + { + "type": "function", + "function": { + "name": "TickerValue", + "description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'", + "parameters": { + "type": "object", + "properties": { + "ticker": { + "type": "string", + "description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').", + }, }, + "required": ["ticker"], + "additionalProperties": False, }, - "required": ["ticker"], - "additionalProperties": False, }, - }, - }), - Tool.model_validate({ - "type": "function", - "function": { - "name": "AnalyzeSite", - "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. 
For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.", - "parameters": { - "type": "object", - "properties": { - "url": { - "type": "string", - "description": "The website URL to download and process", - }, - "question": { - "type": "string", - "description": "The question to ask the second LLM about the content", + } + ), + Tool.model_validate( + { + "type": "function", + "function": { + "name": "AnalyzeSite", + "description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The website URL to download and process", + }, + "question": { + "type": "string", + "description": "The question to ask the second LLM about the content", + }, }, + "required": ["url", "question"], + "additionalProperties": False, }, - "required": ["url", "question"], - "additionalProperties": False, - }, - "returns": { - "type": "object", - "properties": { - "source": { - "type": "string", - "description": "Identifier for the source LLM", - }, - "content": { - "type": "string", - "description": "The complete response from the second LLM", - }, - "metadata": { - "type": "object", - "description": "Additional information about the response", + "returns": { + "type": "object", + "properties": { + "source": { + "type": "string", + "description": "Identifier for the source LLM", + }, + "content": { + "type": "string", + "description": "The complete response from the second LLM", + }, + "metadata": { + "type": "object", + "description": "Additional information about the response", + }, }, }, }, - }, - }), - Tool.model_validate({ - "type": "function", - "function": { - "name": "DateTime", - "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.", - "parameters": { - "type": "object", - "properties": { - "timezone": { - "type": "string", - "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.", - } - }, - "required": [], - }, - }, - }), - Tool.model_validate({ - "type": "function", - "function": { - "name": "WeatherForecast", - "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').", - "minLength": 2, - }, - "state": { - "type": "string", - "description": "State to find the weather forecast (e.g., 'OR', 'WA').", - "minLength": 2, + } + ), + Tool.model_validate( + { + "type": "function", + "function": { + "name": "DateTime", + "description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' 
you would pass the Warsaw timezone to DateTime.", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.", + } }, + "required": [], }, - "required": ["city", "state"], - "additionalProperties": False, }, - }, - }), + } + ), + Tool.model_validate( + { + "type": "function", + "function": { + "name": "WeatherForecast", + "description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').", + "minLength": 2, + }, + "state": { + "type": "string", + "description": "State to find the weather forecast (e.g., 'OR', 'WA').", + "minLength": 2, + }, + }, + "required": ["city", "state"], + "additionalProperties": False, + }, + }, + } + ), ] + class ToolEntry(BaseModel): enabled: bool = True tool: Tool + def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]: - return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled is True] + return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True] + def all_tools() -> List[ToolEntry]: return [ToolEntry(tool=tool) for tool in tools] + def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True] + tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"] __all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] - diff --git a/src/backend/utils/__init__.py b/src/backend/utils/__init__.py index 7fde0bd..e71a7c2 100644 --- a/src/backend/utils/__init__.py +++ b/src/backend/utils/__init__.py @@ -1,3 +1,3 @@ """ Utils package - Utility functions and dependencies -""" \ No newline at end of file +""" diff --git a/src/backend/utils/auth_utils.py b/src/backend/utils/auth_utils.py index d2aa785..0a93d84 100644 --- a/src/backend/utils/auth_utils.py +++ b/src/backend/utils/auth_utils.py @@ -5,65 +5,65 @@ Provides password hashing, verification, and security features """ from __future__ import annotations import traceback -import bcrypt +import bcrypt import secrets import logging from datetime import datetime, timezone, timedelta from typing import Dict, Any, Optional, Tuple -from pydantic import BaseModel +from pydantic import BaseModel logger = logging.getLogger(__name__) + class PasswordSecurity: """Handles password hashing and verification using bcrypt""" - + @staticmethod def hash_password(password: str) -> Tuple[str, str]: """ Hash a password with a random salt using bcrypt - + Args: password: Plain text password - + Returns: Tuple of (password_hash, salt) both as strings """ # Generate a random salt salt = bcrypt.gensalt() - + # Hash the password - password_hash = bcrypt.hashpw(password.encode('utf-8'), salt) - - return password_hash.decode('utf-8'), salt.decode('utf-8') - + password_hash = bcrypt.hashpw(password.encode("utf-8"), salt) + + return password_hash.decode("utf-8"), salt.decode("utf-8") + @staticmethod def verify_password(password: str, password_hash: str) -> bool: """ Verify a password against its hash - + Args: 
password: Plain text password to verify password_hash: Stored password hash - + Returns: True if password matches, False otherwise """ try: - return bcrypt.checkpw( - password.encode('utf-8'), - password_hash.encode('utf-8') - ) + return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8")) except Exception as e: logger.error(f"Password verification error: {e}") return False - + @staticmethod def generate_secure_token(length: int = 32) -> str: """Generate a cryptographically secure random token""" return secrets.token_urlsafe(length) + class AuthenticationRecord(BaseModel): """Authentication record for storing user credentials""" + user_id: str password_hash: str salt: str @@ -76,67 +76,70 @@ class AuthenticationRecord(BaseModel): mfa_secret: Optional[str] = None login_attempts: int = 0 locked_until: Optional[datetime] = None - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() if v else None - } + json_encoders = {datetime: lambda v: v.isoformat() if v else None} + class SecurityConfig: """Security configuration constants""" + MAX_LOGIN_ATTEMPTS = 5 ACCOUNT_LOCKOUT_DURATION_MINUTES = 15 PASSWORD_MIN_LENGTH = 8 TOKEN_EXPIRY_HOURS = 24 REFRESH_TOKEN_EXPIRY_DAYS = 30 + class AuthenticationManager: """Manages authentication operations with security features""" - + def __init__(self, database): self.database = database self.password_security = PasswordSecurity() - + async def create_user_authentication(self, user_id: str, password: str) -> AuthenticationRecord: """ Create authentication record for a new user - + Args: user_id: Unique user identifier password: Plain text password - + Returns: AuthenticationRecord object """ if len(password) < SecurityConfig.PASSWORD_MIN_LENGTH: raise ValueError(f"Password must be at least {SecurityConfig.PASSWORD_MIN_LENGTH} characters long") - + # Hash the password password_hash, salt = self.password_security.hash_password(password) - + # Create authentication record auth_record = AuthenticationRecord( user_id=user_id, password_hash=password_hash, salt=salt, last_password_change=datetime.now(timezone.utc), - login_attempts=0 + login_attempts=0, ) - + # Store in database await self.database.set_authentication(user_id, auth_record.model_dump()) - + logger.info(f"🔐 Created authentication record for user {user_id}") return auth_record - - async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: + + async def verify_user_credentials( + self, login: str, password: str + ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]: """ Verify user credentials with security checks - + Args: login: Email or username password: Plain text password - + Returns: Tuple of (is_valid, user_data, error_message) """ @@ -146,15 +149,15 @@ class AuthenticationManager: if not user_data: logger.warning(f"⚠️ Login attempt with non-existent user: {login}") return False, None, "Invalid credentials" - + # Get authentication record auth_record = await self.database.get_authentication(user_data["id"]) if not auth_record: logger.error(f"❌ No authentication record found for user {user_data['id']}") return False, None, "Authentication record not found" - + auth_data = AuthenticationRecord.model_validate(auth_record) - + # Check if account is locked if auth_data.locked_until and auth_data.locked_until > datetime.now(timezone.utc): time_until_unlock = auth_data.locked_until - datetime.now(timezone.utc) @@ -164,48 +167,54 @@ class AuthenticationManager: seconds = int(total_seconds % 60) 
time_until_unlock_str = f"{minutes}m {seconds}s" logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.") - return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}" - + return ( + False, + None, + f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}", + ) + # Verify password if not self.password_security.verify_password(password, auth_data.password_hash): # Increment failed attempts auth_data.login_attempts += 1 - + # Lock account if too many attempts if auth_data.login_attempts >= SecurityConfig.MAX_LOGIN_ATTEMPTS: auth_data.locked_until = datetime.now(timezone.utc) + timedelta( minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES ) - logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts") - + logger.warning( + f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts" + ) + # Update authentication record await self.database.set_authentication(user_data["id"], auth_data.model_dump()) - + logger.warning(f"⚠️ Invalid password for user {login} (attempt {auth_data.login_attempts})") return False, None, "Invalid credentials" - + # Reset failed attempts on successful login if auth_data.login_attempts > 0: auth_data.login_attempts = 0 auth_data.locked_until = None await self.database.set_authentication(user_data["id"], auth_data.model_dump()) - + logger.info(f"✅ Successful authentication for user {login}") return True, user_data, None - + except Exception as e: logger.error(traceback.format_exc()) logger.error(f"❌ Authentication error for user {login}: {e}") return False, None, "Authentication failed" - + async def check_user_exists(self, email: str, username: str | None = None) -> Tuple[bool, Optional[str]]: """ Check if a user already exists with the given email or username - + Args: email: Email address to check username: Username to check (optional) - + Returns: Tuple of (exists, conflict_field) """ @@ -214,20 +223,20 @@ class AuthenticationManager: existing_user = await self.database.get_user(email) if existing_user: return True, "email" - + # Check username if provided if username: existing_user = await self.database.get_user(username) if existing_user: return True, "username" - + return False, None - + except Exception as e: logger.error(f"❌ Error checking user existence: {e}") # In case of error, assume user doesn't exist to avoid blocking creation return False, None - + async def update_last_login(self, user_id: str): """Update user's last login timestamp""" try: @@ -238,38 +247,40 @@ class AuthenticationManager: except Exception as e: logger.error(f"❌ Error updating last login for user {user_id}: {e}") + # Utility functions for common operations def validate_password_strength(password: str) -> Tuple[bool, list]: """ Validate password strength - + Args: password: Password to validate - + Returns: Tuple of (is_valid, list_of_issues) """ issues = [] - + if len(password) < SecurityConfig.PASSWORD_MIN_LENGTH: issues.append(f"Password must be at least {SecurityConfig.PASSWORD_MIN_LENGTH} characters long") - + if not any(c.isupper() for c in password): issues.append("Password must contain at least one uppercase letter") - + if not any(c.islower() for c in password): issues.append("Password must contain at least one lowercase letter") - + if not any(c.isdigit() for c in password): issues.append("Password must contain at least one digit") - + # Check for special characters 
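A worked sketch of the lockout bookkeeping in verify_user_credentials: once login_attempts reaches MAX_LOGIN_ATTEMPTS the account is locked for a fixed window, and the remaining time is reported as minutes and seconds. The constants mirror SecurityConfig; the attempt count is illustrative.

from datetime import datetime, timedelta, timezone

MAX_LOGIN_ATTEMPTS = 5                       # SecurityConfig.MAX_LOGIN_ATTEMPTS
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15        # SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES

login_attempts = 5                           # illustrative: limit just reached
locked_until = None
if login_attempts >= MAX_LOGIN_ATTEMPTS:
    locked_until = datetime.now(timezone.utc) + timedelta(minutes=ACCOUNT_LOCKOUT_DURATION_MINUTES)

if locked_until and locked_until > datetime.now(timezone.utc):
    total_seconds = (locked_until - datetime.now(timezone.utc)).total_seconds()
    minutes, seconds = int(total_seconds // 60), int(total_seconds % 60)
    print(f"Account is temporarily locked. Retry after {minutes}m {seconds}s")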
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?" if not any(c in special_chars for c in password): issues.append("Password must contain at least one special character") - + return len(issues) == 0, issues + def sanitize_login_input(login: str) -> str: """Sanitize login input (email or username)""" - return login.strip().lower() if login else "" \ No newline at end of file + return login.strip().lower() if login else "" diff --git a/src/backend/utils/dependencies.py b/src/backend/utils/dependencies.py index a72de8b..c702eca 100644 --- a/src/backend/utils/dependencies.py +++ b/src/backend/utils/dependencies.py @@ -19,7 +19,7 @@ from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest from logger import logger from background_tasks import BackgroundTaskManager -#from . rate_limiter import RateLimiter +# from . rate_limiter import RateLimiter # Security security = HTTPBearer() @@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None # Global database manager reference db_manager = None + def set_db_manager(manager: DatabaseManager): """Set the global database manager reference""" global db_manager db_manager = manager + def get_database() -> RedisDatabase: """ Safe database dependency that checks for availability @@ -47,26 +49,18 @@ def get_database() -> RedisDatabase: if db_manager is None: logger.error("Database manager not initialized") - raise HTTPException( - status_code=503, - detail="Database not available - service starting up" - ) - + raise HTTPException(status_code=503, detail="Database not available - service starting up") + if db_manager.is_shutting_down: logger.warning("Database is shutting down") - raise HTTPException( - status_code=503, - detail="Service is shutting down" - ) - + raise HTTPException(status_code=503, detail="Service is shutting down") + try: return db_manager.get_database() except RuntimeError as e: logger.error(f"Database not available: {e}") - raise HTTPException( - status_code=503, - detail="Database connection not available" - ) + raise HTTPException(status_code=503, detail="Database connection not available") + def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): to_encode = data.copy() @@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt + async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)): """Enhanced token verification with guest session recovery""" try: @@ -87,35 +82,35 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM]) user_id: str = payload.get("sub") token_type: str = payload.get("type", "access") - + if user_id is None: raise HTTPException(status_code=401, detail="Invalid authentication credentials") - + # Check if token is blacklisted redis = redis_manager.get_client() blacklist_key = f"blacklisted_token:{credentials.credentials}" - + is_blacklisted = await redis.exists(blacklist_key) if is_blacklisted: logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}") raise HTTPException(status_code=401, detail="Token has been revoked") - + # For guest tokens, verify guest still exists and update activity if token_type == "guest" or payload.get("type") == "guest": database = db_manager.get_database() guest_data = await database.get_guest(user_id) - + if not guest_data: 
logger.warning(f"🚫 Guest session not found for token: {user_id}") raise HTTPException(status_code=401, detail="Guest session expired") - + # Update guest activity guest_data["last_activity"] = datetime.now(UTC).isoformat() await database.set_guest(user_id, guest_data) logger.debug(f"🔄 Guest activity updated: {user_id}") - + return user_id - + except jwt.PyJWTError as e: logger.warning(f"⚠️ JWT decode error: {e}") raise HTTPException(status_code=401, detail="Invalid authentication credentials") @@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials logger.error(f"❌ Token verification error: {e}") raise HTTPException(status_code=401, detail="Token verification failed") + async def get_current_user( - user_id: str = Depends(verify_token_with_blacklist), - database: RedisDatabase = Depends(get_database) + user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: """Get current user from database""" try: @@ -135,54 +130,59 @@ async def get_current_user( candidate_data = await database.get_candidate(user_id) if candidate_data: from helpers.model_cast import cast_to_base_user_with_type + if candidate_data.get("is_AI"): return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data)) else: return cast_to_base_user_with_type(Candidate.model_validate(candidate_data)) - + # Check employers employer = await database.get_employer(user_id) if employer: return Employer.model_validate(employer) - + logger.warning(f"⚠️ User {user_id} not found in database") raise HTTPException(status_code=404, detail="User not found") - + except Exception as e: logger.error(f"❌ Error getting current user: {e}") raise HTTPException(status_code=404, detail="User not found") + async def get_current_user_or_guest( - user_id: str = Depends(verify_token_with_blacklist), - database: RedisDatabase = Depends(get_database) + user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: """Get current user (including guests) from database""" try: # Check candidates first candidate_data = await database.get_candidate(user_id) if candidate_data: - return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data) - + return ( + Candidate.model_validate(candidate_data) + if not candidate_data.get("is_AI") + else CandidateAI.model_validate(candidate_data) + ) + # Check employers employer_data = await database.get_employer(user_id) if employer_data: return Employer.model_validate(employer_data) - + # Check guests guest_data = await database.get_guest(user_id) if guest_data: return Guest.model_validate(guest_data) - + logger.warning(f"⚠️ User {user_id} not found in database") raise HTTPException(status_code=404, detail="User not found") - + except Exception as e: logger.error(f"❌ Error getting current user: {e}") raise HTTPException(status_code=404, detail="User not found") + async def get_current_admin( - user_id: str = Depends(verify_token_with_blacklist), - database: RedisDatabase = Depends(get_database) + user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database) ) -> BaseUserWithType: user = await get_current_user(user_id=user_id, database=database) if isinstance(user, Candidate) and user.is_admin: @@ -193,6 +193,7 @@ async def get_current_admin( logger.warning(f"⚠️ User {user_id} is not an admin") raise HTTPException(status_code=403, 
detail="Admin access required") + prometheus_collector = CollectorRegistry() # Keep the Instrumentator instance alive @@ -201,5 +202,5 @@ instrumentator = Instrumentator( should_ignore_untemplated=True, should_group_untemplated=True, excluded_handlers=[f"{defines.api_prefix}/metrics"], - registry=prometheus_collector + registry=prometheus_collector, ) diff --git a/src/backend/utils/get_requirements_list.py b/src/backend/utils/get_requirements_list.py index e858fe3..ded65ad 100644 --- a/src/backend/utils/get_requirements_list.py +++ b/src/backend/utils/get_requirements_list.py @@ -1,5 +1,6 @@ from typing import List, Dict -from models import (Job) +from models import Job + def get_requirements_list(job: Job) -> List[Dict[str, str]]: requirements: List[Dict[str, str]] = [] @@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]: if job.requirements: if job.requirements.technical_skills: if job.requirements.technical_skills.required: - requirements.extend([ - {"requirement": req, "domain": "Technical Skills (required)"} - for req in job.requirements.technical_skills.required - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Technical Skills (required)"} + for req in job.requirements.technical_skills.required + ] + ) if job.requirements.technical_skills.preferred: - requirements.extend([ - {"requirement": req, "domain": "Technical Skills (preferred)"} - for req in job.requirements.technical_skills.preferred - ]) - + requirements.extend( + [ + {"requirement": req, "domain": "Technical Skills (preferred)"} + for req in job.requirements.technical_skills.preferred + ] + ) + if job.requirements.experience_requirements: if job.requirements.experience_requirements.required: - requirements.extend([ - {"requirement": req, "domain": "Experience (required)"} - for req in job.requirements.experience_requirements.required - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Experience (required)"} + for req in job.requirements.experience_requirements.required + ] + ) if job.requirements.experience_requirements.preferred: - requirements.extend([ - {"requirement": req, "domain": "Experience (preferred)"} - for req in job.requirements.experience_requirements.preferred - ]) - - if job.requirements.soft_skills: - requirements.extend([ - {"requirement": req, "domain": "Soft Skills"} - for req in job.requirements.soft_skills - ]) - - if job.requirements.experience: - requirements.extend([ - {"requirement": req, "domain": "Experience"} - for req in job.requirements.experience - ]) - - if job.requirements.education: - requirements.extend([ - {"requirement": req, "domain": "Education"} - for req in job.requirements.education - ]) - - if job.requirements.certifications: - requirements.extend([ - {"requirement": req, "domain": "Certifications"} - for req in job.requirements.certifications - ]) - - if job.requirements.preferred_attributes: - requirements.extend([ - {"requirement": req, "domain": "Preferred Attributes"} - for req in job.requirements.preferred_attributes - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Experience (preferred)"} + for req in job.requirements.experience_requirements.preferred + ] + ) - return requirements \ No newline at end of file + if job.requirements.soft_skills: + requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills]) + + if job.requirements.experience: + requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience]) + + 
if job.requirements.education: + requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education]) + + if job.requirements.certifications: + requirements.extend( + [{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications] + ) + + if job.requirements.preferred_attributes: + requirements.extend( + [ + {"requirement": req, "domain": "Preferred Attributes"} + for req in job.requirements.preferred_attributes + ] + ) + + return requirements diff --git a/src/backend/utils/helpers.py b/src/backend/utils/helpers.py index 028873e..a1073a5 100644 --- a/src/backend/utils/helpers.py +++ b/src/backend/utils/helpers.py @@ -13,15 +13,13 @@ from fastapi.responses import StreamingResponse import defines from logger import logger from models import DocumentType -from models import ( - Job, - ChatMessage, DocumentType, ApiStatusType -) +from models import Job, ChatMessage, DocumentType, ApiStatusType from typing import List, Dict -from models import (Job) +from models import Job import utils.llm_proxy as llm_manager + async def get_last_item(generator): """Get the last item from an async generator""" last_item = None @@ -36,20 +34,19 @@ def filter_and_paginate( limit: int = 20, sort_by: Optional[str] = None, sort_order: str = "desc", - filters: Optional[Dict] = None + filters: Optional[Dict] = None, ) -> Tuple[List[Any], int]: """Filter, sort, and paginate items""" filtered_items = items.copy() - + # Apply filters (simplified filtering logic) if filters: for key, value in filters.items(): if isinstance(filtered_items[0], dict) and key in filtered_items[0]: filtered_items = [item for item in filtered_items if item.get(key) == value] elif hasattr(filtered_items[0], key) if filtered_items else False: - filtered_items = [item for item in filtered_items - if getattr(item, key, None) == value] - + filtered_items = [item for item in filtered_items if getattr(item, key, None) == value] + # Sort items if sort_by and filtered_items: reverse = sort_order.lower() == "desc" @@ -60,24 +57,25 @@ def filter_and_paginate( filtered_items.sort(key=lambda x: getattr(x, sort_by, ""), reverse=reverse) except (AttributeError, TypeError): pass # Skip sorting if attribute doesn't exist or isn't comparable - + # Paginate total = len(filtered_items) start = (page - 1) * limit end = start + limit paginated_items = filtered_items[start:end] - + return paginated_items, total async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse: """Stream agent response with proper formatting""" + async def message_stream_generator(): """Generator to stream messages with persistence""" final_message = None - + import utils.llm_proxy as llm_manager - + async for generated_message in chat_agent.generate( llm=llm_manager.get_llm(), model=defines.model, @@ -88,36 +86,39 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None logger.error(f"❌ AI generation error: {generated_message.content}") yield f"data: {json.dumps({'status': 'error'})}\n\n" return - + # Store reference to the complete AI message if generated_message.status == ApiStatusType.DONE: final_message = generated_message - + # If the message is not done, convert it to a ChatMessageBase to remove # metadata and other unnecessary fields for streaming if generated_message.status != ApiStatusType.DONE: from models import ChatMessageStreaming, ChatMessageStatus - if not isinstance(generated_message, ChatMessageStreaming) and not 
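A small worked example of the pagination arithmetic in filter_and_paginate: page is 1-based, so the returned slice is items[(page - 1) * limit : (page - 1) * limit + limit]. The data below is illustrative.

items = list(range(1, 48))        # 47 illustrative items
page, limit = 3, 20

start = (page - 1) * limit
paginated, total = items[start:start + limit], len(items)

assert total == 47
assert paginated == list(range(41, 48))   # the third page holds the remaining 7 items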
isinstance(generated_message, ChatMessageStatus): + + if not isinstance(generated_message, ChatMessageStreaming) and not isinstance( + generated_message, ChatMessageStatus + ): raise TypeError( f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}" ) - json_data = generated_message.model_dump(mode='json', by_alias=True) + json_data = generated_message.model_dump(mode="json", by_alias=True) json_str = json.dumps(json_data) - + yield f"data: {json_str}\n\n" - + # After streaming is complete, persist the final AI message to database if final_message and final_message.status == ApiStatusType.DONE: try: if database and chat_session_data: await database.add_chat_message(final_message.session_id, final_message.model_dump()) logger.info(f"🤖 Message saved to database for session {final_message.session_id}") - + # Update session last activity again chat_session_data["lastActivity"] = datetime.now(UTC).isoformat() await database.set_chat_session(final_message.session_id, chat_session_data) - + except Exception as e: logger.error(f"❌ Failed to save message to database: {e}") @@ -145,20 +146,20 @@ def get_candidate_files_dir(username: str) -> pathlib.Path: def get_document_type_from_filename(filename: str) -> DocumentType: """Determine document type from filename extension""" extension = pathlib.Path(filename).suffix.lower() - + type_mapping = { - '.pdf': DocumentType.PDF, - '.docx': DocumentType.DOCX, - '.doc': DocumentType.DOCX, - '.txt': DocumentType.TXT, - '.md': DocumentType.MARKDOWN, - '.markdown': DocumentType.MARKDOWN, - '.png': DocumentType.IMAGE, - '.jpg': DocumentType.IMAGE, - '.jpeg': DocumentType.IMAGE, - '.gif': DocumentType.IMAGE, + ".pdf": DocumentType.PDF, + ".docx": DocumentType.DOCX, + ".doc": DocumentType.DOCX, + ".txt": DocumentType.TXT, + ".md": DocumentType.MARKDOWN, + ".markdown": DocumentType.MARKDOWN, + ".png": DocumentType.IMAGE, + ".jpg": DocumentType.IMAGE, + ".jpeg": DocumentType.IMAGE, + ".gif": DocumentType.IMAGE, } - + return type_mapping.get(extension, DocumentType.TXT) @@ -173,20 +174,15 @@ async def reformat_as_markdown(database, candidate_entity, content: str): """Reformat content as markdown using AI agent""" from models import ChatContextType, MOCK_UUID, ChatMessageError, ChatMessageStatus, ApiActivityType, ChatMessage import utils.llm_proxy as llm_manager - + chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) if not chat_agent: - error_message = ChatMessageError( - session_id=MOCK_UUID, - content="No agent found for job requirements chat type" - ) + error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type") yield error_message return - + status_message = ChatMessageStatus( - session_id=MOCK_UUID, - content="Reformatting job description as markdown...", - activity=ApiActivityType.CONVERTING + session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING ) yield status_message @@ -199,25 +195,22 @@ async def reformat_as_markdown(database, candidate_entity, content: str): system_prompt=""" You are a document editor. Take the provided job description and reformat as legible markdown. Return only the markdown content, no other text. Make sure all content is included. 
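A minimal restatement of the suffix lookup in get_document_type_from_filename; the string values stand in for the DocumentType enum members, and unknown extensions fall back to the TXT type.

import pathlib

def guess_document_type(filename: str) -> str:
    mapping = {".pdf": "pdf", ".docx": "docx", ".doc": "docx", ".txt": "txt",
               ".md": "markdown", ".markdown": "markdown",
               ".png": "image", ".jpg": "image", ".jpeg": "image", ".gif": "image"}
    return mapping.get(pathlib.Path(filename).suffix.lower(), "txt")

assert guess_document_type("resume.PDF") == "pdf"
assert guess_document_type("notes.unknown") == "txt"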
-""" +""", ): pass if not message or not isinstance(message, ChatMessage): logger.error("❌ Failed to reformat job description to markdown") - error_message = ChatMessageError( - session_id=MOCK_UUID, - content="Failed to reformat job description" - ) + error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description") yield error_message return - + chat_message: ChatMessage = message try: chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content) except Exception: pass - + logger.info("✅ Successfully converted content to markdown") yield chat_message return @@ -226,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ async def create_job_from_content(database, current_user, content: str): """Create a job from content using AI analysis""" from models import ( - MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus, - ApiActivityType, ChatContextType, JobRequirementsMessage + MOCK_UUID, + ApiStatusType, + ChatMessageError, + ChatMessageStatus, + ApiActivityType, + ChatContextType, + JobRequirementsMessage, ) - + status_message = ChatMessageStatus( session_id=MOCK_UUID, content=f"Initiating connection with {current_user.first_name}'s AI agent...", - activity=ApiActivityType.INFO + activity=ApiActivityType.INFO, ) yield status_message await asyncio.sleep(0) # Let the status message propagate @@ -246,112 +244,105 @@ async def create_job_from_content(database, current_user, content: str): # Only yield one final DONE message if message.status != ApiStatusType.DONE: yield message - + if not message or not isinstance(message, ChatMessage): - error_message = ChatMessageError( - session_id=MOCK_UUID, - content="Failed to reformat job description" - ) + error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description") yield error_message return - + markdown_message = message chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS) if not chat_agent: error_message = ChatMessageError( - session_id=MOCK_UUID, - content="No agent found for job requirements chat type" + session_id=MOCK_UUID, content="No agent found for job requirements chat type" ) yield error_message return - + status_message = ChatMessageStatus( session_id=MOCK_UUID, content="Analyzing document for company and requirement details...", - activity=ApiActivityType.SEARCHING + activity=ApiActivityType.SEARCHING, ) yield status_message message = None async for message in chat_agent.generate( - llm=llm_manager.get_llm(), - model=defines.model, - session_id=MOCK_UUID, - prompt=markdown_message.content + llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content ): if message.status != ApiStatusType.DONE: yield message if not message or not isinstance(message, JobRequirementsMessage): error_message = ChatMessageError( - session_id=MOCK_UUID, - content="Job extraction did not convert successfully" + session_id=MOCK_UUID, content="Job extraction did not convert successfully" ) yield error_message return - + job_requirements: JobRequirementsMessage = message logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}") yield job_requirements return - + + def get_requirements_list(job: Job) -> List[Dict[str, str]]: requirements: List[Dict[str, str]] = [] if job.requirements: if job.requirements.technical_skills: if job.requirements.technical_skills.required: - requirements.extend([ - {"requirement": req, 
"domain": "Technical Skills (required)"} - for req in job.requirements.technical_skills.required - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Technical Skills (required)"} + for req in job.requirements.technical_skills.required + ] + ) if job.requirements.technical_skills.preferred: - requirements.extend([ - {"requirement": req, "domain": "Technical Skills (preferred)"} - for req in job.requirements.technical_skills.preferred - ]) - + requirements.extend( + [ + {"requirement": req, "domain": "Technical Skills (preferred)"} + for req in job.requirements.technical_skills.preferred + ] + ) + if job.requirements.experience_requirements: if job.requirements.experience_requirements.required: - requirements.extend([ - {"requirement": req, "domain": "Experience (required)"} - for req in job.requirements.experience_requirements.required - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Experience (required)"} + for req in job.requirements.experience_requirements.required + ] + ) if job.requirements.experience_requirements.preferred: - requirements.extend([ - {"requirement": req, "domain": "Experience (preferred)"} - for req in job.requirements.experience_requirements.preferred - ]) - - if job.requirements.soft_skills: - requirements.extend([ - {"requirement": req, "domain": "Soft Skills"} - for req in job.requirements.soft_skills - ]) - - if job.requirements.experience: - requirements.extend([ - {"requirement": req, "domain": "Experience"} - for req in job.requirements.experience - ]) - - if job.requirements.education: - requirements.extend([ - {"requirement": req, "domain": "Education"} - for req in job.requirements.education - ]) - - if job.requirements.certifications: - requirements.extend([ - {"requirement": req, "domain": "Certifications"} - for req in job.requirements.certifications - ]) - - if job.requirements.preferred_attributes: - requirements.extend([ - {"requirement": req, "domain": "Preferred Attributes"} - for req in job.requirements.preferred_attributes - ]) + requirements.extend( + [ + {"requirement": req, "domain": "Experience (preferred)"} + for req in job.requirements.experience_requirements.preferred + ] + ) - return requirements \ No newline at end of file + if job.requirements.soft_skills: + requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills]) + + if job.requirements.experience: + requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience]) + + if job.requirements.education: + requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education]) + + if job.requirements.certifications: + requirements.extend( + [{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications] + ) + + if job.requirements.preferred_attributes: + requirements.extend( + [ + {"requirement": req, "domain": "Preferred Attributes"} + for req in job.requirements.preferred_attributes + ] + ) + + return requirements diff --git a/src/backend/utils/llm_proxy.py b/src/backend/utils/llm_proxy.py index 0e35896..4684ded 100644 --- a/src/backend/utils/llm_proxy.py +++ b/src/backend/utils/llm_proxy.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from typing import Dict, List, Any, AsyncGenerator, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field from enum import Enum import asyncio from dataclasses import dataclass @@ -8,61 +8,66 @@ import os import defines from 
logger import logger + # Standard message format for all providers @dataclass class LLMMessage: role: str # "user", "assistant", "system" content: str - + def __getitem__(self, key: str): """Allow dictionary-style access for backward compatibility""" - if key == 'role': + if key == "role": return self.role - elif key == 'content': + elif key == "content": return self.content else: raise KeyError(f"'{key}' not found in LLMMessage") - + def __setitem__(self, key: str, value: str): """Allow dictionary-style assignment""" - if key == 'role': + if key == "role": self.role = value - elif key == 'content': + elif key == "content": self.content = value else: raise KeyError(f"'{key}' not found in LLMMessage") - + @classmethod - def from_dict(cls, data: Dict[str, str]) -> 'LLMMessage': + def from_dict(cls, data: Dict[str, str]) -> "LLMMessage": """Create LLMMessage from dictionary""" - return cls(role=data['role'], content=data['content']) - + return cls(role=data["role"], content=data["content"]) + def to_dict(self) -> Dict[str, str]: """Convert LLMMessage to dictionary""" - return {'role': self.role, 'content': self.content} - + return {"role": self.role, "content": self.content} + + # Enhanced usage statistics model class UsageStats(BaseModel): """Comprehensive usage statistics across all providers""" + # Token counts (standardized across providers) prompt_tokens: Optional[int] = Field(default=None, description="Number of tokens in the prompt") completion_tokens: Optional[int] = Field(default=None, description="Number of tokens in the completion") total_tokens: Optional[int] = Field(default=None, description="Total number of tokens used") - + # Ollama-specific detailed stats prompt_eval_count: Optional[int] = Field(default=None, description="Number of tokens evaluated in prompt") prompt_eval_duration: Optional[int] = Field(default=None, description="Time spent evaluating prompt (nanoseconds)") eval_count: Optional[int] = Field(default=None, description="Number of tokens generated") eval_duration: Optional[int] = Field(default=None, description="Time spent generating tokens (nanoseconds)") total_duration: Optional[int] = Field(default=None, description="Total request duration (nanoseconds)") - + # Performance metrics tokens_per_second: Optional[float] = Field(default=None, description="Generation speed in tokens/second") prompt_tokens_per_second: Optional[float] = Field(default=None, description="Prompt processing speed") - + # Additional provider-specific stats - extra_stats: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Provider-specific additional statistics") - + extra_stats: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Provider-specific additional statistics" + ) + def calculate_derived_stats(self) -> None: """Calculate derived statistics where possible""" # Calculate tokens per second for Ollama @@ -70,44 +75,49 @@ class UsageStats(BaseModel): # Convert nanoseconds to seconds and calculate tokens/sec duration_seconds = self.eval_duration / 1_000_000_000 self.tokens_per_second = self.eval_count / duration_seconds - + # Calculate prompt processing speed for Ollama if self.prompt_eval_count and self.prompt_eval_duration and self.prompt_eval_duration > 0: duration_seconds = self.prompt_eval_duration / 1_000_000_000 self.prompt_tokens_per_second = self.prompt_eval_count / duration_seconds - + # Standardize token counts across providers if not self.total_tokens and self.prompt_tokens and self.completion_tokens: self.total_tokens = self.prompt_tokens 
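A quick sketch of the backward-compatible access that LLMMessage provides: attribute access, dictionary-style access, and dict round-tripping all refer to the same two fields. The import path follows this diff; the message content is illustrative.

from utils.llm_proxy import LLMMessage

msg = LLMMessage(role="user", content="hello")
assert msg["role"] == "user" and msg.content == "hello"

msg["content"] = "hello there"                      # __setitem__ keeps legacy dict-style code working
assert msg.to_dict() == {"role": "user", "content": "hello there"}
assert LLMMessage.from_dict(msg.to_dict()) == msg   # dataclass equality compares both fields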
+ self.completion_tokens - + # Map Ollama counts to standard format if not already set if not self.prompt_tokens and self.prompt_eval_count: self.prompt_tokens = self.prompt_eval_count if not self.completion_tokens and self.eval_count: self.completion_tokens = self.eval_count + # Embedding response models class EmbeddingData(BaseModel): """Single embedding result""" + embedding: List[float] = Field(description="The embedding vector") index: int = Field(description="Index in the input list") + class EmbeddingResponse(BaseModel): """Response from embeddings API""" + data: List[EmbeddingData] = Field(description="List of embedding results") model: str = Field(description="Model used for embeddings") usage: Optional[UsageStats] = Field(default=None, description="Usage statistics") - + def get_single_embedding(self) -> List[float]: """Get the first embedding for single-text requests""" if not self.data: raise ValueError("No embeddings in response") return self.data[0].embedding - + def get_embeddings(self) -> List[List[float]]: """Get all embeddings as a list of vectors""" return [item.embedding for item in self.data] + class ChatResponse(BaseModel): content: str model: str @@ -115,13 +125,14 @@ class ChatResponse(BaseModel): usage: Optional[UsageStats] = Field(default=None) # Keep legacy usage field for backward compatibility usage_legacy: Optional[Dict[str, int]] = Field(default=None, alias="usage_dict") - + def get_usage_dict(self) -> Dict[str, Any]: """Get usage statistics as dictionary for backward compatibility""" if self.usage: return self.usage.model_dump(exclude_none=True) return self.usage_legacy or {} + class LLMProvider(str, Enum): OLLAMA = "ollama" OPENAI = "openai" @@ -129,679 +140,527 @@ class LLMProvider(str, Enum): GEMINI = "gemini" GROK = "grok" + class BaseLLMAdapter(ABC): """Abstract base class for all LLM adapters""" - + def __init__(self, **config): self.config = config - + @abstractmethod async def chat( - self, - model: str, - messages: List[LLMMessage], - stream: bool = False, - **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Send chat messages and get response""" - + @abstractmethod async def generate( - self, - model: str, - prompt: str, - stream: bool = False, - **kwargs + self, model: str, prompt: str, stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Generate text from prompt""" - + @abstractmethod - async def embeddings( - self, - model: str, - input_texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResponse: + async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Generate embeddings for input text(s)""" - + @abstractmethod async def list_models(self) -> List[str]: """List available models""" + class OllamaAdapter(BaseLLMAdapter): """Adapter for Ollama with enhanced statistics""" - + def __init__(self, **config): super().__init__(**config) import ollama - self.client = ollama.AsyncClient( - host=config.get('host', defines.ollama_api_url) - ) - + + self.client = ollama.AsyncClient(host=config.get("host", defines.ollama_api_url)) + def _create_usage_stats(self, response_data: Dict[str, Any]) -> UsageStats: """Create comprehensive usage statistics from Ollama response""" usage_stats = UsageStats() - + # Extract Ollama-specific stats - if 'prompt_eval_count' in response_data: - usage_stats.prompt_eval_count = response_data['prompt_eval_count'] - if 
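A worked example of the derived statistics in UsageStats.calculate_derived_stats: Ollama reports durations in nanoseconds, so generation speed is eval_count divided by eval_duration converted to seconds. The counts below are illustrative.

eval_count = 256                       # tokens generated
eval_duration = 8_000_000_000          # 8 s, reported in nanoseconds
tokens_per_second = eval_count / (eval_duration / 1_000_000_000)
assert tokens_per_second == 32.0

prompt_eval_count = 1_024              # prompt tokens evaluated
prompt_eval_duration = 2_000_000_000   # 2 s
prompt_tokens_per_second = prompt_eval_count / (prompt_eval_duration / 1_000_000_000)
assert prompt_tokens_per_second == 512.0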
'prompt_eval_duration' in response_data: - usage_stats.prompt_eval_duration = response_data['prompt_eval_duration'] - if 'eval_count' in response_data: - usage_stats.eval_count = response_data['eval_count'] - if 'eval_duration' in response_data: - usage_stats.eval_duration = response_data['eval_duration'] - if 'total_duration' in response_data: - usage_stats.total_duration = response_data['total_duration'] - + if "prompt_eval_count" in response_data: + usage_stats.prompt_eval_count = response_data["prompt_eval_count"] + if "prompt_eval_duration" in response_data: + usage_stats.prompt_eval_duration = response_data["prompt_eval_duration"] + if "eval_count" in response_data: + usage_stats.eval_count = response_data["eval_count"] + if "eval_duration" in response_data: + usage_stats.eval_duration = response_data["eval_duration"] + if "total_duration" in response_data: + usage_stats.total_duration = response_data["total_duration"] + # Store any additional stats extra_stats = {} if type(response_data) is dict: for key, value in response_data.items(): - if key not in ['prompt_eval_count', 'prompt_eval_duration', 'eval_count', - 'eval_duration', 'total_duration', 'message', 'done_reason', 'response']: + if key not in [ + "prompt_eval_count", + "prompt_eval_duration", + "eval_count", + "eval_duration", + "total_duration", + "message", + "done_reason", + "response", + ]: extra_stats[key] = value - + if extra_stats: usage_stats.extra_stats = extra_stats - + # Calculate derived statistics usage_stats.calculate_derived_stats() - + return usage_stats - + async def chat( - self, - model: str, - messages: List[LLMMessage], - stream: bool = False, - **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - # Convert LLMMessage objects to Ollama format ollama_messages = [] for msg in messages: - ollama_messages.append({ - "role": msg.role, - "content": msg.content - }) - + ollama_messages.append({"role": msg.role, "content": msg.content}) + if stream: return self._stream_chat(model, ollama_messages, **kwargs) else: - response = await self.client.chat( - model=model, - messages=ollama_messages, - stream=False, - **kwargs - ) - + response = await self.client.chat(model=model, messages=ollama_messages, stream=False, **kwargs) + usage_stats = self._create_usage_stats(response) - + return ChatResponse( - content=response['message']['content'], + content=response["message"]["content"], model=model, - finish_reason=response.get('done_reason'), - usage=usage_stats + finish_reason=response.get("done_reason"), + usage=usage_stats, ) - + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): # Await the chat call first, then iterate over the result - stream = await self.client.chat( - model=model, - messages=messages, - stream=True, - **kwargs - ) - + stream = await self.client.chat(model=model, messages=messages, stream=True, **kwargs) + # Accumulate stats for final chunk accumulated_stats = {} - + async for chunk in stream: # Update accumulated stats accumulated_stats.update(chunk) - content = chunk.get('message', {}).get('content', '') + content = chunk.get("message", {}).get("content", "") usage_stats = None if chunk.done: usage_stats = self._create_usage_stats(accumulated_stats) yield ChatResponse( - content=content, - model=model, - finish_reason=chunk.get('done_reason'), - usage=usage_stats + content=content, model=model, finish_reason=chunk.get("done_reason"), usage=usage_stats ) else: - yield 
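A hedged usage sketch for the streaming chat path above: the adapter yields ChatResponse chunks, and only the final chunk carries UsageStats. The model name and host are assumptions, and nothing here is executed against a live Ollama server.

import asyncio

from utils.llm_proxy import LLMMessage, OllamaAdapter

async def stream_demo() -> None:
    adapter = OllamaAdapter(host="http://localhost:11434")    # assumption: local Ollama default port
    messages = [LLMMessage(role="user", content="Say hello in one word.")]
    stream = await adapter.chat(model="llama3.2", messages=messages, stream=True)  # assumed model name
    async for chunk in stream:
        print(chunk.content, end="", flush=True)
        if chunk.usage and chunk.usage.tokens_per_second:      # populated only on the final chunk
            print(f"\n~{chunk.usage.tokens_per_second:.1f} tokens/s")

if __name__ == "__main__":
    asyncio.run(stream_demo())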
ChatResponse( - content=content, - model=model, - finish_reason=None, - usage=None - ) - + yield ChatResponse(content=content, model=model, finish_reason=None, usage=None) + async def generate( - self, - model: str, - prompt: str, - stream: bool = False, - **kwargs + self, model: str, prompt: str, stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - if stream: return self._stream_generate(model, prompt, **kwargs) else: - response = await self.client.generate( - model=model, - prompt=prompt, - stream=False, - **kwargs - ) - + response = await self.client.generate(model=model, prompt=prompt, stream=False, **kwargs) + usage_stats = self._create_usage_stats(response) - + return ChatResponse( - content=response['response'], - model=model, - finish_reason=response.get('done_reason'), - usage=usage_stats + content=response["response"], model=model, finish_reason=response.get("done_reason"), usage=usage_stats ) - + async def _stream_generate(self, model: str, prompt: str, **kwargs): # Await the generate call first, then iterate over the result - stream = await self.client.generate( - model=model, - prompt=prompt, - stream=True, - **kwargs - ) - + stream = await self.client.generate(model=model, prompt=prompt, stream=True, **kwargs) + accumulated_stats = {} - + async for chunk in stream: - if chunk.get('response'): + if chunk.get("response"): accumulated_stats.update(chunk) - + usage_stats = None - if chunk.get('done', False): # Only create stats for final chunk + if chunk.get("done", False): # Only create stats for final chunk usage_stats = self._create_usage_stats(accumulated_stats) - + yield ChatResponse( - content=chunk['response'], - model=model, - finish_reason=chunk.get('done_reason'), - usage=usage_stats + content=chunk["response"], model=model, finish_reason=chunk.get("done_reason"), usage=usage_stats ) - async def embeddings( - self, - model: str, - input_texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResponse: + async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Generate embeddings using Ollama""" # Normalize input to list if isinstance(input_texts, str): texts = [input_texts] else: texts = input_texts - + # Ollama embeddings API typically handles one text at a time results = [] final_response = None - + for i, text in enumerate(texts): - response = await self.client.embeddings( - model=model, - prompt=text, - **kwargs - ) - - results.append(EmbeddingData( - embedding=response['embedding'], - index=i - )) + response = await self.client.embeddings(model=model, prompt=text, **kwargs) + + results.append(EmbeddingData(embedding=response["embedding"], index=i)) final_response = response - + # Create usage stats if available from the last response usage_stats = None if final_response and len(results) == 1: usage_stats = self._create_usage_stats(final_response) - - return EmbeddingResponse( - data=results, - model=model, - usage=usage_stats - ) - + + return EmbeddingResponse(data=results, model=model, usage=usage_stats) + async def list_models(self) -> List[str]: models = await self.client.list() - return [model['name'] for model in models['models']] + return [model["name"] for model in models["models"]] + class OpenAIAdapter(BaseLLMAdapter): """Adapter for OpenAI with enhanced statistics""" - + def __init__(self, **config): super().__init__(**config) - import openai - self.client = openai.AsyncOpenAI( - api_key=config.get('api_key', os.getenv('OPENAI_API_KEY')) - ) - + import openai + 
+ self.client = openai.AsyncOpenAI(api_key=config.get("api_key", os.getenv("OPENAI_API_KEY"))) + def _create_usage_stats(self, usage_data: Any) -> UsageStats: """Create usage statistics from OpenAI response""" if not usage_data: return UsageStats() - - usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data - + + usage_dict = usage_data.model_dump() if hasattr(usage_data, "model_dump") else usage_data + return UsageStats( - prompt_tokens=usage_dict.get('prompt_tokens'), - completion_tokens=usage_dict.get('completion_tokens'), - total_tokens=usage_dict.get('total_tokens'), - extra_stats={k: v for k, v in usage_dict.items() - if k not in ['prompt_tokens', 'completion_tokens', 'total_tokens']} + prompt_tokens=usage_dict.get("prompt_tokens"), + completion_tokens=usage_dict.get("completion_tokens"), + total_tokens=usage_dict.get("total_tokens"), + extra_stats={ + k: v for k, v in usage_dict.items() if k not in ["prompt_tokens", "completion_tokens", "total_tokens"] + }, ) - + async def chat( - self, - model: str, - messages: List[LLMMessage], - stream: bool = False, - **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - # Convert LLMMessage objects to OpenAI format openai_messages = [] for msg in messages: - openai_messages.append({ - "role": msg.role, - "content": msg.content - }) - + openai_messages.append({"role": msg.role, "content": msg.content}) + if stream: return self._stream_chat(model, openai_messages, **kwargs) else: response = await self.client.chat.completions.create( - model=model, - messages=openai_messages, - stream=False, - **kwargs + model=model, messages=openai_messages, stream=False, **kwargs ) - + usage_stats = self._create_usage_stats(response.usage) - + return ChatResponse( content=response.choices[0].message.content, model=model, finish_reason=response.choices[0].finish_reason, - usage=usage_stats + usage=usage_stats, ) - + async def _stream_chat(self, model: str, messages: List[Dict], **kwargs): # Await the stream creation first, then iterate - stream = await self.client.chat.completions.create( - model=model, - messages=messages, - stream=True, - **kwargs - ) - + stream = await self.client.chat.completions.create(model=model, messages=messages, stream=True, **kwargs) + async for chunk in stream: if chunk.choices[0].delta.content: # Usage stats are only available in the final chunk for OpenAI streaming usage_stats = None - if (chunk.choices[0].finish_reason is not None and - hasattr(chunk, 'usage') and chunk.usage): + if chunk.choices[0].finish_reason is not None and hasattr(chunk, "usage") and chunk.usage: usage_stats = self._create_usage_stats(chunk.usage) - + yield ChatResponse( content=chunk.choices[0].delta.content, model=model, finish_reason=chunk.choices[0].finish_reason, - usage=usage_stats + usage=usage_stats, ) - + async def generate( - self, - model: str, - prompt: str, - stream: bool = False, - **kwargs + self, model: str, prompt: str, stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - # Convert to chat format for OpenAI messages = [LLMMessage(role="user", content=prompt)] return await self.chat(model, messages, stream, **kwargs) - async def embeddings( - self, - model: str, - input_texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResponse: + async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Generate embeddings using 
OpenAI""" # Normalize input to list if isinstance(input_texts, str): texts = [input_texts] else: texts = input_texts - - response = await self.client.embeddings.create( - model=model, - input=texts, - **kwargs - ) - + + response = await self.client.embeddings.create(model=model, input=texts, **kwargs) + # Convert OpenAI response to our format results = [] for item in response.data: - results.append(EmbeddingData( - embedding=item.embedding, - index=item.index - )) - + results.append(EmbeddingData(embedding=item.embedding, index=item.index)) + # Create usage stats usage_stats = self._create_usage_stats(response.usage) - - return EmbeddingResponse( - data=results, - model=model, - usage=usage_stats - ) - + + return EmbeddingResponse(data=results, model=model, usage=usage_stats) + async def list_models(self) -> List[str]: models = await self.client.models.list() return [model.id for model in models.data] + class AnthropicAdapter(BaseLLMAdapter): """Adapter for Anthropic Claude with enhanced statistics""" - + def __init__(self, **config): super().__init__(**config) - import anthropic - self.client = anthropic.AsyncAnthropic( - api_key=config.get('api_key', os.getenv('ANTHROPIC_API_KEY')) - ) - + import anthropic + + self.client = anthropic.AsyncAnthropic(api_key=config.get("api_key", os.getenv("ANTHROPIC_API_KEY"))) + def _create_usage_stats(self, usage_data: Any) -> UsageStats: """Create usage statistics from Anthropic response""" if not usage_data: return UsageStats() - - usage_dict = usage_data.model_dump() if hasattr(usage_data, 'model_dump') else usage_data - + + usage_dict = usage_data.model_dump() if hasattr(usage_data, "model_dump") else usage_data + return UsageStats( - prompt_tokens=usage_dict.get('input_tokens'), - completion_tokens=usage_dict.get('output_tokens'), - total_tokens=(usage_dict.get('input_tokens', 0) + usage_dict.get('output_tokens', 0)) or None, - extra_stats={k: v for k, v in usage_dict.items() - if k not in ['input_tokens', 'output_tokens']} + prompt_tokens=usage_dict.get("input_tokens"), + completion_tokens=usage_dict.get("output_tokens"), + total_tokens=(usage_dict.get("input_tokens", 0) + usage_dict.get("output_tokens", 0)) or None, + extra_stats={k: v for k, v in usage_dict.items() if k not in ["input_tokens", "output_tokens"]}, ) - + async def chat( - self, - model: str, - messages: List[LLMMessage], - stream: bool = False, - **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - # Anthropic requires system message to be separate system_message = None anthropic_messages = [] - + for msg in messages: if msg.role == "system": system_message = msg.content else: - anthropic_messages.append({ - "role": msg.role, - "content": msg.content - }) - + anthropic_messages.append({"role": msg.role, "content": msg.content}) + request_kwargs = { "model": model, "messages": anthropic_messages, - "max_tokens": kwargs.pop('max_tokens', 1000), - **kwargs + "max_tokens": kwargs.pop("max_tokens", 1000), + **kwargs, } - + if system_message: request_kwargs["system"] = system_message - + if stream: return self._stream_chat(**request_kwargs) else: - response = await self.client.messages.create( - stream=False, - **request_kwargs - ) - + response = await self.client.messages.create(stream=False, **request_kwargs) + usage_stats = self._create_usage_stats(response.usage) - + return ChatResponse( - content=response.content[0].text, - model=model, - finish_reason=response.stop_reason, - 
usage=usage_stats + content=response.content[0].text, model=model, finish_reason=response.stop_reason, usage=usage_stats ) - + async def _stream_chat(self, **kwargs): - model = kwargs['model'] - + model = kwargs["model"] + async with self.client.messages.stream(**kwargs) as stream: final_usage_stats = None - + async for text in stream.text_stream: yield ChatResponse( content=text, model=model, - usage=None # Usage stats not available until stream completion + usage=None, # Usage stats not available until stream completion ) - + # Collect usage stats after stream completion try: - if hasattr(stream, 'get_final_message'): + if hasattr(stream, "get_final_message"): final_message = await stream.get_final_message() - if final_message and hasattr(final_message, 'usage'): + if final_message and hasattr(final_message, "usage"): final_usage_stats = self._create_usage_stats(final_message.usage) - + # Yield a final empty response with usage stats - yield ChatResponse( - content="", - model=model, - usage=final_usage_stats, - finish_reason="stop" - ) + yield ChatResponse(content="", model=model, usage=final_usage_stats, finish_reason="stop") except Exception as e: logger.debug(f"Could not retrieve final usage stats: {e}") - + async def generate( - self, - model: str, - prompt: str, - stream: bool = False, - **kwargs + self, model: str, prompt: str, stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - messages = [LLMMessage(role="user", content=prompt)] return await self.chat(model, messages, stream, **kwargs) - async def embeddings( - self, - model: str, - input_texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResponse: + async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Anthropic doesn't provide embeddings API""" raise NotImplementedError( - "Anthropic does not provide embeddings API. " - "Consider using OpenAI, Ollama, or Gemini for embeddings." - ) - + "Anthropic does not provide embeddings API. " "Consider using OpenAI, Ollama, or Gemini for embeddings." 
+ ) + async def list_models(self) -> List[str]: # Anthropic doesn't have a list models endpoint, return known models - return [ - "claude-3-5-sonnet-20241022", - "claude-3-5-haiku-20241022", - "claude-3-opus-20240229" - ] + return ["claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229"] + class GeminiAdapter(BaseLLMAdapter): """Adapter for Google Gemini with enhanced statistics""" - + def __init__(self, **config): super().__init__(**config) - import google.generativeai as genai - genai.configure(api_key=config.get('api_key', os.getenv('GEMINI_API_KEY'))) + import google.generativeai as genai + + genai.configure(api_key=config.get("api_key", os.getenv("GEMINI_API_KEY"))) self.genai = genai - + def _create_usage_stats(self, response: Any) -> UsageStats: """Create usage statistics from Gemini response""" usage_stats = UsageStats() - + # Gemini usage metadata extraction - if hasattr(response, 'usage_metadata'): + if hasattr(response, "usage_metadata"): usage = response.usage_metadata - usage_stats.prompt_tokens = getattr(usage, 'prompt_token_count', None) - usage_stats.completion_tokens = getattr(usage, 'candidates_token_count', None) - usage_stats.total_tokens = getattr(usage, 'total_token_count', None) - + usage_stats.prompt_tokens = getattr(usage, "prompt_token_count", None) + usage_stats.completion_tokens = getattr(usage, "candidates_token_count", None) + usage_stats.total_tokens = getattr(usage, "total_token_count", None) + # Store additional Gemini-specific stats extra_stats = {} for attr in dir(usage): - if not attr.startswith('_') and attr not in ['prompt_token_count', 'candidates_token_count', 'total_token_count']: + if not attr.startswith("_") and attr not in [ + "prompt_token_count", + "candidates_token_count", + "total_token_count", + ]: extra_stats[attr] = getattr(usage, attr) - + if extra_stats: usage_stats.extra_stats = extra_stats - + return usage_stats - + async def chat( - self, - model: str, - messages: List[LLMMessage], - stream: bool = False, - **kwargs + self, model: str, messages: List[LLMMessage], stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - model_instance = self.genai.GenerativeModel(model) - + # Convert messages to Gemini format chat_history = [] current_message = None - + for i, msg in enumerate(messages[:-1]): # All but last message go to history if msg.role == "user": chat_history.append({"role": "user", "parts": [msg.content]}) elif msg.role == "assistant": chat_history.append({"role": "model", "parts": [msg.content]}) - + # Last message is the current prompt if messages: current_message = messages[-1].content - + chat = model_instance.start_chat(history=chat_history) if not current_message: raise ValueError("No current message provided for chat") - + if stream: return self._stream_chat(chat, current_message, **kwargs) else: response = await chat.send_message_async(current_message, **kwargs) usage_stats = self._create_usage_stats(response) - + return ChatResponse( content=response.text, model=model, finish_reason=response.candidates[0].finish_reason.name if response.candidates else None, - usage=usage_stats + usage=usage_stats, ) - + async def _stream_chat(self, chat, message: str, **kwargs): # Await the stream creation first, then iterate stream = await chat.send_message_async(message, stream=True, **kwargs) - + final_chunk = None - + async for chunk in stream: if chunk.text: final_chunk = chunk # Keep reference to last chunk - + yield ChatResponse( content=chunk.text, 
model=chat.model.model_name, - usage=None # Don't include usage stats until final chunk + usage=None, # Don't include usage stats until final chunk ) - + # After streaming is complete, yield final response with usage stats if final_chunk: usage_stats = self._create_usage_stats(final_chunk) - if usage_stats and any([usage_stats.prompt_tokens, usage_stats.completion_tokens, usage_stats.total_tokens]): + if usage_stats and any( + [usage_stats.prompt_tokens, usage_stats.completion_tokens, usage_stats.total_tokens] + ): yield ChatResponse( content="", model=chat.model.model_name, usage=usage_stats, - finish_reason=final_chunk.candidates[0].finish_reason.name if final_chunk.candidates else None + finish_reason=final_chunk.candidates[0].finish_reason.name if final_chunk.candidates else None, ) - + async def generate( - self, - model: str, - prompt: str, - stream: bool = False, - **kwargs + self, model: str, prompt: str, stream: bool = False, **kwargs ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - messages = [LLMMessage(role="user", content=prompt)] return await self.chat(model, messages, stream, **kwargs) - async def embeddings( - self, - model: str, - input_texts: Union[str, List[str]], - **kwargs - ) -> EmbeddingResponse: + async def embeddings(self, model: str, input_texts: Union[str, List[str]], **kwargs) -> EmbeddingResponse: """Generate embeddings using Google Gemini""" # Normalize input to list if isinstance(input_texts, str): texts = [input_texts] else: texts = input_texts - + results = [] - + # Gemini embeddings - use embedding model for i, text in enumerate(texts): - response = self.genai.embed_content( - model=model, - content=text, - **kwargs - ) - - results.append(EmbeddingData( - embedding=response['embedding'], - index=i - )) - + response = self.genai.embed_content(model=model, content=text, **kwargs) + + results.append(EmbeddingData(embedding=response["embedding"], index=i)) + return EmbeddingResponse( data=results, model=model, - usage=None # Gemini embeddings don't typically return usage stats + usage=None, # Gemini embeddings don't typically return usage stats ) async def list_models(self) -> List[str]: models = self.genai.list_models() - return [model.name for model in models if 'generateContent' in model.supported_generation_methods] + return [model.name for model in models if "generateContent" in model.supported_generation_methods] + class UnifiedLLMProxy: """Main proxy class that provides unified interface to all LLM providers""" - + def __init__(self, default_provider: LLMProvider = LLMProvider.OLLAMA): self.adapters: Dict[LLMProvider, BaseLLMAdapter] = {} self.default_provider = default_provider self._initialized_providers = set() - + def configure_provider(self, provider: LLMProvider, **config): """Configure a specific provider with its settings""" adapter_classes = { @@ -811,32 +670,32 @@ class UnifiedLLMProxy: LLMProvider.GEMINI: GeminiAdapter, # Add other providers as needed } - + if provider in adapter_classes: self.adapters[provider] = adapter_classes[provider](**config) self._initialized_providers.add(provider) else: raise ValueError(f"Unsupported provider: {provider}") - + def set_default_provider(self, provider: LLMProvider): """Set the default provider for requests""" if provider not in self._initialized_providers: raise ValueError(f"Provider {provider} not configured") self.default_provider = provider - + async def chat( - self, - model: str, - messages: Union[List[LLMMessage], List[Dict[str, str]]], + self, + model: str, + messages: 
Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, stream: bool = False, - **kwargs + **kwargs, ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: """Send chat messages using specified or default provider""" - + provider = provider or self.default_provider adapter = self._get_adapter(provider) - + # Normalize messages to LLMMessage objects normalized_messages = [] for msg in messages: @@ -846,231 +705,203 @@ class UnifiedLLMProxy: normalized_messages.append(LLMMessage.from_dict(msg)) else: raise ValueError(f"Invalid message type: {type(msg)}") - + return await adapter.chat(model, normalized_messages, stream, **kwargs) - + async def chat_stream( - self, - model: str, - messages: Union[List[LLMMessage], List[Dict[str, str]]], + self, + model: str, + messages: Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, stream: bool = True, - **kwargs + **kwargs, ) -> AsyncGenerator[ChatResponse, None]: """Stream chat messages using specified or default provider""" if stream is False: raise ValueError("stream must be True for chat_stream") result = await self.chat(model, messages, provider, stream=True, **kwargs) # Type checker now knows this is an AsyncGenerator due to stream=True - async for chunk in result: + async for chunk in result: yield chunk - + async def chat_single( - self, - model: str, - messages: Union[List[LLMMessage], List[Dict[str, str]]], - provider: Optional[LLMProvider] = None, - **kwargs - ) -> ChatResponse: - """Get single chat response using specified or default provider""" - - result = await self.chat(model, messages, provider, stream=False, **kwargs) - # Type checker now knows this is a ChatResponse due to stream=False - return result - - async def generate( - self, - model: str, - prompt: str, - provider: Optional[LLMProvider] = None, - stream: bool = False, - **kwargs - ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: - """Generate text using specified or default provider""" - - provider = provider or self.default_provider - adapter = self._get_adapter(provider) - - return await adapter.generate(model, prompt, stream, **kwargs) - - async def generate_stream( - self, - model: str, - prompt: str, - provider: Optional[LLMProvider] = None, - **kwargs - ) -> AsyncGenerator[ChatResponse, None]: - """Stream text generation using specified or default provider""" - - result = await self.generate(model, prompt, provider, stream=True, **kwargs) - async for chunk in result: - yield chunk - - async def generate_single( - self, - model: str, - prompt: str, - provider: Optional[LLMProvider] = None, - **kwargs - ) -> ChatResponse: - """Get single generation response using specified or default provider""" - - result = await self.generate(model, prompt, provider, stream=False, **kwargs) - return result - - async def embeddings( self, model: str, - input_texts: Union[str, List[str]], + messages: Union[List[LLMMessage], List[Dict[str, str]]], provider: Optional[LLMProvider] = None, - **kwargs + **kwargs, + ) -> ChatResponse: + """Get single chat response using specified or default provider""" + + result = await self.chat(model, messages, provider, stream=False, **kwargs) + # Type checker now knows this is a ChatResponse due to stream=False + return result + + async def generate( + self, model: str, prompt: str, provider: Optional[LLMProvider] = None, stream: bool = False, **kwargs + ) -> Union[ChatResponse, AsyncGenerator[ChatResponse, None]]: + """Generate text using specified or default provider""" 
+ + provider = provider or self.default_provider + adapter = self._get_adapter(provider) + + return await adapter.generate(model, prompt, stream, **kwargs) + + async def generate_stream( + self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs + ) -> AsyncGenerator[ChatResponse, None]: + """Stream text generation using specified or default provider""" + + result = await self.generate(model, prompt, provider, stream=True, **kwargs) + async for chunk in result: + yield chunk + + async def generate_single( + self, model: str, prompt: str, provider: Optional[LLMProvider] = None, **kwargs + ) -> ChatResponse: + """Get single generation response using specified or default provider""" + + result = await self.generate(model, prompt, provider, stream=False, **kwargs) + return result + + async def embeddings( + self, model: str, input_texts: Union[str, List[str]], provider: Optional[LLMProvider] = None, **kwargs ) -> EmbeddingResponse: """Generate embeddings using specified or default provider""" provider = provider or self.default_provider adapter = self._get_adapter(provider) - + return await adapter.embeddings(model, input_texts, **kwargs) - + async def list_models(self, provider: Optional[LLMProvider] = None) -> List[str]: """List available models for specified or default provider""" - + provider = provider or self.default_provider adapter = self._get_adapter(provider) - + return await adapter.list_models() - + async def list_embedding_models(self, provider: Optional[LLMProvider] = None) -> List[str]: """List available embedding models for specified or default provider""" provider = provider or self.default_provider - + # Provider-specific embedding models embedding_models = { - LLMProvider.OLLAMA: [ - "nomic-embed-text", - "mxbai-embed-large", - "all-minilm", - "snowflake-arctic-embed" - ], - LLMProvider.OPENAI: [ - "text-embedding-3-small", - "text-embedding-3-large", - "text-embedding-ada-002" - ], + LLMProvider.OLLAMA: ["nomic-embed-text", "mxbai-embed-large", "all-minilm", "snowflake-arctic-embed"], + LLMProvider.OPENAI: ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"], LLMProvider.ANTHROPIC: [], # No embeddings API - LLMProvider.GEMINI: [ - "models/embedding-001", - "models/text-embedding-004" - ] + LLMProvider.GEMINI: ["models/embedding-001", "models/text-embedding-004"], } - + if provider == LLMProvider.ANTHROPIC: raise NotImplementedError("Anthropic does not provide embeddings API") - + # For Ollama, check which embedding models are actually available if provider == LLMProvider.OLLAMA: try: all_models = await self.list_models(provider) available_embedding_models = [] suggested_models = embedding_models[provider] - + for model in all_models: # Check if it's a known embedding model if any(emb_model in model for emb_model in suggested_models): available_embedding_models.append(model) - + return available_embedding_models if available_embedding_models else suggested_models except: return embedding_models[provider] - + return embedding_models.get(provider, []) - + def _get_adapter(self, provider: LLMProvider) -> BaseLLMAdapter: """Get adapter for specified provider""" if provider not in self.adapters: raise ValueError(f"Provider {provider} not configured") return self.adapters[provider] + # Example usage and configuration class LLMManager: """Singleton manager for the unified LLM proxy""" - + _instance = None _proxy = None - + @classmethod def get_instance(cls): if cls._instance is None: cls._instance = cls() return cls._instance - + def 
__init__(self): if LLMManager._proxy is None: LLMManager._proxy = UnifiedLLMProxy() self._configure_from_environment() - + def _configure_from_environment(self): """Configure providers based on environment variables""" if not self._proxy: raise RuntimeError("UnifiedLLMProxy instance not initialized") # Configure Ollama if available - ollama_host = os.getenv('OLLAMA_HOST', defines.ollama_api_url) + ollama_host = os.getenv("OLLAMA_HOST", defines.ollama_api_url) self._proxy.configure_provider(LLMProvider.OLLAMA, host=ollama_host) - + # Configure OpenAI if API key is available - if os.getenv('OPENAI_API_KEY'): + if os.getenv("OPENAI_API_KEY"): self._proxy.configure_provider(LLMProvider.OPENAI) - + # Configure Anthropic if API key is available - if os.getenv('ANTHROPIC_API_KEY'): + if os.getenv("ANTHROPIC_API_KEY"): self._proxy.configure_provider(LLMProvider.ANTHROPIC) - + # Configure Gemini if API key is available - if os.getenv('GEMINI_API_KEY'): + if os.getenv("GEMINI_API_KEY"): self._proxy.configure_provider(LLMProvider.GEMINI) - + # Set default provider from environment or use Ollama - default_provider = os.getenv('DEFAULT_LLM_PROVIDER', 'ollama') + default_provider = os.getenv("DEFAULT_LLM_PROVIDER", "ollama") try: self._proxy.set_default_provider(LLMProvider(default_provider)) except ValueError: # Fallback to Ollama if specified provider not available self._proxy.set_default_provider(LLMProvider.OLLAMA) - + def get_proxy(self) -> UnifiedLLMProxy: """Get the unified LLM proxy instance""" if not self._proxy: raise RuntimeError("UnifiedLLMProxy instance not initialized") return self._proxy + # Convenience function for easy access def get_llm() -> UnifiedLLMProxy: """Get the configured LLM proxy""" return LLMManager.get_instance().get_proxy() + # Example usage with detailed statistics async def example_usage(): """Example showing how to access detailed statistics""" llm = get_llm() - + # Configure providers llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") - + # Simple chat - messages = [ - LLMMessage(role="user", content="Explain quantum computing in one paragraph") - ] - + messages = [LLMMessage(role="user", content="Explain quantum computing in one paragraph")] + response = await llm.chat_single("llama2", messages) - + print(f"Content: {response.content}") print(f"Model: {response.model}") - + if response.usage: print("Usage Statistics:") print(f" Prompt tokens: {response.usage.prompt_tokens}") print(f" Completion tokens: {response.usage.completion_tokens}") print(f" Total tokens: {response.usage.total_tokens}") - + # Ollama-specific stats if response.usage.prompt_eval_count: print(f" Prompt eval count: {response.usage.prompt_eval_count}") @@ -1080,20 +911,21 @@ async def example_usage(): print(f" Generation speed: {response.usage.tokens_per_second:.2f} tokens/sec") if response.usage.prompt_tokens_per_second: print(f" Prompt processing speed: {response.usage.prompt_tokens_per_second:.2f} tokens/sec") - + # Access as dictionary for backward compatibility usage_dict = response.get_usage_dict() print(f"Usage as dict: {usage_dict}") + async def example_embeddings_usage(): """Example showing how to use the embeddings API""" llm = get_llm() - + # Configure providers llm.configure_provider(LLMProvider.OLLAMA, host="http://localhost:11434") - if os.getenv('OPENAI_API_KEY'): + if os.getenv("OPENAI_API_KEY"): llm.configure_provider(LLMProvider.OPENAI) - + # List available embedding models print("=== Available Embedding Models ===") try: @@ -1101,56 +933,57 @@ async def 
example_embeddings_usage(): print(f"Ollama embedding models: {ollama_embedding_models}") except Exception as e: print(f"Could not list Ollama embedding models: {e}") - - if os.getenv('OPENAI_API_KEY'): + + if os.getenv("OPENAI_API_KEY"): try: openai_embedding_models = await llm.list_embedding_models(LLMProvider.OPENAI) print(f"OpenAI embedding models: {openai_embedding_models}") except Exception as e: print(f"Could not list OpenAI embedding models: {e}") - + # Single text embedding print("\n=== Single Text Embedding ===") single_text = "The quick brown fox jumps over the lazy dog" - + try: response = await llm.embeddings("nomic-embed-text", single_text, provider=LLMProvider.OLLAMA) print(f"Model: {response.model}") print(f"Embedding dimension: {len(response.get_single_embedding())}") print(f"First 5 values: {response.get_single_embedding()[:5]}") - + if response.usage: print(f"Usage: {response.usage.model_dump(exclude_none=True)}") except Exception as e: print(f"Ollama embedding failed: {e}") - + # Batch embeddings print("\n=== Batch Text Embeddings ===") texts = [ "Machine learning is fascinating", - "Natural language processing enables AI communication", + "Natural language processing enables AI communication", "Deep learning uses neural networks", - "Transformers revolutionized AI" + "Transformers revolutionized AI", ] - + try: # Try OpenAI if available - if os.getenv('OPENAI_API_KEY'): + if os.getenv("OPENAI_API_KEY"): response = await llm.embeddings("text-embedding-3-small", texts, provider=LLMProvider.OPENAI) print(f"OpenAI Model: {response.model}") print(f"Number of embeddings: {len(response.data)}") print(f"Embedding dimension: {len(response.data[0].embedding)}") - + # Calculate similarity between first two texts (requires numpy) try: - import numpy as np + import numpy as np + emb1 = np.array(response.data[0].embedding) emb2 = np.array(response.data[1].embedding) similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2)) print(f"Cosine similarity between first two texts: {similarity:.4f}") except ImportError: print("Install numpy to calculate similarity: pip install numpy") - + if response.usage: print(f"Usage: {response.usage.model_dump(exclude_none=True)}") else: @@ -1159,40 +992,44 @@ async def example_embeddings_usage(): print(f"Ollama Model: {response.model}") print(f"Number of embeddings: {len(response.data)}") print(f"Embedding dimension: {len(response.data[0].embedding)}") - + except Exception as e: print(f"Batch embedding failed: {e}") + async def example_streaming_with_stats(): """Example showing how to collect usage stats from streaming responses""" llm = get_llm() - - messages = [ - LLMMessage(role="user", content="Write a short story about AI") - ] - + + messages = [LLMMessage(role="user", content="Write a short story about AI")] + print("Streaming response:") final_stats = None - + async for chunk in llm.chat_stream("llama2", messages): # Print content as it streams if chunk.content: print(chunk.content, end="", flush=True) - + # Collect usage stats from final chunk if chunk.usage: final_stats = chunk.usage - + print("\n\nFinal Usage Statistics:") if final_stats: print(f" Total tokens: {final_stats.total_tokens}") - print(f" Generation speed: {final_stats.tokens_per_second:.2f} tokens/sec" if final_stats.tokens_per_second else " Generation speed: N/A") + print( + f" Generation speed: {final_stats.tokens_per_second:.2f} tokens/sec" + if final_stats.tokens_per_second + else " Generation speed: N/A" + ) else: print(" No usage statistics available") 
+ if __name__ == "__main__": asyncio.run(example_usage()) - print("\n" + "="*50 + "\n") + print("\n" + "=" * 50 + "\n") asyncio.run(example_streaming_with_stats()) - print("\n" + "="*50 + "\n") - asyncio.run(example_embeddings_usage()) \ No newline at end of file + print("\n" + "=" * 50 + "\n") + asyncio.run(example_embeddings_usage()) diff --git a/src/backend/utils/metrics.py b/src/backend/utils/metrics.py index 6f30521..e090b84 100644 --- a/src/backend/utils/metrics.py +++ b/src/backend/utils/metrics.py @@ -1,6 +1,7 @@ from prometheus_client import Counter, Histogram # type: ignore from threading import Lock + def singleton(cls): instance = None lock = Lock() diff --git a/src/backend/utils/rate_limiter.py b/src/backend/utils/rate_limiter.py index 42f946b..b8790b5 100644 --- a/src/backend/utils/rate_limiter.py +++ b/src/backend/utils/rate_limiter.py @@ -6,68 +6,80 @@ from functools import wraps from datetime import datetime, timedelta, UTC from typing import Callable, Dict, Optional, Any from fastapi import Depends, HTTPException, Request -from pydantic import BaseModel # type: ignore +from pydantic import BaseModel # type: ignore from database.manager import RedisDatabase from logger import logger -from . dependencies import get_current_user_or_guest, get_database +from .dependencies import get_current_user_or_guest, get_database + async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter: """Dependency to get rate limiter instance""" return RateLimiter(database) + class RateLimitConfig(BaseModel): """Rate limit configuration""" + requests_per_minute: int requests_per_hour: int requests_per_day: int burst_limit: int # Maximum requests in a short burst burst_window_seconds: int = 60 # Window for burst detection + class GuestRateLimitConfig(RateLimitConfig): """Rate limits for guest users - more restrictive""" + requests_per_minute: int = 10 requests_per_hour: int = 100 requests_per_day: int = 500 burst_limit: int = 15 burst_window_seconds: int = 60 + class AuthenticatedUserRateLimitConfig(RateLimitConfig): """Rate limits for authenticated users - more generous""" + requests_per_minute: int = 60 requests_per_hour: int = 1000 requests_per_day: int = 10000 burst_limit: int = 100 burst_window_seconds: int = 60 + class PremiumUserRateLimitConfig(RateLimitConfig): """Rate limits for premium/admin users - most generous""" + requests_per_minute: int = 120 requests_per_hour: int = 5000 requests_per_day: int = 50000 burst_limit: int = 200 burst_window_seconds: int = 60 + class RateLimitResult(BaseModel): """Result of rate limit check""" + allowed: bool reason: Optional[str] = None retry_after_seconds: Optional[int] = None remaining_requests: Dict[str, int] = {} reset_times: Dict[str, datetime] = {} + class RateLimiter: """Rate limiter using Redis for distributed rate limiting""" - + def __init__(self, database: RedisDatabase): self.database = database self.redis = database.redis - + # Rate limit configurations self.guest_config = GuestRateLimitConfig() self.user_config = AuthenticatedUserRateLimitConfig() self.premium_config = PremiumUserRateLimitConfig() - + def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig: """Get rate limit configuration based on user type""" if user_type == "guest": @@ -76,66 +88,62 @@ class RateLimiter: return self.premium_config else: return self.user_config - + async def check_rate_limit( - self, - user_id: str, - user_type: str, - is_admin: bool = False, - endpoint: Optional[str] = None + self, user_id: 
str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None ) -> RateLimitResult: """ Check if user has exceeded rate limits - + Args: user_id: Unique identifier for the user (guest session ID or user ID) user_type: "guest", "candidate", or "employer" is_admin: Whether user has admin privileges endpoint: Optional endpoint-specific rate limiting - + Returns: RateLimitResult indicating if request is allowed """ config = self.get_config_for_user(user_type, is_admin) current_time = datetime.now(UTC) - + # Create Redis keys for different time windows base_key = f"rate_limit:{user_type}:{user_id}" keys = { "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", - "burst": f"{base_key}:burst" + "burst": f"{base_key}:burst", } - + # Add endpoint-specific limiting if provided if endpoint: keys = {k: f"{v}:{endpoint}" for k, v in keys.items()} - + try: # Use Redis pipeline for atomic operations pipe = self.redis.pipeline() - + # Get current counts for key in keys.values(): pipe.get(key) - + results = await pipe.execute() current_counts = { "minute": int(results[0] or 0), "hour": int(results[1] or 0), "day": int(results[2] or 0), - "burst": int(results[3] or 0) + "burst": int(results[3] or 0), } - + # Check limits limits = { "minute": config.requests_per_minute, "hour": config.requests_per_hour, "day": config.requests_per_day, - "burst": config.burst_limit + "burst": config.burst_limit, } - + # Check each limit for window, current_count in current_counts.items(): limit = limits[window] @@ -146,110 +154,100 @@ class RateLimiter: elif window == "hour": retry_after = 3600 - (current_time.minute * 60 + current_time.second) elif window == "day": - retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second) + retry_after = 86400 - ( + current_time.hour * 3600 + current_time.minute * 60 + current_time.second + ) else: # burst retry_after = config.burst_window_seconds - - logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}") - + + logger.warning( + f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}" + ) + return RateLimitResult( allowed=False, reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}", retry_after_seconds=retry_after, remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()}, - reset_times=self._calculate_reset_times(current_time) + reset_times=self._calculate_reset_times(current_time), ) - + # If we get here, request is allowed - increment counters pipe = self.redis.pipeline() - + # Increment minute counter (expires after 2 minutes) pipe.incr(keys["minute"]) pipe.expire(keys["minute"], 120) - + # Increment hour counter (expires after 2 hours) pipe.incr(keys["hour"]) pipe.expire(keys["hour"], 7200) - + # Increment day counter (expires after 2 days) pipe.incr(keys["day"]) pipe.expire(keys["day"], 172800) - + # Increment burst counter (expires after burst window) pipe.incr(keys["burst"]) pipe.expire(keys["burst"], config.burst_window_seconds) - + await pipe.execute() - + # Calculate remaining requests - remaining = { - k: max(0, limits[k] - (current_counts[k] + 1)) - for k in current_counts.keys() - } - + remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()} + logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}") - + return 
RateLimitResult( - allowed=True, - remaining_requests=remaining, - reset_times=self._calculate_reset_times(current_time) + allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time) ) - + except Exception as e: logger.error(f"❌ Rate limit check failed for {user_id}: {e}") # Fail open - allow request if rate limiting system fails return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request") - + def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]: """Calculate when each rate limit window resets""" next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1) next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) - - return { - "minute": next_minute, - "hour": next_hour, - "day": next_day - } - - async def get_user_rate_limit_status( - self, - user_id: str, - user_type: str, - is_admin: bool = False - ) -> Dict[str, Any]: + + return {"minute": next_minute, "hour": next_hour, "day": next_day} + + async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]: """Get current rate limit status for a user""" config = self.get_config_for_user(user_type, is_admin) current_time = datetime.now(UTC) - + base_key = f"rate_limit:{user_type}:{user_id}" keys = { "minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}", "hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}", "day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}", - "burst": f"{base_key}:burst" + "burst": f"{base_key}:burst", } - + try: pipe = self.redis.pipeline() for key in keys.values(): pipe.get(key) - + results = await pipe.execute() current_counts = { "minute": int(results[0] or 0), "hour": int(results[1] or 0), "day": int(results[2] or 0), - "burst": int(results[3] or 0) + "burst": int(results[3] or 0), } - + limits = { "minute": config.requests_per_minute, "hour": config.requests_per_hour, "day": config.requests_per_day, - "burst": config.burst_limit + "burst": config.burst_limit, } - + return { "user_id": user_id, "user_type": user_type, @@ -258,58 +256,56 @@ class RateLimiter: "limits": limits, "remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()}, "reset_times": self._calculate_reset_times(current_time), - "config": config.model_dump() + "config": config.model_dump(), } - + except Exception as e: logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}") return {"error": str(e)} - + async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool: """Reset all rate limits for a user (admin function)""" try: base_key = f"rate_limit:{user_type}:{user_id}" pattern = f"{base_key}:*" - + cursor = 0 deleted_count = 0 - + while True: cursor, keys = await self.redis.scan(cursor, match=pattern, count=100) if keys: await self.redis.delete(*keys) deleted_count += len(keys) - + if cursor == 0: break - + logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}") return True - + except Exception as e: logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}") return False - - + + # ============================ # Rate Limited Decorator # ============================ + def rate_limited( - guest_per_minute: int = 10, - user_per_minute: int = 60, - admin_per_minute: int = 120, - endpoint_specific: bool = True + guest_per_minute: int = 10, user_per_minute: 
int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True ): """ Decorator to easily apply rate limiting to endpoints - + Args: guest_per_minute: Rate limit for guest users - user_per_minute: Rate limit for authenticated users + user_per_minute: Rate limit for authenticated users admin_per_minute: Rate limit for admin users endpoint_specific: Whether to apply endpoint-specific limits - + Usage: @rate_limited(guest_per_minute=5, user_per_minute=30) @api_router.post("/my-endpoint") @@ -320,61 +316,66 @@ def rate_limited( ): return {"message": "Rate limited endpoint"} """ + def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args, **kwargs): # Extract dependencies from function signature import inspect + inspect.signature(func) - + # Get request, current_user, and rate_limiter from kwargs or args request = None current_user = None rate_limiter = None - + # Try to find dependencies in kwargs first for param_name, param_value in kwargs.items(): if isinstance(param_value, Request): request = param_value - elif hasattr(param_value, 'user_type'): # User-like object + elif hasattr(param_value, "user_type"): # User-like object current_user = param_value elif isinstance(param_value, RateLimiter): rate_limiter = param_value - + # If not found in kwargs, check if they're provided via Depends if not rate_limiter: # Create rate limiter instance (this should ideally come from DI) database = get_database() rate_limiter = RateLimiter(database) - + # Apply rate limiting if we have the required components if request and current_user and rate_limiter: await apply_custom_rate_limiting( - request, current_user, rate_limiter, - guest_per_minute, user_per_minute, admin_per_minute + request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute ) - + # Call the original function return await func(*args, **kwargs) - + return wrapper + return decorator + async def apply_custom_rate_limiting( request: Request, current_user, rate_limiter: RateLimiter, guest_per_minute: int, - user_per_minute: int, - admin_per_minute: int + user_per_minute: int, + admin_per_minute: int, ): """Apply custom rate limiting with specified limits""" try: # Determine user info user_id = current_user.id - user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) - is_admin = getattr(current_user, 'is_admin', False) - + user_type = ( + current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type) + ) + is_admin = getattr(current_user, "is_admin", False) + # Determine appropriate limit if is_admin: requests_per_minute = admin_per_minute @@ -382,16 +383,20 @@ async def apply_custom_rate_limiting( requests_per_minute = guest_per_minute else: requests_per_minute = user_per_minute - + # Create custom rate limit key current_time = datetime.now(UTC) - custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}" - + custom_key = ( + f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}" + ) + # Check current usage current_count = int(await rate_limiter.redis.get(custom_key) or 0) - + if current_count >= requests_per_minute: - logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}") + logger.warning( + f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}" + ) raise HTTPException( 
status_code=429, detail={ @@ -399,40 +404,40 @@ async def apply_custom_rate_limiting( "message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute", "retryAfter": 60 - current_time.second, "userType": user_type, - "endpoint": request.url.path + "endpoint": request.url.path, }, - headers={"Retry-After": str(60 - current_time.second)} + headers={"Retry-After": str(60 - current_time.second)}, ) - + # Increment counter pipe = rate_limiter.redis.pipeline() pipe.incr(custom_key) pipe.expire(custom_key, 120) # 2 minutes TTL await pipe.execute() - - logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}") - + + logger.debug( + f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}" + ) + except HTTPException: raise except Exception as e: logger.error(f"❌ Custom rate limiting error: {e}") # Fail open + # ============================ # Alternative: FastAPI Dependency-Based Rate Limiting # ============================ -def create_rate_limit_dependency( - guest_per_minute: int = 10, - user_per_minute: int = 60, - admin_per_minute: int = 120 -): + +def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120): """ Create a FastAPI dependency for rate limiting - + Usage: rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30) - + @api_router.post("/my-endpoint") async def my_endpoint( rate_check = Depends(rate_limit_5_30), @@ -441,85 +446,89 @@ def create_rate_limit_dependency( ): return {"message": "Rate limited endpoint"} """ + async def rate_limit_dependency( request: Request, - current_user = Depends(get_current_user_or_guest), - rate_limiter: RateLimiter = Depends(get_rate_limiter) + current_user=Depends(get_current_user_or_guest), + rate_limiter: RateLimiter = Depends(get_rate_limiter), ): await apply_custom_rate_limiting( - request, current_user, rate_limiter, - guest_per_minute, user_per_minute, admin_per_minute + request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute ) return True - + return rate_limit_dependency + # ============================ # Rate Limiting Utilities # ============================ + class EndpointRateLimiter: """Utility class for endpoint-specific rate limiting""" - + def __init__(self, rate_limiter: RateLimiter): self.rate_limiter = rate_limiter self.custom_limits = {} - + def set_endpoint_limits(self, endpoint: str, limits: dict): """Set custom limits for an endpoint""" self.custom_limits[endpoint] = limits - + async def check_endpoint_limit(self, request: Request, current_user) -> bool: """Check if request exceeds endpoint-specific limits""" endpoint = request.url.path - + if endpoint not in self.custom_limits: return True # No custom limits set - + limits = self.custom_limits[endpoint] - user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type) - - if getattr(current_user, 'is_admin', False): + user_type = ( + current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type) + ) + + if getattr(current_user, "is_admin", False): user_type = "admin" - + limit = limits.get(user_type, limits.get("default", 60)) - + current_time = datetime.now(UTC) key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}" - + current_count = int(await 
self.rate_limiter.redis.get(key) or 0) - + if current_count >= limit: raise HTTPException( - status_code=429, - detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}" + status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}" ) - + # Increment counter await self.rate_limiter.redis.incr(key) await self.rate_limiter.redis.expire(key, 120) - + return True + # Global endpoint rate limiter instance endpoint_rate_limiter = None + def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter: """Get endpoint rate limiter instance""" global endpoint_rate_limiter if endpoint_rate_limiter is None: endpoint_rate_limiter = EndpointRateLimiter(rate_limiter) - - # Configure endpoint-specific limits - endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", { - "guest": 5, "candidate": 30, "employer": 30, "admin": 100 - }) - endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", { - "guest": 2, "candidate": 10, "employer": 10, "admin": 50 - }) - endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", { - "guest": 1, "candidate": 5, "employer": 20, "admin": 50 - }) - - return endpoint_rate_limiter + # Configure endpoint-specific limits + endpoint_rate_limiter.set_endpoint_limits( + "/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100} + ) + endpoint_rate_limiter.set_endpoint_limits( + "/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50} + ) + endpoint_rate_limiter.set_endpoint_limits( + "/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50} + ) + + return endpoint_rate_limiter diff --git a/src/backend/utils/responses.py b/src/backend/utils/responses.py index 297e08d..ef5c4f7 100644 --- a/src/backend/utils/responses.py +++ b/src/backend/utils/responses.py @@ -3,37 +3,17 @@ Response utility functions for consistent API responses """ from typing import Any, Optional, Dict, List + def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict: - return { - "success": True, - "data": data, - "meta": meta - } + return {"success": True, "data": data, "meta": meta} + def create_error_response(code: str, message: str, details: Any = None) -> Dict: - return { - "success": False, - "error": { - "code": code, - "message": message, - "details": details - } - } + return {"success": False, "error": {"code": code, "message": message, "details": details}} -def create_paginated_response( - data: List[Any], - page: int, - limit: int, - total: int -) -> Dict: + +def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict: total_pages = (total + limit - 1) // limit has_more = page < total_pages - - return { - "data": data, - "total": total, - "page": page, - "limit": limit, - "totalPages": total_pages, - "hasMore": has_more - } \ No newline at end of file + + return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}
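
For reviewers who want to see the reformatted utilities working together, below is a minimal, hypothetical usage sketch; it is not part of the patch above. Only create_rate_limit_dependency (utils/rate_limiter.py) and create_success_response / create_error_response / create_paginated_response (utils/responses.py) come from the changed files. The router, the utils.dependencies import path (inferred from the relative import in rate_limiter.py), and the database.list_candidates() call are assumptions for illustration only.

# Hypothetical sketch -- not part of the diff above.
# Assumes a FastAPI router and a RedisDatabase dependency; list_candidates()
# and its return shape are illustrative placeholders, not a real method.
from fastapi import APIRouter, Depends, Request

from database.manager import RedisDatabase
from utils.dependencies import get_current_user_or_guest, get_database  # assumed module path
from utils.rate_limiter import create_rate_limit_dependency
from utils.responses import (
    create_error_response,
    create_paginated_response,
    create_success_response,
)

api_router = APIRouter(prefix="/api/1.0")

# Dependency built from the helper in rate_limiter.py:
# guests get 5 requests/minute, authenticated users 30, admins 120.
rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30, admin_per_minute=120)


@api_router.get("/candidates")
async def list_candidates_endpoint(
    request: Request,
    page: int = 1,
    limit: int = 20,
    rate_check=Depends(rate_limit_5_30),
    current_user=Depends(get_current_user_or_guest),
    database: RedisDatabase = Depends(get_database),
):
    try:
        # Placeholder data access; substitute the real query here.
        candidates, total = await database.list_candidates(page=page, limit=limit)
    except Exception as e:
        # Error envelope from utils/responses.py
        return create_error_response("CANDIDATE_LIST_FAILED", "Could not list candidates", details=str(e))

    # Wrap one page of results, then the standard success envelope.
    paginated = create_paginated_response(data=candidates, page=page, limit=limit, total=total)
    return create_success_response(paginated, meta={"requestedBy": getattr(current_user, "id", None)})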