Compare commits

...

2 Commits

SHA1        Message                 Date
2cf3fa7b04  ruff reformat           2025-06-18 13:53:07 -07:00
f1c2e16389  Reformatting with ruff  2025-06-18 13:30:54 -07:00
69 changed files with 5012 additions and 5145 deletions

View File

@ -17,11 +17,9 @@ from logger import logger
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
# Maps class_name to (module_name, class_name)
class_registry: Dict[str, Tuple[str, str]] = (
{}
)
class_registry: Dict[str, Tuple[str, str]] = {}
__all__ = ['get_or_create_agent']
__all__ = ["get_or_create_agent"]
package_dir = pathlib.Path(__file__).parent
package_name = __name__
@ -38,16 +36,11 @@ for path in package_dir.glob("*.py"):
# Find all Agent subclasses in the module
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, AnyAgent)
and obj is not AnyAgent
and obj is not Agent
and name not in class_registry
):
if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
class_registry[name] = (full_module_name, name)
globals()[name] = obj
logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore
__all__.append(name) # type: ignore
except ImportError as e:
logger.error(traceback.format_exc())
logger.error(f"Error importing {full_module_name}: {e}")

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field # type: ignore
from typing import (
Literal,
get_args,
@ -13,10 +13,10 @@ import time
import re
from abc import ABC
from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
import numpy as np # type: ignore
import json_extractor as json_extractor
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field # type: ignore
from uuid import uuid4
from typing import List, Optional, ClassVar, Any, Literal
@ -24,7 +24,7 @@ from datetime import datetime, UTC
import numpy as np # type: ignore
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
import os
import re
from pathlib import Path
@ -32,19 +32,45 @@ from pathlib import Path
from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from models import (
Tunables,
ChatMessageUser,
ChatMessage,
RagEntry,
ChatMessageMetaData,
ApiStatusType,
Candidate,
ChatContextType,
)
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
from models import ChromaDBGetResponse
from utils.metrics import Metrics
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
from models import (
ApiActivityType,
ApiMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageStatus,
ChatMessageStreaming,
LLMMessage,
ChatMessage,
ChatOptions,
ChatMessageUser,
Tunables,
ApiStatusType,
ChatMessageMetaData,
Candidate,
)
from logger import logger
import defines
from .registry import agent_registry
from models import ( ChromaDBGetResponse )
from models import ChromaDBGetResponse
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@ -54,18 +80,15 @@ class CandidateEntity(Candidate):
async def cleanup(self):
"""Cleanup resources associated with this entity"""
# Internal instance members
CandidateEntity__agents: List[Agent] = []
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
default=None, exclude=True
)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
CandidateEntity__metrics: Optional[Metrics] = Field(
default=None,
description="Metrics collector for this agent, used to track performance and usage."
default=None, description="Metrics collector for this agent, used to track performance and usage."
)
def __init__(self, candidate=None):
@ -78,17 +101,17 @@ class CandidateEntity(Candidate):
@classmethod
def exists(cls, username: str):
# Validate username format (only allow safe characters)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
return False # Invalid username characters
# Check for minimum and maximum length
if not (3 <= len(username) <= 32):
return False # Invalid username length
# Use Path for safe path handling and normalization
user_dir = Path(defines.user_dir) / username
user_info_path = user_dir / defines.user_info_file
# Ensure the final path is actually within the intended parent directory
# to help prevent directory traversal attacks
try:
@ -96,33 +119,29 @@ class CandidateEntity(Candidate):
return False # Path traversal attempt detected
except (ValueError, RuntimeError): # Potential exceptions from resolve()
return False
# Check if file exists
return user_info_path.is_file()
def get_or_create_agent(self, agent_type: ChatContextType) -> Agent:
"""
Get or create an agent of the specified type for this candidate.
Args:
agent_type: The type of agent to create (default is 'candidate_chat').
**kwargs: Additional fields required by the specific agent subclass.
Returns:
The created agent instance.
"""
# Only instantiate one agent of each type per user
for agent in self.CandidateEntity__agents:
if agent.agent_type == agent_type:
return agent
return get_or_create_agent(
agent_type=agent_type,
user=self,
prometheus_collector=self.prometheus_collector
)
return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)
# Wrapper properties that map into file_watcher
@property
def umap_collection(self) -> ChromaDBGetResponse:
@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
# Fields managed by initialize()
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
@property
def metrics(self) -> Metrics:
if not self.CandidateEntity__metrics:
@ -154,21 +174,16 @@ class CandidateEntity(Candidate):
def observer(self) -> Any:
if not self.CandidateEntity__observer:
raise ValueError("initialize() has not been called.")
return self.CandidateEntity__observer
return self.CandidateEntity__observer
def collect_metrics(self, agent: Agent, response):
if not self.metrics:
logger.warning("No metrics collector set for this agent.")
return
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
response.usage.prompt_eval_count
)
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
async def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
@ -183,13 +198,13 @@ class CandidateEntity(Candidate):
if not prometheus_collector:
raise ValueError("prometheus_collector can not be None")
self.CandidateEntity__prometheus_collector = prometheus_collector
self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)
user_dir = os.path.join(defines.user_dir, self.username)
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
vector_db_dir = os.path.join(user_dir, defines.persist_directory)
rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)
os.makedirs(vector_db_dir, exist_ok=True)
os.makedirs(rag_content_dir, exist_ok=True)
@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
)
has_username_rag = any(item.name == self.username for item in self.rags)
if not has_username_rag:
self.rags.append(RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
))
self.rags.append(
RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
)
)
self.rag_content_size = self.file_watcher.collection.count()
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
class Config:
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
conversation: List[ChatMessageUser] = Field(
default_factory=list,
description="Conversation history for this agent, used to maintain context across messages."
description="Conversation history for this agent, used to maintain context across messages.",
)
@property
@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
last_item = item
return last_item
def set_optimal_context_size(
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
) -> int:
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
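
The context-sizing heuristic in this hunk estimates roughly 1.4 tokens per word and adds a fixed buffer before growing the window. A standalone sketch of that calculation; the 1.4 factor and the 2048 default come from the code above, while the function name is illustrative:

def estimate_context_size(prompt: str, current_ctx: int, ctx_buffer: int = 2048) -> int:
    # Most models average 1.3-1.5 tokens per word; 1.4 is used as a middle estimate.
    word_count = len(prompt.split())
    tokens = int(word_count * 1.4)
    total_ctx = tokens + ctx_buffer
    # Only ever grow the context size, never shrink it.
    return max(current_ctx, total_ctx)

# e.g. estimate_context_size("Summarize the attached resume ...", current_ctx=4096)
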
@ -468,28 +483,28 @@ class Agent(BaseModel, ABC):
"""
if not rag_message.content:
return ""
context = []
for chroma_results in rag_message.content:
for index, metadata in enumerate(chroma_results.metadatas):
content = "\n".join([
line.strip()
for line in chroma_results.documents[index].split("\n")
if line
]).strip()
context.append(f"""
content = "\n".join(
[line.strip() for line in chroma_results.documents[index].split("\n") if line]
).strip()
context.append(
f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
""")
"""
)
return "\n".join(context)
async def generate_rag_results(
self,
session_id: str,
prompt: str,
top_k: int=defines.default_rag_top_k,
threshold: float=defines.default_rag_threshold,
top_k: int = defines.default_rag_top_k,
threshold: float = defines.default_rag_threshold,
) -> AsyncGenerator[ApiMessage, None]:
"""
Generate RAG results for the given query.
@ -501,15 +516,11 @@ Content: {content}
A list of dictionaries containing the RAG results.
"""
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for RAG generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
yield error_message
return
results : List[ChromaDBGetResponse] = []
entries: int = 0
results: List[ChromaDBGetResponse] = []
user: CandidateEntity = self.user
for rag in user.rags:
if not rag.enabled:
@ -518,22 +529,20 @@ Content: {content}
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content = f"Searching RAG context {rag.name}..."
content=f"Searching RAG context {rag.name}...",
)
yield status_message
try:
chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold
)
chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
if not chroma_results:
continue
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
rag_metadata = ChromaDBGetResponse(
rag_metadata = ChromaDBGetResponse(
name=rag.name,
query=prompt,
query_embedding=query_embedding.tolist(),
@ -549,7 +558,7 @@ Content: {content}
continue_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content=f"Error searching RAG context {rag.name}: {str(e)}"
content=f"Error searching RAG context {rag.name}: {str(e)}",
)
yield continue_message
@ -560,25 +569,23 @@ Content: {content}
)
yield final_message
return
async def llm_one_shot(
self,
llm: Any, model: str,
session_id: str, prompt: str, system_prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
async def llm_one_shot(
self,
llm: Any,
model: str,
session_id: str,
prompt: str,
system_prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7,
) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
self.set_optimal_context_size(
llm=llm, model=model, prompt=prompt+system_prompt
)
self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
options = ChatOptions(
seed=8911,
@ -592,9 +599,7 @@ Content: {content}
]
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content=f"Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
@ -610,10 +615,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@ -628,59 +630,46 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
chat_message = ChatMessage(
session_id=session_id,
tunables=tunables,
status=ApiStatusType.DONE,
content=content,
metadata = ChatMessageMetaData(
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
)
),
)
yield chat_message
return
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
user_message = ChatMessageUser(
session_id=session_id,
content=prompt,
)
user = self.user
self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
context = None
rag_message : ChatMessageRagSearch | None = None
rag_message: ChatMessageRagSearch | None = None
if self.user:
message = None
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
@ -692,38 +681,32 @@ Content: {content}
yield message
if not isinstance(message, ChatMessageRagSearch):
raise ValueError(
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
)
raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
rag_message = message
context = self.get_rag_context(rag_message)
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
# Add the conversation history to the messages
messages.extend([
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
for m in self.conversation
])
messages.extend(
[
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
for m in self.conversation
]
)
# Add the RAG context to the messages if available
if context:
messages.append(
LLMMessage(
role="user",
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
)
)
else:
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role="user", content=prompt)
)
llm_history = messages
messages.append(LLMMessage(role="user", content=prompt))
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
@ -827,16 +810,12 @@ Content: {content}
# not use_tools
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content=f"Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
# Set the response for streaming
self.set_optimal_context_size(
llm, model, prompt=prompt
)
self.set_optimal_context_size(llm, model, prompt=prompt)
options = ChatOptions(
seed=8911,
@ -856,10 +835,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@ -873,17 +849,12 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
end_time = time.perf_counter()
chat_message = ChatMessage(
@ -891,7 +862,7 @@ Content: {content}
tunables=tunables,
status=ApiStatusType.DONE,
content=content,
metadata = ChatMessageMetaData(
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
@ -902,10 +873,9 @@ Content: {content}
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time
},
)
),
)
# Add the user and chat messages to the conversation
self.conversation.append(user_message)
self.conversation.append(chat_message)
@ -999,12 +969,13 @@ Content: {content}
raise ValueError("No Markdown found in the response")
_agents: List[Agent] = []
def get_or_create_agent(
agent_type: str,
prometheus_collector: CollectorRegistry,
user: Optional[CandidateEntity]=None) -> Agent:
agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
@ -1028,14 +999,16 @@ def get_or_create_agent(
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
user=user)
agent = agent_cls(
agent_type=agent_type, # type: ignore[call-arg]
user=user,
)
_agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)
CandidateEntity.model_rebuild()
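
The module-level factory at the end of this file matches agent_type against each Agent subclass and caches one instance per type. A hedged sketch of that lookup using plain stand-in classes instead of the project's pydantic models (the real factory also wires in the user and a Prometheus collector):

from typing import ClassVar, List

class Agent:
    agent_type: ClassVar[str] = "base"

class CandidateChat(Agent):
    agent_type: ClassVar[str] = "candidate_chat"

_agents: List[Agent] = []

def get_or_create_agent(agent_type: str) -> Agent:
    # Reuse an existing instance of this type if one was already created.
    for agent in _agents:
        if agent.agent_type == agent_type:
            return agent
    # Otherwise find the subclass whose declared agent_type matches and instantiate it.
    for agent_cls in Agent.__subclasses__():
        if agent_cls.agent_type == agent_type:
            agent = agent_cls()
            _agents.append(agent)
            return agent
    raise ValueError(f"No agent class found for agent_type: {agent_type}")

chat_agent = get_or_create_agent("candidate_chat")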

View File

@ -5,10 +5,10 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, Tunables, ApiStatusType)
from models import ApiMessage, Tunables, ApiStatusType
system_message = f"""
system_message = """
When answering queries, follow these steps:
- When any content from <|context|> is relevant, synthesize information from all sources to provide the most complete answer.
@ -21,21 +21,19 @@ Always <|context|> when possible. Be concise, and never make up information. If
Before answering, ensure you have spelled the candidate's name correctly.
"""
class CandidateChat(Agent):
"""
CandidateChat Agent
"""
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
user = self.user
if not user:
@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
{system_message}
"""
async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
async for message in super().generate(
llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
):
if message.status == ApiStatusType.ERROR:
yield message
return
yield message
# Register the base agent
agent_registry.register(CandidateChat._agent_type, CandidateChat)
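
Like every agent in this diff, CandidateChat.generate() is an async generator that yields status, streaming, error, and final messages, so a caller iterates and branches on message.status. A minimal consumption sketch; the model name and session id are placeholders, and agent and llm are assumed to be already-constructed project objects:

import asyncio

from models import ApiStatusType  # same import used throughout the diff

async def run_chat(agent, llm) -> None:
    """Drive one prompt through an agent and print the final answer."""
    async for message in agent.generate(
        llm=llm,
        model="example-model",     # placeholder model name
        session_id="session-123",  # placeholder session id
        prompt="Tell me about your recent Python work.",
    ):
        if message.status == ApiStatusType.ERROR:
            print("error:", message.content)
            return
        if message.status == ApiStatusType.DONE:
            print("answer:", message.content)

# asyncio.run(run_chat(agent, llm)) once both objects have been constructed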

View File

@ -49,11 +49,13 @@ class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["general"] = "general" # type: ignore
agent_type: Literal["general"] = "general" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# if not self.context:

View File

@ -1,14 +1,11 @@
from __future__ import annotations
from typing import (
Dict,
Literal,
ClassVar,
cast,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
import time
@ -16,7 +13,15 @@ import time
import os
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@ -26,18 +31,16 @@ from image_generator.profile_image import generate_image, ImageRequest
seed = int(time.time())
random.seed(seed)
class ImageGenerator(Agent):
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False
system_prompt: str = "" # No system prompt is used
system_prompt: str = "" # No system prompt is used
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
if not self.user:
logger.error("User is not set for ImageGenerator agent.")
@ -57,11 +60,17 @@ class ImageGenerator(Agent):
yield status_message
logger.info(f"Image generation: {file_path} <- {prompt}")
request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
request = ImageRequest(
filepath=file_path,
session_id=session_id,
prompt=prompt,
iterations=4,
height=256,
width=256,
guidance_scale=7.5,
)
generated_message = None
async for generated_message in generate_image(
request=request
):
async for generated_message in generate_image(request=request):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -71,8 +80,7 @@ class ImageGenerator(Agent):
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -81,26 +89,24 @@ class ImageGenerator(Agent):
logger.info("Image generation done...")
user.profile_image = "profile.png"
# Image generated
generated_image = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
content = f"{defines.api_prefix}/profile/{user.username}",
metadata=generated_message.metadata
content=f"{defines.api_prefix}/profile/{user.username}",
metadata=generated_message.metadata,
)
yield generated_image
return
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error generating image: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Register the base agent
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)

View File

@ -1,16 +1,15 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore
from pydantic import Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
cast,
Any,
Tuple,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
import re
@ -19,10 +18,18 @@ import time
import time
import os
import random
from names_dataset import NameDataset, NameWrapper # type: ignore
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ApiMessageType,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@ -45,6 +52,7 @@ emptyUser = {
"questions": [],
}
def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
return f"""\
You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
@ -86,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
You are providing those only for use later by the system when casting individuals for the role.
"""
generate_resume_system_prompt = """
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
of that is to create an in-depth resume for the person. You will be provided with the following information:
@ -117,10 +126,12 @@ import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EthnicNameGenerator:
def __init__(self):
try:
from names_dataset import NameDataset # type: ignore
from names_dataset import NameDataset # type: ignore
self.nd = NameDataset()
except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset")
@ -128,56 +139,54 @@ class EthnicNameGenerator:
except Exception as e:
logger.error(f"Error initializing NameDataset: {e}")
self.nd = None
# US Census 2020 approximate ethnic distribution
self.ethnic_weights = {
'White': 0.576,
'Hispanic': 0.186,
'Black': 0.134,
'Asian': 0.062,
'Native American': 0.013,
'Pacific Islander': 0.003,
'Mixed/Other': 0.026
"White": 0.576,
"Hispanic": 0.186,
"Black": 0.134,
"Asian": 0.062,
"Native American": 0.013,
"Pacific Islander": 0.003,
"Mixed/Other": 0.026,
}
# Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
self.ethnic_country_mapping = {
'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
'Black': ['US'], # African American names
'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
'Native American': ['US'],
'Pacific Islander': ['US'],
'Mixed/Other': ['US']
"White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
"Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
"Black": ["US"], # African American names
"Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
"Native American": ["US"],
"Pacific Islander": ["US"],
"Mixed/Other": ["US"],
}
def get_weighted_ethnicity(self) -> str:
"""Select ethnicity based on US demographic weights"""
ethnicities = list(self.ethnic_weights.keys())
weights = list(self.ethnic_weights.values())
return random.choices(ethnicities, weights=weights)[0]
def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
n: int = 50, use_first_names: bool = True) -> List[str]:
def get_names_by_criteria(
self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
) -> List[str]:
"""Get names matching criteria using NameDataset's get_top_names method"""
if not self.nd:
return []
all_names = []
for country_code in countries:
try:
# Get top names for this country
top_names = self.nd.get_top_names(
n=n,
use_first_names=use_first_names,
country_alpha2=country_code,
gender=gender
n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
)
if country_code in top_names:
if use_first_names and gender:
# For first names with gender specified
gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
if gender_key in top_names[country_code]:
all_names.extend(top_names[country_code][gender_key])
elif use_first_names:
@ -187,102 +196,98 @@ class EthnicNameGenerator:
else:
# For last names
all_names.extend(top_names[country_code])
except Exception as e:
logger.debug(f"Error getting names for {country_code}: {e}")
continue
return list(set(all_names)) # Remove duplicates
def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a name based on ethnicity using the correct NameDataset API"""
if gender == 'random':
gender = random.choice(['Male', 'Female'])
countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
if gender == "random":
gender = random.choice(["Male", "Female"])
countries = self.ethnic_country_mapping.get(ethnicity, ["US"])
# Get first names
first_names = self.get_names_by_criteria(
countries=countries,
gender=gender,
use_first_names=True
)
first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)
# Get last names
last_names = self.get_names_by_criteria(
countries=countries,
use_first_names=False
)
last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)
# Select names or use fallbacks
if first_names:
first_name = random.choice(first_names)
else:
first_name = self._get_fallback_first_name(gender, ethnicity)
logger.info(f"Using fallback first name for {ethnicity} {gender}")
if last_names:
last_name = random.choice(last_names)
else:
last_name = self._get_fallback_last_name(ethnicity)
logger.info(f"Using fallback last name for {ethnicity}")
return first_name, last_name, ethnicity, gender
def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
"""Provide culturally appropriate fallback first names"""
fallback_names = {
'White': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
"White": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
"Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
},
'Hispanic': {
'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
"Hispanic": {
"Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
"Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
},
'Black': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
"Black": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
"Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
},
"Asian": {
"Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
"Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
},
'Asian': {
'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
}
}
ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))
def _get_fallback_last_name(self, ethnicity: str) -> str:
"""Provide culturally appropriate fallback last names"""
fallback_surnames = {
'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
"White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
"Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
"Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
"Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
}
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))
def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a random name with ethnicity based on US demographics"""
ethnicity = self.get_weighted_ethnicity()
return self.get_name_by_ethnicity(ethnicity, gender)
def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
"""Generate multiple random names"""
names = []
for _ in range(count):
first, last, ethnicity, actual_gender = self.generate_random_name(gender)
names.append({
'full_name': f"{first} {last}",
'first_name': first,
'last_name': last,
'ethnicity': ethnicity,
'gender': actual_gender
})
names.append(
{
"full_name": f"{first} {last}",
"first_name": first,
"last_name": last,
"ethnicity": ethnicity,
"gender": actual_gender,
}
)
return names
class GeneratePersona(Agent):
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@ -307,30 +312,27 @@ class GeneratePersona(Agent):
self.full_name = f"{self.first_name} {self.last_name}"
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
self.randomize()
original_prompt = prompt.strip()
persona = {
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity,
"full_name": self.full_name,
"first_name": self.first_name,
"last_name": self.last_name,
}
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity,
"full_name": self.full_name,
"first_name": self.first_name,
"last_name": self.last_name,
}
prompt = f"""\
```json
{json.dumps(persona)}
```
"""
if original_prompt:
prompt += f"""
Incorporate the following into the job description: {original_prompt}
@ -339,10 +341,11 @@ Incorporate the following into the job description: {original_prompt}
#
# Generate the persona
#
logger.info(f"🤖 Generating persona...")
logger.info("🤖 Generating persona...")
generating_message = None
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=prompt,
system_prompt=generate_persona_system_prompt(persona=persona),
@ -353,11 +356,10 @@ Incorporate the following into the job description: {original_prompt}
yield generating_message
if generating_message.status != ApiStatusType.DONE:
yield generating_message
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Persona generation failed to generate a response."
session_id=session_id, content="Persona generation failed to generate a response."
)
yield error_message
return
@ -375,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
self.username = persona.get("username", None)
if not self.username:
raise ValueError("LLM did not generate a username")
self.username = re.sub(r'\s+', '.', self.username)
self.username = re.sub(r"\s+", ".", self.username)
user_dir = os.path.join(defines.user_dir, persona["username"])
while os.path.exists(user_dir):
match = re.match(r"^(.*?)(\d*)$", persona["username"])
@ -398,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
location_parts = persona["location"].split(",")
if len(location_parts) == 3:
city, state, country = [part.strip() for part in location_parts]
persona["location"] = {
"city": city,
"state": state,
"country": country
}
persona["location"] = {"city": city, "state": state, "country": country}
else:
logger.error(f"Invalid location format: {persona['location']}")
persona["location"] = None
persona["is_ai"] = True
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
)
logger.error(f"❌ Error parsing LLM response: {error_message.content}")
logger.error(traceback.format_exc())
@ -422,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
# Persona generated
persona_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.JSON,
content = json.dumps(persona)
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
)
yield persona_message
@ -434,8 +428,8 @@ Incorporate the following into the job description: {original_prompt}
#
status_message = ChatMessageStatus(
session_id=session_id,
activity = ApiActivityType.THINKING,
content = f"Generating resume for {persona['full_name']}..."
activity=ApiActivityType.THINKING,
content=f"Generating resume for {persona['full_name']}...",
)
logger.info(f"🤖 {status_message.content}")
yield status_message
@ -458,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=content,
system_prompt=generate_resume_system_prompt,
@ -469,11 +464,10 @@ Make sure at least one of the candidate's job descriptions take into account the
raise Exception(generating_message.content)
if generating_message.status != ApiStatusType.DONE:
yield generating_message
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Resume generation failed to generate a response."
session_id=session_id, content="Resume generation failed to generate a response."
)
logger.error(f"{error_message.content}")
yield error_message
@ -481,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
resume = self.extract_markdown_from_text(generating_message.content)
resume_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.TEXT,
content=resume
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
)
yield resume_message
return
@ -504,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
raise ValueError("No JSON found in the response")
# Register the base agent
agent_registry.register(GeneratePersona._agent_type, GeneratePersona)
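
The persona generator above first draws an ethnicity with demographically weighted random.choices and then picks culturally matching names. A compact sketch of just the weighted-draw step, reusing the weights listed in the diff:

import random

# US Census 2020 approximate distribution, as used above
ethnic_weights = {
    "White": 0.576,
    "Hispanic": 0.186,
    "Black": 0.134,
    "Asian": 0.062,
    "Native American": 0.013,
    "Pacific Islander": 0.003,
    "Mixed/Other": 0.026,
}

def get_weighted_ethnicity() -> str:
    ethnicities = list(ethnic_weights.keys())
    weights = list(ethnic_weights.values())
    return random.choices(ethnicities, weights=weights)[0]

sample = [get_weighted_ethnicity() for _ in range(5)]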

View File

@ -1,30 +1,34 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
# override
) # NOTE: You must import Optional for late binding to work
import json
import numpy as np # type: ignore
from logger import logger
from .base import Agent, agent_registry
from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
from models import (
ApiActivityType,
ApiMessage,
ApiStatusType,
ChatMessage,
ChatMessageError,
ChatMessageResume,
ChatMessageStatus,
SkillAssessment,
SkillStrength,
)
class GenerateResume(Agent):
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_resume_prompt(
self,
skills: List[SkillAssessment]
):
def generate_resume_prompt(self, skills: List[SkillAssessment]):
"""
Generate a professional resume based on skill assessment results
@ -39,13 +43,13 @@ class GenerateResume(Agent):
"""
if not self.user:
raise ValueError("User must be bound to agent")
# Extract and organize skill assessment data
skills_by_strength = {
SkillStrength.STRONG: [],
SkillStrength.MODERATE: [],
SkillStrength.WEAK: [],
SkillStrength.NONE: []
SkillStrength.STRONG: [],
SkillStrength.MODERATE: [],
SkillStrength.WEAK: [],
SkillStrength.NONE: [],
}
experience_evidence = {}
@ -67,11 +71,7 @@ class GenerateResume(Agent):
experience_evidence[source] = []
experience_evidence[source].append(
{
"skill": skill,
"quote": evidence.quote,
"context": evidence.context
}
{"skill": skill, "quote": evidence.quote, "context": evidence.context}
)
# Build the system prompt
@ -107,7 +107,7 @@ Phone: {self.user.phone or 'N/A'}
### Weaker Skills (mentioned or implied):
{", ".join(skills_by_strength[SkillStrength.WEAK])}
"""
system_prompt += """\
## EXPERIENCE EVIDENCE:
@ -171,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
) -> AsyncGenerator[ApiMessage, None]:
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = f"Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
system_prompt, prompt = self.generate_resume_prompt(skills=skills)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -189,22 +189,20 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis did not return a valid message."
session_id=session_id, content="Job requirements analysis did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
resume_message = ChatMessageResume(
session_id=session_id,
status=ApiStatusType.DONE,
@ -215,8 +213,9 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
system_prompt=system_prompt,
)
yield resume_message
logger.info(f"✅ Resume generation completed successfully.")
logger.info("✅ Resume generation completed successfully.")
return
# Register the base agent
agent_registry.register(GenerateResume._agent_type, GenerateResume)
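
Before building its prompt, the resume agent buckets each skill assessment by strength and collects evidence quotes per source. An illustrative version of that grouping, with a simplified enum and tuple records standing in for the project's SkillAssessment model:

from collections import defaultdict
from enum import Enum
from typing import Dict, List

class SkillStrength(str, Enum):  # simplified stand-in for the project's SkillStrength
    STRONG = "strong"
    MODERATE = "moderate"
    WEAK = "weak"
    NONE = "none"

# Simplified assessment records: (skill, strength, evidence source, quote)
assessments = [
    ("Python", SkillStrength.STRONG, "resume.md", "Led a Python services team"),
    ("Go", SkillStrength.WEAK, "notes.md", "Reviewed Go code occasionally"),
]

skills_by_strength: Dict[SkillStrength, List[str]] = defaultdict(list)
experience_evidence: Dict[str, List[dict]] = defaultdict(list)

for skill, strength, source, quote in assessments:
    skills_by_strength[strength].append(skill)
    experience_evidence[source].append({"skill": skill, "quote": quote})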

View File

@ -1,24 +1,34 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import inspect
import json
import numpy as np # type: ignore
from .base import Agent, agent_registry
from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
from models import (
ApiActivityType,
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Job,
JobRequirements,
JobRequirementsMessage,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class JobRequirementsAgent(Agent):
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@ -94,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
"""Analyze job requirements from job description."""
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
status_message = ChatMessageStatus(
session_id=session_id,
content="Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
logger.info(f"🔍 {status_message.content}")
generated_message = None
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -110,12 +120,12 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response.")
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
yield generated_message
return
@ -132,39 +142,34 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
display = {
"technical_skills": {
"required": reqs.technical_skills.required,
"preferred": reqs.technical_skills.preferred
"preferred": reqs.technical_skills.preferred,
},
"experience_requirements": {
"required": reqs.experience_requirements.required,
"preferred": reqs.experience_requirements.preferred
"preferred": reqs.experience_requirements.preferred,
},
"soft_skills": reqs.soft_skills,
"experience": reqs.experience,
"education": reqs.education,
"certifications": reqs.certifications,
"preferred_attributes": reqs.preferred_attributes,
"company_values": reqs.company_values
"company_values": reqs.company_values,
}
return display
async def generate(
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="User is not set for this agent."
)
error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = f"Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
@ -178,13 +183,12 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
requirements = None
job_requirements_data = ""
company = ""
@ -214,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
return
except Exception as e:
status_message.status = ApiStatusType.ERROR
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
status_message.content = (
f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {status_message.content}")
yield status_message
@ -238,8 +244,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
job=job,
)
yield job_requirements_message
logger.info(f"✅ Job requirements analysis completed successfully.")
logger.info("✅ Job requirements analysis completed successfully.")
return
# Register the base agent
agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)

View File

@ -5,20 +5,19 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables )
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables
class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
"""
Generate a response based on the user message and the provided LLM.
@ -44,14 +43,14 @@ class Chat(Agent):
if not isinstance(rag_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
rag_message.status = ApiStatusType.DONE
yield rag_message
# Register the base agent
agent_registry.register(Chat._agent_type, Chat)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type
# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
"""Registry for agent types and classes"""

View File

@ -1,26 +1,32 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import json
import numpy as np # type: ignore
from .base import Agent, agent_registry
from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail,
SkillAssessment, Tunables)
from models import (
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageSkillAssessment,
ApiStatusType,
EvidenceDetail,
SkillAssessment,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class SkillMatchAgent(Agent):
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_skill_assessment_prompt(self, skill, rag_context):
@ -100,30 +106,24 @@ JSON RESPONSE:"""
return system_prompt, prompt
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="Agent not attached to user. Attach the agent to a user before generating responses."
content="Agent not attached to user. Attach the agent to a user before generating responses.",
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
skill = prompt.strip()
if not skill:
error_message = ChatMessageError(
session_id=session_id,
content="Skill cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
generated_message = None
async for generated_message in self.generate_rag_results(session_id=session_id, prompt=skill):
if generated_message.status == ApiStatusType.ERROR:
@ -134,8 +134,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -144,19 +143,20 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
rag_message : ChatMessageRagSearch = generated_message
rag_message: ChatMessageRagSearch = generated_message
rag_context = self.get_rag_context(rag_message)
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
):
if generated_message.status == ApiStatusType.ERROR:
logger.error(f"⚠️ {generated_message.content}")
yield generated_message
@ -166,8 +166,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment failed to generate a response."
session_id=session_id, content="Skill assessment failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -175,8 +174,7 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment did not return a valid message."
session_id=session_id, content="Skill assessment did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -199,20 +197,21 @@ JSON RESPONSE:"""
EvidenceDetail(
source=evidence.get("source", ""),
quote=evidence.get("quote", ""),
context=evidence.get("context", "")
) for evidence in skill_assessment_data.get("evidence_details", [])
]
context=evidence.get("context", ""),
)
for evidence in skill_assessment_data.get("evidence_details", [])
],
)
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# if skill_assessment.evidence_strength == "none":
# logger.info("⚠️ No evidence found for skill assessment, returning NONE.")
# with open("src/tmp.txt", "w") as f:
@ -233,8 +232,9 @@ JSON RESPONSE:"""
skill_assessment=skill_assessment,
)
yield skill_assessment_message
logger.info(f"✅ Skill assessment completed successfully.")
logger.info("✅ Skill assessment completed successfully.")
return
# Register the base agent
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)
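A hedged sketch of consuming this agent's stream; the llm handle and model name are placeholders supplied by the surrounding application, not values from this diff:
# Sketch: drain the skill-match stream and keep the final message.
async def assess_skill(agent: SkillMatchAgent, llm, model: str, session_id: str, skill: str):
    final = None
    async for message in agent.generate(llm=llm, model=model, session_id=session_id, prompt=skill):
        if message.status == ApiStatusType.ERROR:
            raise RuntimeError(message.content)
        final = message
    return final  # on success, expected to be a ChatMessageSkillAssessment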

View File

@ -9,99 +9,103 @@ from typing import Optional, List, Dict, Any, Callable
from logger import logger
from database.manager import DatabaseManager
class BackgroundTaskManager:
"""Manages background tasks for the application using asyncio instead of threading"""
def __init__(self, database_manager: DatabaseManager):
self.database_manager = database_manager
self.running = False
self.tasks: List[asyncio.Task] = []
self.main_loop: Optional[asyncio.AbstractEventLoop] = None
async def cleanup_inactive_guests(self, inactive_hours: int = 24):
"""Clean up inactive guest sessions"""
try:
if self.database_manager.is_shutting_down:
logger.info("Skipping guest cleanup - application shutting down")
return 0
database = self.database_manager.get_database()
cleaned_count = await database.cleanup_inactive_guests(inactive_hours)
if cleaned_count > 0:
logger.info(f"🧹 Background cleanup: removed {cleaned_count} inactive guest sessions")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error in guest cleanup: {e}")
return 0
async def cleanup_expired_verification_tokens(self):
"""Clean up expired email verification tokens"""
try:
if self.database_manager.is_shutting_down:
logger.info("Skipping token cleanup - application shutting down")
return 0
database = self.database_manager.get_database()
cleaned_count = await database.cleanup_expired_verification_tokens()
if cleaned_count > 0:
logger.info(f"🧹 Background cleanup: removed {cleaned_count} expired verification tokens")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error in verification token cleanup: {e}")
return 0
async def update_guest_statistics(self):
"""Update guest usage statistics"""
try:
if self.database_manager.is_shutting_down:
logger.info("Skipping stats update - application shutting down")
return {}
database = self.database_manager.get_database()
stats = await database.get_guest_statistics()
# Log interesting statistics
if stats.get('total_guests', 0) > 0:
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted")
if stats.get("total_guests", 0) > 0:
logger.info(
f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted"
)
return stats
except Exception as e:
logger.error(f"❌ Error updating guest statistics: {e}")
return {}
async def cleanup_old_rate_limit_data(self, days_old: int = 7):
"""Clean up old rate limiting data"""
try:
if self.database_manager.is_shutting_down:
logger.info("Skipping rate limit cleanup - application shutting down")
return 0
# Get Redis client safely (using the event loop safe method)
from database.manager import redis_manager
redis = await redis_manager.get_client()
# Clean up rate limit keys older than specified days
cutoff_time = datetime.now(UTC) - timedelta(days=days_old)
pattern = "rate_limit:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await redis.scan(cursor, match=pattern, count=100)
for key in keys:
# Check if key is old enough to delete
try:
ttl = await redis.ttl(key)
if ttl == -1: # No expiration set, check creation time
creation_time = await redis.hget(key, "created_at") # type: ignore
creation_time = await redis.hget(key, "created_at") # type: ignore
if creation_time:
creation_time = datetime.fromisoformat(creation_time).replace(tzinfo=UTC)
if creation_time < cutoff_time:
@ -110,41 +114,41 @@ class BackgroundTaskManager:
deleted_count += 1
except Exception:
continue
if cursor == 0:
break
if deleted_count > 0:
logger.info(f"🧹 Cleaned up {deleted_count} old rate limit keys")
return deleted_count
except Exception as e:
logger.error(f"❌ Error cleaning up rate limit data: {e}")
return 0
async def cleanup_orphaned_data(self):
"""Clean up orphaned database records"""
try:
if self.database_manager.is_shutting_down:
return 0
database = self.database_manager.get_database()
# Clean up orphaned job requirements
orphaned_count = await database.cleanup_orphaned_job_requirements()
if orphaned_count > 0:
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
return orphaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up orphaned data: {e}")
return 0
async def _run_periodic_task(self, name: str, task_func: Callable, interval_seconds: int, *args, **kwargs):
"""Run a periodic task safely in the same event loop"""
logger.info(f"🔄 Starting periodic task: {name} (every {interval_seconds}s)")
while self.running:
try:
# Verify we're still in the correct event loop
@ -152,34 +156,34 @@ class BackgroundTaskManager:
if current_loop != self.main_loop:
logger.error(f"Task {name} detected event loop change! Stopping.")
break
# Run the task
await task_func(*args, **kwargs)
except asyncio.CancelledError:
logger.info(f"Periodic task {name} was cancelled")
break
except Exception as e:
logger.error(f"❌ Error in periodic task {name}: {e}")
# Continue running despite errors
# Sleep with cancellation support
try:
await asyncio.sleep(interval_seconds)
except asyncio.CancelledError:
logger.info(f"Periodic task {name} cancelled during sleep")
break
async def start(self):
"""Start all background tasks in the current event loop"""
if self.running:
logger.warning("⚠️ Background task manager already running")
return
# Store the current event loop
self.main_loop = asyncio.get_running_loop()
self.running = True
# Define periodic tasks with their intervals (in seconds)
periodic_tasks = [
# (name, function, interval_seconds, *args)
@ -189,68 +193,62 @@ class BackgroundTaskManager:
("rate_limit_cleanup", self.cleanup_old_rate_limit_data, 24 * 3600, 7), # Daily, cleanup 7 days old
("orphaned_cleanup", self.cleanup_orphaned_data, 6 * 3600), # Every 6 hours
]
# Create asyncio tasks for each periodic task
for name, func, interval, *args in periodic_tasks:
task = asyncio.create_task(
self._run_periodic_task(name, func, interval, *args),
name=f"background_{name}"
)
task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
self.tasks.append(task)
logger.info(f"📅 Scheduled background task: {name}")
# Run initial cleanup tasks immediately (but don't wait for them)
asyncio.create_task(self._run_initial_cleanup(), name="initial_cleanup")
logger.info("🚀 Background task manager started with asyncio tasks")
async def _run_initial_cleanup(self):
"""Run some cleanup tasks immediately on startup"""
try:
logger.info("🧹 Running initial cleanup tasks...")
# Clean up expired tokens immediately
await asyncio.sleep(5) # Give the app time to fully start
await self.cleanup_expired_verification_tokens()
# Clean up very old inactive guests (7 days old)
await self.cleanup_inactive_guests(inactive_hours=7 * 24)
# Update statistics
await self.update_guest_statistics()
logger.info("✅ Initial cleanup tasks completed")
except Exception as e:
logger.error(f"❌ Error in initial cleanup: {e}")
async def stop(self):
"""Stop all background tasks gracefully"""
logger.info("🛑 Stopping background task manager...")
self.running = False
# Cancel all running tasks
for task in self.tasks:
if not task.done():
task.cancel()
# Wait for all tasks to complete with timeout
if self.tasks:
try:
await asyncio.wait_for(
asyncio.gather(*self.tasks, return_exceptions=True),
timeout=30.0
)
await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
logger.info("✅ All background tasks stopped gracefully")
except asyncio.TimeoutError:
logger.warning("⚠️ Some background tasks did not stop within timeout")
self.tasks.clear()
self.main_loop = None
logger.info("🛑 Background task manager stopped")
async def get_task_status(self) -> Dict[str, Any]:
"""Get status of all background tasks"""
status = {
@ -258,23 +256,23 @@ class BackgroundTaskManager:
"main_loop_id": id(self.main_loop) if self.main_loop else None,
"current_loop_id": None,
"task_count": len(self.tasks),
"tasks": []
"tasks": [],
}
try:
current_loop = asyncio.get_running_loop()
status["current_loop_id"] = id(current_loop)
status["loop_matches"] = (id(current_loop) == id(self.main_loop)) if self.main_loop else False
except RuntimeError:
status["current_loop_id"] = "no_running_loop"
for task in self.tasks:
task_info = {
"name": task.get_name(),
"done": task.done(),
"cancelled": task.cancelled(),
}
if task.done() and not task.cancelled():
try:
task.result() # This will raise an exception if the task failed
@ -286,11 +284,11 @@ class BackgroundTaskManager:
task_info["status"] = "cancelled"
else:
task_info["status"] = "running"
status["tasks"].append(task_info)
return status
async def force_run_task(self, task_name: str) -> Any:
"""Manually trigger a specific background task"""
task_map = {
@ -300,10 +298,10 @@ class BackgroundTaskManager:
"rate_limit_cleanup": self.cleanup_old_rate_limit_data,
"orphaned_cleanup": self.cleanup_orphaned_data,
}
if task_name not in task_map:
raise ValueError(f"Unknown task: {task_name}. Available: {list(task_map.keys())}")
logger.info(f"🔧 Manually running task: {task_name}")
result = await task_map[task_name]()
logger.info(f"✅ Manual task {task_name} completed")
@ -317,11 +315,12 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun
await task_manager.start()
return task_manager
# For integration with your existing app startup
async def initialize_with_background_tasks(database_manager: DatabaseManager):
"""Initialize database and background tasks together"""
# Start background tasks
background_tasks = await setup_background_tasks(database_manager)
# Return both for your app to manage
return database_manager, background_tasks
return database_manager, background_tasks
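As a rough wiring sketch (the serving loop is a placeholder; only the functions above and DatabaseManager.graceful_shutdown(), defined later in this diff, are real):
# Sketch: start background tasks alongside the app, stop them on shutdown.
async def run_with_background_tasks(database_manager: DatabaseManager):
    db_manager, background_tasks = await initialize_with_background_tasks(database_manager)
    try:
        ...  # serve requests / run the application here (placeholder)
    finally:
        await background_tasks.stop()
        await db_manager.graceful_shutdown()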

View File

@ -3,21 +3,22 @@ import os
import sys
import defines
def filter_traceback(tb, app_path=None, module_name=None):
"""
Filter traceback to include only frames from the specified application path or module.
Args:
tb: Traceback object (e.g., from sys.exc_info()[2])
app_path: Directory path of your application (e.g., '/path/to/your/app')
module_name: Name of the module to include (e.g., 'myapp')
Returns:
Formatted traceback string with filtered frames.
"""
# Extract stack frames
stack = traceback.extract_tb(tb)
# Filter frames based on app_path or module_name
filtered_stack = []
for frame in stack:
@ -27,25 +28,26 @@ def filter_traceback(tb, app_path=None, module_name=None):
filtered_stack.append(frame)
elif module_name and frame.filename.startswith(module_name):
filtered_stack.append(frame)
# Format the filtered stack trace
formatted_stack = traceback.format_list(filtered_stack)
# Get exception info to include the exception type and message
exc_type, exc_value, _ = sys.exc_info()
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
# Combine the filtered stack trace with the exception message
return ''.join(formatted_stack + formatted_exc)
return "".join(formatted_stack + formatted_exc)
def format_exc(app_path=defines.app_path, module_name=None):
"""
Custom version of traceback.format_exc() that filters stack frames.
Args:
app_path: Directory path of your application
module_name: Name of the module to include
Returns:
Formatted traceback string with only relevant frames.
"""

View File

@ -1,4 +1,4 @@
from .core import RedisDatabase
from .manager import DatabaseManager, redis_manager
__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager']
__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]

View File

@ -1,15 +1,15 @@
KEY_PREFIXES = {
'viewers': 'viewer:',
'candidates': 'candidate:',
'employers': 'employer:',
'jobs': 'job:',
'job_applications': 'job_application:',
'chat_sessions': 'chat_session:',
'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:',
'users': 'user:',
'candidate_documents': 'candidate_documents:',
'job_requirements': 'job_requirements:',
'resumes': 'resume:',
'user_resumes': 'user_resumes:',
"viewers": "viewer:",
"candidates": "candidate:",
"employers": "employer:",
"jobs": "job:",
"job_applications": "job_application:",
"chat_sessions": "chat_session:",
"chat_messages": "chat_messages:",
"ai_parameters": "ai_parameters:",
"users": "user:",
"candidate_documents": "candidate_documents:",
"job_requirements": "job_requirements:",
"resumes": "resume:",
"user_resumes": "user_resumes:",
}

View File

@ -8,14 +8,15 @@ from .mixins.auth import AuthMixin
from .mixins.analytics import AnalyticsMixin
from .mixins.job import JobMixin
from .mixins.skill import SkillMixin
from .mixins.ai import AIMixin
from .mixins.ai import AIMixin
# RedisDatabase is the main class that combines all mixins for a
# RedisDatabase is the main class that combines all mixins for a
# comprehensive Redis database interface.
class RedisDatabase(
AIMixin,
BaseMixin,
ResumeMixin,
ResumeMixin,
DocumentMixin,
UserMixin,
ChatMixin,
@ -27,7 +28,7 @@ class RedisDatabase(
"""
Main Redis database class combining all mixins
"""
def __init__(self, redis: Redis):
self.redis = redis
super().__init__()
super().__init__()
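A minimal construction sketch; the URL is illustrative, and in the application the client actually comes from redis_manager (shown in the next file):
# Sketch: build the combined facade directly from an asyncio Redis client.
from redis.asyncio import Redis
async def build_db() -> RedisDatabase:
    redis = Redis.from_url("redis://localhost:6379/0")  # illustrative URL
    return RedisDatabase(redis)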

View File

@ -1,19 +1,16 @@
from redis.asyncio import (Redis, ConnectionPool)
from redis.asyncio import Redis, ConnectionPool
from typing import Optional
import json
import logging
import os
from datetime import datetime, UTC
import asyncio
from models import (
# User models
Candidate, Employer, BaseUser, EvidenceDetail, Guest, Authentication, AuthResponse, SkillAssessment,
)
from .core import RedisDatabase
logger = logging.getLogger(__name__)
# _RedisManager is a singleton class that manages the Redis connection and
# _RedisManager is a singleton class that manages the Redis connection and
# provides methods for connecting, disconnecting, and performing health checks.
#
# It uses connection pooling for better performance and resource management.
@ -22,14 +19,14 @@ class _RedisManager:
self.redis: Optional[Redis] = None
self.redis_url = os.getenv("REDIS_URL", "redis://redis:6379")
self.redis_db = int(os.getenv("REDIS_DB", "0"))
# Append database to URL if not already present
if not self.redis_url.endswith(f"/{self.redis_db}"):
self.redis_url = f"{self.redis_url}/{self.redis_db}"
self._connection_pool: Optional[ConnectionPool] = None
self._is_connected = False
async def connect(self):
"""Initialize Redis connection with connection pooling"""
if self._is_connected and self.redis:
@ -46,13 +43,11 @@ class _RedisManager:
retry_on_timeout=True,
socket_keepalive=True,
socket_keepalive_options={},
health_check_interval=30
health_check_interval=30,
)
self.redis = Redis(
connection_pool=self._connection_pool
)
self.redis = Redis(connection_pool=self._connection_pool)
if not self.redis:
raise RuntimeError("Redis client not initialized")
@ -60,67 +55,67 @@ class _RedisManager:
await self.redis.ping()
self._is_connected = True
logger.info("Successfully connected to Redis")
# Log Redis info
info = await self.redis.info()
logger.info(f"Redis version: {info.get('redis_version', 'unknown')}")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
self._is_connected = False
self.redis = None
self._connection_pool = None
raise
async def disconnect(self):
"""Close Redis connection gracefully"""
if not self._is_connected:
logger.info("Redis already disconnected")
return
try:
if self.redis:
# Wait for any pending operations to complete
await asyncio.sleep(0.1)
# Close the client
await self.redis.aclose()
logger.info("Redis client closed")
if self._connection_pool:
# Close the connection pool
await self._connection_pool.aclose()
logger.info("Redis connection pool closed")
self._is_connected = False
self.redis = None
self._connection_pool = None
logger.info("Successfully disconnected from Redis")
except Exception as e:
logger.error(f"Error during Redis disconnect: {e}")
# Force cleanup even if there's an error
self._is_connected = False
self.redis = None
self._connection_pool = None
def get_client(self) -> Redis:
"""Get Redis client instance"""
if not self._is_connected or not self.redis:
raise RuntimeError("Redis client not initialized or disconnected")
return self.redis
@property
def is_connected(self) -> bool:
"""Check if Redis is connected"""
return self._is_connected and self.redis is not None
async def health_check(self) -> dict:
"""Perform health check on Redis connection"""
if not self.is_connected:
return {"status": "disconnected", "error": "Redis not connected"}
if not self.redis:
raise RuntimeError("Redis client not initialized")
@ -128,25 +123,25 @@ class _RedisManager:
# Test basic operations
await self.redis.ping()
info = await self.redis.info()
return {
"status": "healthy",
"redis_version": info.get("redis_version", "unknown"),
"uptime_seconds": info.get("uptime_in_seconds", 0),
"connected_clients": info.get("connected_clients", 0),
"used_memory_human": info.get("used_memory_human", "unknown"),
"total_commands_processed": info.get("total_commands_processed", 0)
"total_commands_processed": info.get("total_commands_processed", 0),
}
except Exception as e:
logger.error(f"Redis health check failed: {e}")
return {"status": "error", "error": str(e)}
async def force_save(self, background: bool = True) -> bool:
"""Force Redis to save data to disk"""
if not self.is_connected:
logger.warning("Cannot save: Redis not connected")
return False
try:
if not self.redis:
raise RuntimeError("Redis client not initialized")
@ -163,12 +158,12 @@ class _RedisManager:
except Exception as e:
logger.error(f"Redis save failed: {e}")
return False
async def get_info(self) -> Optional[dict]:
"""Get Redis server information"""
if not self.is_connected:
return None
try:
if not self.redis:
raise RuntimeError("Redis client not initialized")
@ -177,49 +172,51 @@ class _RedisManager:
logger.error(f"Failed to get Redis info: {e}")
return None
# Global Redis manager instance
redis_manager = _RedisManager()
# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
class DatabaseManager:
"""Enhanced database manager with graceful shutdown capabilities"""
def __init__(self):
self.db: Optional[RedisDatabase] = None
self._shutdown_initiated = False
self._active_requests = 0
self._shutdown_timeout = int(os.getenv("SHUTDOWN_TIMEOUT", "30")) # seconds
self._backup_on_shutdown = os.getenv("BACKUP_ON_SHUTDOWN", "false").lower() == "true"
async def initialize(self):
"""Initialize database connection"""
try:
# Connect to Redis
await redis_manager.connect()
logger.info("Redis connection established")
# Create database instance
self.db = RedisDatabase(redis_manager.get_client())
# Test connection and log stats
if not redis_manager.redis:
raise RuntimeError("Redis client not initialized")
await redis_manager.redis.ping()
stats = await self.db.get_stats()
logger.info(f"Database initialized successfully. Stats: {stats}")
return self.db
except Exception as e:
logger.error(f"Failed to initialize database: {e}")
raise
async def backup_data(self) -> Optional[str]:
"""Create a backup of critical data before shutdown"""
if not self.db:
return None
try:
backup_data = {
"timestamp": datetime.now(UTC).isoformat(),
@ -227,41 +224,41 @@ class DatabaseManager:
"users": await self.db.get_all_users(),
# Add other critical data as needed
}
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
# Save to local file (you might want to save to cloud storage instead)
with open(backup_filename, 'w') as f:
with open(backup_filename, "w") as f:
json.dump(backup_data, f, indent=2, default=str)
logger.info(f"Backup created: {backup_filename}")
return backup_filename
except Exception as e:
logger.error(f"Backup failed: {e}")
return None
async def graceful_shutdown(self):
"""Perform graceful shutdown with optional backup"""
self._shutdown_initiated = True
logger.info("Initiating graceful shutdown...")
# Wait for active requests to complete (with timeout)
wait_time = 0
while self._active_requests > 0 and wait_time < self._shutdown_timeout:
logger.info(f"Waiting for {self._active_requests} active requests to complete...")
await asyncio.sleep(1)
wait_time += 1
if self._active_requests > 0:
logger.warning(f"Shutdown timeout reached. {self._active_requests} requests may be interrupted.")
# Create backup if configured
if self._backup_on_shutdown:
backup_file = await self.backup_data()
if backup_file:
logger.info(f"Pre-shutdown backup completed: {backup_file}")
# Force Redis to save data to disk
try:
if redis_manager.redis:
@ -269,10 +266,10 @@ class DatabaseManager:
try:
await redis_manager.redis.bgsave()
logger.info("Background save initiated")
# Wait a bit for background save to start
await asyncio.sleep(0.5)
except Exception as e:
logger.warning(f"Background save failed, trying synchronous save: {e}")
try:
@ -281,32 +278,32 @@ class DatabaseManager:
logger.info("Synchronous save completed")
except Exception as e2:
logger.warning(f"Synchronous save also failed (Redis persistence may be disabled): {e2}")
except Exception as e:
logger.error(f"Error during Redis save: {e}")
# Close Redis connection
try:
await redis_manager.disconnect()
logger.info("Redis connection closed successfully")
except Exception as e:
logger.error(f"Error closing Redis connection: {e}")
logger.info("Graceful shutdown completed")
def increment_requests(self):
"""Track active requests"""
self._active_requests += 1
def decrement_requests(self):
"""Track completed requests"""
self._active_requests = max(0, self._active_requests - 1)
@property
def is_shutting_down(self) -> bool:
"""Check if shutdown is in progress"""
return self._shutdown_initiated
def get_database(self) -> RedisDatabase:
"""Get database instance"""
if self.db is None:
@ -314,5 +311,3 @@ class DatabaseManager:
if self._shutdown_initiated:
raise RuntimeError("Application is shutting down")
return self.db
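Taken together, a hedged end-to-end sketch of the manager's lifecycle (error handling omitted; get_stats() is the same call initialize() itself uses):
# Sketch: initialize, track one request, then shut down gracefully.
async def main():
    manager = DatabaseManager()
    db = await manager.initialize()        # connects redis_manager, returns RedisDatabase
    manager.increment_requests()
    try:
        print(await db.get_stats())
    finally:
        manager.decrement_requests()
        await manager.graceful_shutdown()  # optional backup, Redis save, disconnect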

View File

@ -8,41 +8,42 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class AIMixin(DatabaseProtocol):
"""Mixin for AI operations"""
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
"""Get AI parameters by ID"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_ai_parameters(self, param_id: str, param_data: Dict):
"""Set AI parameters data"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.set(key, self._serialize(param_data))
async def get_all_ai_parameters(self) -> Dict[str, Any]:
"""Get all AI parameters"""
pattern = f"{KEY_PREFIXES['ai_parameters']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
result[param_id] = self._deserialize(value)
return result
async def delete_ai_parameters(self, param_id: str):
"""Delete AI parameters"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.delete(key)
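A small usage sketch; the parameter payload is illustrative:
# Sketch: store and fetch an AI parameter set by id.
async def ai_params_roundtrip(db: AIMixin):
    await db.set_ai_parameters("default", {"temperature": 0.7})  # illustrative payload
    return await db.get_ai_parameters("default")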

View File

@ -7,6 +7,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class AnalyticsMixin:
"""Mixin for analytics-related database operations"""

View File

@ -8,35 +8,37 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class AuthMixin(DatabaseProtocol):
"""Mixin for auth-related database operations"""
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
"""Find pending verification token by email address"""
try:
pattern = "email_verification:*"
cursor = 0
email_lower = email.lower()
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if (verification_info.get("email", "").lower() == email_lower and
not verification_info.get("verified", False)):
if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
"verified", False
):
# Extract token from key
token = key.replace("email_verification:", "")
verification_info["token"] = token
return verification_info
if cursor == 0:
break
return None
except Exception as e:
logger.error(f"❌ Error finding verification token by email {email}: {e}")
return None
@ -47,22 +49,22 @@ class AuthMixin(DatabaseProtocol):
pattern = "email_verification:*"
cursor = 0
count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if not verification_info.get("verified", False):
count += 1
if cursor == 0:
break
return count
except Exception as e:
logger.error(f"❌ Error counting pending verifications: {e}")
return 0
@ -74,29 +76,29 @@ class AuthMixin(DatabaseProtocol):
cursor = 0
cleaned_count = 0
current_time = datetime.now(timezone.utc)
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
if current_time > expires_at:
await self.redis.delete(key)
cleaned_count += 1
logger.info(f"🧹 Cleaned expired verification token for {verification_info.get('email')}")
if cursor == 0:
break
if cleaned_count > 0:
logger.info(f"🧹 Cleaned up {cleaned_count} expired verification tokens")
return cleaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up expired verification tokens: {e}")
return 0
@ -106,22 +108,19 @@ class AuthMixin(DatabaseProtocol):
try:
key = f"verification_attempts:{email.lower()}"
data = await self.redis.get(key)
if not data:
return 0
attempts_data = json.loads(data)
current_time = datetime.now(timezone.utc)
window_start = current_time - timedelta(hours=24)
# Filter out old attempts
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
return len(recent_attempts)
except Exception as e:
logger.error(f"❌ Error getting verification attempts count for {email}: {e}")
return 0
@ -131,34 +130,31 @@ class AuthMixin(DatabaseProtocol):
try:
key = f"verification_attempts:{email.lower()}"
current_time = datetime.now(timezone.utc)
# Get existing attempts
data = await self.redis.get(key)
attempts_data = json.loads(data) if data else []
# Add current attempt
attempts_data.append(current_time.isoformat())
# Keep only last 24 hours of attempts
window_start = current_time - timedelta(hours=24)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours
json.dumps(recent_attempts)
json.dumps(recent_attempts),
)
return True
except Exception as e:
logger.error(f"❌ Error recording verification attempt for {email}: {e}")
return False
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
"""Store email verification token with user data"""
try:
@ -169,16 +165,16 @@ class AuthMixin(DatabaseProtocol):
"user_data": user_data,
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"verified": False
"verified": False,
}
# Store with 24 hour expiration
await self.redis.setex(
key,
key,
24 * 60 * 60, # 24 hours in seconds
json.dumps(verification_data, default=str)
json.dumps(verification_data, default=str),
)
logger.info(f"📧 Stored email verification token for {email}")
return True
except Exception as e:
@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
24 * 60 * 60, # Keep for remaining TTL
json.dumps(token_data, default=str)
json.dumps(token_data, default=str),
)
return True
return False
@ -219,7 +215,7 @@ class AuthMixin(DatabaseProtocol):
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
"""Store MFA code for verification"""
try:
logger.info("🔐 Storing MFA code for email: %s", email )
logger.info("🔐 Storing MFA code for email: %s", email)
key = f"mfa_code:{email.lower()}:{device_id}"
mfa_data = {
"code": code,
@ -228,16 +224,16 @@ class AuthMixin(DatabaseProtocol):
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"attempts": 0,
"verified": False
"verified": False,
}
# Store with 10 minute expiration
await self.redis.setex(
key,
10 * 60, # 10 minutes in seconds
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
logger.info(f"🔐 Stored MFA code for {email}")
return True
except Exception as e:
@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep original TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return mfa_data["attempts"]
return 0
@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep for remaining TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return True
return False
@ -303,7 +299,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error storing authentication record for {user_id}: {e}")
return False
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve authentication record for a user"""
try:
@ -315,7 +311,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error retrieving authentication record for {user_id}: {e}")
return None
async def delete_authentication(self, user_id: str) -> bool:
"""Delete authentication record for a user"""
try:
@ -326,8 +322,10 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
return False
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
"""Store refresh token for a user"""
try:
key = f"refresh_token:{token}"
@ -337,9 +335,9 @@ class AuthMixin(DatabaseProtocol):
"device": device_info.get("device", "unknown"),
"ip_address": device_info.get("ip_address", "unknown"),
"is_revoked": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl_seconds > 0:
@ -352,7 +350,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error storing refresh token for {user_id}: {e}")
return False
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Retrieve refresh token data"""
try:
@ -364,7 +362,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error retrieving refresh token: {e}")
return None
async def revoke_refresh_token(self, token: str) -> bool:
"""Revoke a refresh token"""
try:
@ -374,13 +372,13 @@ class AuthMixin(DatabaseProtocol):
token_data["is_revoked"] = True
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_data, default=str))
logger.info(f"🔐 Revoked refresh token")
logger.info("🔐 Revoked refresh token")
return True
return False
except Exception as e:
logger.error(f"❌ Error revoking refresh token: {e}")
return False
async def revoke_all_user_tokens(self, user_id: str) -> bool:
"""Revoke all refresh tokens for a user"""
try:
@ -388,10 +386,10 @@ class AuthMixin(DatabaseProtocol):
pattern = "refresh_token:*"
cursor = 0
revoked_count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
@ -401,16 +399,16 @@ class AuthMixin(DatabaseProtocol):
token_info["revoked_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_info, default=str))
revoked_count += 1
if cursor == 0:
break
logger.info(f"🔐 Revoked {revoked_count} refresh tokens for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error revoking all tokens for user {user_id}: {e}")
return False
# Password Reset Token Methods
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
"""Store password reset token"""
@ -420,9 +418,9 @@ class AuthMixin(DatabaseProtocol):
"email": email.lower(),
"expires_at": expires_at.isoformat(),
"used": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
ttl_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if ttl_seconds > 0:
@ -435,7 +433,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error storing password reset token for {email}: {e}")
return False
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Retrieve password reset token data"""
try:
@ -447,7 +445,7 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error retrieving password reset token: {e}")
return None
async def mark_password_reset_token_used(self, token: str) -> bool:
"""Mark password reset token as used"""
try:
@ -457,13 +455,13 @@ class AuthMixin(DatabaseProtocol):
token_data["used"] = True
token_data["used_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_data, default=str))
logger.info(f"🔐 Marked password reset token as used")
logger.info("🔐 Marked password reset token as used")
return True
return False
except Exception as e:
logger.error(f"❌ Error marking password reset token as used: {e}")
return False
# User Activity and Security Logging
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
"""Log security events for audit purposes"""
@ -473,40 +471,39 @@ class AuthMixin(DatabaseProtocol):
"timestamp": datetime.now(timezone.utc).isoformat(),
"user_id": user_id,
"event_type": event_type,
"details": details
"details": details,
}
# Add to list (latest events first)
await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore
await self.redis.lpush(key, json.dumps(event_data, default=str)) # type: ignore
# Keep only last 100 events per day
await self.redis.ltrim(key, 0, 99)# type: ignore
await self.redis.ltrim(key, 0, 99) # type: ignore
# Set expiration for 30 days
await self.redis.expire(key, 30 * 24 * 60 * 60)
logger.info(f"🔒 Logged security event {event_type} for user {user_id}")
return True
except Exception as e:
logger.error(f"❌ Error logging security event for {user_id}: {e}")
return False
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
"""Retrieve security log for a user"""
try:
events = []
for i in range(days):
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
key = f"security_log:{user_id}:{date}"
daily_events = await self.redis.lrange(key, 0, -1)# type: ignore
daily_events = await self.redis.lrange(key, 0, -1) # type: ignore
for event_json in daily_events:
events.append(json.loads(event_json))
# Sort by timestamp (most recent first)
events.sort(key=lambda x: x["timestamp"], reverse=True)
return events
except Exception as e:
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
return []
return []
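A hedged sketch of the refresh-token round trip defined above; the token value, lifetime, and device info are illustrative:
# Sketch: store a refresh token, read it back, then revoke it.
from datetime import datetime, timedelta, timezone
async def rotate_refresh_token(db: AuthMixin, user_id: str, token: str) -> bool:
    expires_at = datetime.now(timezone.utc) + timedelta(days=30)  # illustrative lifetime
    stored = await db.store_refresh_token(user_id, token, expires_at, {"device": "cli", "ip_address": "127.0.0.1"})
    if not stored:
        return False
    data = await db.get_refresh_token(token)
    if data and not data.get("is_revoked", False):
        return await db.revoke_refresh_token(token)
    return False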

View File

@ -5,20 +5,22 @@ from typing import Any, Dict, TYPE_CHECKING
from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
pass
class BaseMixin(DatabaseProtocol):
"""Base mixin with core Redis operations and utilities"""
def _serialize(self, data: Any) -> str:
"""Serialize data to JSON string for Redis storage"""
if data is None:
return ""
return json.dumps(data, default=str)
def _deserialize(self, data: str) -> Any:
"""Deserialize JSON string from Redis"""
if not data:
@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)

View File

@ -8,76 +8,77 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ChatMixin(DatabaseProtocol):
"""Mixin for chat-related database operations"""
# Chat Sessions operations
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
"""Get a summary of chat activity for a specific candidate"""
sessions = await self.get_chat_sessions_by_candidate(candidate_id)
if not sessions:
return {
"candidate_id": candidate_id,
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
for session in sessions:
session_id = session.get("id")
if session_id:
message_count = await self.get_chat_message_count(session_id)
total_messages += message_count
# Sort sessions by creation date
sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", ""))
return {
"candidate_id": candidate_id,
"total_sessions": len(sessions),
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# Chat Sessions operations
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
"""Get chat session by ID"""
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_chat_session(self, session_id: str, session_data: Dict):
"""Set chat session data"""
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
await self.redis.set(key, self._serialize(session_data))
async def get_all_chat_sessions(self) -> Dict[str, Any]:
"""Get all chat sessions"""
pattern = f"{KEY_PREFIXES['chat_sessions']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
result[session_id] = self._deserialize(value)
return result
async def delete_chat_session(self, session_id: str) -> bool:
'''Delete a chat session from Redis'''
"""Delete a chat session from Redis"""
try:
result = await self.redis.delete(f"chat_session:{session_id}")
return result > 0
@ -86,11 +87,11 @@ class ChatMixin(DatabaseProtocol):
raise
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
'''Delete a specific chat message from Redis'''
"""Delete a specific chat message from Redis"""
try:
# Remove from the session's message list
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.lrem(key, 0, message_id)# type: ignore
await self.redis.lrem(key, 0, message_id) # type: ignore
# Delete the message data itself
result = await self.redis.delete(f"chat_message:{message_id}")
return result > 0
@ -102,42 +103,42 @@ class ChatMixin(DatabaseProtocol):
async def get_chat_messages(self, session_id: str) -> List[Dict]:
"""Get chat messages for a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
messages = await self.redis.lrange(key, 0, -1)# type: ignore
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
messages = await self.redis.lrange(key, 0, -1) # type: ignore
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
async def add_chat_message(self, session_id: str, message_data: Dict):
"""Add a chat message to a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
await self.redis.rpush(key, self._serialize(message_data)) # type: ignore
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
"""Set all chat messages for a session (replaces existing)"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
# Clear existing messages
await self.redis.delete(key)
# Add new messages
if messages:
serialized_messages = [self._serialize(msg) for msg in messages]
await self.redis.rpush(key, *serialized_messages)# type: ignore
await self.redis.rpush(key, *serialized_messages) # type: ignore
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
"""Get all chat messages grouped by session"""
pattern = f"{KEY_PREFIXES['chat_messages']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
result = {}
for key in keys:
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
messages = await self.redis.lrange(key, 0, -1)# type: ignore
session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
messages = await self.redis.lrange(key, 0, -1) # type: ignore
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
return result
async def delete_chat_messages(self, session_id: str):
"""Delete all chat messages for a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
@ -148,61 +149,60 @@ class ChatMixin(DatabaseProtocol):
"""Get all chat sessions for a specific user"""
all_sessions = await self.get_all_chat_sessions()
user_sessions = []
for session_data in all_sessions.values():
if session_data.get("userId") == user_id or session_data.get("guestId") == user_id:
user_sessions.append(session_data)
# Sort by last activity (most recent first)
user_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
return user_sessions
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
"""Get all chat sessions related to a specific candidate"""
all_sessions = await self.get_all_chat_sessions()
candidate_sessions = []
for session_data in all_sessions.values():
context = session_data.get("context", {})
if (context.get("relatedEntityType") == "candidate" and
context.get("relatedEntityId") == candidate_id):
if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
candidate_sessions.append(session_data)
# Sort by last activity (most recent first)
candidate_sessions.sort(key=lambda x: x.get("lastActivity", ""), reverse=True)
return candidate_sessions
async def update_chat_session_activity(self, session_id: str):
"""Update the last activity timestamp for a chat session"""
session_data = await self.get_chat_session(session_id)
if session_data:
session_data["lastActivity"] = datetime.now(UTC).isoformat()
await self.set_chat_session(session_id, session_data)
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
"""Get the most recent chat messages for a session"""
messages = await self.get_chat_messages(session_id)
# Return the last 'limit' messages
return messages[-limit:] if len(messages) > limit else messages
async def get_chat_message_count(self, session_id: str) -> int:
"""Get the total number of messages in a chat session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
return await self.redis.llen(key)# type: ignore
return await self.redis.llen(key) # type: ignore
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
"""Search for messages containing specific text in a session"""
messages = await self.get_chat_messages(session_id)
query_lower = query.lower()
matching_messages = []
for msg in messages:
content = msg.get("content", "").lower()
if query_lower in content:
matching_messages.append(msg)
return matching_messages
# Chat Session Management
async def archive_chat_session(self, session_id: str):
"""Archive a chat session"""
@ -211,38 +211,37 @@ class ChatMixin(DatabaseProtocol):
session_data["isArchived"] = True
session_data["updatedAt"] = datetime.now(UTC).isoformat()
await self.set_chat_session(session_id, session_data)
async def delete_chat_session_completely(self, session_id: str):
"""Delete a chat session and all its messages"""
# Delete the session
await self.delete_chat_session(session_id)
# Delete all messages
await self.delete_chat_messages(session_id)
async def cleanup_old_chat_sessions(self, days_old: int = 90):
"""Archive or delete chat sessions older than specified days"""
cutoff_date = datetime.now(UTC) - timedelta(days=days_old)
cutoff_iso = cutoff_date.isoformat()
all_sessions = await self.get_all_chat_sessions()
archived_count = 0
for session_id, session_data in all_sessions.items():
last_activity = session_data.get("lastActivity", session_data.get("createdAt", ""))
if last_activity < cutoff_iso and not session_data.get("isArchived", False):
await self.archive_chat_session(session_id)
archived_count += 1
return archived_count
# Analytics and Reporting
async def get_chat_statistics(self) -> Dict[str, Any]:
"""Get comprehensive chat statistics"""
all_sessions = await self.get_all_chat_sessions()
all_messages = await self.get_all_chat_messages()
stats = {
"total_sessions": len(all_sessions),
"total_messages": sum(len(messages) for messages in all_messages.values()),
@ -250,34 +249,34 @@ class ChatMixin(DatabaseProtocol):
"archived_sessions": 0,
"sessions_by_type": {},
"sessions_with_candidates": 0,
"average_messages_per_session": 0
"average_messages_per_session": 0,
}
# Analyze sessions
for session_data in all_sessions.values():
if session_data.get("isArchived", False):
stats["archived_sessions"] += 1
else:
stats["active_sessions"] += 1
# Count by type
context_type = session_data.get("context", {}).get("type", "unknown")
stats["sessions_by_type"][context_type] = stats["sessions_by_type"].get(context_type, 0) + 1
# Count sessions with candidate association
if session_data.get("context", {}).get("relatedEntityType") == "candidate":
stats["sessions_with_candidates"] += 1
# Calculate averages
if stats["total_sessions"] > 0:
stats["average_messages_per_session"] = stats["total_messages"] / stats["total_sessions"]
return stats
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
"""Bulk update multiple chat sessions"""
pipe = self.redis.pipeline()
for session_id, updates in session_updates.items():
session_data = await self.get_chat_session(session_id)
if session_data:
@ -285,5 +284,5 @@ class ChatMixin(DatabaseProtocol):
session_data["updatedAt"] = datetime.now(UTC).isoformat()
key = f"{KEY_PREFIXES['chat_sessions']}{session_id}"
pipe.set(key, self._serialize(session_data))
await pipe.execute()

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class DocumentMixin(DatabaseProtocol):
"""Mixin for document-related database operations"""
@ -31,17 +32,17 @@ class DocumentMixin(DatabaseProtocol):
try:
# Get all document IDs for this candidate
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
if not document_ids:
logger.info(f"No documents found for candidate {candidate_id}")
return 0
deleted_count = 0
# Use pipeline for efficient batch operations
pipe = self.redis.pipeline()
# Delete each document's metadata
for doc_id in document_ids:
pipe.delete(f"document:{doc_id}")
@ -50,13 +51,13 @@ class DocumentMixin(DatabaseProtocol):
# Delete the candidate's document list
pipe.delete(key)
# Execute all operations
await pipe.execute()
logger.info(f"Successfully deleted {deleted_count} documents for candidate {candidate_id}")
return deleted_count
except Exception as e:
logger.error(f"Error deleting all documents for candidate {candidate_id}: {e}")
raise
@ -64,17 +65,17 @@ class DocumentMixin(DatabaseProtocol):
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
if not document_ids:
return []
# Get all document metadata
pipe = self.redis.pipeline()
for doc_id in document_ids:
pipe.get(f"document:{doc_id}")
values = await pipe.execute()
documents = []
for doc_id, value in zip(document_ids, values):
if value:
@ -83,20 +84,20 @@ class DocumentMixin(DatabaseProtocol):
documents.append(doc_data)
else:
# Clean up orphaned document ID
await self.redis.lrem(key, 0, doc_id)# type: ignore
await self.redis.lrem(key, 0, doc_id) # type: ignore
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
return documents
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
"""Add a document ID to a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.rpush(key, document_id)# type: ignore
await self.redis.rpush(key, document_id) # type: ignore
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
"""Remove a document ID from a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.lrem(key, 0, document_id)# type: ignore
await self.redis.lrem(key, 0, document_id) # type: ignore
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
"""Update document metadata"""
@ -115,29 +116,28 @@ class DocumentMixin(DatabaseProtocol):
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
"""Bulk update RAG status for multiple documents"""
pipe = self.redis.pipeline()
for doc_id in document_ids:
doc_data = await self.get_document(doc_id)
if doc_data and doc_data.get("candidate_id") == candidate_id:
doc_data["include_in_rag"] = include_in_rag
doc_data["updatedAt"] = datetime.now(UTC).isoformat()
pipe.set(f"document:{doc_id}", self._serialize(doc_data))
await pipe.execute()
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
"""Get total number of documents for a candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
return await self.redis.llen(key)# type: ignore
return await self.redis.llen(key) # type: ignore
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
"""Search documents by filename for a candidate"""
all_documents = await self.get_candidate_documents(candidate_id)
query_lower = query.lower()
return [
doc for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
]
return [
doc
for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
]
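A hedged sketch combining the search and bulk-update helpers; it assumes each document dict carries an "id" field, which this hunk does not show:
# Sketch: include every matching document in RAG.
async def include_matches_in_rag(db: DocumentMixin, candidate_id: str, query: str) -> int:
    matches = await db.search_candidate_documents(candidate_id, query)
    doc_ids = [doc["id"] for doc in matches]  # "id" field is an assumption
    await db.bulk_update_document_rag_status(candidate_id, doc_ids, include_in_rag=True)
    return len(doc_ids)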

View File

@ -8,39 +8,41 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class JobMixin(DatabaseProtocol):
"""Mixin for job-related database operations"""
async def get_job(self, job_id: str) -> Optional[Dict]:
"""Get job by ID"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_job(self, job_id: str, job_data: Dict):
"""Set job data"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
await self.redis.set(key, self._serialize(job_data))
async def get_all_jobs(self) -> Dict[str, Any]:
"""Get all jobs"""
pattern = f"{KEY_PREFIXES['jobs']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
job_id = key.replace(KEY_PREFIXES['jobs'], '')
job_id = key.replace(KEY_PREFIXES["jobs"], "")
result[job_id] = self._deserialize(value)
return result
async def delete_job(self, job_id: str):
"""Delete job"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
@ -51,10 +53,10 @@ class JobMixin(DatabaseProtocol):
try:
# Get all documents for the candidate
candidate_documents = await self.get_candidate_documents(candidate_id)
if not candidate_documents:
return []
# Get job requirements for each document
job_requirements = []
for doc in candidate_documents:
@ -66,7 +68,7 @@ class JobMixin(DatabaseProtocol):
requirements["document_filename"] = doc.get("filename")
requirements["document_original_name"] = doc.get("originalName")
job_requirements.append(requirements)
return job_requirements
except Exception as e:
logger.error(f"❌ Error retrieving job requirements for candidate {candidate_id}: {e}")
@ -82,15 +84,15 @@ class JobMixin(DatabaseProtocol):
try:
deleted_count = 0
pipe = self.redis.pipeline()
for doc_id in document_ids:
key = f"{KEY_PREFIXES['job_requirements']}{doc_id}"
pipe.delete(key)
deleted_count += 1
results = await pipe.execute()
actual_deleted = sum(1 for result in results if result > 0)
logger.info(f"📋 Bulk deleted job requirements for {actual_deleted}/{len(document_ids)} documents")
return actual_deleted
except Exception as e:
@ -103,49 +105,49 @@ class JobMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_job_application(self, application_id: str, application_data: Dict):
"""Set job application data"""
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
await self.redis.set(key, self._serialize(application_data))
async def get_all_job_applications(self) -> Dict[str, Any]:
"""Get all job applications"""
pattern = f"{KEY_PREFIXES['job_applications']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
app_id = key.replace(KEY_PREFIXES["job_applications"], "")
result[app_id] = self._deserialize(value)
return result
async def delete_job_application(self, application_id: str):
"""Delete job application"""
key = f"{KEY_PREFIXES['job_applications']}{application_id}"
await self.redis.delete(key)
async def cleanup_orphaned_job_requirements(self) -> int:
"""Clean up job requirements for documents that no longer exist"""
try:
# Get all job requirements
all_requirements = await self.get_all_job_requirements()
if not all_requirements:
return 0
orphaned_count = 0
pipe = self.redis.pipeline()
for document_id in all_requirements.keys():
# Check if the document still exists
document_exists = await self.get_document(document_id)
@ -155,16 +157,16 @@ class JobMixin(DatabaseProtocol):
pipe.delete(key)
orphaned_count += 1
logger.info(f"📋 Queued orphaned job requirements for deletion: {document_id}")
if orphaned_count > 0:
await pipe.execute()
logger.info(f"🧹 Cleaned up {orphaned_count} orphaned job requirements")
return orphaned_count
except Exception as e:
logger.error(f"❌ Error cleaning up orphaned job requirements: {e}")
return 0
return 0
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
"""Get cached job requirements analysis for a document"""
try:
@ -184,19 +186,19 @@ class JobMixin(DatabaseProtocol):
"""Save job requirements analysis results for a document"""
try:
key = f"{KEY_PREFIXES['job_requirements']}{document_id}"
# Add metadata to the requirements
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
"document_id": document_id,
}
await self.redis.set(key, self._serialize(requirements_with_meta))
# Optional: Set expiration (e.g., 30 days) to prevent indefinite storage
# await self.redis.expire(key, 30 * 24 * 60 * 60) # 30 days
logger.info(f"📋 Saved job requirements for document {document_id}")
return True
except Exception as e:
@ -221,21 +223,21 @@ class JobMixin(DatabaseProtocol):
try:
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
if value:
result[document_id] = self._deserialize(value)
return result
except Exception as e:
logger.error(f"❌ Error retrieving all job requirements: {e}")
@ -246,35 +248,30 @@ class JobMixin(DatabaseProtocol):
try:
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}
if keys:
# Get cache dates for analysis
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
stats["documents_with_requirements"].append(document_id)
# Track cache dates
cached_at = requirements_data.get("cached_at")
if cached_at:
cache_date = cached_at[:10] # Extract date part
stats["cache_dates"][cache_date] = stats["cache_dates"].get(cache_date, 0) + 1
return stats
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}

View File

@ -7,153 +7,412 @@ if TYPE_CHECKING:
from models import SkillAssessment
class DatabaseProtocol(Protocol):
# Base mixin
redis: Redis
def _serialize(self, data) -> str: ...
def _deserialize(self, data: str): ...
def _serialize(self, data) -> str:
...
def _deserialize(self, data: str):
...
# Chat mixin
async def add_chat_message(self, session_id: str, message_data: Dict): ...
async def archive_chat_session(self, session_id: str): ...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ...
async def delete_chat_messages(self, session_id: str): ...
async def delete_chat_session_completely(self, session_id: str): ...
async def delete_chat_session(self, session_id: str) -> bool: ...
async def add_chat_message(self, session_id: str, message_data: Dict):
...
async def archive_chat_session(self, session_id: str):
...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
...
async def delete_chat_messages(self, session_id: str):
...
async def delete_chat_session_completely(self, session_id: str):
...
async def delete_chat_session(self, session_id: str) -> bool:
...
# Document mixin
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ...
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
...
# Job mixin
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ...
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
...
# User mixin
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ...
async def delete_employer(self, employer_id: str): ...
async def delete_guest(self, guest_id: str) -> bool: ...
async def delete_user(self, email: str): ...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
async def get_all_users(self) -> Dict[str, Any]: ...
async def get_all_viewers(self) -> Dict[str, Any]: ...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ...
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest_statistics(self) -> Dict[str, Any]: ...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_user_by_username(self, username: str) -> Optional[Dict]: ...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
async def invalidate_user_skill_cache(self, user_id: str) -> int: ...
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ...
async def set_employer(self, employer_id: str, employer_data: Dict): ...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
async def update_user_rag_timestamp(self, user_id: str) -> bool: ...
async def user_exists_by_email(self, email: str) -> bool: ...
async def user_exists_by_username(self, username: str) -> bool: ...
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
...
# Auth mixin
async def cleanup_expired_verification_tokens(self) -> int: ...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
async def cleanup_orphaned_job_requirements(self) -> int: ...
async def clear_all_data(self: "DatabaseProtocol"): ...
async def clear_all_skill_match_cache(self) -> int: ...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
...
async def delete_employer(self, employer_id: str):
...
async def delete_guest(self, guest_id: str) -> bool:
...
async def delete_user(self, email: str):
...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
...
async def get_all_users(self) -> Dict[str, Any]:
...
async def get_all_viewers(self) -> Dict[str, Any]:
...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
...
async def get_employer(self, employer_id: str) -> Optional[Dict]:
...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest_statistics(self) -> Dict[str, Any]:
...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_user_by_username(self, username: str) -> Optional[Dict]:
...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
...
async def invalidate_user_skill_cache(self, user_id: str) -> int:
...
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
...
async def set_employer(self, employer_id: str, employer_data: Dict):
...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
...
async def update_user_rag_timestamp(self, user_id: str) -> bool:
...
async def user_exists_by_email(self, email: str) -> bool:
...
async def user_exists_by_username(self, username: str) -> bool:
...
# Auth mixin
async def cleanup_expired_verification_tokens(self) -> int:
...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
...
async def cleanup_orphaned_job_requirements(self) -> int:
...
async def clear_all_data(self: "DatabaseProtocol"):
...
async def clear_all_skill_match_cache(self) -> int:
...
# Resume mixin
async def delete_ai_parameters(self, param_id: str): ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ...
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
async def delete_authentication(self, user_id: str) -> bool: ...
async def delete_document(self, document_id: str): ...
async def delete_ai_parameters(self, param_id: str):
...
async def delete_job_application(self, application_id: str): ...
async def delete_job_requirements(self, document_id: str) -> bool: ...
async def delete_job(self, job_id: str): ...
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
async def delete_viewer(self, viewer_id: str): ...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
async def get_all_candidates(self) -> Dict[str, Any]: ...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
async def get_all_employers(self) -> Dict[str, Any]: ...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
async def get_all_job_applications(self) -> Dict[str, Any]: ...
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
async def get_all_jobs(self) -> Dict[str, Any]: ...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
async def get_chat_message_count(self, session_id: str) -> int: ...
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
async def get_chat_statistics(self) -> Dict[str, Any]: ...
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
async def get_document(self, document_id: str) -> Optional[Dict]: ...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
async def get_job(self, job_id: str) -> Optional[Dict]: ...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_pending_verifications_count(self) -> int: ...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
async def get_stats(self) -> Dict[str, int]: ...
async def get_verification_attempts_count(self, email: str) -> int: ...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
async def mark_email_verified(self, token: str) -> bool: ...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
async def mark_password_reset_token_used(self, token: str) -> bool: ...
async def record_verification_attempt(self, email: str) -> bool: ...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
async def revoke_refresh_token(self, token: str) -> bool: ...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
async def set_chat_session(self, session_id: str, session_data: Dict): ...
async def set_document(self, document_id: str, document_data: Dict): ...
async def set_job_application(self, application_id: str, application_data: Dict): ...
async def set_job(self, job_id: str, job_data: Dict): ...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
async def update_chat_session_activity(self, session_id: str): ...
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
...
async def delete_all_resumes_for_user(self, user_id: str) -> int:
...
async def delete_authentication(self, user_id: str) -> bool:
...
async def delete_document(self, document_id: str):
...
async def delete_job_application(self, application_id: str):
...
async def delete_job_requirements(self, document_id: str) -> bool:
...
async def delete_job(self, job_id: str):
...
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
...
async def delete_viewer(self, viewer_id: str):
...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
...
async def get_all_ai_parameters(self) -> Dict[str, Any]:
...
async def get_all_candidates(self) -> Dict[str, Any]:
...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
...
async def get_all_chat_sessions(self) -> Dict[str, Any]:
...
async def get_all_employers(self) -> Dict[str, Any]:
...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
...
async def get_all_job_applications(self) -> Dict[str, Any]:
...
async def get_all_job_requirements(self) -> Dict[str, Any]:
...
async def get_all_jobs(self) -> Dict[str, Any]:
...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
...
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
...
async def get_chat_message_count(self, session_id: str) -> int:
...
async def get_chat_messages(self, session_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
...
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
...
async def get_chat_statistics(self) -> Dict[str, Any]:
...
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
...
async def get_document(self, document_id: str) -> Optional[Dict]:
...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_job_application(self, application_id: str) -> Optional[Dict]:
...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
...
async def get_job_requirements_stats(self) -> Dict[str, Any]:
...
async def get_job(self, job_id: str) -> Optional[Dict]:
...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_pending_verifications_count(self) -> int:
...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
...
async def get_stats(self) -> Dict[str, int]:
...
async def get_verification_attempts_count(self, email: str) -> int:
...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
...
async def mark_email_verified(self, token: str) -> bool:
...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
...
async def mark_password_reset_token_used(self, token: str) -> bool:
...
async def record_verification_attempt(self, email: str) -> bool:
...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
...
async def revoke_all_user_tokens(self, user_id: str) -> bool:
...
async def revoke_refresh_token(self, token: str) -> bool:
...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
...
async def set_ai_parameters(self, param_id: str, param_data: Dict):
...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
...
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
...
async def set_chat_session(self, session_id: str, session_data: Dict):
...
async def set_document(self, document_id: str, document_data: Dict):
...
async def set_job_application(self, application_id: str, application_data: Dict):
...
async def set_job(self, job_id: str, job_data: Dict):
...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
...
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
...
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
...
async def update_chat_session_activity(self, session_id: str):
...
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
...
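Because DatabaseProtocol is a structural Protocol, any object exposing these methods satisfies it. A rough sketch of composing the mixins in this changeset into one concrete class; the import paths and class composition below are assumptions for illustration (the real implementation is database.manager.RedisDatabase):

import json
from typing import Any

from redis.asyncio import Redis

# Import paths are assumptions based on the relative imports shown in this diff.
from database.protocols import DatabaseProtocol
from database.mixins import JobMixin, ResumeMixin, SkillMixin, UserMixin


class ExampleDatabase(JobMixin, ResumeMixin, SkillMixin, UserMixin):
    """Illustrative composition only; adjust to the actual package layout."""

    def __init__(self, redis: Redis) -> None:
        self.redis = redis

    def _serialize(self, data: Any) -> str:
        return json.dumps(data, default=str)

    def _deserialize(self, data: str) -> Any:
        return json.loads(data) if data else None


def use_backend(db: DatabaseProtocol) -> None:
    # Structural typing: ExampleDatabase type-checks here because it exposes
    # the full method surface declared above, not because it inherits from
    # DatabaseProtocol.
    ...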

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ResumeMixin(DatabaseProtocol):
"""Mixin for resume-related database operations"""
@ -14,19 +15,19 @@ class ResumeMixin(DatabaseProtocol):
"""Save a resume for a user"""
try:
# Generate resume_id if not present
if 'id' not in resume_data:
if "id" not in resume_data:
raise ValueError("Resume data must include an 'id' field")
resume_id = resume_data['id']
resume_id = resume_data["id"]
# Store the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
await self.redis.set(key, self._serialize(resume_data))
# Add resume_id to user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
logger.info(f"📄 Saved resume {resume_id} for user {user_id}")
return True
except Exception as e:
@ -53,19 +54,19 @@ class ResumeMixin(DatabaseProtocol):
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
return []
# Get all resume data
resumes = []
pipe = self.redis.pipeline()
for resume_id in resume_ids:
pipe.get(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}")
values = await pipe.execute()
for resume_id, value in zip(resume_ids, values):
if value:
resume_data = self._deserialize(value)
@ -73,12 +74,12 @@ class ResumeMixin(DatabaseProtocol):
resumes.append(resume_data)
else:
# Clean up orphaned resume ID
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
# Sort by created_at timestamp (most recent first)
resumes.sort(key=lambda x: x.get("created_at", ""), reverse=True)
logger.info(f"📄 Retrieved {len(resumes)} resumes for user {user_id}")
return resumes
except Exception as e:
@ -91,11 +92,11 @@ class ResumeMixin(DatabaseProtocol):
# Delete the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
result = await self.redis.delete(key)
# Remove from user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
if result > 0:
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
return True
@ -111,31 +112,31 @@ class ResumeMixin(DatabaseProtocol):
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
return 0
deleted_count = 0
# Use pipeline for efficient batch operations
pipe = self.redis.pipeline()
# Delete each resume
for resume_id in resume_ids:
pipe.delete(f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}")
deleted_count += 1
# Delete the user's resume list
pipe.delete(user_resumes_key)
# Execute all operations
await pipe.execute()
logger.info(f"🗑️ Successfully deleted {deleted_count} resumes for user {user_id}")
return deleted_count
except Exception as e:
logger.error(f"❌ Error deleting all resumes for user {user_id}: {e}")
raise
@ -145,21 +146,21 @@ class ResumeMixin(DatabaseProtocol):
try:
pattern = f"{KEY_PREFIXES['resumes']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
# Group by user_id
user_resumes = {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
for key, value in zip(keys, values):
if value:
# Extract user_id from key format: resume:{user_id}:{resume_id}
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
if len(key_parts) >= 1:
user_id = key_parts[0]
resume_data = self._deserialize(value)
@ -167,11 +168,11 @@ class ResumeMixin(DatabaseProtocol):
if user_id not in user_resumes:
user_resumes[user_id] = []
user_resumes[user_id].append(resume_data)
# Sort each user's resumes by created_at
for user_id in user_resumes:
user_resumes[user_id].sort(key=lambda x: x.get("created_at", ""), reverse=True)
return user_resumes
except Exception as e:
logger.error(f"❌ Error retrieving all resumes: {e}")
@ -182,20 +183,22 @@ class ResumeMixin(DatabaseProtocol):
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
query_lower = query.lower()
matching_resumes = []
for resume in all_resumes:
# Search in resume content, job_id, candidate_id, etc.
searchable_text = " ".join([
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", ""))
]).lower()
searchable_text = " ".join(
[
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", "")),
]
).lower()
if query_lower in searchable_text:
matching_resumes.append(resume)
logger.info(f"📄 Found {len(matching_resumes)} matching resumes for user {user_id}")
return matching_resumes
except Exception as e:
@ -206,11 +209,8 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific candidate created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
candidate_resumes = [
resume for resume in all_resumes
if resume.get("candidate_id") == candidate_id
]
candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
return candidate_resumes
except Exception as e:
@ -221,11 +221,8 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific job created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
job_resumes = [
resume for resume in all_resumes
if resume.get("job_id") == job_id
]
job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
return job_resumes
except Exception as e:
@ -236,24 +233,24 @@ class ResumeMixin(DatabaseProtocol):
"""Get resume statistics for a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
stats = {
"total_resumes": len(all_resumes),
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": []
"recent_resumes": [],
}
for resume in all_resumes:
# Count by candidate
candidate_id = resume.get("candidate_id", "unknown")
stats["resumes_by_candidate"][candidate_id] = stats["resumes_by_candidate"].get(candidate_id, 0) + 1
# Count by job
job_id = resume.get("job_id", "unknown")
stats["resumes_by_job"][job_id] = stats["resumes_by_job"].get(job_id, 0) + 1
# Timeline by date
created_at = resume.get("created_at")
if created_at:
@ -262,14 +259,20 @@ class ResumeMixin(DatabaseProtocol):
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except (IndexError, TypeError):
pass
# Get recent resumes (last 5)
stats["recent_resumes"] = all_resumes[:5]
return stats
except Exception as e:
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
return {
"total_resumes": 0,
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": [],
}
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
"""Update specific fields of a resume"""
@ -278,10 +281,10 @@ class ResumeMixin(DatabaseProtocol):
if resume_data:
resume_data.update(updates)
resume_data["updated_at"] = datetime.now(UTC).isoformat()
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
await self.redis.set(key, self._serialize(resume_data))
logger.info(f"📄 Updated resume {resume_id} for user {user_id}")
return resume_data
return None
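A short lifecycle sketch for the resume operations above (save, search, update), assuming a db object that provides ResumeMixin; the field names mirror the ones this mixin reads:

from datetime import datetime, UTC
from typing import Any
from uuid import uuid4


async def demo_resume_flow(db: Any, user_id: str, candidate_id: str, job_id: str) -> None:
    resume = {
        "id": str(uuid4()),  # set_resume() raises ValueError without an "id"
        "candidate_id": candidate_id,
        "job_id": job_id,
        "resume": "Markdown resume body ...",
        "created_at": datetime.now(UTC).isoformat(),
    }
    await db.set_resume(user_id, resume)

    # Search concatenates resume/job_id/candidate_id/created_at and does a
    # case-insensitive substring match.
    matches = await db.search_resumes_for_user(user_id, "markdown")
    print(f"{len(matches)} matching resumes")

    # update_resume() merges the new fields and stamps updated_at.
    await db.update_resume(user_id, resume["id"], {"resume": "Revised body"})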

View File

@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class SkillMixin(DatabaseProtocol):
"""Mixin for Skill-related database operations"""
@ -23,7 +24,7 @@ class SkillMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error getting cached skill match: {e}")
return None
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
"""Invalidate all cached skill matches for a specific candidate"""
try:
@ -35,7 +36,7 @@ class SkillMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"Error invalidating candidate skill cache: {e}")
return 0
async def clear_all_skill_match_cache(self) -> int:
"""Clear all skill match cache (useful after major system updates)"""
try:
@ -47,7 +48,7 @@ class SkillMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"Error clearing skill match cache: {e}")
return 0
async def invalidate_user_skill_cache(self, user_id: str) -> int:
"""Invalidate all cached skill matches when a user's RAG data is updated"""
try:
@ -55,12 +56,12 @@ class SkillMixin(DatabaseProtocol):
# You might need to adjust the pattern based on how you associate candidates with users
pattern = f"skill_match:{user_id}:*"
keys = await self.redis.keys(pattern)
# Filter keys that belong to candidates owned by this user
# This would require additional logic to determine candidate ownership
# For now, you might want to clear all cache when any user's RAG data updates
# or implement a more sophisticated mapping
if keys:
return await self.redis.delete(*keys)
return 0
@ -73,8 +74,10 @@ class SkillMixin(DatabaseProtocol):
try:
# Cache for 1 hour by default
await self.redis.set(
cache_key,
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
cache_key,
json.dumps(
assessment.model_dump(mode="json", by_alias=True), default=str
), # Serialize with datetime handling
)
logger.info(f"💾 Skill match cached: {cache_key}")
except Exception as e:
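A hedged caller-side sketch for the skill-match cache; the cache-key layout and the run_skill_assessment helper are assumptions, not taken from this diff (the only visible constraint is that per-user invalidation scans skill_match:{user_id}:*):

from typing import Any, Optional


async def run_skill_assessment(candidate_id: str, skill: str) -> Any:
    # Hypothetical placeholder for the real LLM-backed assessment.
    ...


async def skill_match_with_cache(db: Any, user_id: str, candidate_id: str, skill: str) -> Optional[Any]:
    # Assumed key layout; it must stay compatible with the
    # "skill_match:{user_id}:*" pattern used by invalidate_user_skill_cache().
    cache_key = f"skill_match:{user_id}:{candidate_id}:{skill.lower()}"

    cached = await db.get_cached_skill_match(cache_key)
    if cached is not None:
        return cached

    assessment = await run_skill_assessment(candidate_id, skill)
    await db.cache_skill_match(cache_key, assessment)
    return assessment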

View File

@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class UserMixin(DatabaseProtocol):
"""Mixin for user operations"""
@ -21,27 +22,27 @@ class UserMixin(DatabaseProtocol):
try:
# Ensure last_activity is always set
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Store in Redis with both hash and individual key for redundancy
await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore
await self.redis.hset("guests", guest_id, json.dumps(guest_data)) # type: ignore
# Also store with a longer TTL as backup
await self.redis.setex(
f"guest_backup:{guest_id}",
f"guest_backup:{guest_id}",
86400 * 7, # 7 days TTL
json.dumps(guest_data)
json.dumps(guest_data),
)
logger.info(f"💾 Guest stored with backup: {guest_id}")
except Exception as e:
logger.error(f"❌ Error storing guest {guest_id}: {e}")
raise
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data with fallback to backup"""
try:
# Try primary storage first
data = await self.redis.hget("guests", guest_id)# type: ignore
data = await self.redis.hget("guests", guest_id) # type: ignore
if data:
guest_data = json.loads(data)
# Update last activity when accessed
@ -49,24 +50,24 @@ class UserMixin(DatabaseProtocol):
await self.set_guest(guest_id, guest_data)
logger.info(f"🔍 Guest found in primary storage: {guest_id}")
return guest_data
# Fallback to backup storage
backup_data = await self.redis.get(f"guest_backup:{guest_id}")
if backup_data:
guest_data = json.loads(backup_data)
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Restore to primary storage
await self.set_guest(guest_id, guest_data)
logger.info(f"🔄 Guest restored from backup: {guest_id}")
return guest_data
logger.warning(f"⚠️ Guest not found: {guest_id}")
return None
except Exception as e:
logger.error(f"❌ Error getting guest {guest_id}: {e}")
return None
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
"""Get guest data by session ID"""
try:
@ -78,23 +79,20 @@ class UserMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error getting guest by session ID {session_id}: {e}")
return None
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
"""Get all guests"""
try:
data = await self.redis.hgetall("guests")# type: ignore
return {
guest_id: json.loads(guest_json)
for guest_id, guest_json in data.items()
}
data = await self.redis.hgetall("guests") # type: ignore
return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
except Exception as e:
logger.error(f"❌ Error getting all guests: {e}")
return {}
async def delete_guest(self, guest_id: str) -> bool:
"""Delete a guest"""
try:
result = await self.redis.hdel("guests", guest_id)# type: ignore
result = await self.redis.hdel("guests", guest_id) # type: ignore
if result:
logger.info(f"🗑️ Guest deleted: {guest_id}")
return True
@ -102,35 +100,35 @@ class UserMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error deleting guest {guest_id}: {e}")
return False
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
"""Clean up inactive guest sessions with safety checks"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
cutoff_time = current_time - timedelta(hours=inactive_hours)
deleted_count = 0
preserved_count = 0
for guest_id, guest_data in all_guests.items():
try:
last_activity_str = guest_data.get("last_activity")
created_at_str = guest_data.get("created_at")
# Skip cleanup if guest is very new (less than 1 hour old)
if created_at_str:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
if current_time - created_at < timedelta(hours=1):
preserved_count += 1
logger.info(f"🛡️ Preserving new guest: {guest_id}")
continue
# Check last activity
should_delete = False
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity < cutoff_time:
should_delete = True
except ValueError:
@ -138,81 +136,81 @@ class UserMixin(DatabaseProtocol):
if not created_at_str:
should_delete = True
else:
# No last activity, but don't delete if guest is new
# No last activity, but don't delete if guest is new
if not created_at_str:
should_delete = True
if should_delete:
await self.delete_guest(guest_id)
deleted_count += 1
else:
preserved_count += 1
except Exception as e:
logger.error(f"❌ Error processing guest {guest_id} for cleanup: {e}")
preserved_count += 1 # Preserve on error
if deleted_count > 0:
logger.info(f"🧹 Guest cleanup: removed {deleted_count}, preserved {preserved_count}")
return deleted_count
except Exception as e:
logger.error(f"❌ Error in guest cleanup: {e}")
return 0
async def get_guest_statistics(self) -> Dict[str, Any]:
"""Get guest usage statistics"""
try:
all_guests = await self.get_all_guests()
current_time = datetime.now(UTC)
stats = {
"total_guests": len(all_guests),
"active_last_hour": 0,
"active_last_day": 0,
"converted_guests": 0,
"by_ip": {},
"creation_timeline": {}
"creation_timeline": {},
}
hour_ago = current_time - timedelta(hours=1)
day_ago = current_time - timedelta(days=1)
for guest_data in all_guests.values():
# Check activity
last_activity_str = guest_data.get("last_activity")
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity > hour_ago:
stats["active_last_hour"] += 1
if last_activity > day_ago:
stats["active_last_day"] += 1
except ValueError:
pass
# Check conversions
if guest_data.get("converted_to_user_id"):
stats["converted_guests"] += 1
# IP tracking
ip = guest_data.get("ip_address", "unknown")
stats["by_ip"][ip] = stats["by_ip"].get(ip, 0) + 1
# Creation timeline
created_at_str = guest_data.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
date_key = created_at.strftime('%Y-%m-%d')
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
date_key = created_at.strftime("%Y-%m-%d")
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except ValueError:
pass
return stats
except Exception as e:
logger.error(f"❌ Error getting guest statistics: {e}")
return {}
return {}
# ================
# Users
@ -222,7 +220,7 @@ class UserMixin(DatabaseProtocol):
username_key = f"{KEY_PREFIXES['users']}{username.lower()}"
data = await self.redis.get(username_key)
return self._deserialize(data) if data else None
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
"""Get the last time user's RAG data was updated (returns timezone-aware UTC)"""
try:
@ -241,19 +239,19 @@ class UserMixin(DatabaseProtocol):
return None
except Exception as e:
logger.error(f"❌ Error getting user RAG update time: {e}")
return None
return None
async def update_user_rag_timestamp(self, user_id: str) -> bool:
"""Set the user's RAG data update time (stores as UTC ISO format)"""
try:
update_time = datetime.now(timezone.utc)
# Ensure we're storing UTC timezone-aware format
if update_time.tzinfo is None:
update_time = update_time.replace(tzinfo=timezone.utc)
else:
update_time = update_time.astimezone(timezone.utc)
rag_update_key = f"user:{user_id}:rag_last_update"
# Store as ISO format with timezone info
timestamp_str = update_time.isoformat() # This includes timezone
@ -274,18 +272,18 @@ class UserMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error storing user by ID {user_id}: {e}")
return False
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user lookup data by user ID"""
try:
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
data = await self.redis.hget("user_lookup_by_id", user_id) # type: ignore
if data:
return json.loads(data)
return None
except Exception as e:
logger.error(f"❌ Error getting user by ID {user_id}: {e}")
return None
return None
async def user_exists_by_email(self, email: str) -> bool:
"""Check if a user exists with the given email"""
try:
@ -295,7 +293,7 @@ class UserMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error checking user existence by email {email}: {e}")
return False
async def user_exists_by_username(self, username: str) -> bool:
"""Check if a user exists with the given username"""
try:
@ -305,31 +303,31 @@ class UserMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error checking user existence by username {username}: {e}")
return False
async def get_all_users(self) -> Dict[str, Any]:
"""Get all users"""
pattern = f"{KEY_PREFIXES['users']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
email = key.replace(KEY_PREFIXES['users'], '')
email = key.replace(KEY_PREFIXES["users"], "")
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
if type(value) == str:
result[email] = value
else:
result[email] = self._deserialize(value)
return result
async def delete_user(self, email: str):
"""Delete user"""
key = f"{KEY_PREFIXES['users']}{email}"
@ -340,7 +338,7 @@ class UserMixin(DatabaseProtocol):
try:
login = login.strip().lower()
key = f"users:{login}"
data = await self.redis.get(key)
if data:
user_data = json.loads(data)
@ -356,7 +354,7 @@ class UserMixin(DatabaseProtocol):
try:
login = login.strip().lower()
key = f"users:{login}"
await self.redis.set(key, json.dumps(user_data, default=str))
logger.info(f"👤 Stored user data for {login}")
return True
@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
logger.error(f"❌ Error storing user {login}: {e}")
return False
# ================
# Employers
# ================
@ -373,37 +370,36 @@ class UserMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['employers']}{employer_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_employer(self, employer_id: str, employer_data: Dict):
"""Set employer data"""
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.set(key, self._serialize(employer_data))
async def get_all_employers(self) -> Dict[str, Any]:
"""Get all employers"""
pattern = f"{KEY_PREFIXES['employers']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
employer_id = key.replace(KEY_PREFIXES['employers'], '')
employer_id = key.replace(KEY_PREFIXES["employers"], "")
result[employer_id] = self._deserialize(value)
return result
async def delete_employer(self, employer_id: str):
"""Delete employer"""
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.delete(key)
# ================
# Candidates
@ -413,33 +409,33 @@ class UserMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['candidates']}{candidate_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
"""Set candidate data"""
key = f"{KEY_PREFIXES['candidates']}{candidate_id}"
await self.redis.set(key, self._serialize(candidate_data))
async def get_all_candidates(self) -> Dict[str, Any]:
"""Get all candidates"""
pattern = f"{KEY_PREFIXES['candidates']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
# Use pipeline for efficiency
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
result[candidate_id] = self._deserialize(value)
return result
async def delete_candidate(self: Self, candidate_id: str) -> Dict[str, int]:
"""
Delete candidate and all related records in a cascading manner
@ -456,20 +452,20 @@ class UserMixin(DatabaseProtocol):
"security_logs": 0,
"ai_parameters": 0,
"candidate_record": 0,
"resumes": 0
"resumes": 0,
}
logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")
# 1. Get candidate data first to retrieve associated information
candidate_data = await self.get_candidate(candidate_id)
if not candidate_data:
logger.warning(f"⚠️ Candidate {candidate_id} not found")
return deletion_stats
candidate_email = candidate_data.get("email", "").lower()
candidate_username = candidate_data.get("username", "").lower()
# 2. Delete all candidate documents and their metadata
try:
documents_deleted = await self.delete_all_candidate_documents(candidate_id)
@ -477,64 +473,68 @@ class UserMixin(DatabaseProtocol):
logger.info(f"🗑️ Deleted {documents_deleted} documents for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting candidate documents: {e}")
# 3. Delete all chat sessions related to this candidate
try:
candidate_sessions = await self.get_chat_sessions_by_candidate(candidate_id)
messages_deleted = 0
for session in candidate_sessions:
session_id = session.get("id")
if session_id:
# Count messages before deletion
message_count = await self.get_chat_message_count(session_id)
messages_deleted += message_count
# Delete chat session and its messages
await self.delete_chat_session_completely(session_id)
deletion_stats["chat_sessions"] = len(candidate_sessions)
deletion_stats["chat_messages"] = messages_deleted
logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting chat sessions: {e}")
# 4. Delete job applications from this candidate
try:
all_applications = await self.get_all_job_applications()
candidate_applications = []
for app_id, app_data in all_applications.items():
if app_data.get("candidateId") == candidate_id:
candidate_applications.append(app_id)
# Delete each application
for app_id in candidate_applications:
await self.delete_job_application(app_id)
deletion_stats["job_applications"] = len(candidate_applications)
logger.info(f"🗑️ Deleted {len(candidate_applications)} job applications for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting job applications: {e}")
# 5. Delete user records (by email and username if they exist)
try:
user_records_deleted = 0
# Delete by email
if candidate_email and await self.user_exists_by_email(candidate_email):
await self.delete_user(candidate_email)
user_records_deleted += 1
logger.info(f"🗑️ Deleted user record by email: {candidate_email}")
# Delete by username (if different from email)
if (candidate_username and
candidate_username != candidate_email and
await self.user_exists_by_username(candidate_username)):
if (
candidate_username
and candidate_username != candidate_email
and await self.user_exists_by_username(candidate_username)
):
await self.delete_user(candidate_username)
user_records_deleted += 1
logger.info(f"🗑️ Deleted user record by username: {candidate_username}")
# Delete user by ID if exists
user_by_id = await self.get_user_by_id(candidate_id)
if user_by_id:
@ -542,12 +542,12 @@ class UserMixin(DatabaseProtocol):
await self.redis.delete(key)
user_records_deleted += 1
logger.info(f"🗑️ Deleted user record by ID: {candidate_id}")
deletion_stats["user_records"] = user_records_deleted
logger.info(f"🗑️ Deleted {user_records_deleted} user records for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting user records: {e}")
# 6. Delete authentication records
try:
auth_deleted = await self.delete_authentication(candidate_id)
@ -556,57 +556,56 @@ class UserMixin(DatabaseProtocol):
logger.info(f"🗑️ Deleted authentication record for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting authentication records: {e}")
# 7. Revoke all refresh tokens for this user
try:
await self.revoke_all_user_tokens(candidate_id)
logger.info(f"🗑️ Revoked all refresh tokens for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error revoking refresh tokens: {e}")
# 8. Delete security logs for this user
try:
security_logs_deleted = 0
# Security logs are stored by date, so we need to scan for them
pattern = f"security_log:{candidate_id}:*"
cursor = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
if keys:
await self.redis.delete(*keys)
security_logs_deleted += len(keys)
if cursor == 0:
break
deletion_stats["security_logs"] = security_logs_deleted
if security_logs_deleted > 0:
logger.info(f"🗑️ Deleted {security_logs_deleted} security log entries for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting security logs: {e}")
# 9. Delete AI parameters that might be specific to this candidate
try:
all_ai_params = await self.get_all_ai_parameters()
candidate_ai_params = []
for param_id, param_data in all_ai_params.items():
if (param_data.get("candidateId") == candidate_id or
param_data.get("userId") == candidate_id):
if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
candidate_ai_params.append(param_id)
# Delete each AI parameter set
for param_id in candidate_ai_params:
await self.delete_ai_parameters(param_id)
deletion_stats["ai_parameters"] = len(candidate_ai_params)
if len(candidate_ai_params) > 0:
logger.info(f"🗑️ Deleted {len(candidate_ai_params)} AI parameter sets for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting AI parameters: {e}")
# 10. Delete email verification tokens if any exist
try:
if candidate_email:
@ -614,10 +613,10 @@ class UserMixin(DatabaseProtocol):
pattern = "email_verification:*"
cursor = 0
tokens_deleted = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
@ -625,25 +624,27 @@ class UserMixin(DatabaseProtocol):
if verification_info.get("email", "").lower() == candidate_email:
await self.redis.delete(key)
tokens_deleted += 1
if cursor == 0:
break
if tokens_deleted > 0:
logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting email verification tokens: {e}")
# 11. Delete password reset tokens if any exist
try:
if candidate_email:
pattern = "password_reset:*"
cursor = 0
tokens_deleted = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
for key in keys:
token_data = await self.redis.get(key)
if token_data:
@ -651,37 +652,37 @@ class UserMixin(DatabaseProtocol):
if reset_info.get("email", "").lower() == candidate_email:
await self.redis.delete(key)
tokens_deleted += 1
if cursor == 0:
break
if tokens_deleted > 0:
logger.info(f"🗑️ Deleted {tokens_deleted} password reset tokens for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting password reset tokens: {e}")
# 12. Delete MFA codes if any exist
try:
if candidate_email:
pattern = f"mfa_code:{candidate_email}:*"
cursor = 0
mfa_codes_deleted = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
if keys:
await self.redis.delete(*keys)
mfa_codes_deleted += len(keys)
if cursor == 0:
break
if mfa_codes_deleted > 0:
logger.info(f"🗑️ Deleted {mfa_codes_deleted} MFA codes for candidate {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting MFA codes: {e}")
# 13. Finally, delete the candidate record itself
try:
key = f"{KEY_PREFIXES['candidates']}{candidate_id}"
@ -690,24 +691,24 @@ class UserMixin(DatabaseProtocol):
logger.info(f"🗑️ Deleted candidate record for {candidate_id}")
except Exception as e:
logger.error(f"❌ Error deleting candidate record: {e}")
# 14. Delete resumes associated with this candidate across all users
try:
all_resumes = await self.get_all_resumes()
candidate_resumes_deleted = 0
for user_id, user_resumes in all_resumes.items():
resumes_to_delete = []
for resume in user_resumes:
if resume.get("candidate_id") == candidate_id:
resumes_to_delete.append(resume.get("resume_id"))
# Delete each resume for this candidate
for resume_id in resumes_to_delete:
if resume_id:
await self.delete_resume(user_id, resume_id)
candidate_resumes_deleted += 1
deletion_stats["resumes"] = candidate_resumes_deleted
if candidate_resumes_deleted > 0:
logger.info(f"🗑️ Deleted {candidate_resumes_deleted} resumes for candidate {candidate_id}")
@ -717,14 +718,16 @@ class UserMixin(DatabaseProtocol):
# 15. Log the deletion as a security event (if we have admin/system user context)
try:
total_items_deleted = sum(deletion_stats.values())
logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}")
logger.info(
f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}"
)
logger.info(f"📊 Deletion breakdown: {deletion_stats}")
except Exception as e:
logger.error(f"❌ Error logging deletion summary: {e}")
return deletion_stats
except Exception as e:
logger.error(f"❌ Critical error during candidate deletion {candidate_id}: {e}")
raise
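delete_candidate() returns a per-category count of everything it removed; a small caller sketch, assuming a db object providing UserMixin:

from typing import Any, Dict


async def remove_candidate(db: Any, candidate_id: str) -> Dict[str, int]:
    # Cascades over documents, chat sessions, applications, user/auth records,
    # tokens, AI parameters and resumes, and reports how many of each were removed.
    stats = await db.delete_candidate(candidate_id)
    print(f"Removed {sum(stats.values())} items for {candidate_id}: {stats}")
    return stats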
@ -748,25 +751,25 @@ class UserMixin(DatabaseProtocol):
"candidate_record": 0,
"resumes": 0,
}
logger.info(f"🗑️ Starting batch deletion for {len(candidate_ids)} candidates")
for candidate_id in candidate_ids:
try:
deletion_stats = await self.delete_candidate(candidate_id)
batch_results[candidate_id] = deletion_stats
# Add to totals
for key, value in deletion_stats.items():
total_stats[key] += value
except Exception as e:
logger.error(f"❌ Failed to delete candidate {candidate_id}: {e}")
batch_results[candidate_id] = {"error": str(e)}
logger.info(f"✅ Completed batch deletion. Total items deleted: {sum(total_stats.values())}")
logger.info(f"📊 Batch totals: {total_stats}")
return {
"individual_results": batch_results,
"totals": total_stats,
@ -774,70 +777,70 @@ class UserMixin(DatabaseProtocol):
"total_candidates_processed": len(candidate_ids),
"successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
"failed_deletions": len([r for r in batch_results.values() if "error" in r]),
"total_items_deleted": sum(total_stats.values())
}
"total_items_deleted": sum(total_stats.values()),
},
}
except Exception as e:
logger.error(f"❌ Critical error during batch candidate deletion: {e}")
raise
raise
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
"""Find candidate by username"""
all_candidates = await self.get_all_candidates()
username_lower = username.lower()
for candidate_data in all_candidates.values():
if candidate_data.get("username", "").lower() == username_lower:
return candidate_data
return None
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
"""Get multiple candidates by their usernames efficiently"""
all_candidates = await self.get_all_candidates()
username_set = {username.lower() for username in usernames}
result = {}
for candidate_data in all_candidates.values():
candidate_username = candidate_data.get("username", "").lower()
if candidate_username in username_set:
result[candidate_username] = candidate_data
return result
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
"""Get a summary of chat activity for a specific candidate"""
sessions = await self.get_chat_sessions_by_candidate(candidate_id)
if not sessions:
return {
"candidate_id": candidate_id,
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
for session in sessions:
session_id = session.get("id")
if session_id:
message_count = await self.get_chat_message_count(session_id)
total_messages += message_count
# Sort sessions by creation date
sessions_by_date = sorted(sessions, key=lambda x: x.get("createdAt", ""))
return {
"candidate_id": candidate_id,
"total_sessions": len(sessions),
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# ================
# Viewers
# ================
@ -846,35 +849,34 @@ class UserMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
data = await self.redis.get(key)
return self._deserialize(data) if data else None
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
"""Set viewer data"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.set(key, self._serialize(viewer_data))
async def get_all_viewers(self) -> Dict[str, Any]:
"""Get all viewers"""
pattern = f"{KEY_PREFIXES['viewers']}*"
keys = await self.redis.keys(pattern)
if not keys:
return {}
# Use pipeline for efficiency
pipe = self.redis.pipeline()
for key in keys:
pipe.get(key)
values = await pipe.execute()
result = {}
for key, value in zip(keys, values):
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
result[viewer_id] = self._deserialize(value)
return result
async def delete_viewer(self, viewer_id: str):
"""Delete viewer"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.delete(key)
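For reference, a minimal sketch of how the cascading delete above might be invoked; `db` stands in for a concrete database instance that mixes in UserMixin, and the candidate id is a made-up value:
stats = await db.delete_candidate("candidate-123")
logger.info(f"Cascading delete removed {sum(stats.values())} items: {stats}")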
View File
@ -3,20 +3,20 @@ import os
ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
user_dir = "/opt/backstory/users"
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
default_username = "jketreno"
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
# Path to candidate full resume
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}"
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}"
resume_doc = "resume.md"
persist_directory = "db" # Relative to "{user_dir}/{user}"
persist_directory = "db" # Relative to "{user_dir}/{user}"
# Model name License Notes
# model = "deepseek-r1:7b" # MIT Tool calls don"t work
# Model name License Notes
# model = "deepseek-r1:7b" # MIT Tool calls don"t work
# model = "gemma3:4b" # Gemma Requires newer ollama https://ai.google.dev/gemma/terms
# model = "llama3.2" # Llama Good results; qwen seems slightly better https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/LICENSE.txt
# model = "mistral:7b" # Apache 2.0 Tool calls don"t work
model = "qwen2.5:7b" # Apache 2.0 Good results
model = "qwen2.5:7b" # Apache 2.0 Good results
# model = "qwen3:8b" # Apache 2.0 Requires newer ollama
model = os.getenv("MODEL_NAME", model)
@ -40,7 +40,7 @@ logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
# RAG and Vector DB settings
## Where to read RAG content
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
# Maximum number of entries for ChromaDB to find
default_rag_top_k = 50
@ -60,7 +60,7 @@ cert_path = "/opt/backstory/keys/cert.pem"
host = os.getenv("BACKSTORY_HOST", "0.0.0.0")
port = int(os.getenv("BACKSTORY_PORT", "8911"))
api_prefix = "/api/1.0"
debug=os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
debug = os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
# Used for filtering tracebacks
app_path="/opt/backstory/src/backend"
app_path = "/opt/backstory/src/backend"
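For illustration, a minimal sketch of how the environment overrides above behave; the override values are assumptions:
import os
os.environ["MODEL_NAME"] = "llama3.2"      # must be set before defines is imported
os.environ["BACKSTORY_DEBUG"] = "yes"
import defines
assert defines.model == "llama3.2"
assert defines.debug is True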
View File
@ -1,32 +1,32 @@
from fastapi import Request
from fastapi import Request
from database.manager import RedisDatabase
import hashlib
from logger import logger
from datetime import datetime, timezone
from user_agents import parse
from user_agents import parse
import json
class DeviceManager:
def __init__(self, database: RedisDatabase):
self.database = database
def generate_device_fingerprint(self, request: Request) -> str:
"""Generate device fingerprint from request"""
user_agent = request.headers.get("user-agent", "")
ip_address = request.client.host if request.client else "unknown"
accept_language = request.headers.get("accept-language", "")
# Create fingerprint
fingerprint_data = f"{user_agent}|{accept_language}"
fingerprint = hashlib.sha256(fingerprint_data.encode()).hexdigest()[:16]
return fingerprint
def parse_device_info(self, request: Request) -> dict:
"""Parse device information from request"""
user_agent_string = request.headers.get("user-agent", "")
user_agent = parse(user_agent_string)
return {
"device_id": self.generate_device_fingerprint(request),
"device_name": f"{user_agent.browser.family} on {user_agent.os.family}",
@ -35,9 +35,9 @@ class DeviceManager:
"os": user_agent.os.family,
"os_version": user_agent.os.version_string,
"ip_address": request.client.host if request.client else "unknown",
"user_agent": user_agent_string
"user_agent": user_agent_string,
}
async def is_trusted_device(self, user_id: str, device_id: str) -> bool:
"""Check if device is trusted for user"""
try:
@ -47,7 +47,7 @@ class DeviceManager:
except Exception as e:
logger.error(f"Error checking trusted device: {e}")
return False
async def add_trusted_device(self, user_id: str, device_id: str, device_info: dict):
"""Add device to trusted devices"""
try:
@ -55,20 +55,20 @@ class DeviceManager:
device_data = {
**device_info,
"added_at": datetime.now(timezone.utc).isoformat(),
"last_used": datetime.now(timezone.utc).isoformat()
"last_used": datetime.now(timezone.utc).isoformat(),
}
# Store for 90 days
await self.database.redis.setex(
key,
key,
90 * 24 * 60 * 60, # 90 days in seconds
json.dumps(device_data, default=str)
json.dumps(device_data, default=str),
)
logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")
except Exception as e:
logger.error(f"Error adding trusted device: {e}")
async def update_device_last_used(self, user_id: str, device_id: str):
"""Update last used timestamp for device"""
try:
@ -80,7 +80,7 @@ class DeviceManager:
await self.database.redis.setex(
key,
90 * 24 * 60 * 60, # Reset 90 day expiry
json.dumps(device_info, default=str)
json.dumps(device_info, default=str),
)
except Exception as e:
logger.error(f"Error updating device last used: {e}")
View File
@ -1,8 +1,8 @@
import os
from typing import Tuple
from logger import logger
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import smtplib
import asyncio
from email_templates import EMAIL_TEMPLATES
@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
import json
from database.manager import RedisDatabase
class EmailService:
def __init__(self):
# Configure these in your .env file
@ -26,64 +27,48 @@ class EmailService:
def _get_template(self, template_name: str) -> dict:
"""Get email template by name"""
return EMAIL_TEMPLATES.get(template_name, {})
def _format_template(self, template: str, **kwargs) -> str:
"""Format template with provided variables"""
return template.format(
app_name=self.app_name,
from_name=self.from_name,
frontend_url=self.frontend_url,
**kwargs
app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
)
async def send_verification_email(
self,
to_email: str,
verification_token: str,
user_name: str,
user_type: str = "user"
self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
):
"""Send email verification email using template"""
try:
template = self._get_template("verification")
verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"
subject = self._format_template(
template["subject"],
user_name=user_name,
to_email=to_email
)
subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)
html_content = self._format_template(
template["html"],
user_name=user_name,
user_type=user_type,
to_email=to_email,
verification_link=verification_link
verification_link=verification_link,
)
await self._send_email(to_email, subject, html_content)
logger.info(f"📧 Verification email sent to {to_email}")
except Exception as e:
logger.error(f"❌ Failed to send verification email to {to_email}: {e}")
raise
async def send_mfa_email(
self,
to_email: str,
mfa_code: str,
device_name: str,
user_name: str,
ip_address: str = "Unknown"
self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
):
"""Send MFA code email using template"""
try:
template = self._get_template("mfa")
login_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
subject = self._format_template(template["subject"])
html_content = self._format_template(
template["html"],
user_name=user_name,
@ -91,64 +76,56 @@ class EmailService:
ip_address=ip_address,
login_time=login_time,
mfa_code=mfa_code,
to_email=to_email
to_email=to_email,
)
await self._send_email(to_email, subject, html_content)
logger.info(f"📧 MFA code sent to {to_email} for device {device_name}")
except Exception as e:
logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
raise
async def send_password_reset_email(
self,
to_email: str,
reset_token: str,
user_name: str
):
async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
"""Send password reset email using template"""
try:
template = self._get_template("password_reset")
reset_link = f"{self.frontend_url}/login/reset-password?token={reset_token}"
subject = self._format_template(template["subject"])
html_content = self._format_template(
template["html"],
user_name=user_name,
reset_link=reset_link,
to_email=to_email
template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
)
await self._send_email(to_email, subject, html_content)
logger.info(f"📧 Password reset email sent to {to_email}")
except Exception as e:
logger.error(f"❌ Failed to send password reset email to {to_email}: {e}")
raise
async def _send_email(self, to_email: str, subject: str, html_content: str):
"""Send email using SMTP with improved error handling"""
try:
if not self.email_user:
raise ValueError("Email user is not configured")
# Create message
msg = MIMEMultipart('alternative')
msg['From'] = f"{self.from_name} <{self.email_user}>"
msg['To'] = to_email
msg['Subject'] = subject
msg['Reply-To'] = self.email_user
msg = MIMEMultipart("alternative")
msg["From"] = f"{self.from_name} <{self.email_user}>"
msg["To"] = to_email
msg["Subject"] = subject
msg["Reply-To"] = self.email_user
# Add HTML content
html_part = MIMEText(html_content, 'html', 'utf-8')
html_part = MIMEText(html_content, "html", "utf-8")
msg.attach(html_part)
# Send email with connection pooling and retry logic
max_retries = 3
if not self.smtp_server or self.smtp_port == 0 or not self.email_user or not self.email_password:
raise ValueError("SMTP configuration is not set in the environment variables")
for attempt in range(max_retries):
try:
with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
@ -157,71 +134,61 @@ class EmailService:
text = msg.as_string()
server.sendmail(self.email_user, to_email, text)
break # Success, exit retry loop
except smtplib.SMTPException as e:
if attempt == max_retries - 1: # Last attempt
raise
logger.warning(f"⚠️ SMTP attempt {attempt + 1} failed, retrying: {e}")
await asyncio.sleep(1) # Wait before retry
logger.debug(f"📧 Email sent successfully to {to_email}")
except Exception as e:
logger.error(f"❌ SMTP error sending to {to_email}: {e}")
raise
class EmailRateLimiter:
def __init__(self, database: RedisDatabase):
self.database = database
async def can_send_email(self, email: str, email_type: str, limit: int = 5, window_minutes: int = 60) -> bool:
"""Check if email can be sent based on rate limiting"""
try:
key = f"email_rate_limit:{email_type}:{email.lower()}"
current_time = datetime.now(timezone.utc)
window_start = current_time - timedelta(minutes=window_minutes)
# Get current count
count_data = await self.database.redis.get(key)
if not count_data:
# First email, allow it
await self._record_email_sent(key, current_time, window_minutes)
return True
email_records = json.loads(count_data)
# Filter out old records
recent_records = [
record for record in email_records
if datetime.fromisoformat(record) > window_start
]
recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]
if len(recent_records) >= limit:
logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")
return False
# Add current email to records
recent_records.append(current_time.isoformat())
await self.database.redis.setex(
key,
window_minutes * 60,
json.dumps(recent_records)
)
await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))
return True
except Exception as e:
logger.error(f"❌ Error checking email rate limit: {e}")
# On error, allow the email to be safe
return True
async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
"""Record that an email was sent"""
await self.database.redis.setex(
key,
ttl_minutes * 60,
json.dumps([timestamp.isoformat()])
)
await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))
class VerificationEmailRateLimiter:
def __init__(self, database: RedisDatabase):
@ -229,7 +196,7 @@ class VerificationEmailRateLimiter:
self.max_attempts_per_hour = 3 # Maximum 3 emails per hour
self.max_attempts_per_day = 10 # Maximum 10 emails per day
self.cooldown_minutes = 5 # 5 minute cooldown between emails
async def can_send_verification_email(self, email: str) -> Tuple[bool, str]:
"""
Check if verification email can be sent based on rate limiting
@ -238,49 +205,49 @@ class VerificationEmailRateLimiter:
try:
email_lower = email.lower()
current_time = datetime.now(timezone.utc)
# Check daily limit
daily_count = await self.database.get_verification_attempts_count(email)
if daily_count >= self.max_attempts_per_day:
return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
return (
False,
f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
)
# Check hourly limit
hour_ago = current_time - timedelta(hours=1)
hourly_key = f"verification_attempts:{email_lower}"
data = await self.database.redis.get(hourly_key)
if data:
attempts_data = json.loads(data)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > hour_ago
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]
if len(recent_attempts) >= self.max_attempts_per_hour:
return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
return (
False,
f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
)
# Check cooldown period
if recent_attempts:
last_attempt = max(datetime.fromisoformat(attempt) for attempt in recent_attempts)
time_since_last = current_time - last_attempt
if time_since_last.total_seconds() < self.cooldown_minutes * 60:
remaining_minutes = self.cooldown_minutes - int(time_since_last.total_seconds() / 60)
return False, f"Please wait {remaining_minutes} more minute(s) before requesting another email."
return True, "OK"
except Exception as e:
logger.error(f"❌ Error checking verification email rate limit: {e}")
# On error, be conservative and deny
return False, "Rate limit check failed. Please try again later."
async def record_email_sent(self, email: str):
"""Record that a verification email was sent"""
await self.database.record_verification_attempt(email)
email_service = EmailService()
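For illustration, a minimal sketch that combines the rate limiter and email service above; `db`, the address, and the token are made-up values:
limiter = VerificationEmailRateLimiter(db)  # `db` is an assumed RedisDatabase instance
allowed, reason = await limiter.can_send_verification_email("jane@example.com")
if allowed:
    await email_service.send_verification_email("jane@example.com", "token-abc123", "Jane Doe")
    await limiter.record_email_sent("jane@example.com")
else:
    logger.warning(f"Verification email suppressed: {reason}")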
View File
@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"mfa": {
"subject": "Security Code for Backstory",
"html": """
@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"password_reset": {
"subject": "Reset your Backstory password",
"html": """
@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
}
}
""",
},
}
View File
@ -3,16 +3,17 @@ import weakref
from datetime import datetime, timedelta
from typing import Dict, Optional
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel # type: ignore
from models import Candidate
from agents.base import CandidateEntity
from database.manager import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager(BaseModel):
"""Manages lifecycle of CandidateEntity instances"""
def __init__(self, default_ttl_minutes: int = 30):
self._entities: Dict[str, CandidateEntity] = {}
self._weak_refs: Dict[str, weakref.ReferenceType] = {}
@ -25,7 +26,7 @@ class EntityManager(BaseModel):
"""Start background cleanup task"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
async def stop_cleanup_task(self):
"""Stop background cleanup task"""
if self._cleanup_task:
@ -36,49 +37,44 @@ class EntityManager(BaseModel):
pass
self._cleanup_task = None
def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
self._database = database
async def get_entity(self, candidate: "Candidate") -> CandidateEntity:
"""Get or create CandidateEntity with proper reference tracking"""
# Check if entity exists and is still valid
if candidate.id in self._entities:
entity = self._entities[candidate.id]
entity.last_accessed = datetime.now()
entity.reference_count += 1
return entity
if not self._prometheus_collector or not self._database:
raise ValueError("EntityManager has not been initialized with required components.")
entity = CandidateEntity(candidate=candidate)
await entity.initialize(
prometheus_collector=self._prometheus_collector,
database=self._database)
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity
self._weak_refs[candidate.id] = weakref.ref(entity, self._on_entity_deleted(candidate.id))
entity.reference_count = 1
entity.last_accessed = datetime.now()
return entity
async def remove_entity(self, candidate_id: str) -> bool:
"""
Immediately remove and cleanup a candidate entity from active persistence.
This should be called when a candidate is being deleted from the system.
Args:
candidate_id: The ID of the candidate entity to remove
Returns:
bool: True if entity was found and removed, False if not found
"""
@ -88,36 +84,38 @@ class EntityManager(BaseModel):
if not entity:
print(f"Entity {candidate_id} not found in active persistence")
return False
# Remove from tracking dictionaries
self._entities.pop(candidate_id, None)
self._weak_refs.pop(candidate_id, None)
# Cleanup the entity
await entity.cleanup()
print(f"Successfully removed entity {candidate_id} from active persistence")
return True
except Exception as e:
print(f"Error removing entity {candidate_id}: {e}")
return False
def _on_entity_deleted(self, user_id: str):
"""Callback when entity is garbage collected"""
def cleanup_callback(weak_ref):
self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
print(f"Entity {user_id} garbage collected")
return cleanup_callback
async def release_entity(self, user_id: str):
"""Explicitly release reference to entity"""
if user_id in self._entities:
entity = self._entities[user_id]
entity.reference_count = max(0, entity.reference_count - 1)
entity.last_accessed = datetime.now()
async def _periodic_cleanup(self):
"""Background task to clean up expired entities"""
while True:
@ -128,20 +126,19 @@ class EntityManager(BaseModel):
break
except Exception as e:
print(f"Error in cleanup task: {e}")
async def _cleanup_expired_entities(self):
"""Remove entities that have expired based on TTL and reference count"""
current_time = datetime.now()
expired_entities = []
for user_id, entity in list(self._entities.items()):
time_since_access = current_time - entity.last_accessed
# Remove if TTL exceeded and no active references
if (time_since_access > timedelta(minutes=self._ttl_minutes)
and entity.reference_count == 0):
if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
expired_entities.append(user_id)
for user_id in expired_entities:
entity = self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
@ -153,6 +150,7 @@ class EntityManager(BaseModel):
# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)
@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
"""Context manager for safe entity access with automatic reference management"""
@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
finally:
await entity_manager.release_entity(candidate.id)
EntityManager.model_rebuild()
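For reference, a minimal sketch of the context manager above, assuming it yields the CandidateEntity and that `candidate` is an already-loaded Candidate:
async with get_candidate_entity(candidate) as entity:
    logger.info(f"Entity {candidate.id} has {entity.reference_count} active reference(s)")
# the reference is released automatically when the block exits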
View File
@ -6,189 +6,195 @@ without getting caught up in serialization format complexities
import sys
from datetime import datetime
from models import (
UserStatus, UserType, SkillLevel, EmploymentType,
Candidate, Employer, Location, Skill
)
from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill
def test_model_creation():
"""Test that we can create models successfully"""
print("🧪 Testing model creation...")
# Create supporting objects
location = Location(city="Austin", country="USA")
skill = Skill(name="Python", category="Programming", level=SkillLevel.ADVANCED)
# Create candidate
candidate = Candidate(
email="test@example.com",
user_type=UserType.CANDIDATE,
username="test_candidate",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
firstName="John",
lastName="Doe",
fullName="John Doe",
first_name="John",
last_name="Doe",
full_name="John Doe",
skills=[skill],
experience=[],
education=[],
preferredJobTypes=[EmploymentType.FULL_TIME],
preferred_job_types=[EmploymentType.FULL_TIME],
location=location,
languages=[],
certifications=[]
certifications=[],
)
# Create employer
employer = Employer(
firstName="Mary",
lastName="Smith",
fullName="Mary Smith",
user_type=UserType.EMPLOYER,
first_name="Mary",
last_name="Smith",
full_name="Mary Smith",
email="hr@company.com",
username="test_employer",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
companyName="Test Company",
company_name="Test Company",
industry="Technology",
companySize="50-200",
companyDescription="A test company",
location=location
company_size="50-200",
company_description="A test company",
location=location,
)
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
print(f"✅ Employer: {employer.company_name}")
print(f"✅ User types: {candidate.user_type}, {employer.user_type}")
return candidate, employer
def test_json_api_format():
"""Test JSON serialization in API format (the most important use case)"""
print("\n📡 Testing JSON API format...")
candidate, employer = test_model_creation()
# Serialize to JSON (API format)
candidate_json = candidate.model_dump_json(by_alias=True)
employer_json = employer.model_dump_json(by_alias=True)
print(f"✅ Candidate JSON: {len(candidate_json)} chars")
print(f"✅ Employer JSON: {len(employer_json)} chars")
# Deserialize from JSON
candidate_back = Candidate.model_validate_json(candidate_json)
employer_back = Employer.model_validate_json(employer_json)
# Verify data integrity
assert candidate_back.email == candidate.email
assert candidate_back.first_name == candidate.first_name
assert employer_back.company_name == employer.company_name
print(f"✅ JSON round-trip successful")
print(f"✅ Data integrity verified")
print("✅ JSON round-trip successful")
print("✅ Data integrity verified")
return True
def test_api_dict_format():
"""Test dictionary format with aliases (for API requests/responses)"""
print("\n📊 Testing API dictionary format...")
candidate, employer = test_model_creation()
# Create API format dictionaries
candidate_dict = candidate.model_dump(by_alias=True)
employer_dict = employer.model_dump(by_alias=True)
# Verify camelCase aliases are used
assert "firstName" in candidate_dict
assert "lastName" in candidate_dict
assert "createdAt" in candidate_dict
assert "companyName" in employer_dict
print(f"✅ API format dictionaries created")
print(f"✅ CamelCase aliases verified")
print("✅ API format dictionaries created")
print("✅ CamelCase aliases verified")
# Test deserializing from API format
candidate_back = Candidate.model_validate(candidate_dict)
employer_back = Employer.model_validate(employer_dict)
assert candidate_back.email == candidate.email
assert employer_back.company_name == employer.company_name
print(f"✅ API format round-trip successful")
print("✅ API format round-trip successful")
return True
def test_validation_constraints():
"""Test that validation constraints work"""
print("\n🔒 Testing validation constraints...")
try:
# Create a candidate with invalid email
invalid_candidate = Candidate(
Candidate(
user_type=UserType.CANDIDATE,
email="invalid-email",
username="test_invalid",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
firstName="Jane",
lastName="Doe",
fullName="Jane Doe"
first_name="Jane",
last_name="Doe",
full_name="Jane Doe",
)
print("❌ Validation should have failed but didn't")
return False
except ValueError as e:
print(f"✅ Validation error caught: {e}")
print(f"✅ Validation error caught: {e}")
return True
def test_enum_values():
"""Test that enum values work correctly"""
print("\n📋 Testing enum values...")
# Test that enum values are properly handled
candidate, employer = test_model_creation()
# Check enum values in serialization
candidate_dict = candidate.model_dump(by_alias=True)
assert candidate_dict["status"] == "active"
assert candidate_dict["userType"] == "candidate"
assert employer.user_type == UserType.EMPLOYER
print(f"✅ Enum values correctly serialized")
print("✅ Enum values correctly serialized")
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
return True
def main():
"""Run all focused tests"""
print("🎯 Focused Pydantic Model Tests")
print("=" * 40)
try:
test_model_creation()
test_json_api_format()
test_api_dict_format()
test_api_dict_format()
test_validation_constraints()
test_enum_values()
print(f"\n🎉 All focused tests passed!")
print("\n🎉 All focused tests passed!")
print("=" * 40)
print("✅ Models work correctly")
print("✅ JSON API format works")
print("✅ Validation constraints work")
print("✅ Enum values work")
print("✅ Ready for type generation!")
return True
except Exception as e:
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{traceback.format_exc()}")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)
sys.exit(0 if success else 1)
View File
@ -1,7 +1,8 @@
from pydantic import BaseModel
from pydantic import BaseModel
import json
from typing import Any, List, Set
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
"""
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
@ -57,4 +58,4 @@ def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited
# Remove the current object from visited to allow processing in other branches
visited.discard(obj_id)
return errors
return errors
View File
@ -1,5 +1,5 @@
from typing import Type, TypeVar
from pydantic import BaseModel
from pydantic import BaseModel
import copy
from models import Candidate, CandidateAI, Employer, Guest, BaseUserWithType
@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
T = TypeVar('T', bound=BaseModel)
T = TypeVar("T", bound=BaseModel)
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
data = {field: getattr(source, field) for field in model_cls.__fields__}
return model_cls(**data)
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
return model_cls(**data)
def cast_to_base_user_with_type(user) -> BaseUserWithType:
"""
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
View File
@ -4,10 +4,11 @@ import re
import time
from typing import Any
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
class ImageModelCache: # Stay loaded for 3 hours
class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
self._pipe = None
self._model_name = None
@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours
cached_model_type = self._get_model_type(self._model_name) if self._model_name else None
if (
self._pipe is not None and
self._model_name == model and
self._device == device and
current_model_type == cached_model_type and
current_time - self._last_access_time < self._timeout_seconds
self._pipe is not None
and self._model_name == model
and self._device == device
and current_model_type == cached_model_type
and current_time - self._last_access_time < self._timeout_seconds
):
self._last_access_time = current_time
return self._pipe
@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours
model,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
def dummy_safety_checker(images, clip_input):
return images, [False] * len(images)
pipe.safety_checker = dummy_safety_checker
else:
pipe = FluxPipeline.from_pretrained(
@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
try:
pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
except Exception as e:
raise Exception(f"Failed to load LoRA weights: {str(e)}")
@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded for 3 hours
async def cleanup_if_expired(self):
async with self._lock:
if (
self._pipe is not None and
time.time() - self._last_access_time >= self._timeout_seconds
):
if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
await self._unload_model()
async def _periodic_cleanup(self):
View File
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel
from pydantic import BaseModel
from typing import Any, AsyncGenerator
import traceback
import asyncio
@ -29,6 +29,7 @@ TIME_ESTIMATES = {
}
}
class ImageRequest(BaseModel):
session_id: str
filepath: str
@ -39,34 +40,40 @@ class ImageRequest(BaseModel):
width: int = 256
guidance_scale: float = 7.5
# Global model cache instance
model_cache = ImageModelCache()
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
"""Background worker for Flux image generation"""
try:
# Flux: Run generation in the background and yield progress updates
status_queue.put(ChatMessageStatus(
session_id=params.session_id,
content=f"Initializing image generation.",
activity=ApiActivityType.GENERATING_IMAGE,
))
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content="Initializing image generation.",
activity=ApiActivityType.GENERATING_IMAGE,
)
)
# Start the generation task
start_gen_time = time.time()
# Simulate your pipe call with progress updates
def status_callback(pipeline, step, timestep, callback_kwargs):
# Send progress updates
progress = int((step+1) / params.iterations * 100)
status_queue.put(ChatMessageStatus(
session_id=params.session_id,
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
activity=ApiActivityType.GENERATING_IMAGE,
))
progress = int((step + 1) / params.iterations * 100)
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
activity=ApiActivityType.GENERATING_IMAGE,
)
)
return callback_kwargs
# Replace this block with your actual Flux pipe call:
image = pipe(
params.prompt,
@ -76,30 +83,36 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
width=params.width,
callback_on_step_end=status_callback,
).images[0]
gen_time = time.time() - start_gen_time
per_step_time = gen_time / params.iterations if params.iterations > 0 else gen_time
logger.info(f"Saving to {params.filepath}")
image.save(params.filepath)
# Final completion status
status_queue.put(ChatMessage(
session_id=params.session_id,
status=ApiStatusType.DONE,
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
))
status_queue.put(
ChatMessage(
session_id=params.session_id,
status=ApiStatusType.DONE,
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
)
)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
status_queue.put(ChatMessageError(
session_id=params.session_id,
content=f"Error during image generation: {str(e)}",
))
status_queue.put(
ChatMessageError(
session_id=params.session_id,
content=f"Error during image generation: {str(e)}",
)
)
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
async def async_generate_image(
pipe: Any, params: ImageRequest
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
"""
Single async function that handles background Flux generation with status streaming
"""
@ -109,40 +122,36 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
try:
# Start background worker thread
worker_thread = Thread(
target=flux_worker,
args=(pipe, params, status_queue, task_id),
daemon=True
)
worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
worker_thread.start()
# Initial status
status_message = ChatMessageStatus(
session_id=params.session_id,
content=f"Starting image generation with task ID {task_id}",
activity=ApiActivityType.THINKING
activity=ApiActivityType.THINKING,
)
yield status_message
# Stream status updates
completed = False
last_heartbeat = time.time()
while not completed and worker_thread.is_alive():
try:
# Try to get status update (non-blocking)
status_update = status_queue.get_nowait()
# Send status update
yield status_update
# Check if completed
if status_update.status == ApiStatusType.DONE:
logger.info(f"Image generation completed for task {task_id}")
completed = True
last_heartbeat = time.time()
except queue.Empty:
# No new status, send heartbeat if needed
current_time = time.time()
@ -154,10 +163,10 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
)
yield heartbeat
last_heartbeat = current_time
# Brief sleep to prevent busy waiting
await asyncio.sleep(0.1)
# Handle thread completion or timeout
if not completed:
if worker_thread.is_alive():
@ -175,20 +184,18 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
content=f"Generation completed for task {task_id}.",
)
yield final_status
except Exception as e:
error_status = ChatMessageError(
session_id=params.session_id,
content=f'Server error: {str(e)}'
)
error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
logger.error(error_status)
yield error_status
finally:
# Cleanup: ensure thread completion
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
def status(session_id: str, status: str) -> ChatMessageStatus:
"""Update chat message status and return it."""
chat_message = ChatMessageStatus(
@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
)
return chat_message
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
"""Generate an image with specified dimensions and yield status updates with time estimates."""
session_id = request.session_id
@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
try:
# Validate prompt
if not prompt:
error_message = ChatMessageError(
session_id=session_id,
content="Prompt cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
logger.error(error_message.content)
yield error_message
return
@ -216,15 +221,14 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
# Validate dimensions
if request.height <= 0 or request.width <= 0:
error_message = ChatMessageError(
session_id=session_id,
content="Height and width must be positive integers."
session_id=session_id, content="Height and width must be positive integers."
)
logger.error(error_message.content)
yield error_message
return
filedir = os.path.dirname(request.filepath)
filename = os.path.basename(request.filepath)
os.path.basename(request.filepath)
os.makedirs(filedir, exist_ok=True)
model_type = "flux"
@ -233,14 +237,17 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
# Get initial time estimate, scaled by resolution
estimates = TIME_ESTIMATES[model_type][device]
resolution_scale = (request.height * request.width) / (512 * 512)
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
# Initialize or get cached pipeline
start_time = time.time()
yield status(session_id, f"Loading generative image model...")
yield status(session_id, "Loading generative image model...")
pipe = await model_cache.get_pipeline(request.model, device)
load_time = time.time() - start_time
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
yield status(
session_id,
f"Model loaded in {load_time:.1f} seconds.",
)
progress = None
async for progress in async_generate_image(pipe, request):
@ -252,13 +259,12 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
if not progress:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Final result
total_time = time.time() - start_time
chat_message = ChatMessage(
@ -269,11 +275,8 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
yield chat_message
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error during image generation: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
logger.error(traceback.format_exc())
logger.error(error_message.content)
yield error_message
return
return
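For illustration, a minimal sketch of driving the generator above; the session id, file path, and model id are assumptions (any FluxPipeline-compatible model would do), and remaining fields are passed explicitly since their defaults are not shown here:
request = ImageRequest(
    session_id="demo-session",
    filepath="/tmp/backstory-demo.png",
    prompt="a lighthouse at dusk, watercolor",
    model="black-forest-labs/FLUX.1-dev",
    iterations=4,
    height=512,
    width=512,
)
async for message in generate_image(request):
    logger.info(f"{message.status}: {message.content}")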
View File
@ -2,13 +2,14 @@ import json
import re
from typing import List, Union
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
"""
Extract JSON blocks from text, even if surrounded by markdown or noisy text.
If allow_multiple is True, returns all JSON blocks; otherwise, only the first.
"""
found = []
# First try to extract from code blocks (most reliable)
code_block_pattern = r"```(?:json)?\s*([\s\S]+?)\s*```"
for match in re.finditer(code_block_pattern, text):
@ -20,24 +21,25 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
return [parsed]
except json.JSONDecodeError:
continue
# If no valid code blocks found, look for standalone JSON objects/arrays
if not found:
standalone_json = _extract_standalone_json(text, allow_multiple)
found.extend(standalone_json)
if not found:
raise ValueError("No valid JSON block found in the text")
return found
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
"""Extract standalone JSON objects or arrays from text using proper brace counting."""
found = []
i = 0
while i < len(text):
if text[i] in '{[':
if text[i] in "{[":
# Found potential JSON start
json_str = _extract_complete_json_at_position(text, i)
if json_str:
@ -52,31 +54,32 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un
except json.JSONDecodeError:
pass
i += 1
return found
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
"""
Extract a complete JSON object or array starting at the given position.
Uses proper brace/bracket counting and string escape handling.
"""
if start_pos >= len(text) or text[start_pos] not in '{[':
if start_pos >= len(text) or text[start_pos] not in "{[":
return ""
start_char = text[start_pos]
end_char = '}' if start_char == '{' else ']'
end_char = "}" if start_char == "{" else "]"
count = 1
i = start_pos + 1
in_string = False
escape_next = False
while i < len(text) and count > 0:
char = text[i]
if escape_next:
escape_next = False
elif char == '\\' and in_string:
elif char == "\\" and in_string:
escape_next = True
elif char == '"' and not escape_next:
in_string = not in_string
@ -85,13 +88,14 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
count += 1
elif char == end_char:
count -= 1
i += 1
if count == 0:
return text[start_pos:i]
return ""
def extract_json_from_text(text: str) -> str:
"""Extract JSON string from text that may contain other content."""
return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
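For illustration, the extractor above applied to a noisy string:
text = 'Model output: here is the data {"name": "qwen2.5:7b", "score": 0.92} plus trailing prose'
blocks = extract_json_blocks(text)    # -> [{'name': 'qwen2.5:7b', 'score': 0.92}]
raw = extract_json_from_text(text)    # -> '{"name": "qwen2.5:7b", "score": 0.92}'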
View File
@ -2,11 +2,11 @@ import os
import warnings
import logging
import defines
def _setup_logging(level=defines.logging_level) -> logging.Logger:
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings(
"ignore", message="Overriding a previously registered kernel"
)
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
# Create a custom formatter
formatter = logging.Formatter(
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Create a handler (e.g., StreamHandler for console output)
@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
logger = logging.getLogger(__name__)
return logger
logger = _setup_logging(level=defines.logging_level)
logger.debug(f"Logging initialized with level: {defines.logging_level}")
logger.debug(f"Logging initialized with level: {defines.logging_level}")
View File
@ -39,7 +39,7 @@ from utils.dependencies import get_database, set_db_manager
from routes import (
admin,
auth,
candidates,
candidates,
chat,
employers,
jobs,
@ -58,6 +58,7 @@ background_task_manager = None
prev_int = signal.getsignal(signal.SIGINT)
prev_term = signal.getsignal(signal.SIGTERM)
def signal_handler(signum, frame):
logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
# now call the old handler (it might raise KeyboardInterrupt or exit)
@ -66,14 +67,16 @@ def signal_handler(signum, frame):
elif signum == signal.SIGTERM and callable(prev_term):
prev_term(signum, frame)
# Global background task manager
background_task_manager: Optional[BackgroundTaskManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
global db_manager, background_task_manager
logger.info("🚀 Starting Backstory API with enhanced background tasks")
logger.info(f"📝 API Documentation available at: http://{defines.host}:{defines.port}{defines.api_prefix}/docs")
logger.info("🔗 API endpoints prefixed with: /api/1.0")
@ -81,10 +84,10 @@ async def lifespan(app: FastAPI):
# Initialize database
db_manager = DatabaseManager()
await db_manager.initialize()
# Set the database manager in dependencies
set_db_manager(db_manager)
entity_manager.initialize(prometheus_collector=prometheus_collector, database=db_manager.get_database())
# Initialize background task manager
@ -96,26 +99,27 @@ async def lifespan(app: FastAPI):
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info("🚀 Application startup completed with background tasks")
yield # Application is running
except Exception as e:
logger.error(f"❌ Failed to start application: {e}")
raise
finally:
# Shutdown
logger.info("Application shutdown requested")
# Stop background tasks first
if background_task_manager:
await background_task_manager.stop()
if db_manager:
await db_manager.graceful_shutdown()
app = FastAPI(
lifespan=lifespan,
title="Backstory API",
@ -129,12 +133,10 @@ app = FastAPI(
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
if ssl_enabled:
allow_origins = ["https://battle-linux.ketrenos.com:3000",
"https://backstory-beta.ketrenos.com"]
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
else:
allow_origins = ["http://battle-linux.ketrenos.com:3000",
"http://backstory-beta.ketrenos.com"]
allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
@ -144,12 +146,14 @@ app.add_middleware(
allow_headers=["*"],
)
# ============================
# Debug data type failures
# ============================
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
content=json.dumps({"detail": str(exc)}),
)
# ============================
# Create API router with prefix
# ============================
@ -181,44 +186,43 @@ api_router.include_router(users.router)
# Health Check and Info Endpoints
# ============================
@app.get("/health")
async def health_check(
database = Depends(get_database),
database=Depends(get_database),
):
"""Health check endpoint"""
try:
if not redis_manager.redis:
raise RuntimeError("Redis client not initialized")
# Test Redis connection
await redis_manager.redis.ping()
# Get database stats
stats = await database.get_stats()
# Redis info
redis_info = await redis_manager.redis.info()
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"database": {
"status": "connected",
"stats": stats
},
"database": {"status": "connected", "stats": stats},
"redis": {
"version": redis_info.get("redis_version", "unknown"),
"uptime": redis_info.get("uptime_in_seconds", 0),
"memory_used": redis_info.get("used_memory_human", "unknown")
}
"memory_used": redis_info.get("used_memory_human", "unknown"),
},
}
except RuntimeError as e:
return {"status": "shutting_down", "message": str(e)}
except Exception as e:
logger.error(f"❌ Health check failed: {e}")
return {"status": "error", "message": str(e)}
# ============================
# Include Router in App
# ============================
@ -231,11 +235,14 @@ app.include_router(api_router)
# ============================
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")
@app.middleware("http")
async def log_requests(request: Request, call_next):
try:
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
logger.info(
f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
)
response = await call_next(request)
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
if response.status_code < 200 or response.status_code >= 300:
@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
return response
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})
# ============================
# Request tracking middleware
# ============================
@ -258,7 +267,7 @@ async def track_requests(request, call_next):
"""Middleware to track active requests during shutdown"""
if db_manager.is_shutting_down:
return JSONResponse(status_code=503, content={"error": "Application is shutting down"})
db_manager.increment_requests()
try:
response = await call_next(request)
@ -266,6 +275,7 @@ async def track_requests(request, call_next):
finally:
db_manager.decrement_requests()
# ============================
# FastAPI Metrics
# ============================
@ -277,7 +287,7 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)
# Instrument the FastAPI app
@ -291,15 +301,17 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
# Static File Serving
# ============================
@app.get("/{path:path}")
async def serve_static(path: str, request: Request):
full_path = os.path.join(defines.static_content, path)
if os.path.exists(full_path) and os.path.isfile(full_path):
return FileResponse(full_path)
return FileResponse(os.path.join(defines.static_content, "index.html"))
# Root endpoint when no static files
@app.get("/", include_in_schema=False)
async def root():
@ -309,21 +321,23 @@ async def root():
"version": "1.0.0",
"api_prefix": defines.api_prefix,
"documentation": f"{defines.api_prefix}/docs",
"health": f"{defines.api_prefix}/health"
"health": f"{defines.api_prefix}/health",
}
async def periodic_verification_cleanup():
"""Background task to periodically clean up expired verification tokens"""
try:
database = get_database()
cleaned_count = await database.cleanup_expired_verification_tokens()
if cleaned_count > 0:
logger.info(f"🧹 Periodic cleanup: removed {cleaned_count} expired verification tokens")
except Exception as e:
logger.error(f"❌ Error in periodic verification cleanup: {e}")
if __name__ == "__main__":
host = defines.host
port = defines.port
@ -341,4 +355,4 @@ if __name__ == "__main__":
)
else:
logger.info(f"Starting web server at http://{host}:{port}")
uvicorn.run(app="main:app", host=host, port=port, log_config=None)
uvicorn.run(app="main:app", host=host, port=port, log_config=None)
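For reference, a minimal sketch of probing the health endpoint registered above once the server is running, assuming SSL is disabled and the default port from defines:
import httpx

resp = httpx.get("http://localhost:8911/health")
print(resp.json()["status"])   # "healthy", "shutting_down", or "error"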
File diff suppressed because it is too large

View File
@ -0,0 +1,23 @@
[tool.ruff]
# Set line length (default is 88, Black-compatible)
line-length = 120
exclude = [
".git",
"__pycache__",
"migrations",
"venv",
"dist",
"build",
]
target-version = "py312"
[tool.ruff.format]
# Use Ruff's formatter (Black-compatible by default)
quote-style = "double" # Enforce double quotes
indent-style = "space" # Use spaces for indentation
skip-magic-trailing-comma = false # Add trailing commas where applicable
[tool.ruff.per-file-ignores]
# Ignore specific rules for certain files (e.g., tests or scripts)
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
"scripts/**" = ["E402"] # Ignore import order in scripts
View File
@ -1,7 +1,3 @@
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
__all__ = [
"ChromaDBFileWatcher",
"start_file_watcher",
"RagEntry"
]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]
View File
@ -7,10 +7,12 @@ import logging
import defines
class Chunk(TypedDict):
text: str
metadata: Dict[str, Any]
def clear_chunk(chunk: Chunk):
chunk["text"] = ""
chunk["metadata"] = {
@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
}
return chunk
class MarkdownChunker:
def __init__(self):
# Initialize the Markdown parser
@ -76,10 +79,7 @@ class MarkdownChunker:
# Initialize a chunk structure
chunk: Chunk = {
"text": "",
"metadata": {
"source_file": file_path,
"lines": total_lines
},
"metadata": {"source_file": file_path, "lines": total_lines},
}
clear_chunk(chunk)
@ -89,9 +89,7 @@ class MarkdownChunker:
return chunks
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
return {
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
}
return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
lines = []
@ -114,7 +112,7 @@ class MarkdownChunker:
chunks: List[Chunk],
chunk: Chunk,
level: int,
buffer: int = defines.chunk_buffer
buffer: int = defines.chunk_buffer,
) -> int:
is_list = False
# Handle heading nodes
@ -141,7 +139,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
@ -186,7 +184,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
@ -198,9 +196,7 @@ class MarkdownChunker:
# Recursively process children
if not is_list:
for child in node.children:
level = self._process_node(
child, current_headings, chunks, chunk, level=level
)
level = self._process_node(child, current_headings, chunks, chunk, level=level)
# After root-level recursion, finalize any remaining chunk
if node.type == "document":
@ -211,7 +207,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
View File
@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from pydantic import BaseModel # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob
@ -11,10 +11,10 @@ import json
import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
import chromadb # type: ignore
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
self,
@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Path for storing file hash state
self.hash_state_path = os.path.join(
self.persist_directory, f"{collection_name}_hash_state.json"
)
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
# Flag to track if this is a new collection
self.is_new_collection = False
@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
process_all: If True, process all files regardless of hash status
"""
# Check for new or modified files
file_paths = glob.glob(
os.path.join(self.watch_directory, "**/*"), recursive=True
)
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
files_checked = 0
files_processed = 0
files_to_process = []
@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
continue
# If file is new, changed, or we're processing all files
if (
process_all
or file_path not in self.file_hashes
or self.file_hashes[file_path] != current_hash
):
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
self.file_hashes[file_path] = current_hash
files_to_process.append(file_path)
logging.info(
f"File {'found' if process_all else 'changed'}: {file_path}"
)
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
logging.info(
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
)
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
# Check for deleted files
deleted_files = []
@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
# Don't block on result, just let it run
logging.info(f"File deleted: {file_path}")
@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not current_hash: # File might have been deleted or is inaccessible
return
if (
file_path in self.file_hashes
and self.file_hashes[file_path] == current_hash
):
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
# File hasn't actually changed in content
logging.info(f"Hash has not changed for {file_path}")
return
@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
# Remove from hash dictionary
if file_path in self.file_hashes:
@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
def _update_umaps(self):
# Update the UMAP embeddings
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
include=["embeddings", "documents", "metadatas"]
))
self._umap_collection = ChromaDBGetResponse.model_validate(
self._collection.get(include=["embeddings", "documents", "metadatas"])
)
if not self._umap_collection or not len(self._umap_collection.embeddings):
logging.warning("⚠️ No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
vectors = np.array(self._umap_collection.embeddings)
self._umap_model_2d = umap.UMAP(
n_components=2,
@ -323,14 +304,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_neighbors=30,
min_dist=0.1,
)
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
# logging.info(
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
self._umap_model_3d = umap.UMAP(
n_components=3,
random_state=8911,
@ -338,7 +317,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_neighbors=30,
min_dist=0.01,
)
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)# type: ignore
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors) # type: ignore
# logging.info(
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
# ) # Should be 3
@ -350,7 +329,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
os.makedirs(self.persist_directory)
# Initialize ChromaDB client
chroma_client = chromadb.PersistentClient(
chroma_client = chromadb.PersistentClient(
path=self.persist_directory,
settings=chromadb.Settings(anonymized_telemetry=False), # type: ignore
)
@ -373,13 +352,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
self.is_new_collection = True
logging.info(f"Recreating collection: {self.collection_name}")
return chroma_client.get_or_create_collection(
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
# Get embedding
try:
response = await self.llm.embeddings(model=defines.embedding_model, input_texts=text)
@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Generate a more unique ID based on content and metadata
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
:8
]
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
@ -438,7 +413,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
logging.error(traceback.format_exc())
logging.error(chunk)
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
source_file = meta.get("source_file")
try:
source_file = meta["source_file"]
@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
logging.info(f"File deleted: {file_path}")
def on_moved(self, event):
@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if (
existing_results
and "ids" in existing_results
and existing_results["ids"]
):
if existing_results and "ids" in existing_results and existing_results["ids"]:
self.collection.delete(ids=existing_results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
p = Path(file_path)
p_as_md = p.with_suffix(".md")
if p_as_md.exists():
logging.info(
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
)
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
# If file_path.md doesn't exist or file_path is newer than file_path.md,
# fire off markitdown
if (not p_as_md.exists()) or (
p.stat().st_mtime > p_as_md.stat().st_mtime
):
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
self._markitdown(file_path, p_as_md)
return
@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Process all files regardless of hash state
num_processed = await self.scan_directory(process_all=True)
logging.info(
f"Vectorstore initialized with {self.collection.count()} documents"
)
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
self._update_umaps()
@ -676,7 +639,7 @@ def start_file_watcher(
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
database=database
database=database,
)
# Process all files if:
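As a rough, self-contained sketch of the change-detection loop in scan_directory above (MD5 over file contents is an assumption here; the watcher's actual hashing helper is not shown in this diff):

import glob
import hashlib
import os
from typing import Dict, List

def changed_files(watch_directory: str, known_hashes: Dict[str, str], process_all: bool = False) -> List[str]:
    """Return files that are new or whose content hash differs, updating known_hashes in place."""
    to_process: List[str] = []
    for file_path in glob.glob(os.path.join(watch_directory, "**/*"), recursive=True):
        if not os.path.isfile(file_path):
            continue
        with open(file_path, "rb") as f:
            current_hash = hashlib.md5(f.read()).hexdigest()
        if process_all or known_hashes.get(file_path) != current_hash:
            known_hashes[file_path] = current_hash
            to_process.append(file_path)
    return to_process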

View File

@ -13,14 +13,4 @@ from . import employers
from . import admin
from . import system
__all__ = [
"auth",
"candidates",
"resumes",
"jobs",
"chat",
"users",
"employers",
"admin",
"system"
]
__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -4,125 +4,69 @@ Chat routes
import json
import uuid
from datetime import datetime, UTC
from typing import (Dict, Any)
from typing import Dict, Any
from fastapi import (
APIRouter, Depends, Body, Depends, Query, Path,
Body, APIRouter
)
from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter
from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_database, get_current_user, get_current_user_or_guest
)
from utils.responses import (
create_success_response, create_error_response, create_paginated_response
)
from utils.helpers import (
stream_agent_response
)
from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
from utils.responses import create_success_response, create_error_response, create_paginated_response
from utils.helpers import stream_agent_response
import backstory_traceback
import entities.entity_manager as entities
from models import (
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
Candidate, Employer, Guest, AuthResponse,
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
EmailVerificationRequest, ResendVerificationRequest,
# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
JobRequirements, JobRequirementsMessage, LoginRequest,
CreateCandidateRequest, CreateEmployerRequest,
# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
Authentication, AuthResponse, CandidateAI,
# Job models
JobApplication, ApplicationStatus,
# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
ChatMessageRagSearch,
# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
# Supporting models
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
# Email
EmailVerificationRequest
)
from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
# Create router for authentication endpoints
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("/sessions/{session_id}/archive")
async def archive_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Archive a chat session"""
try:
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Check if user owns this session or is admin
if session_data.get("userId") != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
)
await database.archive_chat_session(session_id)
return create_success_response({
"message": "Chat session archived successfully",
"sessionId": session_id
})
return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
except Exception as e:
logger.error(f"❌ Archive chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("ARCHIVE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
@router.get("/statistics")
async def get_chat_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get chat statistics (admin/analytics endpoint)"""
try:
stats = await database.get_chat_statistics()
return create_success_response(stats)
except Exception as e:
logger.error(f"❌ Get chat statistics error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
@router.post("/sessions")
async def create_chat_session(
session_data: Dict[str, Any] = Body(...),
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new chat session with optional candidate username association"""
try:
@ -130,31 +74,30 @@ async def create_chat_session(
username = session_data.get("username")
candidate_id = None
candidate_data = None
# If username is provided, look up the candidate
if username:
logger.info(f"🔍 Looking up candidate with username: {username}")
# Get all candidates and find by username
all_candidates_data = await database.get_all_candidates()
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
# Find candidate by username (case-insensitive)
matching_candidates = [
c for c in candidates_list
if c.username.lower() == username.lower()
]
matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
if not matching_candidates:
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
content=create_error_response(
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
),
)
candidate_data = matching_candidates[0]
candidate_id = candidate_data.id
logger.info(f"✅ Found candidate: {candidate_data.full_name} (ID: {candidate_id})")
# Add required fields
session_id = str(uuid.uuid4())
session_data["id"] = session_id
@ -167,7 +110,7 @@ async def create_chat_session(
if candidate_id and candidate_data:
context["relatedEntityId"] = candidate_id
context["relatedEntityType"] = "candidate"
# Add candidate info to additional context for AI reference
additional_context = context.get("additionalContext", {})
additional_context["candidateInfo"] = {
@ -177,72 +120,83 @@ async def create_chat_session(
"username": candidate_data.username,
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
"location": candidate_data.location.city if candidate_data.location else "Unknown"
"location": candidate_data.location.city if candidate_data.location else "Unknown",
}
context["additionalContext"] = additional_context
# Set a descriptive title if not provided
if not session_data.get("title"):
session_data["title"] = f"Chat about {candidate_data.full_name}"
session_data["context"] = context
# Create chat session
chat_session = ChatSession.model_validate(session_data)
await database.set_chat_session(chat_session.id, chat_session.model_dump())
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
logger.info(
f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
)
return create_success_response(chat_session.model_dump(by_alias=True))
except Exception as e:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat session creation error: {e}")
logger.info(json.dumps(session_data, indent=2))
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("/sessions/messages/stream")
async def post_chat_session_message_stream(
user_message: ChatMessageUser = Body(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Post a message to a chat session and stream the response with persistence"""
try:
chat_session_data = await database.get_chat_session(user_message.session_id)
if not chat_session_data:
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
chat_session = ChatSession.model_validate(chat_session_data)
chat_type = chat_session.context.type
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
candidate_info = (
chat_session.context.additional_context.get("candidateInfo", {})
if chat_session.context and chat_session.context.additional_context
else None
)
# Get candidate info if this chat is about a specific candidate
if candidate_info:
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
)
else:
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
)
return JSONResponse(
status_code=400,
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
content=create_error_response(
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
),
)
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
if not candidate:
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
logger.info(
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
)
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
)
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
logger.info(
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
# Entity automatically released when done
@ -251,13 +205,13 @@ async def post_chat_session_message_stream(
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
)
# Persist user message to database
await database.add_chat_message(user_message.session_id, user_message.model_dump())
logger.info(f"💬 User message saved to database for session {user_message.session_id}")
# Update session last activity
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(user_message.session_id, chat_session_data)
@ -268,35 +222,30 @@ async def post_chat_session_message_stream(
database=database,
chat_session_data=chat_session_data,
)
except Exception as e:
except Exception:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat message streaming error")
return JSONResponse(
status_code=500,
content=create_error_response("STREAMING_ERROR", "")
)
logger.error("❌ Chat message streaming error")
return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
@router.get("/sessions/{session_id}/messages")
async def get_chat_session_messages(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
current_user=Depends(get_current_user_or_guest),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get persisted chat messages for a session"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Get messages from database
chat_messages = await database.get_chat_messages(session_id)
# Convert to ChatMessage objects and sort by timestamp
messages_list = []
for msg_data in chat_messages:
@ -306,69 +255,61 @@ async def get_chat_session_messages(
except Exception as e:
logger.warning(f"⚠️ Failed to validate message: {e}")
continue
# Sort by timestamp (oldest first for chat history)
messages_list.sort(key=lambda x: x.timestamp)
# Apply pagination
total = len(messages_list)
start = (page - 1) * limit
end = start + limit
paginated_messages = messages_list[start:end]
paginated_response = create_paginated_response(
[m.model_dump(by_alias=True) for m in paginated_messages],
page, limit, total
[m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get chat messages error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.patch("/sessions/{session_id}")
async def update_chat_session(
session_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Update a chat session's properties"""
try:
# Get the existing session
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only update their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
)
# Validate and apply updates
allowed_fields = {"title", "context", "isArchived", "systemPrompt"}
filtered_updates = {k: v for k, v in updates.items() if k in allowed_fields}
if not filtered_updates:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
)
# Apply updates to session data
session_dict = session.model_dump()
# Handle special field mappings (camelCase to snake_case)
if "isArchived" in filtered_updates:
session_dict["is_archived"] = filtered_updates["isArchived"]
@ -380,7 +321,7 @@ async def update_chat_session(
# Merge context updates with existing context
existing_context = session_dict.get("context", {})
context_updates = filtered_updates["context"]
# Update specific context fields while preserving others
for context_key, context_value in context_updates.items():
if context_key == "additionalContext":
@ -397,141 +338,116 @@ async def update_chat_session(
snake_key = "related_entity_type"
elif context_key == "aiParameters":
snake_key = "ai_parameters"
existing_context[snake_key] = context_value
session_dict["context"] = existing_context
# Update last activity timestamp
session_dict["last_activity"] = datetime.now(UTC).isoformat()
# Validate the updated session
updated_session = ChatSession.model_validate(session_dict)
# Save to database
await database.set_chat_session(session_id, updated_session.model_dump())
logger.info(f"✅ Chat session {session_id} updated by user {current_user.id}")
return create_success_response(updated_session.model_dump(by_alias=True))
except ValueError as ve:
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
except Exception as e:
logger.error(f"❌ Update chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("UPDATE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
@router.delete("/sessions/{session_id}")
async def delete_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
)
# Delete all messages associated with this session
try:
await database.delete_chat_messages(session_id)
chat_messages = await database.get_chat_messages(session_id)
message_count = len(chat_messages)
message_count = len(chat_messages)
logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}")
except Exception as e:
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
# Continue with session deletion even if message deletion fails
# Delete the session itself
await database.delete_chat_session(session_id)
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session deleted successfully",
"sessionId": session_id
})
return create_success_response(
{"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Delete chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
@router.patch("/sessions/{session_id}/reset")
async def reset_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
)
# Delete all messages associated with this session
try:
await database.delete_chat_messages(session_id)
chat_messages = await database.get_chat_messages(session_id)
message_count = len(chat_messages)
message_count = len(chat_messages)
logger.info(f"🗑️ Deleted {message_count} messages from session {session_id}")
except Exception as e:
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
# Continue with session deletion even if message deletion fails
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session reset successfully",
"sessionId": session_id
})
except Exception as e:
logger.error(f"❌ Reset chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
return create_success_response(
{"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Reset chat session error: {e}")
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))

View File

@ -9,58 +9,52 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_current_admin, get_database
)
from utils.dependencies import get_current_admin, get_database
from utils.responses import create_success_response, create_error_response
# Create router for authentication endpoints
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.get("/guest/{guest_id}")
async def debug_guest_session(
guest_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Debug guest session issues (admin only)"""
try:
# Check primary storage
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
primary_exists = primary_data is not None
# Check backup storage
# Check backup storage
backup_data = await database.redis.get(f"guest_backup:{guest_id}")
backup_exists = backup_data is not None
# Check user lookup
user_lookup = await database.get_user_by_id(guest_id)
# Get TTL info
primary_ttl = await database.redis.ttl(f"guests")
primary_ttl = await database.redis.ttl("guests")
backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}")
debug_info = {
"guest_id": guest_id,
"primary_storage": {
"exists": primary_exists,
"data": json.loads(primary_data) if primary_data else None,
"ttl": primary_ttl
"ttl": primary_ttl,
},
"backup_storage": {
"exists": backup_exists,
"exists": backup_exists,
"data": json.loads(backup_data) if backup_data else None,
"ttl": backup_ttl
"ttl": backup_ttl,
},
"user_lookup": user_lookup,
"timestamp": datetime.now(UTC).isoformat()
"timestamp": datetime.now(UTC).isoformat(),
}
return create_success_response(debug_info)
except Exception as e:
logger.error(f"❌ Debug guest session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DEBUG_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))

View File

@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
CreateEmployerRequest
)
from utils.dependencies import (
get_database
)
from models import CreateEmployerRequest
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
from email_service import email_service
from utils.auth_utils import AuthenticationManager
@ -23,34 +19,27 @@ from utils.auth_utils import AuthenticationManager
# Create router for job endpoints
router = APIRouter(prefix="/employers", tags=["employers"])
@router.post("")
async def create_employer_with_verification(
request: CreateEmployerRequest,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
):
"""Create a new employer with email verification"""
try:
# Similar to candidate creation but for employer
auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists(
request.email,
request.username
)
user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
if user_exists and conflict_field:
return JSONResponse(
status_code=409,
content=create_error_response(
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
)
employer_id = str(uuid.uuid4())
current_time = datetime.now(timezone.utc)
employer_data = {
"id": employer_id,
"email": request.email,
@ -64,46 +53,35 @@ async def create_employer_with_verification(
"updatedAt": current_time.isoformat(),
"status": "pending", # Not active until verified
"userType": "employer",
"location": {
"city": "",
"country": "",
"remote": False
},
"socialLinks": []
"location": {"city": "", "country": "", "remote": False},
"socialLinks": [],
}
verification_token = secrets.token_urlsafe(32)
await database.store_email_verification_token(
request.email,
verification_token,
"employer",
{"employer_data": employer_data, "password": request.password, "username": request.username},
)
background_tasks.add_task(
email_service.send_verification_email, request.email, verification_token, request.company_name
)
logger.info(f"✅ Employer registration initiated for: {request.email}")
return create_success_response(
{
"employer_data": employer_data,
"password": request.password,
"username": request.username
"message": "Registration successful! Please check your email to verify your account.",
"email": request.email,
"verificationRequired": True,
}
)
background_tasks.add_task(
email_service.send_verification_email,
request.email,
verification_token,
request.company_name
)
logger.info(f"✅ Employer registration initiated for: {request.email}")
return create_success_response({
"message": "Registration successful! Please check your email to verify your account.",
"email": request.email,
"verificationRequired": True
})
except Exception as e:
logger.error(f"❌ Employer creation error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
)

View File

@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer
)
from utils.dependencies import (
get_current_admin, get_database, get_current_user
MOCK_UUID,
ApiActivityType,
ApiStatusType,
ChatContextType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
DocumentType,
Job,
JobRequirementsMessage,
Candidate,
Employer,
)
from utils.dependencies import get_current_admin, get_database, get_current_user
from utils.responses import create_paginated_response, create_success_response, create_error_response
import utils.llm_proxy as llm_manager
import entities.entity_manager as entities
# Create router for job endpoints
router = APIRouter(prefix="/jobs", tags=["jobs"])
@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING,
)
yield status_message
@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
@ -66,16 +76,16 @@ Return only the markdown content, no other text. Make sure all content is includ
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
chat_message : ChatMessage = message
chat_message: ChatMessage = message
try:
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
except Exception as e:
except Exception:
pass
logger.info(f"✅ Successfully converted content to markdown")
logger.info("✅ Successfully converted content to markdown")
yield chat_message
return
@ -84,11 +94,11 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
await asyncio.sleep(0) # Let the status message propagate
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
message = None
async for message in reformat_as_markdown(database, candidate_entity, content):
@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
@ -132,23 +139,22 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Job extraction did not convert successfully"
content="Job extraction did not convert successfully",
)
yield error_message
return
job_requirements : JobRequirementsMessage = message
job_requirements: JobRequirementsMessage = message
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
yield job_requirements
return
@router.post("")
async def create_job(
job_data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
try:
@ -158,27 +164,24 @@ async def create_job(
job.id = str(uuid.uuid4())
job.owner_id = current_user.id
job.owner = current_user
await database.set_job(job.id, job.model_dump())
return create_success_response(job.model_dump(by_alias=True))
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("")
async def create_candidate_job(
job_data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
is_employer = isinstance(current_user, Employer)
isinstance(current_user, Employer)
try:
job = Job.model_validate(job_data)
@ -187,46 +190,39 @@ async def create_candidate_job(
job.id = str(uuid.uuid4())
job.owner_id = current_user.id
job.owner = current_user
await database.set_job(job.id, job.model_dump())
return create_success_response(job.model_dump(by_alias=True))
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.patch("/{job_id}")
async def update_job(
job_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Update a candidate"""
try:
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Job not found for update: {job_data}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
job = Job.model_validate(job_data)
# Check authorization (user can only update their own profile)
if current_user.is_admin is False and job.owner_id != current_user.id:
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
)
# Apply updates
updates["updatedAt"] = datetime.now(UTC).isoformat()
logger.info(f"🔄 Updating job {job_id} with data: {updates}")
@ -234,36 +230,33 @@ async def update_job(
job_dict.update(updates)
updated_job = Job.model_validate(job_dict)
await database.set_job(job_id, updated_job.model_dump())
return create_success_response(updated_job.model_dump(by_alias=True))
except Exception as e:
logger.error(f"❌ Update job error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("UPDATE_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
@router.post("/from-content")
async def create_job_from_description(
content: str = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Upload a document for the current candidate"""
async def content_stream_generator(content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
logger.info(f"📁 Received file content: size='{len(content)} bytes'")
last_yield_was_streaming = False
async for message in create_job_from_content(database=database, current_user=current_user, content=content):
if message.status != ApiStatusType.STREAMING:
@ -277,10 +270,11 @@ async def create_job_from_description(
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -304,18 +298,25 @@ async def create_job_from_description(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
@router.post("/upload")
async def create_job_from_file(
file: UploadFile = File(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB)
@ -324,66 +325,81 @@ async def create_job_from_file(
if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
"""Upload a document for the current candidate"""
async def upload_stream_generator(file_content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename"
content="File must have a valid filename",
)
yield error_message
return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
logger.info(
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
)
# Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
)
yield error_message
return
document_type = get_document_type_from_filename(file.filename or "unknown.txt")
if document_type != DocumentType.MARKDOWN and document_type != DocumentType.TXT:
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING
activity=ApiActivityType.CONVERTING,
)
yield status_message
try:
@ -391,8 +407,8 @@ async def create_job_from_file(
stream = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing
)
url=file.filename, # optional, helps with logging and guessing
)
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content
logger.info(f"✅ Converted {file.filename} to Markdown format")
@ -404,16 +420,19 @@ async def create_job_from_file(
yield error_message
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
async for message in create_job_from_content(
database=database, current_user=current_user, content=file_content
):
yield message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -437,40 +456,39 @@ async def create_job_from_file(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document",
).model_dump(mode="json", by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
@router.get("/{job_id}")
async def get_job(
job_id: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a job by ID"""
try:
job_data = await database.get_job(job_id)
if not job_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Increment view count
job_data["views"] = job_data.get("views", 0) + 1
await database.set_job(job_id, job_data)
job = Job.model_validate(job_data)
return create_success_response(job.model_dump(by_alias=True))
except Exception as e:
logger.error(f"❌ Get job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.get("")
async def get_jobs(
@ -479,37 +497,32 @@ async def get_jobs(
sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get paginated list of jobs"""
try:
filter_dict = None
if filters:
filter_dict = json.loads(filters)
# Get all jobs from Redis
all_jobs_data = await database.get_all_jobs()
jobs_list = []
for job in all_jobs_data.values():
jobs_list.append(Job.model_validate(job))
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("FETCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
@router.get("/search")
async def search_jobs(
@ -517,84 +530,67 @@ async def search_jobs(
filters: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Search jobs"""
try:
filter_dict = {}
if filters:
filter_dict = json.loads(filters)
# Get all jobs from Redis
all_jobs_data = await database.get_all_jobs()
jobs_list = [Job.model_validate(data) for data in all_jobs_data.values() if data.get("is_active", True)]
if query:
query_lower = query.lower()
jobs_list = [
j for j in jobs_list
if ((j.title and query_lower in j.title.lower()) or
(j.description and query_lower in j.description.lower()) or
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
j
for j in jobs_list
if (
(j.title and query_lower in j.title.lower())
or (j.description and query_lower in j.description.lower())
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
)
]
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, filters=filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Search jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("SEARCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
@router.delete("/{job_id}")
async def delete_job(
job_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Delete a Job"""
try:
# Check if admin user
if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
        # Get job data
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Delete job from database
await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({
"message": "Job deleted successfully",
"jobId": job_id
})
return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
except Exception as e:
logger.error(f"❌ Delete job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))

View File

@ -6,40 +6,37 @@ from utils.llm_proxy import LLMProvider, get_llm
router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/models")
async def list_models(provider: Optional[str] = None):
"""List available models for a provider"""
try:
llm = get_llm()
provider_enum = None
if provider:
try:
provider_enum = LLMProvider(provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {provider}"
)
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
models = await llm.list_models(provider_enum)
return {
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("")
async def list_providers():
"""List all configured providers"""
llm = get_llm()
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value
"default": llm.default_provider.value,
}
@router.post("/{provider}/set-default")
async def set_default_provider(provider: str):
"""Set the default provider"""
@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint
@router.get("/health")
async def health_check():
@ -59,5 +57,5 @@ async def health_check():
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value
"default_provider": llm.default_provider.value,
}

View File

@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
import backstory_traceback as backstory_traceback
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
)
from utils.dependencies import (
get_database, get_current_user
)
from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
from utils.dependencies import get_database, get_current_user
from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/resumes", tags=["resumes"])
@router.post("/{candidate_id}/{job_id}")
async def create_candidate_resume(
candidate_id: str = Path(..., description="ID of the candidate"),
candidate_id: str = Path(..., description="ID of the candidate"),
job_id: str = Path(..., description="ID of the job"),
resume_content: str = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new resume for a candidate/job combination"""
async def message_stream_generator():
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
@ -39,10 +37,10 @@ async def create_candidate_resume(
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Candidate with ID '{candidate_id}' not found"
content=f"Candidate with ID '{candidate_id}' not found",
)
yield error_message
return
return
candidate = Candidate.model_validate(candidate_data)
job_data = await database.get_job(job_id)
@ -50,14 +48,16 @@ async def create_candidate_resume(
logger.error(f"❌ Job with ID '{job_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Job with ID '{job_id}' not found"
content=f"Job with ID '{job_id}' not found",
)
yield error_message
return
job = Job.model_validate(job_data)
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
logger.info(
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
)
# Job and Candidate are valid. Save the resume
resume = Resume(
job_id=job_id,
@ -66,28 +66,26 @@ async def create_candidate_resume(
)
resume_message: ResumeMessage = ResumeMessage(
session_id=MOCK_UUID, # No session ID for document uploads
resume=resume
resume=resume,
)
# Save to database
success = await database.set_resume(current_user.id, resume.model_dump())
if not success:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to save resume to database"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
yield error_message
return
logger.info(f"✅ Successfully saved resume {resume_message.resume.id} for user {current_user.id}")
yield resume_message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -111,22 +109,26 @@ async def create_candidate_resume(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Resume creation error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume"
).model_dump(mode='json', by_alias=True))]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume",
).model_dump(mode="json", by_alias=True)
)
]
),
media_type="text/event-stream",
)
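For consumers of this endpoint, the messages above arrive as server-sent events. A minimal client sketch follows; the base URL, the /resumes mount point, and the bearer-token header are assumptions, and httpx is used purely for illustration.

import json
import httpx

async def stream_resume_creation(base_url: str, token: str, candidate_id: str, job_id: str, content: str):
    # Yields each parsed SSE payload emitted by create_candidate_resume.
    headers = {"Authorization": f"Bearer {token}"}
    async with httpx.AsyncClient(base_url=base_url) as client:
        async with client.stream(
            "POST", f"/resumes/{candidate_id}/{job_id}", json=content, headers=headers
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])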
@router.get("")
async def get_user_resumes(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get all resumes for the current user"""
try:
resumes_data = await database.get_all_resumes_for_user(current_user.id)
resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data]
resumes: List[Resume] = [Resume.model_validate(data) for data in resumes_data]
for resume in resumes:
job_data = await database.get_job(resume.job_id)
if job_data:
@ -135,156 +137,130 @@ async def get_user_resumes(
if candidate_data:
resume.candidate = Candidate.model_validate(candidate_data)
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date
return create_success_response({
"resumes": resumes,
"count": len(resumes)
})
return create_success_response({"resumes": resumes, "count": len(resumes)})
except Exception as e:
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
@router.get("/{resume_id}")
async def get_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get a specific resume by ID"""
try:
resume = await database.get_resume(current_user.id, resume_id)
if not resume:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"resume": resume
}
return {"success": True, "resume": resume}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
@router.delete("/{resume_id}")
async def delete_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Delete a specific resume"""
try:
success = await database.delete_resume(current_user.id, resume_id)
if not success:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"message": f"Resume {resume_id} deleted successfully"
}
return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete resume")
@router.get("/candidate/{candidate_id}")
async def get_resumes_by_candidate(
candidate_id: str = Path(..., description="ID of the candidate"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific candidate"""
try:
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
return {
"success": True,
"candidate_id": candidate_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
@router.get("/job/{job_id}")
async def get_resumes_by_job(
job_id: str = Path(..., description="ID of the job"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific job"""
try:
resumes = await database.get_resumes_by_job(current_user.id, job_id)
return {
"success": True,
"job_id": job_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
@router.get("/search")
async def search_resumes(
q: str = Query(..., description="Search query"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Search resumes by content"""
try:
resumes = await database.search_resumes_for_user(current_user.id, q)
return {
"success": True,
"query": q,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to search resumes")
@router.get("/stats")
async def get_resume_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Get resume statistics for the current user"""
try:
stats = await database.get_resume_statistics(current_user.id)
return {
"success": True,
"statistics": stats
}
return {"success": True, "statistics": stats}
except Exception as e:
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
@router.put("/{resume_id}")
async def update_resume(
resume_id: str = Path(..., description="ID of the resume"),
resume: str = Body(..., description="Updated resume content"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Update the content of a specific resume"""
try:
updates = {
"resume": resume,
"updated_at": datetime.now(UTC).isoformat()
}
updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
if not updated_resume_data:
logger.warning(f"⚠️ Resume {resume_id} not found for user {current_user.id}")
raise HTTPException(status_code=404, detail="Resume not found")
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
return create_success_response({
"success": True,
"message": f"Resume {resume_id} updated successfully",
"resume": updated_resume
})
return create_success_response(
{"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
)
except HTTPException:
raise
except Exception as e:

View File

@ -10,6 +10,7 @@ from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/system", tags=["system"])
async def get_redis() -> Redis:
"""Dependency to get Redis client"""
return redis_manager.get_client()
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
async def get_system_info(request: Request):
"""Get system information"""
from system_info import system_info # Import system_info function from system_info module
system = system_info()
return create_success_response(system.model_dump(mode='json'))
return create_success_response(system.model_dump(mode="json"))
@router.get("/redis/stats")
async def redis_stats(redis: Redis = Depends(get_redis)):
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
"total_commands_processed": info.get("total_commands_processed"),
"keyspace_hits": info.get("keyspace_hits"),
"keyspace_misses": info.get("keyspace_misses"),
"uptime_in_seconds": info.get("uptime_in_seconds")
"uptime_in_seconds": info.get("uptime_in_seconds"),
}
except Exception as e:
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")

View File

@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
BaseUserWithType
)
from utils.dependencies import (
get_database
)
from models import BaseUserWithType
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
# Create router for job endpoints
router = APIRouter(prefix="/users", tags=["users"])
# reference can be candidateId, username, or email
@router.get("/users/{reference}")
async def get_user(
reference: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a candidate by username"""
try:
# Normalize reference to lowercase for case-insensitive search
@ -31,50 +25,41 @@ async def get_user(
all_candidate_data = await database.get_all_candidates()
if not all_candidate_data:
logger.warning(f"⚠️ No users found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
logger.warning("⚠️ No users found in database")
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
user_data = None
for user in all_candidate_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
if not user_data:
all_guest_data = await database.get_all_guests()
if not all_guest_data:
logger.warning(f"⚠️ No guests found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
logger.warning("⚠️ No guests found in database")
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
for user in all_guest_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
if not user_data:
if not user_data:
logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
user = BaseUserWithType.model_validate(user_data)
return create_success_response(user.model_dump(by_alias=True))
except Exception as e:
logger.error(f"❌ Get user error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))

View File

@ -4,6 +4,7 @@ import subprocess
import math
from models import SystemInfo
def get_installed_ram():
try:
with open("/proc/meminfo", "r") as f:
@ -19,9 +20,7 @@ def get_graphics_cards():
gpus = []
try:
# Run the ze-monitor utility
result = subprocess.run(
["ze-monitor"], capture_output=True, text=True, check=True
)
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
# Clean up the output (remove leading/trailing whitespace and newlines)
output = result.stdout.strip()
@ -71,6 +70,7 @@ def get_cpu_info():
except Exception as e:
return f"Error retrieving CPU info: {e}"
def system_info() -> SystemInfo:
"""
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
@ -85,4 +85,4 @@ def system_info() -> SystemInfo:
embedding_model=defines.embedding_model,
max_context_length=defines.max_context,
)
return system
return system

View File

@ -1,17 +1,17 @@
import os
from pydantic import BaseModel
from pydantic import BaseModel
from typing import List, Optional, Any, Dict
from datetime import datetime
from typing import (
Any,
)
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup
from geopy.geocoders import Nominatim
import pytz
from geopy.geocoders import Nominatim
import pytz
import requests
import yfinance as yf
import yfinance as yf
import logging
@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
# Process the forecast data into a simpler format
forecast = {
"location": data["properties"]
.get("relativeLocation", {})
.get("properties", {}),
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
"updated": data["properties"].get("updated", ""),
"periods": [],
}
@ -181,7 +179,7 @@ def get_forecast(grid_endpoint):
def TickerValue(ticker_symbols):
api_key = os.getenv("TWELVEDATA_API_KEY", "")
if not api_key:
return {"error": f"Error fetching data: No API key for TwelveData"}
return {"error": "Error fetching data: No API key for TwelveData"}
results = []
for ticker_symbol in ticker_symbols.split(","):
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
if ticker_symbol == "":
continue
url = (
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
response = requests.get(url)
data = response.json()
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc())
results.append(
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
return results[0] if len(results) == 1 else results
@ -278,8 +272,10 @@ def DateTime(timezone="America/Los_Angeles"):
except Exception as e:
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
async def GenerateImage(llm, model: str, prompt: str):
return { "image_id": "image-a830a83-bd831" }
return {"image_id": "image-a830a83-bd831"}
async def AnalyzeSite(llm, model: str, url: str, question: str):
"""
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
return f"Error processing the website content: {str(e)}"
# %%
class Function(BaseModel):
name: str
@ -355,171 +350,181 @@ class Function(BaseModel):
parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel):
type: str
function: Function
tools : List[Tool] = [
# Tool.model_validate({
# "type": "function",
# "function": {
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
# "type": "object",
# "properties": {
# "prompt": {
# "type": "string",
# "description": "Detailed image description including style, colors, subject, composition"
# }
# },
# "required": ["prompt"]
# },
# "returns": {
# "type": "object",
# "properties": {
# "image_id": {
# "type": "string",
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
# }
# }
# }
# }
# }),
Tool.model_validate({
"type": "function",
"function": {
"name": "TickerValue",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
tools: List[Tool] = [
# Tool.model_validate({
# "type": "function",
# "function": {
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
# "type": "object",
# "properties": {
# "prompt": {
# "type": "string",
# "description": "Detailed image description including style, colors, subject, composition"
# }
# },
# "required": ["prompt"]
# },
# "returns": {
# "type": "object",
# "properties": {
# "image_id": {
# "type": "string",
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
# }
# }
# }
# }
# }),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "TickerValue",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
},
},
"required": ["ticker"],
"additionalProperties": False,
},
"required": ["ticker"],
"additionalProperties": False,
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "AnalyzeSite",
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The website URL to download and process",
},
"question": {
"type": "string",
"description": "The question to ask the second LLM about the content",
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "AnalyzeSite",
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The website URL to download and process",
},
"question": {
"type": "string",
"description": "The question to ask the second LLM about the content",
},
},
"required": ["url", "question"],
"additionalProperties": False,
},
"required": ["url", "question"],
"additionalProperties": False,
},
"returns": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Identifier for the source LLM",
},
"content": {
"type": "string",
"description": "The complete response from the second LLM",
},
"metadata": {
"type": "object",
"description": "Additional information about the response",
"returns": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Identifier for the source LLM",
},
"content": {
"type": "string",
"description": "The complete response from the second LLM",
},
"metadata": {
"type": "object",
"description": "Additional information about the response",
},
},
},
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "DateTime",
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
}
},
"required": [],
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "WeatherForecast",
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
"minLength": 2,
},
"state": {
"type": "string",
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
"minLength": 2,
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "DateTime",
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
}
},
"required": [],
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
}),
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "WeatherForecast",
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
"minLength": 2,
},
"state": {
"type": "string",
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
"minLength": 2,
},
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
}
),
]
class ToolEntry(BaseModel):
enabled: bool = True
tool: Tool
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]

View File

@ -1,3 +1,3 @@
"""
Utils package - Utility functions and dependencies
"""
"""

View File

@ -5,65 +5,65 @@ Provides password hashing, verification, and security features
"""
from __future__ import annotations
import traceback
import bcrypt
import bcrypt
import secrets
import logging
from datetime import datetime, timezone, timedelta
from typing import Dict, Any, Optional, Tuple
from pydantic import BaseModel
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class PasswordSecurity:
"""Handles password hashing and verification using bcrypt"""
@staticmethod
def hash_password(password: str) -> Tuple[str, str]:
"""
Hash a password with a random salt using bcrypt
Args:
password: Plain text password
Returns:
Tuple of (password_hash, salt) both as strings
"""
# Generate a random salt
salt = bcrypt.gensalt()
# Hash the password
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt)
return password_hash.decode('utf-8'), salt.decode('utf-8')
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
return password_hash.decode("utf-8"), salt.decode("utf-8")
@staticmethod
def verify_password(password: str, password_hash: str) -> bool:
"""
Verify a password against its hash
Args:
password: Plain text password to verify
password_hash: Stored password hash
Returns:
True if password matches, False otherwise
"""
try:
return bcrypt.checkpw(
password.encode('utf-8'),
password_hash.encode('utf-8')
)
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@staticmethod
def generate_secure_token(length: int = 32) -> str:
"""Generate a cryptographically secure random token"""
return secrets.token_urlsafe(length)
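A minimal round-trip sketch for PasswordSecurity, assuming bcrypt is installed; the literal password below is illustrative only.

hashed, salt = PasswordSecurity.hash_password("S3cure!example")
assert PasswordSecurity.verify_password("S3cure!example", hashed)
assert not PasswordSecurity.verify_password("wrong-guess", hashed)
token = PasswordSecurity.generate_secure_token()  # 32 random bytes, URL-safe encoded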
class AuthenticationRecord(BaseModel):
"""Authentication record for storing user credentials"""
user_id: str
password_hash: str
salt: str
@ -76,67 +76,70 @@ class AuthenticationRecord(BaseModel):
mfa_secret: Optional[str] = None
login_attempts: int = 0
locked_until: Optional[datetime] = None
class Config:
json_encoders = {
datetime: lambda v: v.isoformat() if v else None
}
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
class SecurityConfig:
"""Security configuration constants"""
MAX_LOGIN_ATTEMPTS = 5
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
PASSWORD_MIN_LENGTH = 8
TOKEN_EXPIRY_HOURS = 24
REFRESH_TOKEN_EXPIRY_DAYS = 30
class AuthenticationManager:
"""Manages authentication operations with security features"""
def __init__(self, database):
self.database = database
self.password_security = PasswordSecurity()
async def create_user_authentication(self, user_id: str, password: str) -> AuthenticationRecord:
"""
Create authentication record for a new user
Args:
user_id: Unique user identifier
password: Plain text password
Returns:
AuthenticationRecord object
"""
if len(password) < SecurityConfig.PASSWORD_MIN_LENGTH:
raise ValueError(f"Password must be at least {SecurityConfig.PASSWORD_MIN_LENGTH} characters long")
# Hash the password
password_hash, salt = self.password_security.hash_password(password)
# Create authentication record
auth_record = AuthenticationRecord(
user_id=user_id,
password_hash=password_hash,
salt=salt,
last_password_change=datetime.now(timezone.utc),
login_attempts=0
login_attempts=0,
)
# Store in database
await self.database.set_authentication(user_id, auth_record.model_dump())
logger.info(f"🔐 Created authentication record for user {user_id}")
return auth_record
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
async def verify_user_credentials(
self, login: str, password: str
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
"""
Verify user credentials with security checks
Args:
login: Email or username
password: Plain text password
Returns:
Tuple of (is_valid, user_data, error_message)
"""
@ -146,15 +149,15 @@ class AuthenticationManager:
if not user_data:
logger.warning(f"⚠️ Login attempt with non-existent user: {login}")
return False, None, "Invalid credentials"
# Get authentication record
auth_record = await self.database.get_authentication(user_data["id"])
if not auth_record:
logger.error(f"❌ No authentication record found for user {user_data['id']}")
return False, None, "Authentication record not found"
auth_data = AuthenticationRecord.model_validate(auth_record)
# Check if account is locked
if auth_data.locked_until and auth_data.locked_until > datetime.now(timezone.utc):
time_until_unlock = auth_data.locked_until - datetime.now(timezone.utc)
@ -164,48 +167,54 @@ class AuthenticationManager:
seconds = int(total_seconds % 60)
time_until_unlock_str = f"{minutes}m {seconds}s"
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}"
return (
False,
None,
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
)
# Verify password
if not self.password_security.verify_password(password, auth_data.password_hash):
# Increment failed attempts
auth_data.login_attempts += 1
# Lock account if too many attempts
if auth_data.login_attempts >= SecurityConfig.MAX_LOGIN_ATTEMPTS:
auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
)
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts")
logger.warning(
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
)
# Update authentication record
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
logger.warning(f"⚠️ Invalid password for user {login} (attempt {auth_data.login_attempts})")
return False, None, "Invalid credentials"
# Reset failed attempts on successful login
if auth_data.login_attempts > 0:
auth_data.login_attempts = 0
auth_data.locked_until = None
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
logger.info(f"✅ Successful authentication for user {login}")
return True, user_data, None
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"❌ Authentication error for user {login}: {e}")
return False, None, "Authentication failed"
async def check_user_exists(self, email: str, username: str | None = None) -> Tuple[bool, Optional[str]]:
"""
Check if a user already exists with the given email or username
Args:
email: Email address to check
username: Username to check (optional)
Returns:
Tuple of (exists, conflict_field)
"""
@ -214,20 +223,20 @@ class AuthenticationManager:
existing_user = await self.database.get_user(email)
if existing_user:
return True, "email"
# Check username if provided
if username:
existing_user = await self.database.get_user(username)
if existing_user:
return True, "username"
return False, None
except Exception as e:
logger.error(f"❌ Error checking user existence: {e}")
# In case of error, assume user doesn't exist to avoid blocking creation
return False, None
async def update_last_login(self, user_id: str):
"""Update user's last login timestamp"""
try:
@ -238,38 +247,40 @@ class AuthenticationManager:
except Exception as e:
logger.error(f"❌ Error updating last login for user {user_id}: {e}")
# Utility functions for common operations
def validate_password_strength(password: str) -> Tuple[bool, list]:
"""
Validate password strength
Args:
password: Password to validate
Returns:
Tuple of (is_valid, list_of_issues)
"""
issues = []
if len(password) < SecurityConfig.PASSWORD_MIN_LENGTH:
issues.append(f"Password must be at least {SecurityConfig.PASSWORD_MIN_LENGTH} characters long")
if not any(c.isupper() for c in password):
issues.append("Password must contain at least one uppercase letter")
if not any(c.islower() for c in password):
issues.append("Password must contain at least one lowercase letter")
if not any(c.isdigit() for c in password):
issues.append("Password must contain at least one digit")
# Check for special characters
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
if not any(c in special_chars for c in password):
issues.append("Password must contain at least one special character")
return len(issues) == 0, issues
def sanitize_login_input(login: str) -> str:
"""Sanitize login input (email or username)"""
return login.strip().lower() if login else ""
return login.strip().lower() if login else ""
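A short sketch of how the module-level helpers combine before credentials are created; the inputs are illustrative and the database wiring is assumed.

login = sanitize_login_input("  Alice@Example.COM ")  # -> "alice@example.com"
ok, issues = validate_password_strength("Tr0ub4dor&3")
if not ok:
    raise ValueError("; ".join(issues))
# With a RedisDatabase instance, AuthenticationManager(database).create_user_authentication(user_id, password)
# would then hash and persist the credentials; user_id and database are assumptions in this sketch.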

View File

@ -19,7 +19,7 @@ from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
from logger import logger
from background_tasks import BackgroundTaskManager
#from . rate_limiter import RateLimiter
# from . rate_limiter import RateLimiter
# Security
security = HTTPBearer()
@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
# Global database manager reference
db_manager = None
def set_db_manager(manager: DatabaseManager):
"""Set the global database manager reference"""
global db_manager
db_manager = manager
def get_database() -> RedisDatabase:
"""
Safe database dependency that checks for availability
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
if db_manager is None:
logger.error("Database manager not initialized")
raise HTTPException(
status_code=503,
detail="Database not available - service starting up"
)
raise HTTPException(status_code=503, detail="Database not available - service starting up")
if db_manager.is_shutting_down:
logger.warning("Database is shutting down")
raise HTTPException(
status_code=503,
detail="Service is shutting down"
)
raise HTTPException(status_code=503, detail="Service is shutting down")
try:
return db_manager.get_database()
except RuntimeError as e:
logger.error(f"Database not available: {e}")
raise HTTPException(
status_code=503,
detail="Database connection not available"
)
raise HTTPException(status_code=503, detail="Database connection not available")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Enhanced token verification with guest session recovery"""
try:
@ -87,35 +82,35 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
payload = jwt.decode(credentials.credentials, JWT_SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
token_type: str = payload.get("type", "access")
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
# Check if token is blacklisted
redis = redis_manager.get_client()
blacklist_key = f"blacklisted_token:{credentials.credentials}"
is_blacklisted = await redis.exists(blacklist_key)
if is_blacklisted:
logger.warning(f"🚫 Attempt to use blacklisted token for user {user_id}")
raise HTTPException(status_code=401, detail="Token has been revoked")
# For guest tokens, verify guest still exists and update activity
if token_type == "guest" or payload.get("type") == "guest":
database = db_manager.get_database()
guest_data = await database.get_guest(user_id)
if not guest_data:
logger.warning(f"🚫 Guest session not found for token: {user_id}")
raise HTTPException(status_code=401, detail="Guest session expired")
# Update guest activity
guest_data["last_activity"] = datetime.now(UTC).isoformat()
await database.set_guest(user_id, guest_data)
logger.debug(f"🔄 Guest activity updated: {user_id}")
return user_id
except jwt.PyJWTError as e:
logger.warning(f"⚠️ JWT decode error: {e}")
raise HTTPException(status_code=401, detail="Invalid authentication credentials")
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
logger.error(f"❌ Token verification error: {e}")
raise HTTPException(status_code=401, detail="Token verification failed")
async def get_current_user(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user from database"""
try:
@ -135,54 +130,59 @@ async def get_current_user(
candidate_data = await database.get_candidate(user_id)
if candidate_data:
from helpers.model_cast import cast_to_base_user_with_type
if candidate_data.get("is_AI"):
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
else:
return cast_to_base_user_with_type(Candidate.model_validate(candidate_data))
# Check employers
employer = await database.get_employer(user_id)
if employer:
return Employer.model_validate(employer)
logger.warning(f"⚠️ User {user_id} not found in database")
raise HTTPException(status_code=404, detail="User not found")
except Exception as e:
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_user_or_guest(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user (including guests) from database"""
try:
# Check candidates first
candidate_data = await database.get_candidate(user_id)
if candidate_data:
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
return (
Candidate.model_validate(candidate_data)
if not candidate_data.get("is_AI")
else CandidateAI.model_validate(candidate_data)
)
# Check employers
employer_data = await database.get_employer(user_id)
if employer_data:
return Employer.model_validate(employer_data)
# Check guests
guest_data = await database.get_guest(user_id)
if guest_data:
return Guest.model_validate(guest_data)
logger.warning(f"⚠️ User {user_id} not found in database")
raise HTTPException(status_code=404, detail="User not found")
except Exception as e:
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_admin(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
user = await get_current_user(user_id=user_id, database=database)
if isinstance(user, Candidate) and user.is_admin:
@ -193,6 +193,7 @@ async def get_current_admin(
logger.warning(f"⚠️ User {user_id} is not an admin")
raise HTTPException(status_code=403, detail="Admin access required")
prometheus_collector = CollectorRegistry()
# Keep the Instrumentator instance alive
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)
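As an illustration of how these dependencies compose, a hypothetical route might look like the sketch below; the router prefix and endpoint name are assumptions, not part of this module.

from fastapi import APIRouter, Depends

example_router = APIRouter(prefix="/example", tags=["example"])

@example_router.get("/me")
async def whoami(user: BaseUserWithType = Depends(get_current_user_or_guest)):
    # Resolves candidates, employers, or guests from the verified bearer token.
    return user.model_dump(by_alias=True)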

View File

@ -1,5 +1,6 @@
from typing import List, Dict
from models import (Job)
from models import Job
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
if job.requirements.preferred_attributes:
requirements.extend([
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
]
)
return requirements
if job.requirements.soft_skills:
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
]
)
return requirements

View File

@ -13,44 +13,13 @@ from fastapi.responses import StreamingResponse
import defines
from logger import logger
from models import DocumentType
from models import (
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
Candidate, Employer, Guest, AuthResponse,
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
EmailVerificationRequest, ResendVerificationRequest,
# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
JobRequirements, JobRequirementsMessage, LoginRequest,
CreateCandidateRequest, CreateEmployerRequest,
# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
Authentication, AuthResponse, CandidateAI,
# Job models
JobApplication, ApplicationStatus,
# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ChatSenderType, ApiMessageType, ChatContextType,
ChatMessageRagSearch,
# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
# Supporting models
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
# Email
EmailVerificationRequest,
ApiStatusType
)
from models import Job, ChatMessage, DocumentType, ApiStatusType
from typing import List, Dict
from models import (Job)
from models import Job
import utils.llm_proxy as llm_manager
async def get_last_item(generator):
"""Get the last item from an async generator"""
last_item = None
@ -65,20 +34,19 @@ def filter_and_paginate(
limit: int = 20,
sort_by: Optional[str] = None,
sort_order: str = "desc",
filters: Optional[Dict] = None
filters: Optional[Dict] = None,
) -> Tuple[List[Any], int]:
"""Filter, sort, and paginate items"""
filtered_items = items.copy()
# Apply filters (simplified filtering logic)
if filters:
for key, value in filters.items():
if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
filtered_items = [item for item in filtered_items if item.get(key) == value]
elif hasattr(filtered_items[0], key) if filtered_items else False:
filtered_items = [item for item in filtered_items
if getattr(item, key, None) == value]
filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
# Sort items
if sort_by and filtered_items:
reverse = sort_order.lower() == "desc"
@ -89,25 +57,25 @@ def filter_and_paginate(
filtered_items.sort(key=lambda x: getattr(x, sort_by, ""), reverse=reverse)
except (AttributeError, TypeError):
pass # Skip sorting if attribute doesn't exist or isn't comparable
# Paginate
total = len(filtered_items)
start = (page - 1) * limit
end = start + limit
paginated_items = filtered_items[start:end]
return paginated_items, total
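A small, self-contained sketch of the helper above using plain dicts; the values are illustrative.

items = [
    {"title": "Backend Engineer", "is_active": True},
    {"title": "Analyst", "is_active": False},
]
page_items, total = filter_and_paginate(items, page=1, limit=10, filters={"is_active": True})
# page_items == [{"title": "Backend Engineer", "is_active": True}], total == 1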
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
"""Stream agent response with proper formatting"""
async def message_stream_generator():
"""Generator to stream messages with persistence"""
last_log = None
final_message = None
import utils.llm_proxy as llm_manager
async for generated_message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
@ -118,36 +86,39 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
logger.error(f"❌ AI generation error: {generated_message.content}")
yield f"data: {json.dumps({'status': 'error'})}\n\n"
return
# Store reference to the complete AI message
if generated_message.status == ApiStatusType.DONE:
final_message = generated_message
# If the message is not done, convert it to a ChatMessageBase to remove
# metadata and other unnecessary fields for streaming
if generated_message.status != ApiStatusType.DONE:
from models import ChatMessageStreaming, ChatMessageStatus
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
generated_message, ChatMessageStatus
):
raise TypeError(
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
)
json_data = generated_message.model_dump(mode='json', by_alias=True)
json_data = generated_message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n"
# After streaming is complete, persist the final AI message to database
if final_message and final_message.status == ApiStatusType.DONE:
try:
if database and chat_session_data:
await database.add_chat_message(final_message.session_id, final_message.model_dump())
logger.info(f"🤖 Message saved to database for session {final_message.session_id}")
# Update session last activity again
chat_session_data["lastActivity"] = datetime.now(UTC).isoformat()
await database.set_chat_session(final_message.session_id, chat_session_data)
except Exception as e:
logger.error(f"❌ Failed to save message to database: {e}")
@ -175,20 +146,20 @@ def get_candidate_files_dir(username: str) -> pathlib.Path:
def get_document_type_from_filename(filename: str) -> DocumentType:
"""Determine document type from filename extension"""
extension = pathlib.Path(filename).suffix.lower()
type_mapping = {
'.pdf': DocumentType.PDF,
'.docx': DocumentType.DOCX,
'.doc': DocumentType.DOCX,
'.txt': DocumentType.TXT,
'.md': DocumentType.MARKDOWN,
'.markdown': DocumentType.MARKDOWN,
'.png': DocumentType.IMAGE,
'.jpg': DocumentType.IMAGE,
'.jpeg': DocumentType.IMAGE,
'.gif': DocumentType.IMAGE,
".pdf": DocumentType.PDF,
".docx": DocumentType.DOCX,
".doc": DocumentType.DOCX,
".txt": DocumentType.TXT,
".md": DocumentType.MARKDOWN,
".markdown": DocumentType.MARKDOWN,
".png": DocumentType.IMAGE,
".jpg": DocumentType.IMAGE,
".jpeg": DocumentType.IMAGE,
".gif": DocumentType.IMAGE,
}
return type_mapping.get(extension, DocumentType.TXT)
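A couple of quick checks showing the mapping above; the filenames are illustrative.

assert get_document_type_from_filename("resume.PDF") == DocumentType.PDF
assert get_document_type_from_filename("notes.markdown") == DocumentType.MARKDOWN
assert get_document_type_from_filename("archive.zip") == DocumentType.TXT  # unknown extensions fall back to TXT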
@ -203,20 +174,15 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
"""Reformat content as markdown using AI agent"""
from models import ChatContextType, MOCK_UUID, ChatMessageError, ChatMessageStatus, ApiActivityType, ChatMessage
import utils.llm_proxy as llm_manager
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
)
yield status_message
@ -229,26 +195,23 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
if not message or not isinstance(message, ChatMessage):
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
chat_message: ChatMessage = message
try:
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
except Exception as e:
except Exception:
pass
logger.info(f"✅ Successfully converted content to markdown")
logger.info("✅ Successfully converted content to markdown")
yield chat_message
return
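Editor's note: reformat_as_markdown is an async generator, so status and error messages stream out first and the converted ChatMessage arrives last. A sketch of draining it for just the final content, reusing the ChatMessage/ChatMessageError models and logger already imported in this module; database and candidate_entity come from the caller's context.
from typing import Optional
async def markdown_for(database, candidate_entity, raw_text: str) -> Optional[str]:
    """Drain the generator and keep only the finished ChatMessage content."""
    final_content = None
    async for msg in reformat_as_markdown(database, candidate_entity, raw_text):
        if isinstance(msg, ChatMessageError):
            logger.error(f"markdown conversion failed: {msg.content}")
            return None
        if isinstance(msg, ChatMessage):
            final_content = msg.content  # the last yielded ChatMessage is the converted document
    return final_content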
@ -256,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
async def create_job_from_content(database, current_user, content: str):
"""Create a job from content using AI analysis"""
from models import (
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
ApiActivityType, ChatContextType, JobRequirementsMessage
MOCK_UUID,
ApiStatusType,
ChatMessageError,
ChatMessageStatus,
ApiActivityType,
ChatContextType,
JobRequirementsMessage,
)
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
@ -276,112 +244,105 @@ async def create_job_from_content(database, current_user, content: str):
# Only yield one final DONE message
if message.status != ApiStatusType.DONE:
yield message
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
markdown_message = message
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
session_id=MOCK_UUID, content="No agent found for job requirements chat type"
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Job extraction did not convert successfully"
session_id=MOCK_UUID, content="Job extraction did not convert successfully"
)
yield error_message
return
job_requirements: JobRequirementsMessage = message
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
yield job_requirements
return
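Editor's note: create_job_from_content follows the same streaming pattern; intermediate status messages are yielded as they happen and the single DONE result is a JobRequirementsMessage. A hedged consumer sketch:
async def extract_job(database, current_user, description: str):
    """Return the JobRequirementsMessage for the description, or None if extraction failed."""
    from models import JobRequirementsMessage, ChatMessageError
    result = None
    async for msg in create_job_from_content(database, current_user, description):
        if isinstance(msg, ChatMessageError):
            logger.error(f"job extraction failed: {msg.content}")
            return None
        if isinstance(msg, JobRequirementsMessage):
            result = msg  # the final DONE message carries the structured requirements
    return result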
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
if job.requirements.preferred_attributes:
requirements.extend([
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
]
)
return requirements
if job.requirements.soft_skills:
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
]
)
return requirements
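Editor's note: get_requirements_list flattens every populated requirements bucket into uniform {"requirement", "domain"} rows, so ranking or matching code downstream does not need to know which bucket a requirement came from. Illustrative usage; the Job instance itself comes from the models package and the example requirement strings are invented.
rows = get_requirements_list(job)
for row in rows:
    print(f"[{row['domain']}] {row['requirement']}")
# e.g. "[Technical Skills (required)] Python", "[Experience (preferred)] 5+ years with Kubernetes",
#      "[Education] BS in Computer Science", ...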

File diff suppressed because it is too large


View File

@ -1,6 +1,7 @@
from prometheus_client import Counter, Histogram # type: ignore
from threading import Lock
def singleton(cls):
instance = None
lock = Lock()

View File

@ -6,68 +6,80 @@ from functools import wraps
from datetime import datetime, timedelta, UTC
from typing import Callable, Dict, Optional, Any
from fastapi import Depends, HTTPException, Request
from pydantic import BaseModel # type: ignore
from pydantic import BaseModel # type: ignore
from database.manager import RedisDatabase
from logger import logger
from . dependencies import get_current_user_or_guest, get_database
from .dependencies import get_current_user_or_guest, get_database
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
"""Dependency to get rate limiter instance"""
return RateLimiter(database)
class RateLimitConfig(BaseModel):
"""Rate limit configuration"""
requests_per_minute: int
requests_per_hour: int
requests_per_day: int
burst_limit: int # Maximum requests in a short burst
burst_window_seconds: int = 60 # Window for burst detection
class GuestRateLimitConfig(RateLimitConfig):
"""Rate limits for guest users - more restrictive"""
requests_per_minute: int = 10
requests_per_hour: int = 100
requests_per_day: int = 500
burst_limit: int = 15
burst_window_seconds: int = 60
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
"""Rate limits for authenticated users - more generous"""
requests_per_minute: int = 60
requests_per_hour: int = 1000
requests_per_day: int = 10000
burst_limit: int = 100
burst_window_seconds: int = 60
class PremiumUserRateLimitConfig(RateLimitConfig):
"""Rate limits for premium/admin users - most generous"""
requests_per_minute: int = 120
requests_per_hour: int = 5000
requests_per_day: int = 50000
burst_limit: int = 200
burst_window_seconds: int = 60
class RateLimitResult(BaseModel):
"""Result of rate limit check"""
allowed: bool
reason: Optional[str] = None
retry_after_seconds: Optional[int] = None
remaining_requests: Dict[str, int] = {}
reset_times: Dict[str, datetime] = {}
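Editor's note: the three config classes above encode the tiers entirely in their defaults, so selecting a tier is just instantiating the right class. Since they are pydantic models, the values can be inspected or dumped directly:
GuestRateLimitConfig().model_dump()
# -> {"requests_per_minute": 10, "requests_per_hour": 100, "requests_per_day": 500,
#     "burst_limit": 15, "burst_window_seconds": 60}
AuthenticatedUserRateLimitConfig().requests_per_minute   # 60
PremiumUserRateLimitConfig().requests_per_minute          # 120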
class RateLimiter:
"""Rate limiter using Redis for distributed rate limiting"""
def __init__(self, database: RedisDatabase):
self.database = database
self.redis = database.redis
# Rate limit configurations
self.guest_config = GuestRateLimitConfig()
self.user_config = AuthenticatedUserRateLimitConfig()
self.premium_config = PremiumUserRateLimitConfig()
def get_config_for_user(self, user_type: str, is_admin: bool = False) -> RateLimitConfig:
"""Get rate limit configuration based on user type"""
if user_type == "guest":
@ -76,66 +88,62 @@ class RateLimiter:
return self.premium_config
else:
return self.user_config
async def check_rate_limit(
self,
user_id: str,
user_type: str,
is_admin: bool = False,
endpoint: Optional[str] = None
self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
) -> RateLimitResult:
"""
Check if user has exceeded rate limits
Args:
user_id: Unique identifier for the user (guest session ID or user ID)
user_type: "guest", "candidate", or "employer"
is_admin: Whether user has admin privileges
endpoint: Optional endpoint-specific rate limiting
Returns:
RateLimitResult indicating if request is allowed
"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
# Create Redis keys for different time windows
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
# Add endpoint-specific limiting if provided
if endpoint:
keys = {k: f"{v}:{endpoint}" for k, v in keys.items()}
try:
# Use Redis pipeline for atomic operations
pipe = self.redis.pipeline()
# Get current counts
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
# Check limits
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
# Check each limit
for window, current_count in current_counts.items():
limit = limits[window]
@ -146,110 +154,100 @@ class RateLimiter:
elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
retry_after = 86400 - (
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
)
else: # burst
retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
logger.warning(
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
)
return RateLimitResult(
allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time)
reset_times=self._calculate_reset_times(current_time),
)
# If we get here, request is allowed - increment counters
pipe = self.redis.pipeline()
# Increment minute counter (expires after 2 minutes)
pipe.incr(keys["minute"])
pipe.expire(keys["minute"], 120)
# Increment hour counter (expires after 2 hours)
pipe.incr(keys["hour"])
pipe.expire(keys["hour"], 7200)
# Increment day counter (expires after 2 days)
pipe.incr(keys["day"])
pipe.expire(keys["day"], 172800)
# Increment burst counter (expires after burst window)
pipe.incr(keys["burst"])
pipe.expire(keys["burst"], config.burst_window_seconds)
await pipe.execute()
# Calculate remaining requests
remaining = {
k: max(0, limits[k] - (current_counts[k] + 1))
for k in current_counts.keys()
}
remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
return RateLimitResult(
allowed=True,
remaining_requests=remaining,
reset_times=self._calculate_reset_times(current_time)
allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
)
except Exception as e:
logger.error(f"❌ Rate limit check failed for {user_id}: {e}")
# Fail open - allow request if rate limiting system fails
return RateLimitResult(allowed=True, reason="Rate limit check failed - allowing request")
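Editor's note: check_rate_limit keys its Redis counters per user and time window (minute, hour, day, plus a burst bucket) and deliberately fails open when Redis is unreachable. A sketch of calling it from a request handler and turning a denial into an HTTP 429, assuming the user object carries id, user_type, and is_admin as elsewhere in this file:
async def enforce_limits(rate_limiter: "RateLimiter", current_user, endpoint: Optional[str] = None):
    result = await rate_limiter.check_rate_limit(
        user_id=current_user.id,
        user_type=str(getattr(current_user, "user_type", "guest")),
        is_admin=getattr(current_user, "is_admin", False),
        endpoint=endpoint,
    )
    if not result.allowed:
        raise HTTPException(
            status_code=429,
            detail=result.reason,
            headers={"Retry-After": str(result.retry_after_seconds or 60)},
        )
    # e.g. {"minute": 9, "hour": 99, "day": 499, "burst": 14} for a fresh guest session
    return result.remaining_requests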
def _calculate_reset_times(self, current_time: datetime) -> Dict[str, datetime]:
"""Calculate when each rate limit window resets"""
next_minute = current_time.replace(second=0, microsecond=0) + timedelta(minutes=1)
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
return {
"minute": next_minute,
"hour": next_hour,
"day": next_day
}
async def get_user_rate_limit_status(
self,
user_id: str,
user_type: str,
is_admin: bool = False
) -> Dict[str, Any]:
return {"minute": next_minute, "hour": next_hour, "day": next_day}
async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
"""Get current rate limit status for a user"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
base_key = f"rate_limit:{user_type}:{user_id}"
keys = {
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
try:
pipe = self.redis.pipeline()
for key in keys.values():
pipe.get(key)
results = await pipe.execute()
current_counts = {
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
return {
"user_id": user_id,
"user_type": user_type,
@ -258,58 +256,56 @@ class RateLimiter:
"limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump()
"config": config.model_dump(),
}
except Exception as e:
logger.error(f"❌ Failed to get rate limit status for {user_id}: {e}")
return {"error": str(e)}
async def reset_user_rate_limits(self, user_id: str, user_type: str) -> bool:
"""Reset all rate limits for a user (admin function)"""
try:
base_key = f"rate_limit:{user_type}:{user_id}"
pattern = f"{base_key}:*"
cursor = 0
deleted_count = 0
while True:
cursor, keys = await self.redis.scan(cursor, match=pattern, count=100)
if keys:
await self.redis.delete(*keys)
deleted_count += len(keys)
if cursor == 0:
break
logger.info(f"🔄 Reset {deleted_count} rate limit keys for {user_type} {user_id}")
return True
except Exception as e:
logger.error(f"❌ Failed to reset rate limits for {user_id}: {e}")
return False
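Editor's note: reset_user_rate_limits SCANs and deletes every counter under the user's key prefix, which makes it a natural fit for an admin-only endpoint. A sketch under that assumption; the route path is hypothetical and the admin guard mirrors the is_admin checks used elsewhere in this file.
@api_router.post("/admin/rate-limits/{user_type}/{user_id}/reset")   # hypothetical admin route
async def reset_limits(
    user_type: str,
    user_id: str,
    current_user=Depends(get_current_user_or_guest),
    rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
    if not getattr(current_user, "is_admin", False):
        raise HTTPException(status_code=403, detail="Admin privileges required")
    ok = await rate_limiter.reset_user_rate_limits(user_id, user_type)
    return {"success": ok}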
# ============================
# Rate Limited Decorator
# ============================
def rate_limited(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120,
endpoint_specific: bool = True
guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
):
"""
Decorator to easily apply rate limiting to endpoints
Args:
guest_per_minute: Rate limit for guest users
user_per_minute: Rate limit for authenticated users
user_per_minute: Rate limit for authenticated users
admin_per_minute: Rate limit for admin users
endpoint_specific: Whether to apply endpoint-specific limits
Usage:
@rate_limited(guest_per_minute=5, user_per_minute=30)
@api_router.post("/my-endpoint")
@ -320,61 +316,66 @@ def rate_limited(
):
return {"message": "Rate limited endpoint"}
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract dependencies from function signature
import inspect
sig = inspect.signature(func)
inspect.signature(func)
# Get request, current_user, and rate_limiter from kwargs or args
request = None
current_user = None
rate_limiter = None
# Try to find dependencies in kwargs first
for param_name, param_value in kwargs.items():
if isinstance(param_value, Request):
request = param_value
elif hasattr(param_value, 'user_type'): # User-like object
elif hasattr(param_value, "user_type"): # User-like object
current_user = param_value
elif isinstance(param_value, RateLimiter):
rate_limiter = param_value
# If not found in kwargs, check if they're provided via Depends
if not rate_limiter:
# Create rate limiter instance (this should ideally come from DI)
database = get_database()
rate_limiter = RateLimiter(database)
# Apply rate limiting if we have the required components
if request and current_user and rate_limiter:
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
# Call the original function
return await func(*args, **kwargs)
return wrapper
return decorator
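Editor's note: one way to wire the decorator onto a route. Here the route decorator is outermost so FastAPI registers the wrapped handler, and Request, the current user, and the RateLimiter are declared as keyword parameters so the wrapper can find them in kwargs. The path and limits are illustrative.
@api_router.post("/documents/analyze")  # hypothetical endpoint
@rate_limited(guest_per_minute=5, user_per_minute=30, admin_per_minute=120)
async def analyze_document(
    request: Request,
    current_user=Depends(get_current_user_or_guest),
    rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
    # apply_custom_rate_limiting has already run by the time this body executes;
    # a 429 with a Retry-After header is raised when the per-minute window is exhausted.
    return {"message": "Rate limited endpoint"}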
async def apply_custom_rate_limiting(
request: Request,
current_user,
rate_limiter: RateLimiter,
guest_per_minute: int,
user_per_minute: int,
admin_per_minute: int
user_per_minute: int,
admin_per_minute: int,
):
"""Apply custom rate limiting with specified limits"""
try:
# Determine user info
user_id = current_user.id
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
is_admin = getattr(current_user, 'is_admin', False)
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
is_admin = getattr(current_user, "is_admin", False)
# Determine appropriate limit
if is_admin:
requests_per_minute = admin_per_minute
@ -382,16 +383,20 @@ async def apply_custom_rate_limiting(
requests_per_minute = guest_per_minute
else:
requests_per_minute = user_per_minute
# Create custom rate limit key
current_time = datetime.now(UTC)
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
custom_key = (
f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
)
# Check current usage
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
if current_count >= requests_per_minute:
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
logger.warning(
f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
)
raise HTTPException(
status_code=429,
detail={
@ -399,40 +404,40 @@ async def apply_custom_rate_limiting(
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
"retryAfter": 60 - current_time.second,
"userType": user_type,
"endpoint": request.url.path
"endpoint": request.url.path,
},
headers={"Retry-After": str(60 - current_time.second)}
headers={"Retry-After": str(60 - current_time.second)},
)
# Increment counter
pipe = rate_limiter.redis.pipeline()
pipe.incr(custom_key)
pipe.expire(custom_key, 120) # 2 minutes TTL
await pipe.execute()
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
logger.debug(
f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
)
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Custom rate limiting error: {e}")
# Fail open
# ============================
# Alternative: FastAPI Dependency-Based Rate Limiting
# ============================
def create_rate_limit_dependency(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120
):
def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
"""
Create a FastAPI dependency for rate limiting
Usage:
rate_limit_5_30 = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30)
@api_router.post("/my-endpoint")
async def my_endpoint(
rate_check = Depends(rate_limit_5_30),
@ -441,85 +446,89 @@ def create_rate_limit_dependency(
):
return {"message": "Rate limited endpoint"}
"""
async def rate_limit_dependency(
request: Request,
current_user = Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter)
current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
return True
return rate_limit_dependency
# ============================
# Rate Limiting Utilities
# ============================
class EndpointRateLimiter:
"""Utility class for endpoint-specific rate limiting"""
def __init__(self, rate_limiter: RateLimiter):
self.rate_limiter = rate_limiter
self.custom_limits = {}
def set_endpoint_limits(self, endpoint: str, limits: dict):
"""Set custom limits for an endpoint"""
self.custom_limits[endpoint] = limits
async def check_endpoint_limit(self, request: Request, current_user) -> bool:
"""Check if request exceeds endpoint-specific limits"""
endpoint = request.url.path
if endpoint not in self.custom_limits:
return True # No custom limits set
limits = self.custom_limits[endpoint]
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
if getattr(current_user, 'is_admin', False):
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
if getattr(current_user, "is_admin", False):
user_type = "admin"
limit = limits.get(user_type, limits.get("default", 60))
current_time = datetime.now(UTC)
key = f"endpoint_limit:{endpoint}:{user_type}:{current_user.id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
current_count = int(await self.rate_limiter.redis.get(key) or 0)
if current_count >= limit:
raise HTTPException(
status_code=429,
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
)
# Increment counter
await self.rate_limiter.redis.incr(key)
await self.rate_limiter.redis.expire(key, 120)
return True
# Global endpoint rate limiter instance
endpoint_rate_limiter = None
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
"""Get endpoint rate limiter instance"""
global endpoint_rate_limiter
if endpoint_rate_limiter is None:
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
# Configure endpoint-specific limits
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
})
return endpoint_rate_limiter
# Configure endpoint-specific limits
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
)
return endpoint_rate_limiter
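Editor's note: EndpointRateLimiter layers per-endpoint, per-user-type ceilings on top of the global limiter, and get_endpoint_rate_limiter seeds it lazily with the stream, upload, and jobs limits shown above. A sketch of using it inside a route whose path matches one of those configurations; the handler body is invented.
@api_router.post("/api/1.0/jobs")
async def create_job_endpoint(
    request: Request,
    current_user=Depends(get_current_user_or_guest),
    endpoint_limiter: EndpointRateLimiter = Depends(get_endpoint_rate_limiter),
):
    # Raises 429 once e.g. a guest exceeds 1 request/minute or a candidate exceeds 5 for this path.
    await endpoint_limiter.check_endpoint_limit(request, current_user)
    return {"message": "job accepted"}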

View File

@ -3,37 +3,17 @@ Response utility functions for consistent API responses
"""
from typing import Any, Optional, Dict, List
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
return {
"success": True,
"data": data,
"meta": meta
}
return {"success": True, "data": data, "meta": meta}
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
return {
"success": False,
"error": {
"code": code,
"message": message,
"details": details
}
}
return {"success": False, "error": {"code": code, "message": message, "details": details}}
def create_paginated_response(
data: List[Any],
page: int,
limit: int,
total: int
) -> Dict:
def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
total_pages = (total + limit - 1) // limit
has_more = page < total_pages
return {
"data": data,
"total": total,
"page": page,
"limit": limit,
"totalPages": total_pages,
"hasMore": has_more
}
return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}