Compare commits
2 Commits
f53ff967cb
...
2cf3fa7b04
Author | SHA1 | Date | |
---|---|---|---|
2cf3fa7b04 | |||
f1c2e16389 |
@ -17,11 +17,9 @@ from logger import logger
|
||||
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
|
||||
|
||||
# Maps class_name to (module_name, class_name)
|
||||
class_registry: Dict[str, Tuple[str, str]] = (
|
||||
{}
|
||||
)
|
||||
class_registry: Dict[str, Tuple[str, str]] = {}
|
||||
|
||||
__all__ = ['get_or_create_agent']
|
||||
__all__ = ["get_or_create_agent"]
|
||||
|
||||
package_dir = pathlib.Path(__file__).parent
|
||||
package_name = __name__
|
||||
@ -38,12 +36,7 @@ for path in package_dir.glob("*.py"):
|
||||
|
||||
# Find all Agent subclasses in the module
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if (
|
||||
issubclass(obj, AnyAgent)
|
||||
and obj is not AnyAgent
|
||||
and obj is not Agent
|
||||
and name not in class_registry
|
||||
):
|
||||
if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
|
||||
class_registry[name] = (full_module_name, name)
|
||||
globals()[name] = obj
|
||||
logger.info(f"Adding agent: {name}")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -13,10 +13,10 @@ import time
|
||||
import re
|
||||
from abc import ABC
|
||||
from datetime import datetime, UTC
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
import json_extractor as json_extractor
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, ClassVar, Any, Literal
|
||||
|
||||
@ -24,7 +24,7 @@ from datetime import datetime, UTC
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
@ -32,19 +32,45 @@ from pathlib import Path
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
||||
from models import (
|
||||
Tunables,
|
||||
ChatMessageUser,
|
||||
ChatMessage,
|
||||
RagEntry,
|
||||
ChatMessageMetaData,
|
||||
ApiStatusType,
|
||||
Candidate,
|
||||
ChatContextType,
|
||||
)
|
||||
import utils.llm_proxy as llm_manager
|
||||
from database.manager import RedisDatabase
|
||||
from models import ChromaDBGetResponse
|
||||
from utils.metrics import Metrics
|
||||
|
||||
|
||||
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageRagSearch,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
LLMMessage,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatMessageUser,
|
||||
Tunables,
|
||||
ApiStatusType,
|
||||
ChatMessageMetaData,
|
||||
Candidate,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
from .registry import agent_registry
|
||||
|
||||
from models import ( ChromaDBGetResponse )
|
||||
from models import ChromaDBGetResponse
|
||||
|
||||
|
||||
class CandidateEntity(Candidate):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
|
||||
@ -59,13 +85,10 @@ class CandidateEntity(Candidate):
|
||||
CandidateEntity__agents: List[Agent] = []
|
||||
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
|
||||
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
|
||||
|
||||
CandidateEntity__metrics: Optional[Metrics] = Field(
|
||||
default=None,
|
||||
description="Metrics collector for this agent, used to track performance and usage."
|
||||
default=None, description="Metrics collector for this agent, used to track performance and usage."
|
||||
)
|
||||
|
||||
def __init__(self, candidate=None):
|
||||
@ -78,7 +101,7 @@ class CandidateEntity(Candidate):
|
||||
@classmethod
|
||||
def exists(cls, username: str):
|
||||
# Validate username format (only allow safe characters)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||
return False # Invalid username characters
|
||||
|
||||
# Check for minimum and maximum length
|
||||
@ -117,11 +140,7 @@ class CandidateEntity(Candidate):
|
||||
if agent.agent_type == agent_type:
|
||||
return agent
|
||||
|
||||
return get_or_create_agent(
|
||||
agent_type=agent_type,
|
||||
user=self,
|
||||
prometheus_collector=self.prometheus_collector
|
||||
)
|
||||
return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)
|
||||
|
||||
# Wrapper properties that map into file_watcher
|
||||
@property
|
||||
@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
|
||||
|
||||
# Fields managed by initialize()
|
||||
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
|
||||
|
||||
@property
|
||||
def metrics(self) -> Metrics:
|
||||
if not self.CandidateEntity__metrics:
|
||||
@ -160,15 +180,10 @@ class CandidateEntity(Candidate):
|
||||
if not self.metrics:
|
||||
logger.warning("No metrics collector set for this agent.")
|
||||
return
|
||||
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
|
||||
response.usage.prompt_eval_count
|
||||
)
|
||||
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
|
||||
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
if self.CandidateEntity__initialized:
|
||||
# Initialization can only be attempted once; if there are multiple attempts, it means
|
||||
# a subsystem is failing or there is a logic bug in the code.
|
||||
@ -188,8 +203,8 @@ class CandidateEntity(Candidate):
|
||||
self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)
|
||||
|
||||
user_dir = os.path.join(defines.user_dir, self.username)
|
||||
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
|
||||
vector_db_dir = os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)
|
||||
|
||||
os.makedirs(vector_db_dir, exist_ok=True)
|
||||
os.makedirs(rag_content_dir, exist_ok=True)
|
||||
@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
|
||||
)
|
||||
has_username_rag = any(item.name == self.username for item in self.rags)
|
||||
if not has_username_rag:
|
||||
self.rags.append(RagEntry(
|
||||
self.rags.append(
|
||||
RagEntry(
|
||||
name=self.username,
|
||||
description=f"Expert data about {self.full_name}.",
|
||||
))
|
||||
)
|
||||
)
|
||||
self.rag_content_size = self.file_watcher.collection.count()
|
||||
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
Base class for all agent types.
|
||||
This class defines the common attributes and methods for all agent types.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
|
||||
|
||||
@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
conversation: List[ChatMessageUser] = Field(
|
||||
default_factory=list,
|
||||
description="Conversation history for this agent, used to maintain context across messages."
|
||||
description="Conversation history for this agent, used to maintain context across messages.",
|
||||
)
|
||||
|
||||
@property
|
||||
@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
|
||||
last_item = item
|
||||
return last_item
|
||||
|
||||
def set_optimal_context_size(
|
||||
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||
) -> int:
|
||||
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
|
||||
# Most models average 1.3-1.5 tokens per word
|
||||
word_count = len(prompt.split())
|
||||
tokens = int(word_count * 1.4)
|
||||
@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
|
||||
total_ctx = tokens + ctx_buffer
|
||||
|
||||
if total_ctx > self.context_size:
|
||||
logger.info(
|
||||
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||
)
|
||||
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
|
||||
|
||||
# Grow the context size if necessary
|
||||
self.context_size = max(self.context_size, total_ctx)
|
||||
@ -472,24 +487,24 @@ class Agent(BaseModel, ABC):
|
||||
context = []
|
||||
for chroma_results in rag_message.content:
|
||||
for index, metadata in enumerate(chroma_results.metadatas):
|
||||
content = "\n".join([
|
||||
line.strip()
|
||||
for line in chroma_results.documents[index].split("\n")
|
||||
if line
|
||||
]).strip()
|
||||
context.append(f"""
|
||||
content = "\n".join(
|
||||
[line.strip() for line in chroma_results.documents[index].split("\n") if line]
|
||||
).strip()
|
||||
context.append(
|
||||
f"""
|
||||
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
|
||||
Document reference: {chroma_results.ids[index]}
|
||||
Content: {content}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
return "\n".join(context)
|
||||
|
||||
async def generate_rag_results(
|
||||
self,
|
||||
session_id: str,
|
||||
prompt: str,
|
||||
top_k: int=defines.default_rag_top_k,
|
||||
threshold: float=defines.default_rag_threshold,
|
||||
top_k: int = defines.default_rag_top_k,
|
||||
threshold: float = defines.default_rag_threshold,
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
"""
|
||||
Generate RAG results for the given query.
|
||||
@ -501,15 +516,11 @@ Content: {content}
|
||||
A list of dictionaries containing the RAG results.
|
||||
"""
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for RAG generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
results : List[ChromaDBGetResponse] = []
|
||||
entries: int = 0
|
||||
results: List[ChromaDBGetResponse] = []
|
||||
user: CandidateEntity = self.user
|
||||
for rag in user.rags:
|
||||
if not rag.enabled:
|
||||
@ -518,14 +529,12 @@ Content: {content}
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
content = f"Searching RAG context {rag.name}..."
|
||||
content=f"Searching RAG context {rag.name}...",
|
||||
)
|
||||
yield status_message
|
||||
|
||||
try:
|
||||
chroma_results = await user.file_watcher.find_similar(
|
||||
query=prompt, top_k=top_k, threshold=threshold
|
||||
)
|
||||
chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
|
||||
if not chroma_results:
|
||||
continue
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
|
||||
@ -549,7 +558,7 @@ Content: {content}
|
||||
continue_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
content=f"Error searching RAG context {rag.name}: {str(e)}"
|
||||
content=f"Error searching RAG context {rag.name}: {str(e)}",
|
||||
)
|
||||
yield continue_message
|
||||
|
||||
@ -563,22 +572,20 @@ Content: {content}
|
||||
|
||||
async def llm_one_shot(
|
||||
self,
|
||||
llm: Any, model: str,
|
||||
session_id: str, prompt: str, system_prompt: str,
|
||||
llm: Any,
|
||||
model: str,
|
||||
session_id: str,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
|
||||
temperature=0.7,
|
||||
) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for chat generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.set_optimal_context_size(
|
||||
llm=llm, model=model, prompt=prompt+system_prompt
|
||||
)
|
||||
self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
|
||||
|
||||
options = ChatOptions(
|
||||
seed=8911,
|
||||
@ -592,9 +599,7 @@ Content: {content}
|
||||
]
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.GENERATING,
|
||||
content=f"Generating response..."
|
||||
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -610,10 +615,7 @@ Content: {content}
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -628,46 +630,34 @@ Content: {content}
|
||||
yield streaming_message
|
||||
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
|
||||
|
||||
chat_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
tunables=tunables,
|
||||
status=ApiStatusType.DONE,
|
||||
content=content,
|
||||
metadata = ChatMessageMetaData(
|
||||
metadata=ChatMessageMetaData(
|
||||
options=options,
|
||||
eval_count=response.usage.eval_count,
|
||||
eval_duration=response.usage.eval_duration,
|
||||
prompt_eval_count=response.usage.prompt_eval_count,
|
||||
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||
|
||||
)
|
||||
),
|
||||
)
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for chat generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -675,12 +665,11 @@ Content: {content}
|
||||
session_id=session_id,
|
||||
content=prompt,
|
||||
)
|
||||
user = self.user
|
||||
|
||||
self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
context = None
|
||||
rag_message : ChatMessageRagSearch | None = None
|
||||
rag_message: ChatMessageRagSearch | None = None
|
||||
if self.user:
|
||||
message = None
|
||||
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
||||
@ -692,38 +681,32 @@ Content: {content}
|
||||
yield message
|
||||
|
||||
if not isinstance(message, ChatMessageRagSearch):
|
||||
raise ValueError(
|
||||
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
|
||||
)
|
||||
raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
|
||||
|
||||
rag_message = message
|
||||
context = self.get_rag_context(rag_message)
|
||||
|
||||
# Create a pruned down message list based purely on the prompt and responses,
|
||||
# discarding the full preamble generated by prepare_message
|
||||
messages: List[LLMMessage] = [
|
||||
LLMMessage(role="system", content=self.system_prompt)
|
||||
]
|
||||
messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
|
||||
# Add the conversation history to the messages
|
||||
messages.extend([
|
||||
messages.extend(
|
||||
[
|
||||
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
|
||||
for m in self.conversation
|
||||
])
|
||||
]
|
||||
)
|
||||
# Add the RAG context to the messages if available
|
||||
if context:
|
||||
messages.append(
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
|
||||
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only the actual user query is provided with the full context message
|
||||
messages.append(
|
||||
LLMMessage(role="user", content=prompt)
|
||||
)
|
||||
|
||||
llm_history = messages
|
||||
messages.append(LLMMessage(role="user", content=prompt))
|
||||
|
||||
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||
# message.metadata.tools = {
|
||||
@ -827,16 +810,12 @@ Content: {content}
|
||||
|
||||
# not use_tools
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.GENERATING,
|
||||
content=f"Generating response..."
|
||||
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
|
||||
)
|
||||
yield status_message
|
||||
|
||||
# Set the response for streaming
|
||||
self.set_optimal_context_size(
|
||||
llm, model, prompt=prompt
|
||||
)
|
||||
self.set_optimal_context_size(llm, model, prompt=prompt)
|
||||
|
||||
options = ChatOptions(
|
||||
seed=8911,
|
||||
@ -856,10 +835,7 @@ Content: {content}
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -873,17 +849,12 @@ Content: {content}
|
||||
yield streaming_message
|
||||
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
|
||||
end_time = time.perf_counter()
|
||||
|
||||
chat_message = ChatMessage(
|
||||
@ -891,7 +862,7 @@ Content: {content}
|
||||
tunables=tunables,
|
||||
status=ApiStatusType.DONE,
|
||||
content=content,
|
||||
metadata = ChatMessageMetaData(
|
||||
metadata=ChatMessageMetaData(
|
||||
options=options,
|
||||
eval_count=response.usage.eval_count,
|
||||
eval_duration=response.usage.eval_duration,
|
||||
@ -902,9 +873,8 @@ Content: {content}
|
||||
"llm_streamed": end_time - start_time,
|
||||
"llm_with_tools": 0, # Placeholder for tool processing time
|
||||
},
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Add the user and chat messages to the conversation
|
||||
self.conversation.append(user_message)
|
||||
@ -999,12 +969,13 @@ Content: {content}
|
||||
|
||||
raise ValueError("No Markdown found in the response")
|
||||
|
||||
|
||||
_agents: List[Agent] = []
|
||||
|
||||
|
||||
def get_or_create_agent(
|
||||
agent_type: str,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
user: Optional[CandidateEntity]=None) -> Agent:
|
||||
agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
|
||||
) -> Agent:
|
||||
"""
|
||||
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
||||
|
||||
@ -1028,14 +999,16 @@ def get_or_create_agent(
|
||||
for agent_cls in Agent.__subclasses__():
|
||||
if agent_cls.model_fields["agent_type"].default == agent_type:
|
||||
# Create the agent instance with provided kwargs
|
||||
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
|
||||
user=user)
|
||||
agent = agent_cls(
|
||||
agent_type=agent_type, # type: ignore[call-arg]
|
||||
user=user,
|
||||
)
|
||||
_agents.append(agent)
|
||||
return agent
|
||||
|
||||
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Agent._agent_type, Agent)
|
||||
CandidateEntity.model_rebuild()
|
||||
|
||||
|
@ -5,10 +5,10 @@ from .base import Agent, agent_registry
|
||||
from logger import logger
|
||||
|
||||
from .registry import agent_registry
|
||||
from models import ( ApiMessage, Tunables, ApiStatusType)
|
||||
from models import ApiMessage, Tunables, ApiStatusType
|
||||
|
||||
|
||||
system_message = f"""
|
||||
system_message = """
|
||||
When answering queries, follow these steps:
|
||||
|
||||
- When any content from <|context|> is relevant, synthesize information from all sources to provide the most complete answer.
|
||||
@ -21,6 +21,7 @@ Always <|context|> when possible. Be concise, and never make up information. If
|
||||
Before answering, ensure you have spelled the candidate's name correctly.
|
||||
"""
|
||||
|
||||
|
||||
class CandidateChat(Agent):
|
||||
"""
|
||||
CandidateChat Agent
|
||||
@ -32,10 +33,7 @@ class CandidateChat(Agent):
|
||||
system_prompt: str = system_message
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
user = self.user
|
||||
if not user:
|
||||
@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
|
||||
{system_message}
|
||||
"""
|
||||
|
||||
async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
|
||||
async for message in super().generate(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
|
||||
):
|
||||
if message.status == ApiStatusType.ERROR:
|
||||
yield message
|
||||
return
|
||||
yield message
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(CandidateChat._agent_type, CandidateChat)
|
||||
|
||||
|
@ -49,11 +49,13 @@ class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["general"] = "general" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
||||
|
||||
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
# if not self.context:
|
||||
|
@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import random
|
||||
import time
|
||||
@ -16,7 +13,15 @@ import time
|
||||
import os
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
import backstory_traceback as traceback
|
||||
@ -26,6 +31,7 @@ from image_generator.profile_image import generate_image, ImageRequest
|
||||
seed = int(time.time())
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
class ImageGenerator(Agent):
|
||||
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -34,10 +40,7 @@ class ImageGenerator(Agent):
|
||||
system_prompt: str = "" # No system prompt is used
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
if not self.user:
|
||||
logger.error("User is not set for ImageGenerator agent.")
|
||||
@ -57,11 +60,17 @@ class ImageGenerator(Agent):
|
||||
yield status_message
|
||||
|
||||
logger.info(f"Image generation: {file_path} <- {prompt}")
|
||||
request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
|
||||
request = ImageRequest(
|
||||
filepath=file_path,
|
||||
session_id=session_id,
|
||||
prompt=prompt,
|
||||
iterations=4,
|
||||
height=256,
|
||||
width=256,
|
||||
guidance_scale=7.5,
|
||||
)
|
||||
generated_message = None
|
||||
async for generated_message in generate_image(
|
||||
request=request
|
||||
):
|
||||
async for generated_message in generate_image(request=request):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -71,8 +80,7 @@ class ImageGenerator(Agent):
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Image generation failed to produce a valid response."
|
||||
session_id=session_id, content="Image generation failed to produce a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -86,21 +94,19 @@ class ImageGenerator(Agent):
|
||||
generated_image = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content = f"{defines.api_prefix}/profile/{user.username}",
|
||||
metadata=generated_message.metadata
|
||||
content=f"{defines.api_prefix}/profile/{user.username}",
|
||||
metadata=generated_message.metadata,
|
||||
)
|
||||
yield generated_image
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error generating image: {str(e)}"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)
|
||||
|
@ -1,16 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from pydantic import Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
Tuple,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import random
|
||||
import re
|
||||
@ -19,10 +18,18 @@ import time
|
||||
import time
|
||||
import os
|
||||
import random
|
||||
from names_dataset import NameDataset, NameWrapper # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ApiMessageType,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
import backstory_traceback as traceback
|
||||
@ -45,6 +52,7 @@ emptyUser = {
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
|
||||
def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
|
||||
return f"""\
|
||||
You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
|
||||
@ -86,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
|
||||
You are providing those only for use later by the system when casting individuals for the role.
|
||||
"""
|
||||
|
||||
|
||||
generate_resume_system_prompt = """
|
||||
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
|
||||
of that is to create an in-depth resume for the person. You will be provided with the following information:
|
||||
@ -117,10 +126,12 @@ import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EthnicNameGenerator:
|
||||
def __init__(self):
|
||||
try:
|
||||
from names_dataset import NameDataset # type: ignore
|
||||
|
||||
self.nd = NameDataset()
|
||||
except ImportError:
|
||||
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
||||
@ -131,24 +142,24 @@ class EthnicNameGenerator:
|
||||
|
||||
# US Census 2020 approximate ethnic distribution
|
||||
self.ethnic_weights = {
|
||||
'White': 0.576,
|
||||
'Hispanic': 0.186,
|
||||
'Black': 0.134,
|
||||
'Asian': 0.062,
|
||||
'Native American': 0.013,
|
||||
'Pacific Islander': 0.003,
|
||||
'Mixed/Other': 0.026
|
||||
"White": 0.576,
|
||||
"Hispanic": 0.186,
|
||||
"Black": 0.134,
|
||||
"Asian": 0.062,
|
||||
"Native American": 0.013,
|
||||
"Pacific Islander": 0.003,
|
||||
"Mixed/Other": 0.026,
|
||||
}
|
||||
|
||||
# Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
|
||||
self.ethnic_country_mapping = {
|
||||
'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
|
||||
'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
|
||||
'Black': ['US'], # African American names
|
||||
'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
|
||||
'Native American': ['US'],
|
||||
'Pacific Islander': ['US'],
|
||||
'Mixed/Other': ['US']
|
||||
"White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
|
||||
"Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
|
||||
"Black": ["US"], # African American names
|
||||
"Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
|
||||
"Native American": ["US"],
|
||||
"Pacific Islander": ["US"],
|
||||
"Mixed/Other": ["US"],
|
||||
}
|
||||
|
||||
def get_weighted_ethnicity(self) -> str:
|
||||
@ -157,8 +168,9 @@ class EthnicNameGenerator:
|
||||
weights = list(self.ethnic_weights.values())
|
||||
return random.choices(ethnicities, weights=weights)[0]
|
||||
|
||||
def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
|
||||
n: int = 50, use_first_names: bool = True) -> List[str]:
|
||||
def get_names_by_criteria(
|
||||
self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
|
||||
) -> List[str]:
|
||||
"""Get names matching criteria using NameDataset's get_top_names method"""
|
||||
if not self.nd:
|
||||
return []
|
||||
@ -168,16 +180,13 @@ class EthnicNameGenerator:
|
||||
try:
|
||||
# Get top names for this country
|
||||
top_names = self.nd.get_top_names(
|
||||
n=n,
|
||||
use_first_names=use_first_names,
|
||||
country_alpha2=country_code,
|
||||
gender=gender
|
||||
n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
|
||||
)
|
||||
|
||||
if country_code in top_names:
|
||||
if use_first_names and gender:
|
||||
# For first names with gender specified
|
||||
gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
|
||||
gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
|
||||
if gender_key in top_names[country_code]:
|
||||
all_names.extend(top_names[country_code][gender_key])
|
||||
elif use_first_names:
|
||||
@ -194,25 +203,18 @@ class EthnicNameGenerator:
|
||||
|
||||
return list(set(all_names)) # Remove duplicates
|
||||
|
||||
def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
|
||||
def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
|
||||
"""Generate a name based on ethnicity using the correct NameDataset API"""
|
||||
if gender == 'random':
|
||||
gender = random.choice(['Male', 'Female'])
|
||||
if gender == "random":
|
||||
gender = random.choice(["Male", "Female"])
|
||||
|
||||
countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
|
||||
countries = self.ethnic_country_mapping.get(ethnicity, ["US"])
|
||||
|
||||
# Get first names
|
||||
first_names = self.get_names_by_criteria(
|
||||
countries=countries,
|
||||
gender=gender,
|
||||
use_first_names=True
|
||||
)
|
||||
first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)
|
||||
|
||||
# Get last names
|
||||
last_names = self.get_names_by_criteria(
|
||||
countries=countries,
|
||||
use_first_names=False
|
||||
)
|
||||
last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)
|
||||
|
||||
# Select names or use fallbacks
|
||||
if first_names:
|
||||
@ -232,57 +234,60 @@ class EthnicNameGenerator:
|
||||
def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
|
||||
"""Provide culturally appropriate fallback first names"""
|
||||
fallback_names = {
|
||||
'White': {
|
||||
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
|
||||
'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
|
||||
"White": {
|
||||
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
|
||||
"Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
|
||||
},
|
||||
'Hispanic': {
|
||||
'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
|
||||
'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
|
||||
"Hispanic": {
|
||||
"Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
|
||||
"Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
|
||||
},
|
||||
'Black': {
|
||||
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
|
||||
'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
|
||||
"Black": {
|
||||
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
|
||||
"Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
|
||||
},
|
||||
"Asian": {
|
||||
"Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
|
||||
"Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
|
||||
},
|
||||
'Asian': {
|
||||
'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
|
||||
'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
|
||||
}
|
||||
}
|
||||
|
||||
ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
|
||||
return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
|
||||
ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
|
||||
return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))
|
||||
|
||||
def _get_fallback_last_name(self, ethnicity: str) -> str:
|
||||
"""Provide culturally appropriate fallback last names"""
|
||||
fallback_surnames = {
|
||||
'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
|
||||
'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
|
||||
'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
|
||||
'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
|
||||
"White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
|
||||
"Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
|
||||
"Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
|
||||
"Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
|
||||
}
|
||||
|
||||
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
|
||||
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))
|
||||
|
||||
def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
|
||||
def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
|
||||
"""Generate a random name with ethnicity based on US demographics"""
|
||||
ethnicity = self.get_weighted_ethnicity()
|
||||
return self.get_name_by_ethnicity(ethnicity, gender)
|
||||
|
||||
def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
|
||||
def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
|
||||
"""Generate multiple random names"""
|
||||
names = []
|
||||
for _ in range(count):
|
||||
first, last, ethnicity, actual_gender = self.generate_random_name(gender)
|
||||
names.append({
|
||||
'full_name': f"{first} {last}",
|
||||
'first_name': first,
|
||||
'last_name': last,
|
||||
'ethnicity': ethnicity,
|
||||
'gender': actual_gender
|
||||
})
|
||||
names.append(
|
||||
{
|
||||
"full_name": f"{first} {last}",
|
||||
"first_name": first,
|
||||
"last_name": last,
|
||||
"ethnicity": ethnicity,
|
||||
"gender": actual_gender,
|
||||
}
|
||||
)
|
||||
return names
|
||||
|
||||
|
||||
class GeneratePersona(Agent):
|
||||
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -307,23 +312,20 @@ class GeneratePersona(Agent):
|
||||
self.full_name = f"{self.first_name} {self.last_name}"
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
self.randomize()
|
||||
|
||||
original_prompt = prompt.strip()
|
||||
|
||||
persona = {
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity,
|
||||
"full_name": self.full_name,
|
||||
"first_name": self.first_name,
|
||||
"last_name": self.last_name,
|
||||
}
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity,
|
||||
"full_name": self.full_name,
|
||||
"first_name": self.first_name,
|
||||
"last_name": self.last_name,
|
||||
}
|
||||
|
||||
prompt = f"""\
|
||||
```json
|
||||
@ -339,10 +341,11 @@ Incorporate the following into the job description: {original_prompt}
|
||||
#
|
||||
# Generate the persona
|
||||
#
|
||||
logger.info(f"🤖 Generating persona...")
|
||||
logger.info("🤖 Generating persona...")
|
||||
generating_message = None
|
||||
async for generating_message in self.llm_one_shot(
|
||||
llm=llm, model=model,
|
||||
llm=llm,
|
||||
model=model,
|
||||
session_id=session_id,
|
||||
prompt=prompt,
|
||||
system_prompt=generate_persona_system_prompt(persona=persona),
|
||||
@ -356,8 +359,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
|
||||
if not generating_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Persona generation failed to generate a response."
|
||||
session_id=session_id, content="Persona generation failed to generate a response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -375,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
self.username = persona.get("username", None)
|
||||
if not self.username:
|
||||
raise ValueError("LLM did not generate a username")
|
||||
self.username = re.sub(r'\s+', '.', self.username)
|
||||
self.username = re.sub(r"\s+", ".", self.username)
|
||||
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||
while os.path.exists(user_dir):
|
||||
match = re.match(r"^(.*?)(\d*)$", persona["username"])
|
||||
@ -398,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
|
||||
location_parts = persona["location"].split(",")
|
||||
if len(location_parts) == 3:
|
||||
city, state, country = [part.strip() for part in location_parts]
|
||||
persona["location"] = {
|
||||
"city": city,
|
||||
"state": state,
|
||||
"country": country
|
||||
}
|
||||
persona["location"] = {"city": city, "state": state, "country": country}
|
||||
else:
|
||||
logger.error(f"Invalid location format: {persona['location']}")
|
||||
persona["location"] = None
|
||||
persona["is_ai"] = True
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
|
||||
session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
|
||||
)
|
||||
logger.error(f"❌ Error parsing LLM response: {error_message.content}")
|
||||
logger.error(traceback.format_exc())
|
||||
@ -422,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
|
||||
# Persona generated
|
||||
persona_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
type=ApiMessageType.JSON,
|
||||
content = json.dumps(persona)
|
||||
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
|
||||
)
|
||||
yield persona_message
|
||||
|
||||
@ -434,8 +428,8 @@ Incorporate the following into the job description: {original_prompt}
|
||||
#
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity = ApiActivityType.THINKING,
|
||||
content = f"Generating resume for {persona['full_name']}..."
|
||||
activity=ApiActivityType.THINKING,
|
||||
content=f"Generating resume for {persona['full_name']}...",
|
||||
)
|
||||
logger.info(f"🤖 {status_message.content}")
|
||||
yield status_message
|
||||
@ -458,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
|
||||
Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
|
||||
|
||||
async for generating_message in self.llm_one_shot(
|
||||
llm=llm, model=model,
|
||||
llm=llm,
|
||||
model=model,
|
||||
session_id=session_id,
|
||||
prompt=content,
|
||||
system_prompt=generate_resume_system_prompt,
|
||||
@ -472,8 +467,7 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
if not generating_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Resume generation failed to generate a response."
|
||||
session_id=session_id, content="Resume generation failed to generate a response."
|
||||
)
|
||||
logger.error(f"❌ {error_message.content}")
|
||||
yield error_message
|
||||
@ -481,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
resume = self.extract_markdown_from_text(generating_message.content)
|
||||
resume_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
type=ApiMessageType.TEXT,
|
||||
content=resume
|
||||
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
|
||||
)
|
||||
yield resume_message
|
||||
return
|
||||
@ -504,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
raise ValueError("No JSON found in the response")
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(GeneratePersona._agent_type, GeneratePersona)
|
||||
|
@ -1,30 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from logger import logger
|
||||
from .base import Agent, agent_registry
|
||||
from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ApiStatusType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageResume,
|
||||
ChatMessageStatus,
|
||||
SkillAssessment,
|
||||
SkillStrength,
|
||||
)
|
||||
|
||||
|
||||
class GenerateResume(Agent):
|
||||
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
def generate_resume_prompt(
|
||||
self,
|
||||
skills: List[SkillAssessment]
|
||||
):
|
||||
def generate_resume_prompt(self, skills: List[SkillAssessment]):
|
||||
"""
|
||||
Generate a professional resume based on skill assessment results
|
||||
|
||||
@ -45,7 +49,7 @@ class GenerateResume(Agent):
|
||||
SkillStrength.STRONG: [],
|
||||
SkillStrength.MODERATE: [],
|
||||
SkillStrength.WEAK: [],
|
||||
SkillStrength.NONE: []
|
||||
SkillStrength.NONE: [],
|
||||
}
|
||||
|
||||
experience_evidence = {}
|
||||
@ -67,11 +71,7 @@ class GenerateResume(Agent):
|
||||
experience_evidence[source] = []
|
||||
|
||||
experience_evidence[source].append(
|
||||
{
|
||||
"skill": skill,
|
||||
"quote": evidence.quote,
|
||||
"context": evidence.context
|
||||
}
|
||||
{"skill": skill, "quote": evidence.quote, "context": evidence.context}
|
||||
)
|
||||
|
||||
# Build the system prompt
|
||||
@ -171,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
# Stage 1A: Analyze job requirements
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content = f"Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
system_prompt, prompt = self.generate_resume_prompt(skills=skills)
|
||||
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -189,8 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response."
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -198,8 +197,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
|
||||
if not isinstance(generated_message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis did not return a valid message."
|
||||
session_id=session_id, content="Job requirements analysis did not return a valid message."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -215,8 +213,9 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
yield resume_message
|
||||
logger.info(f"✅ Resume generation completed successfully.")
|
||||
logger.info("✅ Resume generation completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(GenerateResume._agent_type, GenerateResume)
|
||||
|
@ -1,24 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import inspect
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Job,
|
||||
JobRequirements,
|
||||
JobRequirementsMessage,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import backstory_traceback as traceback
|
||||
|
||||
|
||||
class JobRequirementsAgent(Agent):
|
||||
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -94,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
"""Analyze job requirements from job description."""
|
||||
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content="Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
logger.info(f"🔍 {status_message.content}")
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -110,8 +120,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response.")
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
@ -132,18 +142,18 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
display = {
|
||||
"technical_skills": {
|
||||
"required": reqs.technical_skills.required,
|
||||
"preferred": reqs.technical_skills.preferred
|
||||
"preferred": reqs.technical_skills.preferred,
|
||||
},
|
||||
"experience_requirements": {
|
||||
"required": reqs.experience_requirements.required,
|
||||
"preferred": reqs.experience_requirements.preferred
|
||||
"preferred": reqs.experience_requirements.preferred,
|
||||
},
|
||||
"soft_skills": reqs.soft_skills,
|
||||
"experience": reqs.experience,
|
||||
"education": reqs.education,
|
||||
"certifications": reqs.certifications,
|
||||
"preferred_attributes": reqs.preferred_attributes,
|
||||
"company_values": reqs.company_values
|
||||
"company_values": reqs.company_values,
|
||||
}
|
||||
|
||||
return display
|
||||
@ -152,19 +162,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="User is not set for this agent."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
# Stage 1A: Analyze job requirements
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content = f"Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -178,8 +183,7 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response."
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -214,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
return
|
||||
except Exception as e:
|
||||
status_message.status = ApiStatusType.ERROR
|
||||
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
||||
status_message.content = (
|
||||
f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {status_message.content}")
|
||||
yield status_message
|
||||
@ -238,8 +244,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
job=job,
|
||||
)
|
||||
yield job_requirements_message
|
||||
logger.info(f"✅ Job requirements analysis completed successfully.")
|
||||
logger.info("✅ Job requirements analysis completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)
|
||||
|
@ -5,20 +5,19 @@ from .base import Agent, agent_registry
|
||||
from logger import logger
|
||||
|
||||
from .registry import agent_registry
|
||||
from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables )
|
||||
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables
|
||||
|
||||
|
||||
class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
"""
|
||||
Generate a response based on the user message and the provided LLM.
|
||||
@ -44,8 +43,7 @@ class Chat(Agent):
|
||||
if not isinstance(rag_message, ChatMessageRagSearch):
|
||||
logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -53,5 +51,6 @@ class Chat(Agent):
|
||||
rag_message.status = ApiStatusType.DONE
|
||||
yield rag_message
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Chat._agent_type, Chat)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Optional, Type
|
||||
|
||||
|
||||
# We'll use a registry pattern rather than hardcoded strings
|
||||
class AgentRegistry:
|
||||
"""Registry for agent types and classes"""
|
||||
|
@ -1,24 +1,30 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail,
|
||||
SkillAssessment, Tunables)
|
||||
from models import (
|
||||
ApiMessage,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageRagSearch,
|
||||
ChatMessageSkillAssessment,
|
||||
ApiStatusType,
|
||||
EvidenceDetail,
|
||||
SkillAssessment,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import backstory_traceback as traceback
|
||||
|
||||
|
||||
class SkillMatchAgent(Agent):
|
||||
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -100,15 +106,12 @@ JSON RESPONSE:"""
|
||||
return system_prompt, prompt
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Agent not attached to user. Attach the agent to a user before generating responses."
|
||||
content="Agent not attached to user. Attach the agent to a user before generating responses.",
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -116,10 +119,7 @@ JSON RESPONSE:"""
|
||||
|
||||
skill = prompt.strip()
|
||||
if not skill:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill cannot be empty."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
@ -134,8 +134,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -144,19 +143,20 @@ JSON RESPONSE:"""
|
||||
if not isinstance(generated_message, ChatMessageRagSearch):
|
||||
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
rag_message : ChatMessageRagSearch = generated_message
|
||||
rag_message: ChatMessageRagSearch = generated_message
|
||||
|
||||
rag_context = self.get_rag_context(rag_message)
|
||||
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
|
||||
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
|
||||
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
logger.error(f"⚠️ {generated_message.content}")
|
||||
yield generated_message
|
||||
@ -166,8 +166,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment failed to generate a response."
|
||||
session_id=session_id, content="Skill assessment failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -175,8 +174,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if not isinstance(generated_message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment did not return a valid message."
|
||||
session_id=session_id, content="Skill assessment did not return a valid message."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -199,14 +197,15 @@ JSON RESPONSE:"""
|
||||
EvidenceDetail(
|
||||
source=evidence.get("source", ""),
|
||||
quote=evidence.get("quote", ""),
|
||||
context=evidence.get("context", "")
|
||||
) for evidence in skill_assessment_data.get("evidence_details", [])
|
||||
]
|
||||
context=evidence.get("context", ""),
|
||||
)
|
||||
for evidence in skill_assessment_data.get("evidence_details", [])
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
@ -233,8 +232,9 @@ JSON RESPONSE:"""
|
||||
skill_assessment=skill_assessment,
|
||||
)
|
||||
yield skill_assessment_message
|
||||
logger.info(f"✅ Skill assessment completed successfully.")
|
||||
logger.info("✅ Skill assessment completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)
|
||||
|
@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Any, Callable
|
||||
from logger import logger
|
||||
from database.manager import DatabaseManager
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""Manages background tasks for the application using asyncio instead of threading"""
|
||||
|
||||
@ -65,10 +66,12 @@ class BackgroundTaskManager:
|
||||
stats = await database.get_guest_statistics()
|
||||
|
||||
# Log interesting statistics
|
||||
if stats.get('total_guests', 0) > 0:
|
||||
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
|
||||
if stats.get("total_guests", 0) > 0:
|
||||
logger.info(
|
||||
f"📊 Guest stats: {stats['total_guests']} total, "
|
||||
f"{stats['active_last_hour']} active in last hour, "
|
||||
f"{stats['converted_guests']} converted")
|
||||
f"{stats['converted_guests']} converted"
|
||||
)
|
||||
|
||||
return stats
|
||||
except Exception as e:
|
||||
@ -84,6 +87,7 @@ class BackgroundTaskManager:
|
||||
|
||||
# Get Redis client safely (using the event loop safe method)
|
||||
from database.manager import redis_manager
|
||||
|
||||
redis = await redis_manager.get_client()
|
||||
|
||||
# Clean up rate limit keys older than specified days
|
||||
@ -192,10 +196,7 @@ class BackgroundTaskManager:
|
||||
|
||||
# Create asyncio tasks for each periodic task
|
||||
for name, func, interval, *args in periodic_tasks:
|
||||
task = asyncio.create_task(
|
||||
self._run_periodic_task(name, func, interval, *args),
|
||||
name=f"background_{name}"
|
||||
)
|
||||
task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
|
||||
self.tasks.append(task)
|
||||
logger.info(f"📅 Scheduled background task: {name}")
|
||||
|
||||
@ -238,10 +239,7 @@ class BackgroundTaskManager:
|
||||
# Wait for all tasks to complete with timeout
|
||||
if self.tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*self.tasks, return_exceptions=True),
|
||||
timeout=30.0
|
||||
)
|
||||
await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
|
||||
logger.info("✅ All background tasks stopped gracefully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ Some background tasks did not stop within timeout")
|
||||
@ -258,7 +256,7 @@ class BackgroundTaskManager:
|
||||
"main_loop_id": id(self.main_loop) if self.main_loop else None,
|
||||
"current_loop_id": None,
|
||||
"task_count": len(self.tasks),
|
||||
"tasks": []
|
||||
"tasks": [],
|
||||
}
|
||||
|
||||
try:
|
||||
@ -317,6 +315,7 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun
|
||||
await task_manager.start()
|
||||
return task_manager
|
||||
|
||||
|
||||
# For integration with your existing app startup
|
||||
async def initialize_with_background_tasks(database_manager: DatabaseManager):
|
||||
"""Initialize database and background tasks together"""
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
import defines
|
||||
|
||||
|
||||
def filter_traceback(tb, app_path=None, module_name=None):
|
||||
"""
|
||||
Filter traceback to include only frames from the specified application path or module.
|
||||
@ -36,7 +37,8 @@ def filter_traceback(tb, app_path=None, module_name=None):
|
||||
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
|
||||
|
||||
# Combine the filtered stack trace with the exception message
|
||||
return ''.join(formatted_stack + formatted_exc)
|
||||
return "".join(formatted_stack + formatted_exc)
|
||||
|
||||
|
||||
def format_exc(app_path=defines.app_path, module_name=None):
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .core import RedisDatabase
|
||||
from .manager import DatabaseManager, redis_manager
|
||||
|
||||
__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager']
|
||||
__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]
|
||||
|
@ -1,15 +1,15 @@
|
||||
KEY_PREFIXES = {
|
||||
'viewers': 'viewer:',
|
||||
'candidates': 'candidate:',
|
||||
'employers': 'employer:',
|
||||
'jobs': 'job:',
|
||||
'job_applications': 'job_application:',
|
||||
'chat_sessions': 'chat_session:',
|
||||
'chat_messages': 'chat_messages:',
|
||||
'ai_parameters': 'ai_parameters:',
|
||||
'users': 'user:',
|
||||
'candidate_documents': 'candidate_documents:',
|
||||
'job_requirements': 'job_requirements:',
|
||||
'resumes': 'resume:',
|
||||
'user_resumes': 'user_resumes:',
|
||||
"viewers": "viewer:",
|
||||
"candidates": "candidate:",
|
||||
"employers": "employer:",
|
||||
"jobs": "job:",
|
||||
"job_applications": "job_application:",
|
||||
"chat_sessions": "chat_session:",
|
||||
"chat_messages": "chat_messages:",
|
||||
"ai_parameters": "ai_parameters:",
|
||||
"users": "user:",
|
||||
"candidate_documents": "candidate_documents:",
|
||||
"job_requirements": "job_requirements:",
|
||||
"resumes": "resume:",
|
||||
"user_resumes": "user_resumes:",
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ from .mixins.job import JobMixin
|
||||
from .mixins.skill import SkillMixin
|
||||
from .mixins.ai import AIMixin
|
||||
|
||||
|
||||
# RedisDatabase is the main class that combines all mixins for a
|
||||
# comprehensive Redis database interface.
|
||||
class RedisDatabase(
|
||||
|
@ -1,18 +1,15 @@
|
||||
from redis.asyncio import (Redis, ConnectionPool)
|
||||
from redis.asyncio import Redis, ConnectionPool
|
||||
from typing import Optional, Optional
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, UTC
|
||||
import asyncio
|
||||
from models import (
|
||||
# User models
|
||||
Candidate, Employer, BaseUser, EvidenceDetail, Guest, Authentication, AuthResponse, SkillAssessment,
|
||||
)
|
||||
from .core import RedisDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# _RedisManager is a singleton class that manages the Redis connection and
|
||||
# provides methods for connecting, disconnecting, and performing health checks.
|
||||
#
|
||||
@ -46,12 +43,10 @@ class _RedisManager:
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options={},
|
||||
health_check_interval=30
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
self.redis = Redis(
|
||||
connection_pool=self._connection_pool
|
||||
)
|
||||
self.redis = Redis(connection_pool=self._connection_pool)
|
||||
|
||||
if not self.redis:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
@ -135,7 +130,7 @@ class _RedisManager:
|
||||
"uptime_seconds": info.get("uptime_in_seconds", 0),
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"used_memory_human": info.get("used_memory_human", "unknown"),
|
||||
"total_commands_processed": info.get("total_commands_processed", 0)
|
||||
"total_commands_processed": info.get("total_commands_processed", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
@ -177,9 +172,11 @@ class _RedisManager:
|
||||
logger.error(f"Failed to get Redis info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global Redis manager instance
|
||||
redis_manager = _RedisManager()
|
||||
|
||||
|
||||
# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
|
||||
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
|
||||
class DatabaseManager:
|
||||
@ -231,7 +228,7 @@ class DatabaseManager:
|
||||
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
# Save to local file (you might want to save to cloud storage instead)
|
||||
with open(backup_filename, 'w') as f:
|
||||
with open(backup_filename, "w") as f:
|
||||
json.dump(backup_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Backup created: {backup_filename}")
|
||||
@ -314,5 +311,3 @@ class DatabaseManager:
|
||||
if self._shutdown_initiated:
|
||||
raise RuntimeError("Application is shutting down")
|
||||
return self.db
|
||||
|
||||
|
@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIMixin(DatabaseProtocol):
|
||||
"""Mixin for AI operations"""
|
||||
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
|
||||
"""Get AI parameters by ID"""
|
||||
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
|
||||
@ -36,7 +38,7 @@ class AIMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
|
||||
param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
|
||||
result[param_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -45,4 +47,3 @@ class AIMixin(DatabaseProtocol):
|
||||
"""Delete AI parameters"""
|
||||
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
@ -7,6 +7,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalyticsMixin:
|
||||
"""Mixin for analytics-related database operations"""
|
||||
|
||||
|
@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMixin(DatabaseProtocol):
|
||||
"""Mixin for auth-related database operations"""
|
||||
|
||||
@ -25,8 +26,9 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data = await self.redis.get(key)
|
||||
if token_data:
|
||||
verification_info = json.loads(token_data)
|
||||
if (verification_info.get("email", "").lower() == email_lower and
|
||||
not verification_info.get("verified", False)):
|
||||
if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
|
||||
"verified", False
|
||||
):
|
||||
# Extract token from key
|
||||
token = key.replace("email_verification:", "")
|
||||
verification_info["token"] = token
|
||||
@ -115,10 +117,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
window_start = current_time - timedelta(hours=24)
|
||||
|
||||
# Filter out old attempts
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > window_start
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
|
||||
|
||||
return len(recent_attempts)
|
||||
|
||||
@ -141,16 +140,13 @@ class AuthMixin(DatabaseProtocol):
|
||||
|
||||
# Keep only last 24 hours of attempts
|
||||
window_start = current_time - timedelta(hours=24)
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > window_start
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
|
||||
|
||||
# Store with 24 hour expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # 24 hours
|
||||
json.dumps(recent_attempts)
|
||||
json.dumps(recent_attempts),
|
||||
)
|
||||
|
||||
return True
|
||||
@ -169,14 +165,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"user_data": user_data,
|
||||
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"verified": False
|
||||
"verified": False,
|
||||
}
|
||||
|
||||
# Store with 24 hour expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # 24 hours in seconds
|
||||
json.dumps(verification_data, default=str)
|
||||
json.dumps(verification_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"📧 Stored email verification token for {email}")
|
||||
@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # Keep for remaining TTL
|
||||
json.dumps(token_data, default=str)
|
||||
json.dumps(token_data, default=str),
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@ -219,7 +215,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
|
||||
"""Store MFA code for verification"""
|
||||
try:
|
||||
logger.info("🔐 Storing MFA code for email: %s", email )
|
||||
logger.info("🔐 Storing MFA code for email: %s", email)
|
||||
key = f"mfa_code:{email.lower()}:{device_id}"
|
||||
mfa_data = {
|
||||
"code": code,
|
||||
@ -228,14 +224,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"attempts": 0,
|
||||
"verified": False
|
||||
"verified": False,
|
||||
}
|
||||
|
||||
# Store with 10 minute expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # 10 minutes in seconds
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"🔐 Stored MFA code for {email}")
|
||||
@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # Keep original TTL
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
return mfa_data["attempts"]
|
||||
return 0
|
||||
@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # Keep for remaining TTL
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@ -327,7 +323,9 @@ class AuthMixin(DatabaseProtocol):
|
||||
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
|
||||
async def store_refresh_token(
|
||||
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
|
||||
) -> bool:
|
||||
"""Store refresh token for a user"""
|
||||
try:
|
||||
key = f"refresh_token:{token}"
|
||||
@ -337,7 +335,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
"device": device_info.get("device", "unknown"),
|
||||
"ip_address": device_info.get("ip_address", "unknown"),
|
||||
"is_revoked": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store with expiration
|
||||
@ -374,7 +372,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data["is_revoked"] = True
|
||||
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
|
||||
await self.redis.set(key, json.dumps(token_data, default=str))
|
||||
logger.info(f"🔐 Revoked refresh token")
|
||||
logger.info("🔐 Revoked refresh token")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
@ -420,7 +418,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
"email": email.lower(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"used": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store with expiration
|
||||
@ -457,7 +455,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data["used"] = True
|
||||
token_data["used_at"] = datetime.now(timezone.utc).isoformat()
|
||||
await self.redis.set(key, json.dumps(token_data, default=str))
|
||||
logger.info(f"🔐 Marked password reset token as used")
|
||||
logger.info("🔐 Marked password reset token as used")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
@ -473,14 +471,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"user_id": user_id,
|
||||
"event_type": event_type,
|
||||
"details": details
|
||||
"details": details,
|
||||
}
|
||||
|
||||
# Add to list (latest events first)
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str)) # type: ignore
|
||||
|
||||
# Keep only last 100 events per day
|
||||
await self.redis.ltrim(key, 0, 99)# type: ignore
|
||||
await self.redis.ltrim(key, 0, 99) # type: ignore
|
||||
|
||||
# Set expiration for 30 days
|
||||
await self.redis.expire(key, 30 * 24 * 60 * 60)
|
||||
@ -496,10 +494,10 @@ class AuthMixin(DatabaseProtocol):
|
||||
try:
|
||||
events = []
|
||||
for i in range(days):
|
||||
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
|
||||
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
key = f"security_log:{user_id}:{date}"
|
||||
|
||||
daily_events = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
daily_events = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
for event_json in daily_events:
|
||||
events.append(json.loads(event_json))
|
||||
|
||||
@ -509,4 +507,3 @@ class AuthMixin(DatabaseProtocol):
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
@ -5,11 +5,13 @@ from typing import Any, Dict, TYPE_CHECKING
|
||||
from .protocols import DatabaseProtocol
|
||||
|
||||
from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class BaseMixin(DatabaseProtocol):
|
||||
"""Base mixin with core Redis operations and utilities"""
|
||||
|
||||
@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMixin(DatabaseProtocol):
|
||||
"""Mixin for chat-related database operations"""
|
||||
|
||||
@ -22,7 +23,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"total_sessions": 0,
|
||||
"total_messages": 0,
|
||||
"first_chat": None,
|
||||
"last_chat": None
|
||||
"last_chat": None,
|
||||
}
|
||||
|
||||
total_messages = 0
|
||||
@ -41,7 +42,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"total_messages": total_messages,
|
||||
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
|
||||
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
|
||||
"recent_sessions": sessions[:5] # Last 5 sessions
|
||||
"recent_sessions": sessions[:5], # Last 5 sessions
|
||||
}
|
||||
|
||||
# Chat Sessions operations
|
||||
@ -71,13 +72,13 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
|
||||
session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
|
||||
result[session_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
|
||||
async def delete_chat_session(self, session_id: str) -> bool:
|
||||
'''Delete a chat session from Redis'''
|
||||
"""Delete a chat session from Redis"""
|
||||
try:
|
||||
result = await self.redis.delete(f"chat_session:{session_id}")
|
||||
return result > 0
|
||||
@ -86,11 +87,11 @@ class ChatMixin(DatabaseProtocol):
|
||||
raise
|
||||
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
|
||||
'''Delete a specific chat message from Redis'''
|
||||
"""Delete a specific chat message from Redis"""
|
||||
try:
|
||||
# Remove from the session's message list
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.lrem(key, 0, message_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, message_id) # type: ignore
|
||||
# Delete the message data itself
|
||||
result = await self.redis.delete(f"chat_message:{message_id}")
|
||||
return result > 0
|
||||
@ -102,13 +103,13 @@ class ChatMixin(DatabaseProtocol):
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
||||
"""Get chat messages for a session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
messages = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
|
||||
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict):
|
||||
"""Add a chat message to a session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
|
||||
await self.redis.rpush(key, self._serialize(message_data)) # type: ignore
|
||||
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
||||
"""Set all chat messages for a session (replaces existing)"""
|
||||
@ -120,7 +121,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
# Add new messages
|
||||
if messages:
|
||||
serialized_messages = [self._serialize(msg) for msg in messages]
|
||||
await self.redis.rpush(key, *serialized_messages)# type: ignore
|
||||
await self.redis.rpush(key, *serialized_messages) # type: ignore
|
||||
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
||||
"""Get all chat messages grouped by session"""
|
||||
@ -132,8 +133,8 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key in keys:
|
||||
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
|
||||
messages = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
|
||||
|
||||
return result
|
||||
@ -164,8 +165,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
for session_data in all_sessions.values():
|
||||
context = session_data.get("context", {})
|
||||
if (context.get("relatedEntityType") == "candidate" and
|
||||
context.get("relatedEntityId") == candidate_id):
|
||||
if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
|
||||
candidate_sessions.append(session_data)
|
||||
|
||||
# Sort by last activity (most recent first)
|
||||
@ -188,7 +188,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
async def get_chat_message_count(self, session_id: str) -> int:
|
||||
"""Get the total number of messages in a chat session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
return await self.redis.llen(key)# type: ignore
|
||||
return await self.redis.llen(key) # type: ignore
|
||||
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
||||
"""Search for messages containing specific text in a session"""
|
||||
@ -236,7 +236,6 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
return archived_count
|
||||
|
||||
|
||||
# Analytics and Reporting
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive chat statistics"""
|
||||
@ -250,7 +249,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"archived_sessions": 0,
|
||||
"sessions_by_type": {},
|
||||
"sessions_with_candidates": 0,
|
||||
"average_messages_per_session": 0
|
||||
"average_messages_per_session": 0,
|
||||
}
|
||||
|
||||
# Analyze sessions
|
||||
|
@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentMixin(DatabaseProtocol):
|
||||
"""Mixin for document-related database operations"""
|
||||
|
||||
@ -31,7 +32,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all document IDs for this candidate
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
|
||||
if not document_ids:
|
||||
logger.info(f"No documents found for candidate {candidate_id}")
|
||||
@ -83,7 +84,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
documents.append(doc_data)
|
||||
else:
|
||||
# Clean up orphaned document ID
|
||||
await self.redis.lrem(key, 0, doc_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, doc_id) # type: ignore
|
||||
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
|
||||
|
||||
return documents
|
||||
@ -91,12 +92,12 @@ class DocumentMixin(DatabaseProtocol):
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Add a document ID to a candidate's document list"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.rpush(key, document_id)# type: ignore
|
||||
await self.redis.rpush(key, document_id) # type: ignore
|
||||
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Remove a document ID from a candidate's document list"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.lrem(key, 0, document_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, document_id) # type: ignore
|
||||
|
||||
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
|
||||
"""Update document metadata"""
|
||||
@ -128,7 +129,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
|
||||
"""Get total number of documents for a candidate"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
return await self.redis.llen(key)# type: ignore
|
||||
return await self.redis.llen(key) # type: ignore
|
||||
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
|
||||
"""Search documents by filename for a candidate"""
|
||||
@ -136,8 +137,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
query_lower = query.lower()
|
||||
|
||||
return [
|
||||
doc for doc in all_documents
|
||||
if (query_lower in doc.get("filename", "").lower() or
|
||||
query_lower in doc.get("originalName", "").lower())
|
||||
doc
|
||||
for doc in all_documents
|
||||
if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
|
||||
]
|
||||
|
||||
|
@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobMixin(DatabaseProtocol):
|
||||
"""Mixin for job-related database operations"""
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]:
|
||||
"""Get job by ID"""
|
||||
key = f"{KEY_PREFIXES['jobs']}{job_id}"
|
||||
@ -36,7 +38,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
job_id = key.replace(KEY_PREFIXES['jobs'], '')
|
||||
job_id = key.replace(KEY_PREFIXES["jobs"], "")
|
||||
result[job_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -124,7 +126,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
|
||||
app_id = key.replace(KEY_PREFIXES["job_applications"], "")
|
||||
result[app_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -189,7 +191,7 @@ class JobMixin(DatabaseProtocol):
|
||||
requirements_with_meta = {
|
||||
**requirements,
|
||||
"cached_at": datetime.now(UTC).isoformat(),
|
||||
"document_id": document_id
|
||||
"document_id": document_id,
|
||||
}
|
||||
|
||||
await self.redis.set(key, self._serialize(requirements_with_meta))
|
||||
@ -232,7 +234,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
|
||||
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
|
||||
if value:
|
||||
result[document_id] = self._deserialize(value)
|
||||
|
||||
@ -247,11 +249,7 @@ class JobMixin(DatabaseProtocol):
|
||||
pattern = f"{KEY_PREFIXES['job_requirements']}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
|
||||
stats = {
|
||||
"total_cached_requirements": len(keys),
|
||||
"cache_dates": {},
|
||||
"documents_with_requirements": []
|
||||
}
|
||||
stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}
|
||||
|
||||
if keys:
|
||||
# Get cache dates for analysis
|
||||
@ -264,7 +262,7 @@ class JobMixin(DatabaseProtocol):
|
||||
if value:
|
||||
requirements_data = self._deserialize(value)
|
||||
if requirements_data:
|
||||
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
|
||||
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
|
||||
stats["documents_with_requirements"].append(document_id)
|
||||
|
||||
# Track cache dates
|
||||
@ -277,4 +275,3 @@ class JobMixin(DatabaseProtocol):
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting job requirements stats: {e}")
|
||||
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
|
||||
|
||||
|
@ -7,153 +7,412 @@ if TYPE_CHECKING:
|
||||
|
||||
from models import SkillAssessment
|
||||
|
||||
|
||||
class DatabaseProtocol(Protocol):
|
||||
# Base mixin
|
||||
redis: Redis
|
||||
def _serialize(self, data) -> str: ...
|
||||
def _deserialize(self, data: str): ...
|
||||
|
||||
def _serialize(self, data) -> str:
|
||||
...
|
||||
|
||||
def _deserialize(self, data: str):
|
||||
...
|
||||
|
||||
# Chat mixin
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict): ...
|
||||
async def archive_chat_session(self, session_id: str): ...
|
||||
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ...
|
||||
async def delete_chat_messages(self, session_id: str): ...
|
||||
async def delete_chat_session_completely(self, session_id: str): ...
|
||||
async def delete_chat_session(self, session_id: str) -> bool: ...
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict):
|
||||
...
|
||||
|
||||
async def archive_chat_session(self, session_id: str):
|
||||
...
|
||||
|
||||
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
|
||||
...
|
||||
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_chat_messages(self, session_id: str):
|
||||
...
|
||||
|
||||
async def delete_chat_session_completely(self, session_id: str):
|
||||
...
|
||||
|
||||
async def delete_chat_session(self, session_id: str) -> bool:
|
||||
...
|
||||
|
||||
# Document mixin
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ...
|
||||
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ...
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
|
||||
...
|
||||
|
||||
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
|
||||
...
|
||||
|
||||
# Job mixin
|
||||
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ...
|
||||
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ...
|
||||
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
|
||||
...
|
||||
|
||||
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
|
||||
...
|
||||
|
||||
# User mixin
|
||||
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ...
|
||||
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ...
|
||||
async def delete_employer(self, employer_id: str): ...
|
||||
async def delete_guest(self, guest_id: str) -> bool: ...
|
||||
async def delete_user(self, email: str): ...
|
||||
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
|
||||
async def get_all_users(self) -> Dict[str, Any]: ...
|
||||
async def get_all_viewers(self) -> Dict[str, Any]: ...
|
||||
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ...
|
||||
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
|
||||
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_guest_statistics(self) -> Dict[str, Any]: ...
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_user_by_username(self, username: str) -> Optional[Dict]: ...
|
||||
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
|
||||
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ...
|
||||
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
|
||||
async def invalidate_user_skill_cache(self, user_id: str) -> int: ...
|
||||
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ...
|
||||
async def set_employer(self, employer_id: str, employer_data: Dict): ...
|
||||
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ...
|
||||
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ...
|
||||
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
|
||||
async def update_user_rag_timestamp(self, user_id: str) -> bool: ...
|
||||
async def user_exists_by_email(self, email: str) -> bool: ...
|
||||
async def user_exists_by_username(self, username: str) -> bool: ...
|
||||
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
|
||||
...
|
||||
|
||||
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
async def delete_employer(self, employer_id: str):
|
||||
...
|
||||
|
||||
async def delete_guest(self, guest_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_user(self, email: str):
|
||||
...
|
||||
|
||||
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_users(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_viewers(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_employer(self, employer_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_guest_statistics(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
||||
...
|
||||
|
||||
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def invalidate_user_skill_cache(self, user_id: str) -> int:
|
||||
...
|
||||
|
||||
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
|
||||
...
|
||||
|
||||
async def set_employer(self, employer_id: str, employer_data: Dict):
|
||||
...
|
||||
|
||||
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
|
||||
...
|
||||
|
||||
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def user_exists_by_email(self, email: str) -> bool:
|
||||
...
|
||||
|
||||
async def user_exists_by_username(self, username: str) -> bool:
|
||||
...
|
||||
|
||||
# Auth mixin
|
||||
async def cleanup_expired_verification_tokens(self) -> int: ...
|
||||
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ...
|
||||
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
|
||||
async def cleanup_orphaned_job_requirements(self) -> int: ...
|
||||
async def clear_all_data(self: "DatabaseProtocol"): ...
|
||||
async def clear_all_skill_match_cache(self) -> int: ...
|
||||
async def cleanup_expired_verification_tokens(self) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_orphaned_job_requirements(self) -> int:
|
||||
...
|
||||
|
||||
async def clear_all_data(self: "DatabaseProtocol"):
|
||||
...
|
||||
|
||||
async def clear_all_skill_match_cache(self) -> int:
|
||||
...
|
||||
|
||||
# Resume mixin
|
||||
|
||||
async def delete_ai_parameters(self, param_id: str): ...
|
||||
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ...
|
||||
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
|
||||
async def delete_authentication(self, user_id: str) -> bool: ...
|
||||
async def delete_document(self, document_id: str): ...
|
||||
async def delete_ai_parameters(self, param_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_application(self, application_id: str): ...
|
||||
async def delete_job_requirements(self, document_id: str) -> bool: ...
|
||||
async def delete_job(self, job_id: str): ...
|
||||
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
|
||||
async def delete_viewer(self, viewer_id: str): ...
|
||||
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
|
||||
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
|
||||
async def get_all_candidates(self) -> Dict[str, Any]: ...
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
|
||||
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
|
||||
async def get_all_employers(self) -> Dict[str, Any]: ...
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
|
||||
async def get_all_job_applications(self) -> Dict[str, Any]: ...
|
||||
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
|
||||
async def get_all_jobs(self) -> Dict[str, Any]: ...
|
||||
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
|
||||
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
|
||||
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
|
||||
async def get_chat_message_count(self, session_id: str) -> int: ...
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
|
||||
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
|
||||
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]: ...
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
|
||||
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
|
||||
async def get_document(self, document_id: str) -> Optional[Dict]: ...
|
||||
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
|
||||
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
|
||||
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]: ...
|
||||
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
|
||||
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_pending_verifications_count(self) -> int: ...
|
||||
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
|
||||
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
|
||||
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
|
||||
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
|
||||
async def get_stats(self) -> Dict[str, int]: ...
|
||||
async def get_verification_attempts_count(self, email: str) -> int: ...
|
||||
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
|
||||
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
|
||||
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
|
||||
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
|
||||
async def mark_email_verified(self, token: str) -> bool: ...
|
||||
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
|
||||
async def mark_password_reset_token_used(self, token: str) -> bool: ...
|
||||
async def record_verification_attempt(self, email: str) -> bool: ...
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
|
||||
async def revoke_refresh_token(self, token: str) -> bool: ...
|
||||
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
|
||||
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
|
||||
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
|
||||
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
|
||||
async def set_chat_session(self, session_id: str, session_data: Dict): ...
|
||||
async def set_document(self, document_id: str, document_data: Dict): ...
|
||||
async def set_job_application(self, application_id: str, application_data: Dict): ...
|
||||
async def set_job(self, job_id: str, job_data: Dict): ...
|
||||
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
|
||||
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
|
||||
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
|
||||
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
|
||||
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
|
||||
async def update_chat_session_activity(self, session_id: str): ...
|
||||
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
|
||||
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def delete_all_resumes_for_user(self, user_id: str) -> int:
|
||||
...
|
||||
|
||||
async def delete_authentication(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_document(self, document_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_application(self, application_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_requirements(self, document_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_job(self, job_id: str):
|
||||
...
|
||||
|
||||
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_viewer(self, viewer_id: str):
|
||||
...
|
||||
|
||||
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_ai_parameters(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_candidates(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
||||
...
|
||||
|
||||
async def get_all_chat_sessions(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_employers(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_all_job_applications(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_job_requirements(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_jobs(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
|
||||
...
|
||||
|
||||
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
|
||||
...
|
||||
|
||||
async def get_chat_message_count(self, session_id: str) -> int:
|
||||
...
|
||||
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_document(self, document_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_job_application(self, application_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements_stats(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
|
||||
...
|
||||
|
||||
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_pending_verifications_count(self) -> int:
|
||||
...
|
||||
|
||||
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_stats(self) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
async def get_verification_attempts_count(self, email: str) -> int:
|
||||
...
|
||||
|
||||
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
|
||||
...
|
||||
|
||||
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def mark_email_verified(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def mark_password_reset_token_used(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def record_verification_attempt(self, email: str) -> bool:
|
||||
...
|
||||
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
|
||||
...
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def revoke_refresh_token(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
|
||||
...
|
||||
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def set_ai_parameters(self, param_id: str, param_data: Dict):
|
||||
...
|
||||
|
||||
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
||||
...
|
||||
|
||||
async def set_chat_session(self, session_id: str, session_data: Dict):
|
||||
...
|
||||
|
||||
async def set_document(self, document_id: str, document_data: Dict):
|
||||
...
|
||||
|
||||
async def set_job_application(self, application_id: str, application_data: Dict):
|
||||
...
|
||||
|
||||
async def set_job(self, job_id: str, job_data: Dict):
|
||||
...
|
||||
|
||||
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
|
||||
...
|
||||
|
||||
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
|
||||
...
|
||||
|
||||
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
|
||||
...
|
||||
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
|
||||
...
|
||||
|
||||
async def store_refresh_token(
|
||||
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
async def update_chat_session_activity(self, session_id: str):
|
||||
...
|
||||
|
||||
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
|
||||
...
|
||||
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
|
||||
...
|
||||
|
@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResumeMixin(DatabaseProtocol):
|
||||
"""Mixin for resume-related database operations"""
|
||||
|
||||
@ -14,10 +15,10 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Save a resume for a user"""
|
||||
try:
|
||||
# Generate resume_id if not present
|
||||
if 'id' not in resume_data:
|
||||
if "id" not in resume_data:
|
||||
raise ValueError("Resume data must include an 'id' field")
|
||||
|
||||
resume_id = resume_data['id']
|
||||
resume_id = resume_data["id"]
|
||||
|
||||
# Store the resume data
|
||||
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
|
||||
@ -53,7 +54,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -73,7 +74,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
resumes.append(resume_data)
|
||||
else:
|
||||
# Clean up orphaned resume ID
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
|
||||
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
|
||||
|
||||
# Sort by created_at timestamp (most recent first)
|
||||
@ -94,7 +95,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
|
||||
# Remove from user's resume list
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
|
||||
|
||||
if result > 0:
|
||||
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
|
||||
@ -111,7 +112,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -159,7 +160,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
for key, value in zip(keys, values):
|
||||
if value:
|
||||
# Extract user_id from key format: resume:{user_id}:{resume_id}
|
||||
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
|
||||
key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
|
||||
if len(key_parts) >= 1:
|
||||
user_id = key_parts[0]
|
||||
resume_data = self._deserialize(value)
|
||||
@ -186,12 +187,14 @@ class ResumeMixin(DatabaseProtocol):
|
||||
matching_resumes = []
|
||||
for resume in all_resumes:
|
||||
# Search in resume content, job_id, candidate_id, etc.
|
||||
searchable_text = " ".join([
|
||||
searchable_text = " ".join(
|
||||
[
|
||||
resume.get("resume", ""),
|
||||
resume.get("job_id", ""),
|
||||
resume.get("candidate_id", ""),
|
||||
str(resume.get("created_at", ""))
|
||||
]).lower()
|
||||
str(resume.get("created_at", "")),
|
||||
]
|
||||
).lower()
|
||||
|
||||
if query_lower in searchable_text:
|
||||
matching_resumes.append(resume)
|
||||
@ -206,10 +209,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Get all resumes for a specific candidate created by a user"""
|
||||
try:
|
||||
all_resumes = await self.get_all_resumes_for_user(user_id)
|
||||
candidate_resumes = [
|
||||
resume for resume in all_resumes
|
||||
if resume.get("candidate_id") == candidate_id
|
||||
]
|
||||
candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]
|
||||
|
||||
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
|
||||
return candidate_resumes
|
||||
@ -221,10 +221,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Get all resumes for a specific job created by a user"""
|
||||
try:
|
||||
all_resumes = await self.get_all_resumes_for_user(user_id)
|
||||
job_resumes = [
|
||||
resume for resume in all_resumes
|
||||
if resume.get("job_id") == job_id
|
||||
]
|
||||
job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]
|
||||
|
||||
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
|
||||
return job_resumes
|
||||
@ -242,7 +239,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"resumes_by_candidate": {},
|
||||
"resumes_by_job": {},
|
||||
"creation_timeline": {},
|
||||
"recent_resumes": []
|
||||
"recent_resumes": [],
|
||||
}
|
||||
|
||||
for resume in all_resumes:
|
||||
@ -269,7 +266,13 @@ class ResumeMixin(DatabaseProtocol):
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
|
||||
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
|
||||
return {
|
||||
"total_resumes": 0,
|
||||
"resumes_by_candidate": {},
|
||||
"resumes_by_job": {},
|
||||
"creation_timeline": {},
|
||||
"recent_resumes": [],
|
||||
}
|
||||
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
|
||||
"""Update specific fields of a resume"""
|
||||
|
@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillMixin(DatabaseProtocol):
|
||||
"""Mixin for Skill-related database operations"""
|
||||
|
||||
@ -74,7 +75,9 @@ class SkillMixin(DatabaseProtocol):
|
||||
# Cache for 1 hour by default
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
|
||||
json.dumps(
|
||||
assessment.model_dump(mode="json", by_alias=True), default=str
|
||||
), # Serialize with datetime handling
|
||||
)
|
||||
logger.info(f"💾 Skill match cached: {cache_key}")
|
||||
except Exception as e:
|
||||
|
@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserMixin(DatabaseProtocol):
|
||||
"""Mixin for user operations"""
|
||||
|
||||
@ -23,13 +24,13 @@ class UserMixin(DatabaseProtocol):
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
|
||||
# Store in Redis with both hash and individual key for redundancy
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data)) # type: ignore
|
||||
|
||||
# Also store with a longer TTL as backup
|
||||
await self.redis.setex(
|
||||
f"guest_backup:{guest_id}",
|
||||
86400 * 7, # 7 days TTL
|
||||
json.dumps(guest_data)
|
||||
json.dumps(guest_data),
|
||||
)
|
||||
|
||||
logger.info(f"💾 Guest stored with backup: {guest_id}")
|
||||
@ -41,7 +42,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"""Get guest data with fallback to backup"""
|
||||
try:
|
||||
# Try primary storage first
|
||||
data = await self.redis.hget("guests", guest_id)# type: ignore
|
||||
data = await self.redis.hget("guests", guest_id) # type: ignore
|
||||
if data:
|
||||
guest_data = json.loads(data)
|
||||
# Update last activity when accessed
|
||||
@ -82,11 +83,8 @@ class UserMixin(DatabaseProtocol):
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all guests"""
|
||||
try:
|
||||
data = await self.redis.hgetall("guests")# type: ignore
|
||||
return {
|
||||
guest_id: json.loads(guest_json)
|
||||
for guest_id, guest_json in data.items()
|
||||
}
|
||||
data = await self.redis.hgetall("guests") # type: ignore
|
||||
return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting all guests: {e}")
|
||||
return {}
|
||||
@ -94,7 +92,7 @@ class UserMixin(DatabaseProtocol):
|
||||
async def delete_guest(self, guest_id: str) -> bool:
|
||||
"""Delete a guest"""
|
||||
try:
|
||||
result = await self.redis.hdel("guests", guest_id)# type: ignore
|
||||
result = await self.redis.hdel("guests", guest_id) # type: ignore
|
||||
if result:
|
||||
logger.info(f"🗑️ Guest deleted: {guest_id}")
|
||||
return True
|
||||
@ -120,7 +118,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
# Skip cleanup if guest is very new (less than 1 hour old)
|
||||
if created_at_str:
|
||||
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
|
||||
if current_time - created_at < timedelta(hours=1):
|
||||
preserved_count += 1
|
||||
logger.info(f"🛡️ Preserving new guest: {guest_id}")
|
||||
@ -130,7 +128,7 @@ class UserMixin(DatabaseProtocol):
|
||||
should_delete = False
|
||||
if last_activity_str:
|
||||
try:
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
|
||||
if last_activity < cutoff_time:
|
||||
should_delete = True
|
||||
except ValueError:
|
||||
@ -172,7 +170,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"active_last_day": 0,
|
||||
"converted_guests": 0,
|
||||
"by_ip": {},
|
||||
"creation_timeline": {}
|
||||
"creation_timeline": {},
|
||||
}
|
||||
|
||||
hour_ago = current_time - timedelta(hours=1)
|
||||
@ -183,7 +181,7 @@ class UserMixin(DatabaseProtocol):
|
||||
last_activity_str = guest_data.get("last_activity")
|
||||
if last_activity_str:
|
||||
try:
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
|
||||
if last_activity > hour_ago:
|
||||
stats["active_last_hour"] += 1
|
||||
if last_activity > day_ago:
|
||||
@ -203,8 +201,8 @@ class UserMixin(DatabaseProtocol):
|
||||
created_at_str = guest_data.get("created_at")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
date_key = created_at.strftime('%Y-%m-%d')
|
||||
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
|
||||
date_key = created_at.strftime("%Y-%m-%d")
|
||||
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
|
||||
except ValueError:
|
||||
pass
|
||||
@ -278,7 +276,7 @@ class UserMixin(DatabaseProtocol):
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user lookup data by user ID"""
|
||||
try:
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id) # type: ignore
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
@ -321,7 +319,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
email = key.replace(KEY_PREFIXES['users'], '')
|
||||
email = key.replace(KEY_PREFIXES["users"], "")
|
||||
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
|
||||
if type(value) == str:
|
||||
result[email] = value
|
||||
@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
|
||||
logger.error(f"❌ Error storing user {login}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ================
|
||||
# Employers
|
||||
# ================
|
||||
@ -394,7 +391,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
employer_id = key.replace(KEY_PREFIXES['employers'], '')
|
||||
employer_id = key.replace(KEY_PREFIXES["employers"], "")
|
||||
result[employer_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -404,7 +401,6 @@ class UserMixin(DatabaseProtocol):
|
||||
key = f"{KEY_PREFIXES['employers']}{employer_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
||||
# ================
|
||||
# Candidates
|
||||
# ================
|
||||
@ -435,7 +431,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
|
||||
candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
|
||||
result[candidate_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -456,7 +452,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"security_logs": 0,
|
||||
"ai_parameters": 0,
|
||||
"candidate_record": 0,
|
||||
"resumes": 0
|
||||
"resumes": 0,
|
||||
}
|
||||
|
||||
logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")
|
||||
@ -495,7 +491,9 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
deletion_stats["chat_sessions"] = len(candidate_sessions)
|
||||
deletion_stats["chat_messages"] = messages_deleted
|
||||
logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
|
||||
logger.info(
|
||||
f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting chat sessions: {e}")
|
||||
|
||||
@ -528,9 +526,11 @@ class UserMixin(DatabaseProtocol):
|
||||
logger.info(f"🗑️ Deleted user record by email: {candidate_email}")
|
||||
|
||||
# Delete by username (if different from email)
|
||||
if (candidate_username and
|
||||
candidate_username != candidate_email and
|
||||
await self.user_exists_by_username(candidate_username)):
|
||||
if (
|
||||
candidate_username
|
||||
and candidate_username != candidate_email
|
||||
and await self.user_exists_by_username(candidate_username)
|
||||
):
|
||||
await self.delete_user(candidate_username)
|
||||
user_records_deleted += 1
|
||||
logger.info(f"🗑️ Deleted user record by username: {candidate_username}")
|
||||
@ -593,8 +593,7 @@ class UserMixin(DatabaseProtocol):
|
||||
candidate_ai_params = []
|
||||
|
||||
for param_id, param_data in all_ai_params.items():
|
||||
if (param_data.get("candidateId") == candidate_id or
|
||||
param_data.get("userId") == candidate_id):
|
||||
if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
|
||||
candidate_ai_params.append(param_id)
|
||||
|
||||
# Delete each AI parameter set
|
||||
@ -630,7 +629,9 @@ class UserMixin(DatabaseProtocol):
|
||||
break
|
||||
|
||||
if tokens_deleted > 0:
|
||||
logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
|
||||
logger.info(
|
||||
f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting email verification tokens: {e}")
|
||||
|
||||
@ -717,8 +718,10 @@ class UserMixin(DatabaseProtocol):
|
||||
# 15. Log the deletion as a security event (if we have admin/system user context)
|
||||
try:
|
||||
total_items_deleted = sum(deletion_stats.values())
|
||||
logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
|
||||
f"Total items deleted: {total_items_deleted}")
|
||||
logger.info(
|
||||
f"✅ Completed cascading delete for candidate {candidate_id}. "
|
||||
f"Total items deleted: {total_items_deleted}"
|
||||
)
|
||||
logger.info(f"📊 Deletion breakdown: {deletion_stats}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error logging deletion summary: {e}")
|
||||
@ -774,8 +777,8 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_candidates_processed": len(candidate_ids),
|
||||
"successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
|
||||
"failed_deletions": len([r for r in batch_results.values() if "error" in r]),
|
||||
"total_items_deleted": sum(total_stats.values())
|
||||
}
|
||||
"total_items_deleted": sum(total_stats.values()),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -816,7 +819,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_sessions": 0,
|
||||
"total_messages": 0,
|
||||
"first_chat": None,
|
||||
"last_chat": None
|
||||
"last_chat": None,
|
||||
}
|
||||
|
||||
total_messages = 0
|
||||
@ -835,7 +838,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_messages": total_messages,
|
||||
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
|
||||
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
|
||||
"recent_sessions": sessions[:5] # Last 5 sessions
|
||||
"recent_sessions": sessions[:5], # Last 5 sessions
|
||||
}
|
||||
|
||||
# ================
|
||||
@ -868,7 +871,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
|
||||
viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
|
||||
result[viewer_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -877,4 +880,3 @@ class UserMixin(DatabaseProtocol):
|
||||
"""Delete viewer"""
|
||||
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
@ -60,7 +60,7 @@ cert_path = "/opt/backstory/keys/cert.pem"
|
||||
host = os.getenv("BACKSTORY_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("BACKSTORY_PORT", "8911"))
|
||||
api_prefix = "/api/1.0"
|
||||
debug=os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
|
||||
debug = os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
# Used for filtering tracebacks
|
||||
app_path="/opt/backstory/src/backend"
|
||||
app_path = "/opt/backstory/src/backend"
|
||||
|
@ -6,6 +6,7 @@ from datetime import datetime, timezone
|
||||
from user_agents import parse
|
||||
import json
|
||||
|
||||
|
||||
class DeviceManager:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
self.database = database
|
||||
@ -13,7 +14,6 @@ class DeviceManager:
|
||||
def generate_device_fingerprint(self, request: Request) -> str:
|
||||
"""Generate device fingerprint from request"""
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
ip_address = request.client.host if request.client else "unknown"
|
||||
accept_language = request.headers.get("accept-language", "")
|
||||
|
||||
# Create fingerprint
|
||||
@ -35,7 +35,7 @@ class DeviceManager:
|
||||
"os": user_agent.os.family,
|
||||
"os_version": user_agent.os.version_string,
|
||||
"ip_address": request.client.host if request.client else "unknown",
|
||||
"user_agent": user_agent_string
|
||||
"user_agent": user_agent_string,
|
||||
}
|
||||
|
||||
async def is_trusted_device(self, user_id: str, device_id: str) -> bool:
|
||||
@ -55,14 +55,14 @@ class DeviceManager:
|
||||
device_data = {
|
||||
**device_info,
|
||||
"added_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_used": datetime.now(timezone.utc).isoformat()
|
||||
"last_used": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store for 90 days
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
90 * 24 * 60 * 60, # 90 days in seconds
|
||||
json.dumps(device_data, default=str)
|
||||
json.dumps(device_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")
|
||||
@ -80,7 +80,7 @@ class DeviceManager:
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
90 * 24 * 60 * 60, # Reset 90 day expiry
|
||||
json.dumps(device_info, default=str)
|
||||
json.dumps(device_info, default=str),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating device last used: {e}")
|
||||
|
@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
|
||||
import json
|
||||
from database.manager import RedisDatabase
|
||||
|
||||
|
||||
class EmailService:
|
||||
def __init__(self):
|
||||
# Configure these in your .env file
|
||||
@ -30,36 +31,25 @@ class EmailService:
|
||||
def _format_template(self, template: str, **kwargs) -> str:
|
||||
"""Format template with provided variables"""
|
||||
return template.format(
|
||||
app_name=self.app_name,
|
||||
from_name=self.from_name,
|
||||
frontend_url=self.frontend_url,
|
||||
**kwargs
|
||||
app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
|
||||
)
|
||||
|
||||
async def send_verification_email(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: str,
|
||||
user_type: str = "user"
|
||||
self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
|
||||
):
|
||||
"""Send email verification email using template"""
|
||||
try:
|
||||
template = self._get_template("verification")
|
||||
verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"
|
||||
|
||||
subject = self._format_template(
|
||||
template["subject"],
|
||||
user_name=user_name,
|
||||
to_email=to_email
|
||||
)
|
||||
subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)
|
||||
|
||||
html_content = self._format_template(
|
||||
template["html"],
|
||||
user_name=user_name,
|
||||
user_type=user_type,
|
||||
to_email=to_email,
|
||||
verification_link=verification_link
|
||||
verification_link=verification_link,
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -70,12 +60,7 @@ class EmailService:
|
||||
raise
|
||||
|
||||
async def send_mfa_email(
|
||||
self,
|
||||
to_email: str,
|
||||
mfa_code: str,
|
||||
device_name: str,
|
||||
user_name: str,
|
||||
ip_address: str = "Unknown"
|
||||
self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
|
||||
):
|
||||
"""Send MFA code email using template"""
|
||||
try:
|
||||
@ -91,7 +76,7 @@ class EmailService:
|
||||
ip_address=ip_address,
|
||||
login_time=login_time,
|
||||
mfa_code=mfa_code,
|
||||
to_email=to_email
|
||||
to_email=to_email,
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -101,12 +86,7 @@ class EmailService:
|
||||
logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
|
||||
raise
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: str
|
||||
):
|
||||
async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
|
||||
"""Send password reset email using template"""
|
||||
try:
|
||||
template = self._get_template("password_reset")
|
||||
@ -115,10 +95,7 @@ class EmailService:
|
||||
subject = self._format_template(template["subject"])
|
||||
|
||||
html_content = self._format_template(
|
||||
template["html"],
|
||||
user_name=user_name,
|
||||
reset_link=reset_link,
|
||||
to_email=to_email
|
||||
template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -134,14 +111,14 @@ class EmailService:
|
||||
if not self.email_user:
|
||||
raise ValueError("Email user is not configured")
|
||||
# Create message
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = f"{self.from_name} <{self.email_user}>"
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = subject
|
||||
msg['Reply-To'] = self.email_user
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = f"{self.from_name} <{self.email_user}>"
|
||||
msg["To"] = to_email
|
||||
msg["Subject"] = subject
|
||||
msg["Reply-To"] = self.email_user
|
||||
|
||||
# Add HTML content
|
||||
html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
html_part = MIMEText(html_content, "html", "utf-8")
|
||||
msg.attach(html_part)
|
||||
|
||||
# Send email with connection pooling and retry logic
|
||||
@ -157,7 +134,6 @@ class EmailService:
|
||||
text = msg.as_string()
|
||||
server.sendmail(self.email_user, to_email, text)
|
||||
break # Success, exit retry loop
|
||||
|
||||
except smtplib.SMTPException as e:
|
||||
if attempt == max_retries - 1: # Last attempt
|
||||
raise
|
||||
@ -170,6 +146,7 @@ class EmailService:
|
||||
logger.error(f"❌ SMTP error sending to {to_email}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class EmailRateLimiter:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
self.database = database
|
||||
@ -191,10 +168,7 @@ class EmailRateLimiter:
|
||||
email_records = json.loads(count_data)
|
||||
|
||||
# Filter out old records
|
||||
recent_records = [
|
||||
record for record in email_records
|
||||
if datetime.fromisoformat(record) > window_start
|
||||
]
|
||||
recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]
|
||||
|
||||
if len(recent_records) >= limit:
|
||||
logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")
|
||||
@ -202,11 +176,7 @@ class EmailRateLimiter:
|
||||
|
||||
# Add current email to records
|
||||
recent_records.append(current_time.isoformat())
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
window_minutes * 60,
|
||||
json.dumps(recent_records)
|
||||
)
|
||||
await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))
|
||||
|
||||
return True
|
||||
|
||||
@ -217,11 +187,8 @@ class EmailRateLimiter:
|
||||
|
||||
async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
|
||||
"""Record that an email was sent"""
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
ttl_minutes * 60,
|
||||
json.dumps([timestamp.isoformat()])
|
||||
)
|
||||
await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))
|
||||
|
||||
|
||||
class VerificationEmailRateLimiter:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
@ -242,7 +209,10 @@ class VerificationEmailRateLimiter:
|
||||
# Check daily limit
|
||||
daily_count = await self.database.get_verification_attempts_count(email)
|
||||
if daily_count >= self.max_attempts_per_day:
|
||||
return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
|
||||
return (
|
||||
False,
|
||||
f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
|
||||
)
|
||||
|
||||
# Check hourly limit
|
||||
hour_ago = current_time - timedelta(hours=1)
|
||||
@ -251,13 +221,13 @@ class VerificationEmailRateLimiter:
|
||||
|
||||
if data:
|
||||
attempts_data = json.loads(data)
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > hour_ago
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]
|
||||
|
||||
if len(recent_attempts) >= self.max_attempts_per_hour:
|
||||
return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
|
||||
return (
|
||||
False,
|
||||
f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
|
||||
)
|
||||
|
||||
# Check cooldown period
|
||||
if recent_attempts:
|
||||
@ -280,7 +250,4 @@ class VerificationEmailRateLimiter:
|
||||
await self.database.record_verification_attempt(email)
|
||||
|
||||
|
||||
|
||||
email_service = EmailService()
|
||||
|
||||
|
@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
""",
|
||||
},
|
||||
|
||||
"mfa": {
|
||||
"subject": "Security Code for Backstory",
|
||||
"html": """
|
||||
@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
""",
|
||||
},
|
||||
|
||||
"password_reset": {
|
||||
"subject": "Reset your Backstory password",
|
||||
"html": """
|
||||
@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
}
|
||||
""",
|
||||
},
|
||||
}
|
@ -3,13 +3,14 @@ import weakref
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
from models import Candidate
|
||||
from agents.base import CandidateEntity
|
||||
from database.manager import RedisDatabase
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
|
||||
|
||||
class EntityManager(BaseModel):
|
||||
"""Manages lifecycle of CandidateEntity instances"""
|
||||
|
||||
@ -36,10 +37,7 @@ class EntityManager(BaseModel):
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
"""Initialize the EntityManager with Prometheus collector"""
|
||||
self._prometheus_collector = prometheus_collector
|
||||
self._database = database
|
||||
@ -58,9 +56,7 @@ class EntityManager(BaseModel):
|
||||
raise ValueError("EntityManager has not been initialized with required components.")
|
||||
|
||||
entity = CandidateEntity(candidate=candidate)
|
||||
await entity.initialize(
|
||||
prometheus_collector=self._prometheus_collector,
|
||||
database=self._database)
|
||||
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
|
||||
|
||||
# Store with reference tracking
|
||||
self._entities[candidate.id] = entity
|
||||
@ -105,10 +101,12 @@ class EntityManager(BaseModel):
|
||||
|
||||
def _on_entity_deleted(self, user_id: str):
|
||||
"""Callback when entity is garbage collected"""
|
||||
|
||||
def cleanup_callback(weak_ref):
|
||||
self._entities.pop(user_id, None)
|
||||
self._weak_refs.pop(user_id, None)
|
||||
print(f"Entity {user_id} garbage collected")
|
||||
|
||||
return cleanup_callback
|
||||
|
||||
async def release_entity(self, user_id: str):
|
||||
@ -138,8 +136,7 @@ class EntityManager(BaseModel):
|
||||
time_since_access = current_time - entity.last_accessed
|
||||
|
||||
# Remove if TTL exceeded and no active references
|
||||
if (time_since_access > timedelta(minutes=self._ttl_minutes)
|
||||
and entity.reference_count == 0):
|
||||
if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
|
||||
expired_entities.append(user_id)
|
||||
|
||||
for user_id in expired_entities:
|
||||
@ -153,6 +150,7 @@ class EntityManager(BaseModel):
|
||||
# Global entity manager instance
|
||||
entity_manager = EntityManager(default_ttl_minutes=30)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_candidate_entity(candidate: Candidate):
|
||||
"""Context manager for safe entity access with automatic reference management"""
|
||||
@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
|
||||
finally:
|
||||
await entity_manager.release_entity(candidate.id)
|
||||
|
||||
|
||||
EntityManager.model_rebuild()
|
||||
|
@ -6,10 +6,7 @@ without getting caught up in serialization format complexities
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from models import (
|
||||
UserStatus, UserType, SkillLevel, EmploymentType,
|
||||
Candidate, Employer, Location, Skill
|
||||
)
|
||||
from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill
|
||||
|
||||
|
||||
def test_model_creation():
|
||||
@ -23,37 +20,38 @@ def test_model_creation():
|
||||
# Create candidate
|
||||
candidate = Candidate(
|
||||
email="test@example.com",
|
||||
user_type=UserType.CANDIDATE,
|
||||
username="test_candidate",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
firstName="John",
|
||||
lastName="Doe",
|
||||
fullName="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
full_name="John Doe",
|
||||
skills=[skill],
|
||||
experience=[],
|
||||
education=[],
|
||||
preferredJobTypes=[EmploymentType.FULL_TIME],
|
||||
preferred_job_types=[EmploymentType.FULL_TIME],
|
||||
location=location,
|
||||
languages=[],
|
||||
certifications=[]
|
||||
certifications=[],
|
||||
)
|
||||
|
||||
# Create employer
|
||||
employer = Employer(
|
||||
firstName="Mary",
|
||||
lastName="Smith",
|
||||
fullName="Mary Smith",
|
||||
user_type=UserType.EMPLOYER,
|
||||
first_name="Mary",
|
||||
last_name="Smith",
|
||||
full_name="Mary Smith",
|
||||
email="hr@company.com",
|
||||
username="test_employer",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
companyName="Test Company",
|
||||
company_name="Test Company",
|
||||
industry="Technology",
|
||||
companySize="50-200",
|
||||
companyDescription="A test company",
|
||||
location=location
|
||||
company_size="50-200",
|
||||
company_description="A test company",
|
||||
location=location,
|
||||
)
|
||||
|
||||
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
|
||||
@ -62,6 +60,7 @@ def test_model_creation():
|
||||
|
||||
return candidate, employer
|
||||
|
||||
|
||||
def test_json_api_format():
|
||||
"""Test JSON serialization in API format (the most important use case)"""
|
||||
print("\n📡 Testing JSON API format...")
|
||||
@ -84,11 +83,12 @@ def test_json_api_format():
|
||||
assert candidate_back.first_name == candidate.first_name
|
||||
assert employer_back.company_name == employer.company_name
|
||||
|
||||
print(f"✅ JSON round-trip successful")
|
||||
print(f"✅ Data integrity verified")
|
||||
print("✅ JSON round-trip successful")
|
||||
print("✅ Data integrity verified")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_api_dict_format():
|
||||
"""Test dictionary format with aliases (for API requests/responses)"""
|
||||
print("\n📊 Testing API dictionary format...")
|
||||
@ -105,8 +105,8 @@ def test_api_dict_format():
|
||||
assert "createdAt" in candidate_dict
|
||||
assert "companyName" in employer_dict
|
||||
|
||||
print(f"✅ API format dictionaries created")
|
||||
print(f"✅ CamelCase aliases verified")
|
||||
print("✅ API format dictionaries created")
|
||||
print("✅ CamelCase aliases verified")
|
||||
|
||||
# Test deserializing from API format
|
||||
candidate_back = Candidate.model_validate(candidate_dict)
|
||||
@ -115,25 +115,27 @@ def test_api_dict_format():
|
||||
assert candidate_back.email == candidate.email
|
||||
assert employer_back.company_name == employer.company_name
|
||||
|
||||
print(f"✅ API format round-trip successful")
|
||||
print("✅ API format round-trip successful")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_validation_constraints():
|
||||
"""Test that validation constraints work"""
|
||||
print("\n🔒 Testing validation constraints...")
|
||||
|
||||
try:
|
||||
# Create a candidate with invalid email
|
||||
invalid_candidate = Candidate(
|
||||
Candidate(
|
||||
user_type=UserType.CANDIDATE,
|
||||
email="invalid-email",
|
||||
username="test_invalid",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
firstName="Jane",
|
||||
lastName="Doe",
|
||||
fullName="Jane Doe"
|
||||
first_name="Jane",
|
||||
last_name="Doe",
|
||||
full_name="Jane Doe",
|
||||
)
|
||||
print("❌ Validation should have failed but didn't")
|
||||
return False
|
||||
@ -141,6 +143,7 @@ def test_validation_constraints():
|
||||
print(f"✅ Validation error caught: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def test_enum_values():
|
||||
"""Test that enum values work correctly"""
|
||||
print("\n📋 Testing enum values...")
|
||||
@ -155,11 +158,12 @@ def test_enum_values():
|
||||
assert candidate_dict["userType"] == "candidate"
|
||||
assert employer.user_type == UserType.EMPLOYER
|
||||
|
||||
print(f"✅ Enum values correctly serialized")
|
||||
print("✅ Enum values correctly serialized")
|
||||
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all focused tests"""
|
||||
print("🎯 Focused Pydantic Model Tests")
|
||||
@ -172,7 +176,7 @@ def main():
|
||||
test_validation_constraints()
|
||||
test_enum_values()
|
||||
|
||||
print(f"\n🎉 All focused tests passed!")
|
||||
print("\n🎉 All focused tests passed!")
|
||||
print("=" * 40)
|
||||
print("✅ Models work correctly")
|
||||
print("✅ JSON API format works")
|
||||
@ -185,10 +189,12 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"\n❌ {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
import json
|
||||
from typing import Any, List, Set
|
||||
|
||||
|
||||
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
|
||||
"""
|
||||
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
|
||||
|
@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from
|
||||
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
|
||||
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
|
||||
data = {field: getattr(source, field) for field in model_cls.__fields__}
|
||||
return model_cls(**data)
|
||||
|
||||
|
||||
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
|
||||
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
|
||||
return model_cls(**data)
|
||||
|
||||
|
||||
def cast_to_base_user_with_type(user) -> BaseUserWithType:
|
||||
"""
|
||||
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
|
||||
|
@ -7,6 +7,7 @@ from typing import Any
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, FluxPipeline
|
||||
|
||||
|
||||
class ImageModelCache: # Stay loaded for 3 hours
|
||||
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
||||
self._pipe = None
|
||||
@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
cached_model_type = self._get_model_type(self._model_name) if self._model_name else None
|
||||
|
||||
if (
|
||||
self._pipe is not None and
|
||||
self._model_name == model and
|
||||
self._device == device and
|
||||
current_model_type == cached_model_type and
|
||||
current_time - self._last_access_time < self._timeout_seconds
|
||||
self._pipe is not None
|
||||
and self._model_name == model
|
||||
and self._device == device
|
||||
and current_model_type == cached_model_type
|
||||
and current_time - self._last_access_time < self._timeout_seconds
|
||||
):
|
||||
self._last_access_time = current_time
|
||||
return self._pipe
|
||||
@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
model,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
|
||||
def dummy_safety_checker(images, clip_input):
|
||||
return images, [False] * len(images)
|
||||
|
||||
pipe.safety_checker = dummy_safety_checker
|
||||
else:
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
try:
|
||||
pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
|
||||
pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load LoRA weights: {str(e)}")
|
||||
|
||||
@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
|
||||
async def cleanup_if_expired(self):
|
||||
async with self._lock:
|
||||
if (
|
||||
self._pipe is not None and
|
||||
time.time() - self._last_access_time >= self._timeout_seconds
|
||||
):
|
||||
if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
|
||||
await self._unload_model()
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
|
@ -29,6 +29,7 @@ TIME_ESTIMATES = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str
|
||||
filepath: str
|
||||
@ -39,18 +40,22 @@ class ImageRequest(BaseModel):
|
||||
width: int = 256
|
||||
guidance_scale: float = 7.5
|
||||
|
||||
|
||||
# Global model cache instance
|
||||
model_cache = ImageModelCache()
|
||||
|
||||
|
||||
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
|
||||
"""Background worker for Flux image generation"""
|
||||
try:
|
||||
# Flux: Run generation in the background and yield progress updates
|
||||
status_queue.put(ChatMessageStatus(
|
||||
status_queue.put(
|
||||
ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Initializing image generation.",
|
||||
content="Initializing image generation.",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
# Start the generation task
|
||||
start_gen_time = time.time()
|
||||
@ -58,13 +63,15 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
# Simulate your pipe call with progress updates
|
||||
def status_callback(pipeline, step, timestep, callback_kwargs):
|
||||
# Send progress updates
|
||||
progress = int((step+1) / params.iterations * 100)
|
||||
progress = int((step + 1) / params.iterations * 100)
|
||||
|
||||
status_queue.put(ChatMessageStatus(
|
||||
status_queue.put(
|
||||
ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
)
|
||||
)
|
||||
return callback_kwargs
|
||||
|
||||
# Replace this block with your actual Flux pipe call:
|
||||
@ -84,22 +91,28 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
image.save(params.filepath)
|
||||
|
||||
# Final completion status
|
||||
status_queue.put(ChatMessage(
|
||||
status_queue.put(
|
||||
ChatMessage(
|
||||
session_id=params.session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
status_queue.put(ChatMessageError(
|
||||
status_queue.put(
|
||||
ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f"Error during image generation: {str(e)}",
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
|
||||
async def async_generate_image(
|
||||
pipe: Any, params: ImageRequest
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
|
||||
"""
|
||||
Single async function that handles background Flux generation with status streaming
|
||||
"""
|
||||
@ -109,18 +122,14 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
|
||||
|
||||
try:
|
||||
# Start background worker thread
|
||||
worker_thread = Thread(
|
||||
target=flux_worker,
|
||||
args=(pipe, params, status_queue, task_id),
|
||||
daemon=True
|
||||
)
|
||||
worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
|
||||
worker_thread.start()
|
||||
|
||||
# Initial status
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Starting image generation with task ID {task_id}",
|
||||
activity=ApiActivityType.THINKING
|
||||
activity=ApiActivityType.THINKING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -177,18 +186,16 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
|
||||
yield final_status
|
||||
|
||||
except Exception as e:
|
||||
error_status = ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f'Server error: {str(e)}'
|
||||
)
|
||||
error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
|
||||
logger.error(error_status)
|
||||
yield error_status
|
||||
|
||||
finally:
|
||||
# Cleanup: ensure thread completion
|
||||
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
|
||||
if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
|
||||
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
|
||||
|
||||
|
||||
def status(session_id: str, status: str) -> ChatMessageStatus:
|
||||
"""Update chat message status and return it."""
|
||||
chat_message = ChatMessageStatus(
|
||||
@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
|
||||
)
|
||||
return chat_message
|
||||
|
||||
|
||||
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
|
||||
"""Generate an image with specified dimensions and yield status updates with time estimates."""
|
||||
session_id = request.session_id
|
||||
@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
try:
|
||||
# Validate prompt
|
||||
if not prompt:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Prompt cannot be empty."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
return
|
||||
@ -216,15 +221,14 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
# Validate dimensions
|
||||
if request.height <= 0 or request.width <= 0:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Height and width must be positive integers."
|
||||
session_id=session_id, content="Height and width must be positive integers."
|
||||
)
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
filedir = os.path.dirname(request.filepath)
|
||||
filename = os.path.basename(request.filepath)
|
||||
os.path.basename(request.filepath)
|
||||
os.makedirs(filedir, exist_ok=True)
|
||||
|
||||
model_type = "flux"
|
||||
@ -233,14 +237,17 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
# Get initial time estimate, scaled by resolution
|
||||
estimates = TIME_ESTIMATES[model_type][device]
|
||||
resolution_scale = (request.height * request.width) / (512 * 512)
|
||||
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
|
||||
estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
|
||||
|
||||
# Initialize or get cached pipeline
|
||||
start_time = time.time()
|
||||
yield status(session_id, f"Loading generative image model...")
|
||||
yield status(session_id, "Loading generative image model...")
|
||||
pipe = await model_cache.get_pipeline(request.model, device)
|
||||
load_time = time.time() - start_time
|
||||
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
|
||||
yield status(
|
||||
session_id,
|
||||
f"Model loaded in {load_time:.1f} seconds.",
|
||||
)
|
||||
|
||||
progress = None
|
||||
async for progress in async_generate_image(pipe, request):
|
||||
@ -252,8 +259,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
|
||||
if not progress:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Image generation failed to produce a valid response."
|
||||
session_id=session_id, content="Image generation failed to produce a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -269,10 +275,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
yield chat_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error during image generation: {str(e)}"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
|
||||
"""
|
||||
Extract JSON blocks from text, even if surrounded by markdown or noisy text.
|
||||
@ -31,13 +32,14 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
|
||||
"""Extract standalone JSON objects or arrays from text using proper brace counting."""
|
||||
found = []
|
||||
i = 0
|
||||
|
||||
while i < len(text):
|
||||
if text[i] in '{[':
|
||||
if text[i] in "{[":
|
||||
# Found potential JSON start
|
||||
json_str = _extract_complete_json_at_position(text, i)
|
||||
if json_str:
|
||||
@ -55,16 +57,17 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
"""
|
||||
Extract a complete JSON object or array starting at the given position.
|
||||
Uses proper brace/bracket counting and string escape handling.
|
||||
"""
|
||||
if start_pos >= len(text) or text[start_pos] not in '{[':
|
||||
if start_pos >= len(text) or text[start_pos] not in "{[":
|
||||
return ""
|
||||
|
||||
start_char = text[start_pos]
|
||||
end_char = '}' if start_char == '{' else ']'
|
||||
end_char = "}" if start_char == "{" else "]"
|
||||
|
||||
count = 1
|
||||
i = start_pos + 1
|
||||
@ -76,7 +79,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
elif char == '\\' and in_string:
|
||||
elif char == "\\" and in_string:
|
||||
escape_next = True
|
||||
elif char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
@ -92,6 +95,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
return text[start_pos:i]
|
||||
return ""
|
||||
|
||||
|
||||
def extract_json_from_text(text: str) -> str:
|
||||
"""Extract JSON string from text that may contain other content."""
|
||||
return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
|
||||
|
@ -2,11 +2,11 @@ import os
|
||||
import warnings
|
||||
import logging
|
||||
import defines
|
||||
|
||||
|
||||
def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="Overriding a previously registered kernel"
|
||||
)
|
||||
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
|
||||
warnings.filterwarnings("ignore", message="Warning only once for all operators")
|
||||
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
|
||||
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
|
||||
@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
|
||||
# Create a custom formatter
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
# Create a handler (e.g., StreamHandler for console output)
|
||||
@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
logger = logging.getLogger(__name__)
|
||||
return logger
|
||||
|
||||
|
||||
logger = _setup_logging(level=defines.logging_level)
|
||||
logger.debug(f"Logging initialized with level: {defines.logging_level}")
|
@ -58,6 +58,7 @@ background_task_manager = None
|
||||
prev_int = signal.getsignal(signal.SIGINT)
|
||||
prev_term = signal.getsignal(signal.SIGTERM)
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
|
||||
# now call the old handler (it might raise KeyboardInterrupt or exit)
|
||||
@ -66,9 +67,11 @@ def signal_handler(signum, frame):
|
||||
elif signum == signal.SIGTERM and callable(prev_term):
|
||||
prev_term(signum, frame)
|
||||
|
||||
|
||||
# Global background task manager
|
||||
background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
|
||||
if db_manager:
|
||||
await db_manager.graceful_shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="Backstory API",
|
||||
@ -129,11 +133,9 @@ app = FastAPI(
|
||||
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
|
||||
|
||||
if ssl_enabled:
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000",
|
||||
"https://backstory-beta.ketrenos.com"]
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
|
||||
else:
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000",
|
||||
"http://backstory-beta.ketrenos.com"]
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
@ -144,12 +146,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Debug data type failures
|
||||
# ============================
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
|
||||
@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
||||
content=json.dumps({"detail": str(exc)}),
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Create API router with prefix
|
||||
# ============================
|
||||
@ -181,9 +186,10 @@ api_router.include_router(users.router)
|
||||
# Health Check and Info Endpoints
|
||||
# ============================
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(
|
||||
database = Depends(get_database),
|
||||
database=Depends(get_database),
|
||||
):
|
||||
"""Health check endpoint"""
|
||||
try:
|
||||
@ -202,15 +208,12 @@ async def health_check(
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"database": {
|
||||
"status": "connected",
|
||||
"stats": stats
|
||||
},
|
||||
"database": {"status": "connected", "stats": stats},
|
||||
"redis": {
|
||||
"version": redis_info.get("redis_version", "unknown"),
|
||||
"uptime": redis_info.get("uptime_in_seconds", 0),
|
||||
"memory_used": redis_info.get("used_memory_human", "unknown")
|
||||
}
|
||||
"memory_used": redis_info.get("used_memory_human", "unknown"),
|
||||
},
|
||||
}
|
||||
|
||||
except RuntimeError as e:
|
||||
@ -219,6 +222,7 @@ async def health_check(
|
||||
logger.error(f"❌ Health check failed: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# ============================
|
||||
# Include Router in App
|
||||
# ============================
|
||||
@ -231,11 +235,14 @@ app.include_router(api_router)
|
||||
# ============================
|
||||
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
try:
|
||||
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
|
||||
logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
|
||||
logger.info(
|
||||
f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
|
||||
)
|
||||
response = await call_next(request)
|
||||
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
|
||||
if response.status_code < 200 or response.status_code >= 300:
|
||||
@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
|
||||
return response
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
|
||||
return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})
|
||||
|
||||
|
||||
# ============================
|
||||
# Request tracking middleware
|
||||
# ============================
|
||||
@ -266,6 +275,7 @@ async def track_requests(request, call_next):
|
||||
finally:
|
||||
db_manager.decrement_requests()
|
||||
|
||||
|
||||
# ============================
|
||||
# FastAPI Metrics
|
||||
# ============================
|
||||
@ -277,7 +287,7 @@ instrumentator = Instrumentator(
|
||||
should_ignore_untemplated=True,
|
||||
should_group_untemplated=True,
|
||||
excluded_handlers=[f"{defines.api_prefix}/metrics"],
|
||||
registry=prometheus_collector
|
||||
registry=prometheus_collector,
|
||||
)
|
||||
|
||||
# Instrument the FastAPI app
|
||||
@ -291,6 +301,7 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
|
||||
# Static File Serving
|
||||
# ============================
|
||||
|
||||
|
||||
@app.get("/{path:path}")
|
||||
async def serve_static(path: str, request: Request):
|
||||
full_path = os.path.join(defines.static_content, path)
|
||||
@ -300,6 +311,7 @@ async def serve_static(path: str, request: Request):
|
||||
|
||||
return FileResponse(os.path.join(defines.static_content, "index.html"))
|
||||
|
||||
|
||||
# Root endpoint when no static files
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
@ -309,9 +321,10 @@ async def root():
|
||||
"version": "1.0.0",
|
||||
"api_prefix": defines.api_prefix,
|
||||
"documentation": f"{defines.api_prefix}/docs",
|
||||
"health": f"{defines.api_prefix}/health"
|
||||
"health": f"{defines.api_prefix}/health",
|
||||
}
|
||||
|
||||
|
||||
async def periodic_verification_cleanup():
|
||||
"""Background task to periodically clean up expired verification tokens"""
|
||||
try:
|
||||
@ -324,6 +337,7 @@ async def periodic_verification_cleanup():
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in periodic verification cleanup: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
host = defines.host
|
||||
port = defines.port
|
||||
|
File diff suppressed because it is too large
Load Diff
23
src/backend/pyproject.toml
Normal file
23
src/backend/pyproject.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[tool.ruff]
|
||||
# Set line length (default is 88, Black-compatible)
|
||||
line-length = 120
|
||||
exclude = [
|
||||
".git",
|
||||
"__pycache__",
|
||||
"migrations",
|
||||
"venv",
|
||||
"dist",
|
||||
"build",
|
||||
]
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.format]
|
||||
# Use Ruff's formatter (Black-compatible by default)
|
||||
quote-style = "double" # Enforce double quotes
|
||||
indent-style = "space" # Use spaces for indentation
|
||||
skip-magic-trailing-comma = false # Add trailing commas where applicable
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
# Ignore specific rules for certain files (e.g., tests or scripts)
|
||||
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
|
||||
"scripts/**" = ["E402"] # Ignore import order in scripts
|
@ -1,7 +1,3 @@
|
||||
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
|
||||
__all__ = [
|
||||
"ChromaDBFileWatcher",
|
||||
"start_file_watcher",
|
||||
"RagEntry"
|
||||
]
|
||||
|
||||
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]
|
||||
|
@ -7,10 +7,12 @@ import logging
|
||||
|
||||
import defines
|
||||
|
||||
|
||||
class Chunk(TypedDict):
|
||||
text: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
def clear_chunk(chunk: Chunk):
|
||||
chunk["text"] = ""
|
||||
chunk["metadata"] = {
|
||||
@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
|
||||
}
|
||||
return chunk
|
||||
|
||||
|
||||
class MarkdownChunker:
|
||||
def __init__(self):
|
||||
# Initialize the Markdown parser
|
||||
@ -76,10 +79,7 @@ class MarkdownChunker:
|
||||
# Initialize a chunk structure
|
||||
chunk: Chunk = {
|
||||
"text": "",
|
||||
"metadata": {
|
||||
"source_file": file_path,
|
||||
"lines": total_lines
|
||||
},
|
||||
"metadata": {"source_file": file_path, "lines": total_lines},
|
||||
}
|
||||
clear_chunk(chunk)
|
||||
|
||||
@ -89,9 +89,7 @@ class MarkdownChunker:
|
||||
return chunks
|
||||
|
||||
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
|
||||
}
|
||||
return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}
|
||||
|
||||
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
|
||||
lines = []
|
||||
@ -114,7 +112,7 @@ class MarkdownChunker:
|
||||
chunks: List[Chunk],
|
||||
chunk: Chunk,
|
||||
level: int,
|
||||
buffer: int = defines.chunk_buffer
|
||||
buffer: int = defines.chunk_buffer,
|
||||
) -> int:
|
||||
is_list = False
|
||||
# Handle heading nodes
|
||||
@ -141,7 +139,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
@ -186,7 +184,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
@ -198,9 +196,7 @@ class MarkdownChunker:
|
||||
# Recursively process children
|
||||
if not is_list:
|
||||
for child in node.children:
|
||||
level = self._process_node(
|
||||
child, current_headings, chunks, chunk, level=level
|
||||
)
|
||||
level = self._process_node(child, current_headings, chunks, chunk, level=level)
|
||||
|
||||
# After root-level recursion, finalize any remaining chunk
|
||||
if node.type == "document":
|
||||
@ -211,7 +207,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from typing import List, Optional, Dict, Any
|
||||
import os
|
||||
import glob
|
||||
@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
|
||||
DEFAULT_CHUNK_SIZE = 750
|
||||
DEFAULT_CHUNK_OVERLAP = 100
|
||||
|
||||
|
||||
class RagEntry(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
def __init__(
|
||||
self,
|
||||
@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
# Path for storing file hash state
|
||||
self.hash_state_path = os.path.join(
|
||||
self.persist_directory, f"{collection_name}_hash_state.json"
|
||||
)
|
||||
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
|
||||
|
||||
# Flag to track if this is a new collection
|
||||
self.is_new_collection = False
|
||||
@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
process_all: If True, process all files regardless of hash status
|
||||
"""
|
||||
# Check for new or modified files
|
||||
file_paths = glob.glob(
|
||||
os.path.join(self.watch_directory, "**/*"), recursive=True
|
||||
)
|
||||
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
|
||||
files_checked = 0
|
||||
files_processed = 0
|
||||
files_to_process = []
|
||||
@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
continue
|
||||
|
||||
# If file is new, changed, or we're processing all files
|
||||
if (
|
||||
process_all
|
||||
or file_path not in self.file_hashes
|
||||
or self.file_hashes[file_path] != current_hash
|
||||
):
|
||||
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
|
||||
self.file_hashes[file_path] = current_hash
|
||||
files_to_process.append(file_path)
|
||||
logging.info(
|
||||
f"File {'found' if process_all else 'changed'}: {file_path}"
|
||||
)
|
||||
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
|
||||
|
||||
logging.info(
|
||||
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
|
||||
)
|
||||
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
|
||||
|
||||
# Check for deleted files
|
||||
deleted_files = []
|
||||
@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not os.path.exists(file_path):
|
||||
deleted_files.append(file_path)
|
||||
# Schedule removal
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
# Don't block on result, just let it run
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not current_hash: # File might have been deleted or is inaccessible
|
||||
return
|
||||
|
||||
if (
|
||||
file_path in self.file_hashes
|
||||
and self.file_hashes[file_path] == current_hash
|
||||
):
|
||||
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
|
||||
# File hasn't actually changed in content
|
||||
logging.info(f"Hash has not changed for {file_path}")
|
||||
return
|
||||
@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if results and "ids" in results and results["ids"]:
|
||||
self.collection.delete(ids=results["ids"])
|
||||
await self.database.update_user_rag_timestamp(self.user_id)
|
||||
logging.info(
|
||||
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
|
||||
)
|
||||
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
|
||||
|
||||
# Remove from hash dictionary
|
||||
if file_path in self.file_hashes:
|
||||
@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
|
||||
def _update_umaps(self):
|
||||
# Update the UMAP embeddings
|
||||
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
|
||||
include=["embeddings", "documents", "metadatas"]
|
||||
))
|
||||
self._umap_collection = ChromaDBGetResponse.model_validate(
|
||||
self._collection.get(include=["embeddings", "documents", "metadatas"])
|
||||
)
|
||||
if not self._umap_collection or not len(self._umap_collection.embeddings):
|
||||
logging.warning("⚠️ No embeddings found in the collection.")
|
||||
return
|
||||
|
||||
# During initialization
|
||||
logging.info(
|
||||
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
|
||||
vectors = np.array(self._umap_collection.embeddings)
|
||||
self._umap_model_2d = umap.UMAP(
|
||||
n_components=2,
|
||||
@ -328,9 +309,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
|
||||
# ) # Should be 2
|
||||
|
||||
logging.info(
|
||||
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
|
||||
self._umap_model_3d = umap.UMAP(
|
||||
n_components=3,
|
||||
random_state=8911,
|
||||
@ -338,7 +317,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
n_neighbors=30,
|
||||
min_dist=0.01,
|
||||
)
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)# type: ignore
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors) # type: ignore
|
||||
# logging.info(
|
||||
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
|
||||
# ) # Should be 3
|
||||
@ -373,9 +352,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
self.is_new_collection = True
|
||||
logging.info(f"Recreating collection: {self.collection_name}")
|
||||
|
||||
return chroma_client.get_or_create_collection(
|
||||
name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
|
||||
|
||||
async def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""Generate and normalize an embedding for the given text."""
|
||||
@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Generate a more unique ID based on content and metadata
|
||||
path_hash = ""
|
||||
if "path" in metadata:
|
||||
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
|
||||
:8
|
||||
]
|
||||
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
|
||||
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
|
||||
chunk_id = f"{path_hash}_{i}_{content_hash}"
|
||||
|
||||
@ -438,7 +413,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
logging.error(traceback.format_exc())
|
||||
logging.error(chunk)
|
||||
|
||||
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
|
||||
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
|
||||
source_file = meta.get("source_file")
|
||||
try:
|
||||
source_file = meta["source_file"]
|
||||
@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
return
|
||||
|
||||
file_path = event.src_path
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
def on_moved(self, event):
|
||||
@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
try:
|
||||
# Remove existing entries for this file
|
||||
existing_results = self.collection.get(where={"path": file_path})
|
||||
if (
|
||||
existing_results
|
||||
and "ids" in existing_results
|
||||
and existing_results["ids"]
|
||||
):
|
||||
if existing_results and "ids" in existing_results and existing_results["ids"]:
|
||||
self.collection.delete(ids=existing_results["ids"])
|
||||
await self.database.update_user_rag_timestamp(self.user_id)
|
||||
|
||||
@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
p = Path(file_path)
|
||||
p_as_md = p.with_suffix(".md")
|
||||
if p_as_md.exists():
|
||||
logging.info(
|
||||
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
|
||||
)
|
||||
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
|
||||
|
||||
# If file_path.md doesn't exist or file_path is newer than file_path.md,
|
||||
# fire off markitdown
|
||||
if (not p_as_md.exists()) or (
|
||||
p.stat().st_mtime > p_as_md.stat().st_mtime
|
||||
):
|
||||
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
|
||||
self._markitdown(file_path, p_as_md)
|
||||
return
|
||||
|
||||
@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Process all files regardless of hash state
|
||||
num_processed = await self.scan_directory(process_all=True)
|
||||
|
||||
logging.info(
|
||||
f"Vectorstore initialized with {self.collection.count()} documents"
|
||||
)
|
||||
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
|
||||
|
||||
self._update_umaps()
|
||||
|
||||
@ -676,7 +639,7 @@ def start_file_watcher(
|
||||
persist_directory=persist_directory,
|
||||
collection_name=collection_name,
|
||||
recreate=recreate,
|
||||
database=database
|
||||
database=database,
|
||||
)
|
||||
|
||||
# Process all files if:
|
||||
|
@ -13,14 +13,4 @@ from . import employers
|
||||
from . import admin
|
||||
from . import system
|
||||
|
||||
__all__ = [
|
||||
"auth",
|
||||
"candidates",
|
||||
"resumes",
|
||||
"jobs",
|
||||
"chat",
|
||||
"users",
|
||||
"employers",
|
||||
"admin",
|
||||
"system"
|
||||
]
|
||||
__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]
|
||||
|
@ -5,8 +5,18 @@ import json
|
||||
from datetime import datetime, timezone, UTC
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, HTTPException, Depends, Body, Request, HTTPException,
|
||||
Depends, Query, Path, Body, APIRouter, Request
|
||||
APIRouter,
|
||||
HTTPException,
|
||||
Depends,
|
||||
Body,
|
||||
Request,
|
||||
HTTPException,
|
||||
Depends,
|
||||
Query,
|
||||
Path,
|
||||
Body,
|
||||
APIRouter,
|
||||
Request,
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -14,52 +24,51 @@ from fastapi.responses import JSONResponse
|
||||
from utils.rate_limiter import RateLimiter, get_rate_limiter
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_current_user_or_guest, get_database, background_task_manager
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager
|
||||
from utils.responses import (
|
||||
create_paginated_response, create_success_response, create_error_response,
|
||||
create_paginated_response,
|
||||
create_success_response,
|
||||
create_error_response,
|
||||
)
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-guests")
|
||||
async def manual_guest_cleanup(
|
||||
inactive_hours: int = Body(24, embed=True),
|
||||
current_user = Depends(get_current_admin),
|
||||
admin_user = Depends(get_current_admin)
|
||||
current_user=Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
):
|
||||
"""Manually trigger guest cleanup (admin only)"""
|
||||
try:
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
|
||||
|
||||
logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual guest cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-tokens")
|
||||
async def manual_token_cleanup(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def manual_token_cleanup(admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger verification token cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
@ -67,31 +76,28 @@ async def manual_token_cleanup(
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
|
||||
|
||||
logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual token cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-rate-limits")
|
||||
async def manual_rate_limit_cleanup(
|
||||
days_old: int = Body(7, embed=True),
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger rate limit data cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
@ -99,60 +105,55 @@ async def manual_rate_limit_cleanup(
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
|
||||
|
||||
logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual rate limit cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# System Health and Maintenance Endpoints
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/system/health")
|
||||
async def get_system_health(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_system_health(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Get comprehensive system health status (admin only)"""
|
||||
try:
|
||||
# Database health
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
database_manager = getattr(request.app.state, "database_manager", None)
|
||||
db_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if database_manager:
|
||||
try:
|
||||
database = database_manager.get_database()
|
||||
database_manager.get_database()
|
||||
from database.manager import redis_manager
|
||||
|
||||
redis_health = await redis_manager.health_check()
|
||||
db_health = {
|
||||
"status": redis_health.get("status", "unknown"),
|
||||
"healthy": redis_health.get("status") == "healthy",
|
||||
"details": redis_health
|
||||
"details": redis_health,
|
||||
}
|
||||
except Exception as e:
|
||||
db_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
db_health = {"status": "error", "healthy": False, "error": str(e)}
|
||||
|
||||
# Background task health
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
task_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if background_task_manager:
|
||||
@ -166,39 +167,30 @@ async def get_system_health(
|
||||
"healthy": task_status["running"] and failed_tasks == 0,
|
||||
"running_tasks": running_tasks,
|
||||
"failed_tasks": failed_tasks,
|
||||
"total_tasks": task_status["task_count"]
|
||||
"total_tasks": task_status["task_count"],
|
||||
}
|
||||
except Exception as e:
|
||||
task_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
task_health = {"status": "error", "healthy": False, "error": str(e)}
|
||||
|
||||
# Overall health
|
||||
overall_healthy = db_health["healthy"] and task_health["healthy"]
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"overall_healthy": overall_healthy,
|
||||
"components": {
|
||||
"database": db_health,
|
||||
"background_tasks": task_health
|
||||
"components": {"database": db_health, "background_tasks": task_health},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting system health: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/maintenance/cleanup")
|
||||
async def run_maintenance_cleanup(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Run comprehensive maintenance cleanup (admin only)"""
|
||||
try:
|
||||
@ -217,50 +209,40 @@ async def run_maintenance_cleanup(
|
||||
cleanup_results[operation_name] = {
|
||||
"success": True,
|
||||
"cleaned_count": result,
|
||||
"message": f"Cleaned {result} items"
|
||||
"message": f"Cleaned {result} items",
|
||||
}
|
||||
except Exception as e:
|
||||
cleanup_results[operation_name] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"Failed: {str(e)}"
|
||||
}
|
||||
cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"}
|
||||
|
||||
# Calculate totals
|
||||
total_cleaned = sum(
|
||||
result.get("cleaned_count", 0)
|
||||
for result in cleanup_results.values()
|
||||
if result.get("success", False)
|
||||
result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False)
|
||||
)
|
||||
|
||||
successful_operations = len([
|
||||
r for r in cleanup_results.values()
|
||||
if r.get("success", False)
|
||||
])
|
||||
successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)])
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
|
||||
"total_cleaned": total_cleaned,
|
||||
"successful_operations": successful_operations,
|
||||
"details": cleanup_results
|
||||
})
|
||||
"details": cleanup_results,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in maintenance cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# Background Task Statistics
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/tasks/stats")
|
||||
async def get_task_statistics(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get background task execution statistics (admin only)"""
|
||||
try:
|
||||
@ -268,7 +250,7 @@ async def get_task_statistics(
|
||||
guest_stats = await database.get_guest_statistics()
|
||||
|
||||
# Get background task manager status
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
task_manager_stats = {}
|
||||
|
||||
if background_task_manager:
|
||||
@ -276,7 +258,7 @@ async def get_task_statistics(
|
||||
task_manager_stats = {
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"task_breakdown": {}
|
||||
"task_breakdown": {},
|
||||
}
|
||||
|
||||
# Count tasks by status
|
||||
@ -284,40 +266,35 @@ async def get_task_statistics(
|
||||
status = task["status"]
|
||||
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"guest_statistics": guest_stats,
|
||||
"task_manager": task_manager_stats,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting task statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# Background Task Status Endpoints
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/tasks/status")
|
||||
async def get_background_task_status(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Get background task manager status (admin only)"""
|
||||
try:
|
||||
# Get background task manager from app state
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not background_task_manager:
|
||||
return create_success_response({
|
||||
"running": False,
|
||||
"message": "Background task manager not initialized",
|
||||
"tasks": [],
|
||||
"task_count": 0
|
||||
})
|
||||
return create_success_response(
|
||||
{"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0}
|
||||
)
|
||||
|
||||
# Get comprehensive task status using the new method
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
@ -329,87 +306,64 @@ async def get_background_task_status(
|
||||
}
|
||||
|
||||
# Format the response
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"loop_status": {
|
||||
"main_loop_id": task_status["main_loop_id"],
|
||||
"current_loop_id": task_status["current_loop_id"],
|
||||
"loop_matches": task_status.get("loop_matches", False)
|
||||
"loop_matches": task_status.get("loop_matches", False),
|
||||
},
|
||||
"tasks": task_status["tasks"],
|
||||
"system_info": system_info
|
||||
})
|
||||
"system_info": system_info,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/run/{task_name}")
|
||||
async def run_background_task(
|
||||
task_name: str,
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger a specific background task (admin only)"""
|
||||
try:
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"MANAGER_UNAVAILABLE",
|
||||
"Background task manager not initialized"
|
||||
)
|
||||
content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"),
|
||||
)
|
||||
|
||||
# List of available tasks
|
||||
available_tasks = [
|
||||
"guest_cleanup",
|
||||
"token_cleanup",
|
||||
"guest_stats",
|
||||
"rate_limit_cleanup",
|
||||
"orphaned_cleanup"
|
||||
]
|
||||
available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"]
|
||||
|
||||
if task_name not in available_tasks:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_TASK",
|
||||
f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
)
|
||||
"INVALID_TASK", f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
),
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = await background_task_manager.force_run_task(task_name)
|
||||
|
||||
return create_success_response({
|
||||
"task_name": task_name,
|
||||
"result": result,
|
||||
"message": f"Task {task_name} completed successfully"
|
||||
})
|
||||
return create_success_response(
|
||||
{"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TASK", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running task {task_name}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/tasks/list")
|
||||
async def list_available_tasks(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def list_available_tasks(admin_user=Depends(get_current_admin)):
|
||||
"""List all available background tasks (admin only)"""
|
||||
try:
|
||||
tasks = [
|
||||
@ -417,63 +371,51 @@ async def list_available_tasks(
|
||||
"name": "guest_cleanup",
|
||||
"description": "Clean up inactive guest sessions",
|
||||
"interval": "6 hours",
|
||||
"parameters": ["inactive_hours (default: 48)"]
|
||||
"parameters": ["inactive_hours (default: 48)"],
|
||||
},
|
||||
{
|
||||
"name": "token_cleanup",
|
||||
"description": "Clean up expired email verification tokens",
|
||||
"interval": "12 hours",
|
||||
"parameters": []
|
||||
"parameters": [],
|
||||
},
|
||||
{
|
||||
"name": "guest_stats",
|
||||
"description": "Update guest usage statistics",
|
||||
"interval": "1 hour",
|
||||
"parameters": []
|
||||
"parameters": [],
|
||||
},
|
||||
{
|
||||
"name": "rate_limit_cleanup",
|
||||
"description": "Clean up old rate limiting data",
|
||||
"interval": "24 hours",
|
||||
"parameters": ["days_old (default: 7)"]
|
||||
"parameters": ["days_old (default: 7)"],
|
||||
},
|
||||
{
|
||||
"name": "orphaned_cleanup",
|
||||
"description": "Clean up orphaned database records",
|
||||
"interval": "6 hours",
|
||||
"parameters": []
|
||||
}
|
||||
"parameters": [],
|
||||
},
|
||||
]
|
||||
|
||||
return create_success_response({
|
||||
"total_tasks": len(tasks),
|
||||
"tasks": tasks
|
||||
})
|
||||
return create_success_response({"total_tasks": len(tasks), "tasks": tasks})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error listing tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LIST_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/restart")
|
||||
async def restart_background_tasks(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Restart the background task manager (admin only)"""
|
||||
try:
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
database_manager = getattr(request.app.state, "database_manager", None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not database_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"DATABASE_UNAVAILABLE",
|
||||
"Database manager not available"
|
||||
)
|
||||
status_code=503, content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available")
|
||||
)
|
||||
|
||||
# Stop existing background tasks
|
||||
@ -483,6 +425,7 @@ async def restart_background_tasks(
|
||||
|
||||
# Create and start new background task manager
|
||||
from background_tasks import BackgroundTaskManager
|
||||
|
||||
new_background_task_manager = BackgroundTaskManager(database_manager)
|
||||
await new_background_task_manager.start()
|
||||
|
||||
@ -492,22 +435,20 @@ async def restart_background_tasks(
|
||||
# Get status of new manager
|
||||
status = await new_background_task_manager.get_task_status()
|
||||
|
||||
return create_success_response({
|
||||
"message": "Background task manager restarted successfully",
|
||||
"new_status": status
|
||||
})
|
||||
return create_success_response(
|
||||
{"message": "Background task manager restarted successfully", "new_status": status}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error restarting background tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESTART_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e)))
|
||||
|
||||
|
||||
# ============================
|
||||
# Task Monitoring and Metrics
|
||||
# ============================
|
||||
|
||||
|
||||
class TaskMetrics:
|
||||
"""Collect metrics for background tasks"""
|
||||
|
||||
@ -544,36 +485,32 @@ class TaskMetrics:
|
||||
metrics[task_name] = {
|
||||
"total_runs": self.task_runs[task_name],
|
||||
"total_errors": self.task_errors[task_name],
|
||||
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
|
||||
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name]
|
||||
if self.task_runs[task_name] > 0
|
||||
else 0,
|
||||
"average_duration": avg_duration,
|
||||
"last_runs": durations[-10:] if durations else []
|
||||
"last_runs": durations[-10:] if durations else [],
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
# Global task metrics
|
||||
task_metrics = TaskMetrics()
|
||||
|
||||
|
||||
@router.get("/tasks/metrics")
|
||||
async def get_task_metrics(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_task_metrics(admin_user=Depends(get_current_admin)):
|
||||
"""Get background task metrics (admin only)"""
|
||||
try:
|
||||
global task_metrics
|
||||
metrics = task_metrics.get_metrics()
|
||||
|
||||
return create_success_response({
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task metrics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("METRICS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e)))
|
||||
|
||||
|
||||
# ============================
|
||||
@ -581,8 +518,7 @@ async def get_task_metrics(
|
||||
# ============================
|
||||
# @router.get("/verification-stats")
|
||||
async def get_verification_statistics(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get verification statistics (admin only)"""
|
||||
try:
|
||||
@ -591,22 +527,19 @@ async def get_verification_statistics(
|
||||
|
||||
stats = {
|
||||
"pending_verifications": await database.get_pending_verifications_count(),
|
||||
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
|
||||
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(),
|
||||
}
|
||||
|
||||
return create_success_response(stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting verification stats: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/cleanup-verifications")
|
||||
async def cleanup_verification_tokens(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Manually trigger cleanup of expired verification tokens (admin only)"""
|
||||
try:
|
||||
@ -617,24 +550,24 @@ async def cleanup_verification_tokens(
|
||||
|
||||
logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
|
||||
"cleaned_count": cleaned_count
|
||||
})
|
||||
"cleaned_count": cleaned_count,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in manual cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/pending-verifications")
|
||||
async def get_pending_verifications(
|
||||
current_user = Depends(get_current_admin),
|
||||
current_user=Depends(get_current_admin),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get list of pending email verifications (admin only)"""
|
||||
try:
|
||||
@ -656,14 +589,16 @@ async def get_pending_verifications(
|
||||
if not verification_info.get("verified", False):
|
||||
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
|
||||
|
||||
pending_verifications.append({
|
||||
pending_verifications.append(
|
||||
{
|
||||
"email": verification_info.get("email"),
|
||||
"user_type": verification_info.get("user_type"),
|
||||
"created_at": verification_info.get("created_at"),
|
||||
"expires_at": verification_info.get("expires_at"),
|
||||
"is_expired": current_time > expires_at,
|
||||
"resend_count": verification_info.get("resend_count", 0)
|
||||
})
|
||||
"resend_count": verification_info.get("resend_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
@ -677,35 +612,27 @@ async def get_pending_verifications(
|
||||
end = start + limit
|
||||
paginated_verifications = pending_verifications[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
paginated_verifications,
|
||||
page, limit, total
|
||||
)
|
||||
paginated_response = create_paginated_response(paginated_verifications, page, limit, total)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting pending verifications: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/rate-limits/info")
|
||||
async def get_user_rate_limit_status(
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(current_user.id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
@ -725,27 +652,22 @@ async def get_user_rate_limit_status(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/rate-limits/{user_id}")
|
||||
async def get_anyone_rate_limit_status(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
@ -765,47 +687,36 @@ async def get_anyone_rate_limit_status(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/rate-limits/{user_id}/reset")
|
||||
async def reset_user_rate_limits(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Reset rate limits for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
|
||||
|
||||
if success:
|
||||
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
|
||||
return create_success_response({
|
||||
"message": f"Rate limits reset for {user_type} {user_id}",
|
||||
"resetBy": admin_user.id
|
||||
})
|
||||
return create_success_response(
|
||||
{"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
|
||||
status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset rate limits error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
|
||||
|
@ -20,18 +20,25 @@ from device_manager import DeviceManager
|
||||
from email_service import VerificationEmailRateLimiter, email_service
|
||||
from logger import logger
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, Candidate,
|
||||
Employer, Guest, AuthResponse, MFARequest,
|
||||
MFAData, MFAVerifyRequest, ResendVerificationRequest,
|
||||
MFARequestResponse, MFARequestResponse
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, create_access_token
|
||||
LoginRequest,
|
||||
CreateCandidateRequest,
|
||||
Candidate,
|
||||
Employer,
|
||||
Guest,
|
||||
AuthResponse,
|
||||
MFARequest,
|
||||
MFAData,
|
||||
MFAVerifyRequest,
|
||||
ResendVerificationRequest,
|
||||
MFARequestResponse,
|
||||
MFARequestResponse,
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
from utils.auth_utils import (
|
||||
AuthenticationManager, SecurityConfig,
|
||||
AuthenticationManager,
|
||||
SecurityConfig,
|
||||
validate_password_strength,
|
||||
)
|
||||
|
||||
@ -43,28 +50,31 @@ if JWT_SECRET_KEY == "":
|
||||
raise ValueError("JWT_SECRET_KEY environment variable is not set")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
# ============================
|
||||
# Password Reset Endpoints
|
||||
# ============================
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
def validate_password_strength(cls, v):
|
||||
is_valid, issues = validate_password_strength(v)
|
||||
if not is_valid:
|
||||
raise ValueError('; '.join(issues))
|
||||
raise ValueError("; ".join(issues))
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/guest")
|
||||
async def create_guest_session_enhanced(
|
||||
request: Request,
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter)
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
):
|
||||
"""Create a guest session with enhanced validation and persistence"""
|
||||
try:
|
||||
@ -73,21 +83,15 @@ async def create_guest_session_enhanced(
|
||||
|
||||
# Check rate limits for guest session creation
|
||||
rate_result = await rate_limiter.check_rate_limit(
|
||||
user_id=ip_address,
|
||||
user_type="guest_creation",
|
||||
is_admin=False,
|
||||
endpoint="/guest"
|
||||
user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest"
|
||||
)
|
||||
|
||||
if not rate_result.allowed:
|
||||
logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response(
|
||||
"RATE_LIMITED",
|
||||
rate_result.reason or "Too many guest sessions created"
|
||||
),
|
||||
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}
|
||||
content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"),
|
||||
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)},
|
||||
)
|
||||
|
||||
# Generate unique guest identifier with timestamp for uniqueness
|
||||
@ -139,7 +143,7 @@ async def create_guest_session_enhanced(
|
||||
"email": guest_data["email"],
|
||||
"username": guest_username,
|
||||
"session_id": session_id,
|
||||
"created_at": current_time.isoformat()
|
||||
"created_at": current_time.isoformat(),
|
||||
}
|
||||
|
||||
await database.set_user(guest_data["email"], user_auth_data)
|
||||
@ -149,11 +153,11 @@ async def create_guest_session_enhanced(
|
||||
# Create authentication tokens with longer expiry for guests
|
||||
access_token = create_access_token(
|
||||
data={"sub": guest_id, "type": "guest"},
|
||||
expires_delta=timedelta(hours=48) # Longer expiry for guests
|
||||
expires_delta=timedelta(hours=48), # Longer expiry for guests
|
||||
)
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": guest_id, "type": "refresh_guest"},
|
||||
expires_delta=timedelta(days=14) # 2 weeks refresh for guests
|
||||
expires_delta=timedelta(days=14), # 2 weeks refresh for guests
|
||||
)
|
||||
|
||||
# Verify guest was stored correctly
|
||||
@ -161,8 +165,7 @@ async def create_guest_session_enhanced(
|
||||
if not verification:
|
||||
logger.error(f"❌ Failed to verify guest storage: {guest_id}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
|
||||
status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
|
||||
)
|
||||
|
||||
# Create guest object for response
|
||||
@ -178,7 +181,7 @@ async def create_guest_session_enhanced(
|
||||
"user": guest.model_dump(by_alias=True),
|
||||
"expiresAt": int((current_time + timedelta(hours=48)).timestamp()),
|
||||
"userType": "guest",
|
||||
"isGuest": True
|
||||
"isGuest": True,
|
||||
}
|
||||
|
||||
return create_success_response(auth_response)
|
||||
@ -186,25 +189,25 @@ async def create_guest_session_enhanced(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Guest session creation error: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
|
||||
status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/guest/convert")
|
||||
async def convert_guest_to_user(
|
||||
registration_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Convert a guest session to a permanent user account"""
|
||||
try:
|
||||
# Verify current user is a guest
|
||||
if current_user.user_type != "guest":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("NOT_GUEST", "Only guest users can be converted")
|
||||
status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted")
|
||||
)
|
||||
|
||||
guest: Guest = current_user
|
||||
@ -215,25 +218,18 @@ async def convert_guest_to_user(
|
||||
try:
|
||||
candidate_request = CreateCandidateRequest.model_validate(registration_data)
|
||||
except ValidationError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e)))
|
||||
|
||||
# Check if email/username already exists
|
||||
auth_manager = AuthenticationManager(database)
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(
|
||||
candidate_request.email,
|
||||
candidate_request.username
|
||||
candidate_request.email, candidate_request.username
|
||||
)
|
||||
|
||||
if user_exists:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content=create_error_response(
|
||||
"USER_EXISTS",
|
||||
f"A user with this {conflict_field} already exists"
|
||||
)
|
||||
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
|
||||
)
|
||||
|
||||
# Create candidate
|
||||
@ -253,7 +249,7 @@ async def convert_guest_to_user(
|
||||
"updated_at": current_time.isoformat(),
|
||||
"status": "active",
|
||||
"is_admin": False,
|
||||
"converted_from_guest": guest.id
|
||||
"converted_from_guest": guest.id,
|
||||
}
|
||||
|
||||
candidate = Candidate.model_validate(candidate_data)
|
||||
@ -269,7 +265,7 @@ async def convert_guest_to_user(
|
||||
"id": candidate_id,
|
||||
"type": "candidate",
|
||||
"email": candidate.email,
|
||||
"username": candidate.username
|
||||
"username": candidate.username,
|
||||
}
|
||||
|
||||
await database.set_user(candidate.email, user_auth_data)
|
||||
@ -286,43 +282,45 @@ async def convert_guest_to_user(
|
||||
access_token = create_access_token(data={"sub": candidate_id})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": candidate_id, "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
auth_response = AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=candidate,
|
||||
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()),
|
||||
)
|
||||
|
||||
logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Guest account successfully converted to candidate",
|
||||
"auth": auth_response.model_dump(by_alias=True),
|
||||
"conversionType": "candidate"
|
||||
})
|
||||
"conversionType": "candidate",
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported")
|
||||
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Guest conversion error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
|
||||
status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
access_token: str = Body(..., alias="accessToken"),
|
||||
refresh_token: str = Body(..., alias="refreshToken"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Logout endpoint - revokes both access and refresh tokens"""
|
||||
logger.info(f"🔑 User {current_user.id} is logging out")
|
||||
@ -336,21 +334,18 @@ async def logout(
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
except jwt.PyJWTError as e:
|
||||
logger.warning(f"⚠️ Invalid refresh token during logout: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
|
||||
# Verify that the refresh token belongs to the current user
|
||||
if user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user")
|
||||
)
|
||||
|
||||
# Get Redis client
|
||||
@ -362,12 +357,14 @@ async def logout(
|
||||
await redis.setex(
|
||||
f"blacklisted_token:{refresh_token}",
|
||||
refresh_ttl,
|
||||
json.dumps({
|
||||
json.dumps(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"token_type": "refresh",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout"
|
||||
})
|
||||
"reason": "user_logout",
|
||||
}
|
||||
),
|
||||
)
|
||||
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
|
||||
|
||||
@ -385,12 +382,14 @@ async def logout(
|
||||
await redis.setex(
|
||||
f"blacklisted_token:{access_token}",
|
||||
access_ttl,
|
||||
json.dumps({
|
||||
json.dumps(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"token_type": "access",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout"
|
||||
})
|
||||
"reason": "user_logout",
|
||||
}
|
||||
),
|
||||
)
|
||||
logger.info(f"🔒 Blacklisted access token for user {user_id}")
|
||||
else:
|
||||
@ -409,26 +408,20 @@ async def logout(
|
||||
# )
|
||||
|
||||
logger.info(f"🔑 User {user_id} logged out successfully")
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Logged out successfully",
|
||||
"tokensRevoked": {
|
||||
"refreshToken": True,
|
||||
"accessToken": bool(access_token)
|
||||
"tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Logout error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGOUT_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/logout-all")
|
||||
async def logout_all_devices(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)):
|
||||
"""Logout from all devices by revoking all tokens for the user"""
|
||||
try:
|
||||
redis = redis_manager.get_client()
|
||||
@ -437,25 +430,20 @@ async def logout_all_devices(
|
||||
await redis.setex(
|
||||
f"user_tokens_revoked:{current_user.id}",
|
||||
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
|
||||
datetime.now(UTC).isoformat()
|
||||
datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
|
||||
return create_success_response({
|
||||
"message": "Logged out from all devices successfully"
|
||||
})
|
||||
return create_success_response({"message": "Logged out from all devices successfully"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Logout all devices error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token_endpoint(
|
||||
refresh_token: str = Body(..., alias="refreshToken"),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Refresh token endpoint"""
|
||||
try:
|
||||
@ -466,8 +454,7 @@ async def refresh_token_endpoint(
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
@ -484,37 +471,29 @@ async def refresh_token_endpoint(
|
||||
user = Employer.model_validate(employer_data)
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
auth_response = AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token, # Keep same refresh token
|
||||
user=user,
|
||||
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp())
|
||||
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()),
|
||||
)
|
||||
|
||||
return create_success_response(auth_response.model_dump(by_alias=True))
|
||||
|
||||
except jwt.PyJWTError:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token"))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Token refresh error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("REFRESH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/resend-verification")
|
||||
async def resend_verification_email(
|
||||
request: ResendVerificationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Resend verification email with comprehensive rate limiting and validation"""
|
||||
try:
|
||||
@ -527,10 +506,7 @@ async def resend_verification_email(
|
||||
can_send, reason = await rate_limiter.can_send_verification_email(email_lower)
|
||||
if not can_send:
|
||||
logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response("RATE_LIMITED", reason)
|
||||
)
|
||||
return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason))
|
||||
|
||||
# Clean up expired tokens first
|
||||
await database.cleanup_expired_verification_tokens()
|
||||
@ -541,9 +517,11 @@ async def resend_verification_email(
|
||||
# User exists and is verified - don't reveal this for security
|
||||
logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}")
|
||||
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Look for pending verification token
|
||||
verification_data = await database.find_verification_token_by_email(email_lower)
|
||||
@ -552,9 +530,11 @@ async def resend_verification_email(
|
||||
# No pending verification found - don't reveal this for security
|
||||
logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}")
|
||||
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Check if verification token has expired
|
||||
expires_at = datetime.fromisoformat(verification_data["expires_at"])
|
||||
@ -568,8 +548,8 @@ async def resend_verification_email(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"TOKEN_EXPIRED",
|
||||
"Your verification link has expired. Please register again to create a new account."
|
||||
)
|
||||
"Your verification link has expired. Please register again to create a new account.",
|
||||
),
|
||||
)
|
||||
|
||||
# Generate new verification token (invalidate old one)
|
||||
@ -577,20 +557,19 @@ async def resend_verification_email(
|
||||
new_token = secrets.token_urlsafe(32)
|
||||
|
||||
# Update verification data with new token and reset attempts
|
||||
verification_data.update({
|
||||
verification_data.update(
|
||||
{
|
||||
"token": new_token,
|
||||
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
|
||||
"resent_at": current_time.isoformat(),
|
||||
"resend_count": verification_data.get("resend_count", 0) + 1
|
||||
})
|
||||
"resend_count": verification_data.get("resend_count", 0) + 1,
|
||||
}
|
||||
)
|
||||
|
||||
# Store new token and remove old one
|
||||
await database.redis.delete(f"email_verification:{old_token}")
|
||||
await database.store_email_verification_token(
|
||||
email_lower,
|
||||
new_token,
|
||||
verification_data["user_type"],
|
||||
verification_data["user_data"]
|
||||
email_lower, new_token, verification_data["user_type"], verification_data["user_data"]
|
||||
)
|
||||
|
||||
# Get user name for email
|
||||
@ -610,73 +589,66 @@ async def resend_verification_email(
|
||||
await rate_limiter.record_email_sent(email_lower)
|
||||
|
||||
# Send new verification email in background
|
||||
background_tasks.add_task(
|
||||
email_service.send_verification_email,
|
||||
email_lower,
|
||||
new_token,
|
||||
user_name,
|
||||
user_type
|
||||
)
|
||||
background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type)
|
||||
|
||||
# Log security event
|
||||
await database.log_security_event(
|
||||
verification_data["user_data"].get("candidate_data", {}).get("id") or
|
||||
verification_data["user_data"].get("employer_data", {}).get("id") or "unknown",
|
||||
verification_data["user_data"].get("candidate_data", {}).get("id")
|
||||
or verification_data["user_data"].get("employer_data", {}).get("id")
|
||||
or "unknown",
|
||||
"verification_resend",
|
||||
{
|
||||
"email": email_lower,
|
||||
"user_type": user_type,
|
||||
"resend_count": verification_data.get("resend_count", 1),
|
||||
"old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging
|
||||
"ip_address": "unknown" # You can extract this from request if needed
|
||||
"ip_address": "unknown", # You can extract this from request if needed
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})"
|
||||
)
|
||||
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
|
||||
"resendCount": verification_data.get("resend_count", 1),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})")
|
||||
|
||||
return create_success_response({
|
||||
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
|
||||
"resendCount": verification_data.get("resend_count", 1)
|
||||
})
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning(f"⚠️ Invalid resend verification request: {ve}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(ve))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Resend verification email error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.")
|
||||
content=create_error_response(
|
||||
"RESEND_FAILED", "An error occurred while processing your request. Please try again later."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/request")
|
||||
async def request_mfa(
|
||||
request: MFARequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Request MFA for login from new device"""
|
||||
try:
|
||||
# Verify credentials first
|
||||
auth_manager = AuthenticationManager(database)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
|
||||
request.email,
|
||||
request.password
|
||||
)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password)
|
||||
|
||||
if not is_valid or not user_data:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("AUTH_FAILED", "Invalid credentials")
|
||||
)
|
||||
return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials"))
|
||||
|
||||
# Check if device is trusted
|
||||
device_manager = DeviceManager(database)
|
||||
device_info = device_manager.parse_device_info(http_request)
|
||||
device_manager.parse_device_info(http_request)
|
||||
|
||||
is_trusted = await device_manager.is_trusted_device(user_data["id"], request.device_id)
|
||||
|
||||
@ -684,10 +656,7 @@ async def request_mfa(
|
||||
# Device is trusted, proceed with normal login
|
||||
await device_manager.update_device_last_used(user_data["id"], request.device_id)
|
||||
|
||||
return create_success_response({
|
||||
"mfa_required": False,
|
||||
"message": "Device is trusted, proceed with login"
|
||||
})
|
||||
return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"})
|
||||
|
||||
# Generate MFA code
|
||||
mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code
|
||||
@ -709,8 +678,7 @@ async def request_mfa(
|
||||
|
||||
if not email:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
)
|
||||
|
||||
# Store MFA code
|
||||
@ -718,13 +686,7 @@ async def request_mfa(
|
||||
logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}")
|
||||
|
||||
# Send MFA code via email
|
||||
background_tasks.add_task(
|
||||
email_service.send_mfa_email,
|
||||
email,
|
||||
mfa_code,
|
||||
request.device_name,
|
||||
user_name
|
||||
)
|
||||
background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name)
|
||||
|
||||
logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}")
|
||||
|
||||
@ -735,25 +697,22 @@ async def request_mfa(
|
||||
device_id=request.device_id,
|
||||
device_name=request.device_name,
|
||||
)
|
||||
mfa_response = MFARequestResponse(
|
||||
mfa_required=True,
|
||||
mfa_data=mfa_data
|
||||
)
|
||||
mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data)
|
||||
return create_success_response(mfa_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MFA request error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
|
||||
status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
request: LoginRequest,
|
||||
http_request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""login with automatic MFA email sending for new devices"""
|
||||
try:
|
||||
@ -766,16 +725,12 @@ async def login(
|
||||
device_id = device_info["device_id"]
|
||||
|
||||
# Verify credentials first
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
|
||||
request.login,
|
||||
request.password
|
||||
)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password)
|
||||
|
||||
if not is_valid or not user_data:
|
||||
logger.warning(f"⚠️ Failed login attempt for: {request.login}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
|
||||
status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
|
||||
)
|
||||
|
||||
# Check if device is trusted
|
||||
@ -804,8 +759,7 @@ async def login(
|
||||
|
||||
if not email:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
)
|
||||
|
||||
# Store MFA code
|
||||
@ -817,12 +771,7 @@ async def login(
|
||||
|
||||
# Send MFA code via email in background
|
||||
background_tasks.add_task(
|
||||
email_service.send_mfa_email,
|
||||
email,
|
||||
mfa_code,
|
||||
device_info["device_name"],
|
||||
user_name,
|
||||
ip_address
|
||||
email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address
|
||||
)
|
||||
|
||||
# Log security event
|
||||
@ -834,8 +783,8 @@ async def login(
|
||||
"device_name": device_info["device_name"],
|
||||
"ip_address": ip_address,
|
||||
"user_agent": device_info.get("user_agent", ""),
|
||||
"auto_sent": True
|
||||
}
|
||||
"auto_sent": True,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}")
|
||||
@ -847,8 +796,8 @@ async def login(
|
||||
email=email,
|
||||
device_id=device_id,
|
||||
device_name=device_info["device_name"],
|
||||
code_sent=mfa_code
|
||||
)
|
||||
code_sent=mfa_code,
|
||||
),
|
||||
)
|
||||
return create_success_response(mfa_response.model_dump(by_alias=True))
|
||||
|
||||
@ -860,7 +809,7 @@ async def login(
|
||||
access_token = create_access_token(data={"sub": user_data["id"]})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user_data["id"], "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
# Get user object
|
||||
@ -876,8 +825,7 @@ async def login(
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
)
|
||||
|
||||
# Log successful login from trusted device
|
||||
@ -888,8 +836,8 @@ async def login(
|
||||
"device_id": device_id,
|
||||
"device_name": device_info["device_name"],
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"trusted_device": True
|
||||
}
|
||||
"trusted_device": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Create response
|
||||
@ -897,7 +845,9 @@ async def login(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user,
|
||||
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int(
|
||||
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"🔑 User {request.login} logged in successfully from trusted device")
|
||||
@ -908,17 +858,12 @@ async def login(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Login error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGIN_ERROR", "An error occurred during login")
|
||||
status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/verify")
|
||||
async def verify_mfa(
|
||||
request: MFAVerifyRequest,
|
||||
http_request: Request,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)):
|
||||
"""Verify MFA code and complete login with error handling"""
|
||||
try:
|
||||
# Get MFA data
|
||||
@ -928,13 +873,17 @@ async def verify_mfa(
|
||||
logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.")
|
||||
content=create_error_response(
|
||||
"NO_MFA_SESSION", "No active MFA session found. Please try logging in again."
|
||||
),
|
||||
)
|
||||
|
||||
if mfa_data.get("verified"):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.")
|
||||
content=create_error_response(
|
||||
"ALREADY_VERIFIED", "This MFA code has already been used. Please login again."
|
||||
),
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
@ -944,7 +893,7 @@ async def verify_mfa(
|
||||
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.")
|
||||
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."),
|
||||
)
|
||||
|
||||
# Check attempts
|
||||
@ -954,7 +903,9 @@ async def verify_mfa(
|
||||
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.")
|
||||
content=create_error_response(
|
||||
"TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again."
|
||||
),
|
||||
)
|
||||
|
||||
# Verify code
|
||||
@ -965,9 +916,8 @@ async def verify_mfa(
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_CODE",
|
||||
f"Invalid MFA code. {remaining_attempts} attempts remaining."
|
||||
)
|
||||
"INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining."
|
||||
),
|
||||
)
|
||||
|
||||
# Mark as verified
|
||||
@ -976,20 +926,13 @@ async def verify_mfa(
|
||||
# Get user data
|
||||
user_data = await database.get_user(request.email)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
# Add device to trusted devices if requested
|
||||
if request.remember_device:
|
||||
device_manager = DeviceManager(database)
|
||||
device_info = device_manager.parse_device_info(http_request)
|
||||
await device_manager.add_trusted_device(
|
||||
user_data["id"],
|
||||
request.device_id,
|
||||
device_info
|
||||
)
|
||||
await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info)
|
||||
logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}")
|
||||
|
||||
# Update last login
|
||||
@ -1000,7 +943,7 @@ async def verify_mfa(
|
||||
access_token = create_access_token(data={"sub": user_data["id"]})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user_data["id"], "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
# Get user object
|
||||
@ -1016,8 +959,7 @@ async def verify_mfa(
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
)
|
||||
|
||||
# Log successful MFA verification and login
|
||||
@ -1028,8 +970,8 @@ async def verify_mfa(
|
||||
"device_id": request.device_id,
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"device_remembered": request.remember_device,
|
||||
"attempts_used": current_attempts + 1
|
||||
}
|
||||
"attempts_used": current_attempts + 1,
|
||||
},
|
||||
)
|
||||
|
||||
await database.log_security_event(
|
||||
@ -1039,8 +981,8 @@ async def verify_mfa(
|
||||
"device_id": request.device_id,
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"mfa_verified": True,
|
||||
"new_device": True
|
||||
}
|
||||
"new_device": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Clean up MFA session
|
||||
@ -1051,7 +993,9 @@ async def verify_mfa(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user,
|
||||
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int(
|
||||
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"✅ MFA verified and login completed for {request.email}")
|
||||
@ -1062,15 +1006,12 @@ async def verify_mfa(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ MFA verification error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
|
||||
status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/request")
|
||||
async def request_password_reset(
|
||||
request: PasswordResetRequest,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
|
||||
"""Request password reset"""
|
||||
try:
|
||||
# Check if user exists
|
||||
@ -1100,15 +1041,12 @@ async def request_password_reset(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Password reset request error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", "An error occurred processing the request")
|
||||
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm")
|
||||
async def confirm_password_reset(
|
||||
request: PasswordResetConfirm,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
|
||||
"""Confirm password reset with token"""
|
||||
try:
|
||||
# Find user by reset token
|
||||
@ -1122,8 +1060,5 @@ async def confirm_password_reset(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Password reset confirm error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
|
||||
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
|
||||
)
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -4,125 +4,69 @@ Chat routes
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, UTC
|
||||
from typing import (Dict, Any)
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, Depends, Body, Depends, Query, Path,
|
||||
Body, APIRouter
|
||||
)
|
||||
from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_database, get_current_user, get_current_user_or_guest
|
||||
)
|
||||
from utils.responses import (
|
||||
create_success_response, create_error_response, create_paginated_response
|
||||
)
|
||||
from utils.helpers import (
|
||||
stream_agent_response
|
||||
)
|
||||
from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
|
||||
from utils.responses import create_success_response, create_error_response, create_paginated_response
|
||||
from utils.helpers import stream_agent_response
|
||||
import backstory_traceback
|
||||
|
||||
import entities.entity_manager as entities
|
||||
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
# API
|
||||
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
|
||||
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
|
||||
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
|
||||
JobRequirements, JobRequirementsMessage, LoginRequest,
|
||||
CreateCandidateRequest, CreateEmployerRequest,
|
||||
|
||||
# User models
|
||||
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
|
||||
Authentication, AuthResponse, CandidateAI,
|
||||
|
||||
# Job models
|
||||
JobApplication, ApplicationStatus,
|
||||
|
||||
# Chat models
|
||||
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
|
||||
ChatMessageRagSearch,
|
||||
|
||||
# Document models
|
||||
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
|
||||
|
||||
# Supporting models
|
||||
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
|
||||
|
||||
# Email
|
||||
EmailVerificationRequest
|
||||
)
|
||||
from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/archive")
|
||||
async def archive_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Archive a chat session"""
|
||||
try:
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
# Check if user owns this session or is admin
|
||||
if session_data.get("userId") != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
|
||||
)
|
||||
|
||||
await database.archive_chat_session(session_id)
|
||||
|
||||
return create_success_response({
|
||||
"message": "Chat session archived successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Archive chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("ARCHIVE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_chat_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get chat statistics (admin/analytics endpoint)"""
|
||||
try:
|
||||
stats = await database.get_chat_statistics()
|
||||
return create_success_response(stats)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat statistics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_chat_session(
|
||||
session_data: Dict[str, Any] = Body(...),
|
||||
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new chat session with optional candidate username association"""
|
||||
try:
|
||||
@ -140,15 +84,14 @@ async def create_chat_session(
|
||||
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
|
||||
|
||||
# Find candidate by username (case-insensitive)
|
||||
matching_candidates = [
|
||||
c for c in candidates_list
|
||||
if c.username.lower() == username.lower()
|
||||
]
|
||||
matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
|
||||
|
||||
if not matching_candidates:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
|
||||
content=create_error_response(
|
||||
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
|
||||
),
|
||||
)
|
||||
|
||||
candidate_data = matching_candidates[0]
|
||||
@ -177,7 +120,7 @@ async def create_chat_session(
|
||||
"username": candidate_data.username,
|
||||
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
|
||||
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
|
||||
"location": candidate_data.location.city if candidate_data.location else "Unknown"
|
||||
"location": candidate_data.location.city if candidate_data.location else "Unknown",
|
||||
}
|
||||
context["additionalContext"] = additional_context
|
||||
|
||||
@ -191,8 +134,10 @@ async def create_chat_session(
|
||||
chat_session = ChatSession.model_validate(session_data)
|
||||
await database.set_chat_session(chat_session.id, chat_session.model_dump())
|
||||
|
||||
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
|
||||
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
|
||||
logger.info(
|
||||
f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
|
||||
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
|
||||
)
|
||||
|
||||
return create_success_response(chat_session.model_dump(by_alias=True))
|
||||
|
||||
@ -200,49 +145,58 @@ async def create_chat_session(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat session creation error: {e}")
|
||||
logger.info(json.dumps(session_data, indent=2))
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("/sessions/messages/stream")
|
||||
async def post_chat_session_message_stream(
|
||||
user_message: ChatMessageUser = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Post a message to a chat session and stream the response with persistence"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(user_message.session_id)
|
||||
if not chat_session_data:
|
||||
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
chat_session = ChatSession.model_validate(chat_session_data)
|
||||
chat_type = chat_session.context.type
|
||||
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
|
||||
candidate_info = (
|
||||
chat_session.context.additional_context.get("candidateInfo", {})
|
||||
if chat_session.context and chat_session.context.additional_context
|
||||
else None
|
||||
)
|
||||
|
||||
# Get candidate info if this chat is about a specific candidate
|
||||
if candidate_info:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
|
||||
logger.info(
|
||||
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
|
||||
logger.info(
|
||||
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
|
||||
content=create_error_response(
|
||||
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
|
||||
),
|
||||
)
|
||||
|
||||
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
|
||||
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
|
||||
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
|
||||
if not candidate:
|
||||
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
|
||||
logger.info(
|
||||
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
|
||||
)
|
||||
logger.info(
|
||||
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
|
||||
)
|
||||
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
|
||||
|
||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||
# Entity automatically released when done
|
||||
@ -251,7 +205,7 @@ async def post_chat_session_message_stream(
|
||||
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
|
||||
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
|
||||
)
|
||||
|
||||
# Persist user message to database
|
||||
@ -269,30 +223,25 @@ async def post_chat_session_message_stream(
|
||||
chat_session_data=chat_session_data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat message streaming error")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STREAMING_ERROR", "")
|
||||
)
|
||||
logger.error("❌ Chat message streaming error")
|
||||
return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/messages")
|
||||
async def get_chat_session_messages(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get persisted chat messages for a session"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(session_id)
|
||||
if not chat_session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
# Get messages from database
|
||||
chat_messages = await database.get_chat_messages(session_id)
|
||||
@ -317,43 +266,36 @@ async def get_chat_session_messages(
|
||||
paginated_messages = messages_list[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[m.model_dump(by_alias=True) for m in paginated_messages],
|
||||
page, limit, total
|
||||
[m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat messages error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.patch("/sessions/{session_id}")
|
||||
async def update_chat_session(
|
||||
session_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update a chat session's properties"""
|
||||
try:
|
||||
# Get the existing session
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only update their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
|
||||
)
|
||||
|
||||
# Validate and apply updates
|
||||
@ -362,8 +304,7 @@ async def update_chat_session(
|
||||
|
||||
if not filtered_updates:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
|
||||
status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
|
||||
)
|
||||
|
||||
# Apply updates to session data
|
||||
@ -417,40 +358,31 @@ async def update_chat_session(
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(ve))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("UPDATE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
@ -469,42 +401,34 @@ async def delete_chat_session(
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session deleted successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.patch("/sessions/{session_id}/reset")
|
||||
async def reset_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
@ -518,20 +442,12 @@ async def reset_chat_session(
|
||||
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
|
||||
# Continue with session deletion even if message deletion fails
|
||||
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session reset successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
|
||||
|
@ -9,19 +9,16 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@router.get("/guest/{guest_id}")
|
||||
async def debug_guest_session(
|
||||
guest_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Debug guest session issues (admin only)"""
|
||||
try:
|
||||
@ -37,7 +34,7 @@ async def debug_guest_session(
|
||||
user_lookup = await database.get_user_by_id(guest_id)
|
||||
|
||||
# Get TTL info
|
||||
primary_ttl = await database.redis.ttl(f"guests")
|
||||
primary_ttl = await database.redis.ttl("guests")
|
||||
backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}")
|
||||
|
||||
debug_info = {
|
||||
@ -45,22 +42,19 @@ async def debug_guest_session(
|
||||
"primary_storage": {
|
||||
"exists": primary_exists,
|
||||
"data": json.loads(primary_data) if primary_data else None,
|
||||
"ttl": primary_ttl
|
||||
"ttl": primary_ttl,
|
||||
},
|
||||
"backup_storage": {
|
||||
"exists": backup_exists,
|
||||
"data": json.loads(backup_data) if backup_data else None,
|
||||
"ttl": backup_ttl
|
||||
"ttl": backup_ttl,
|
||||
},
|
||||
"user_lookup": user_lookup,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
return create_success_response(debug_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Debug guest session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DEBUG_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))
|
||||
|
@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
CreateEmployerRequest
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database
|
||||
)
|
||||
from models import CreateEmployerRequest
|
||||
from utils.dependencies import get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from email_service import email_service
|
||||
from utils.auth_utils import AuthenticationManager
|
||||
@ -23,29 +19,22 @@ from utils.auth_utils import AuthenticationManager
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/employers", tags=["employers"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_employer_with_verification(
|
||||
request: CreateEmployerRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new employer with email verification"""
|
||||
try:
|
||||
# Similar to candidate creation but for employer
|
||||
auth_manager = AuthenticationManager(database)
|
||||
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(
|
||||
request.email,
|
||||
request.username
|
||||
)
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
|
||||
|
||||
if user_exists and conflict_field:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content=create_error_response(
|
||||
"USER_EXISTS",
|
||||
f"A user with this {conflict_field} already exists"
|
||||
)
|
||||
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
|
||||
)
|
||||
|
||||
employer_id = str(uuid.uuid4())
|
||||
@ -64,12 +53,8 @@ async def create_employer_with_verification(
|
||||
"updatedAt": current_time.isoformat(),
|
||||
"status": "pending", # Not active until verified
|
||||
"userType": "employer",
|
||||
"location": {
|
||||
"city": "",
|
||||
"country": "",
|
||||
"remote": False
|
||||
},
|
||||
"socialLinks": []
|
||||
"location": {"city": "", "country": "", "remote": False},
|
||||
"socialLinks": [],
|
||||
}
|
||||
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
@ -78,32 +63,25 @@ async def create_employer_with_verification(
|
||||
request.email,
|
||||
verification_token,
|
||||
"employer",
|
||||
{
|
||||
"employer_data": employer_data,
|
||||
"password": request.password,
|
||||
"username": request.username
|
||||
}
|
||||
{"employer_data": employer_data, "password": request.password, "username": request.username},
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
email_service.send_verification_email,
|
||||
request.email,
|
||||
verification_token,
|
||||
request.company_name
|
||||
email_service.send_verification_email, request.email, verification_token, request.company_name
|
||||
)
|
||||
|
||||
logger.info(f"✅ Employer registration initiated for: {request.email}")
|
||||
|
||||
return create_success_response({
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Registration successful! Please check your email to verify your account.",
|
||||
"email": request.email,
|
||||
"verificationRequired": True
|
||||
})
|
||||
"verificationRequired": True,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Employer creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
|
||||
status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
|
||||
)
|
||||
|
||||
|
@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user
|
||||
MOCK_UUID,
|
||||
ApiActivityType,
|
||||
ApiStatusType,
|
||||
ChatContextType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
DocumentType,
|
||||
Job,
|
||||
JobRequirementsMessage,
|
||||
Candidate,
|
||||
Employer,
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database, get_current_user
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
import utils.llm_proxy as llm_manager
|
||||
import entities.entity_manager as entities
|
||||
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
content="No agent found for job requirements chat type",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
content="Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
""",
|
||||
):
|
||||
pass
|
||||
|
||||
@ -66,16 +76,16 @@ Return only the markdown content, no other text. Make sure all content is includ
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
content="Failed to reformat job description",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
chat_message : ChatMessage = message
|
||||
chat_message: ChatMessage = message
|
||||
try:
|
||||
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"✅ Successfully converted content to markdown")
|
||||
logger.info("✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
@ -84,7 +94,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
activity=ApiActivityType.INFO,
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
content="Failed to reformat job description",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
content="No agent found for job requirements chat type",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
content="Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
@ -132,23 +139,22 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Job extraction did not convert successfully"
|
||||
content="Job extraction did not convert successfully",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
job_requirements : JobRequirementsMessage = message
|
||||
job_requirements: JobRequirementsMessage = message
|
||||
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new job"""
|
||||
try:
|
||||
@ -165,20 +171,17 @@ async def create_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_candidate_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new job"""
|
||||
is_employer = isinstance(current_user, Employer)
|
||||
isinstance(current_user, Employer)
|
||||
|
||||
try:
|
||||
job = Job.model_validate(job_data)
|
||||
@ -194,28 +197,22 @@ async def create_candidate_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.patch("/{job_id}")
|
||||
async def update_job(
|
||||
job_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update a candidate"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Job not found for update: {job_data}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
@ -223,8 +220,7 @@ async def update_job(
|
||||
if current_user.is_admin is False and job.owner_id != current_user.id:
|
||||
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
|
||||
)
|
||||
|
||||
# Apply updates
|
||||
@ -239,25 +235,22 @@ async def update_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("UPDATE_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("/from-content")
|
||||
async def create_job_from_description(
|
||||
content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Upload a document for the current candidate"""
|
||||
|
||||
async def content_stream_generator(content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
content="Only candidates can upload documents",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -277,10 +270,11 @@ async def create_job_from_description(
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -304,18 +298,25 @@ async def create_job_from_description(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
content="Failed to upload document",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def create_job_from_file(
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Upload a job document for the current candidate and create a Job"""
|
||||
# Check file size (limit to 10MB)
|
||||
@ -324,55 +325,70 @@ async def create_job_from_file(
|
||||
if len(file_content) > max_size:
|
||||
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File size exceeds 10MB limit"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
content="File size exceeds 10MB limit",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
if len(file_content) == 0:
|
||||
logger.info(f"⚠️ File is empty: {file.filename}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File is empty"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
content="File is empty",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
"""Upload a document for the current candidate"""
|
||||
|
||||
async def upload_stream_generator(file_content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
content="Only candidates can upload documents",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
|
||||
file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
|
||||
if not file.filename or file.filename.strip() == "":
|
||||
logger.warning("⚠️ File upload attempt with missing filename")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File must have a valid filename"
|
||||
content="File must have a valid filename",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
|
||||
logger.info(
|
||||
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
|
||||
)
|
||||
|
||||
# Validate file type
|
||||
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
|
||||
allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
|
||||
|
||||
if file_extension not in allowed_types:
|
||||
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -383,7 +399,7 @@ async def create_job_from_file(
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Converting content from {document_type}...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
activity=ApiActivityType.CONVERTING,
|
||||
)
|
||||
yield status_message
|
||||
try:
|
||||
@ -391,7 +407,7 @@ async def create_job_from_file(
|
||||
stream = io.BytesIO(file_content)
|
||||
stream_info = StreamInfo(
|
||||
extension=file_extension, # e.g., ".pdf"
|
||||
url=file.filename # optional, helps with logging and guessing
|
||||
url=file.filename, # optional, helps with logging and guessing
|
||||
)
|
||||
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
|
||||
file_content = result.text_content
|
||||
@ -405,15 +421,18 @@ async def create_job_from_file(
|
||||
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
|
||||
return
|
||||
|
||||
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
|
||||
async for message in create_job_from_content(
|
||||
database=database, current_user=current_user, content=file_content
|
||||
):
|
||||
yield message
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -437,26 +456,27 @@ async def create_job_from_file(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
content="Failed to upload document",
|
||||
).model_dump(mode="json", by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{job_id}")
|
||||
async def get_job(
|
||||
job_id: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get a job by ID"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
# Increment view count
|
||||
job_data["views"] = job_data.get("views", 0) + 1
|
||||
@ -467,10 +487,8 @@ async def get_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_jobs(
|
||||
@ -479,7 +497,7 @@ async def get_jobs(
|
||||
sortBy: Optional[str] = Query(None, alias="sortBy"),
|
||||
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
|
||||
filters: Optional[str] = Query(None),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get paginated list of jobs"""
|
||||
try:
|
||||
@ -493,23 +511,18 @@ async def get_jobs(
|
||||
for job in all_jobs_data.values():
|
||||
jobs_list.append(Job.model_validate(job))
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, sortBy, sortOrder, filter_dict
|
||||
)
|
||||
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("FETCH_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_jobs(
|
||||
@ -517,7 +530,7 @@ async def search_jobs(
|
||||
filters: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Search jobs"""
|
||||
try:
|
||||
@ -532,69 +545,52 @@ async def search_jobs(
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
jobs_list = [
|
||||
j for j in jobs_list
|
||||
if ((j.title and query_lower in j.title.lower()) or
|
||||
(j.description and query_lower in j.description.lower()) or
|
||||
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
|
||||
j
|
||||
for j in jobs_list
|
||||
if (
|
||||
(j.title and query_lower in j.title.lower())
|
||||
or (j.description and query_lower in j.description.lower())
|
||||
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
|
||||
)
|
||||
]
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, filters=filter_dict
|
||||
)
|
||||
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Search jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("SEARCH_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
async def delete_job(
|
||||
job_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a Job"""
|
||||
try:
|
||||
# Check if admin user
|
||||
if not admin_user.is_admin:
|
||||
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Only admins can delete")
|
||||
)
|
||||
return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
|
||||
|
||||
# Get candidate data
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
# Delete job from database
|
||||
await database.delete_job(job_id)
|
||||
|
||||
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Job deleted successfully",
|
||||
"jobId": job_id
|
||||
})
|
||||
return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", "Failed to delete job")
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))
|
||||
|
@ -6,6 +6,7 @@ from utils.llm_proxy import LLMProvider, get_llm
|
||||
|
||||
router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(provider: Optional[str] = None):
|
||||
"""List available models for a provider"""
|
||||
@ -17,29 +18,25 @@ async def list_models(provider: Optional[str] = None):
|
||||
try:
|
||||
provider_enum = LLMProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported provider: {provider}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
|
||||
|
||||
models = await llm.list_models(provider_enum)
|
||||
return {
|
||||
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
|
||||
"models": models
|
||||
}
|
||||
return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_providers():
|
||||
"""List all configured providers"""
|
||||
llm = get_llm()
|
||||
return {
|
||||
"providers": [provider.value for provider in llm._initialized_providers],
|
||||
"default": llm.default_provider.value
|
||||
"default": llm.default_provider.value,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{provider}/set-default")
|
||||
async def set_default_provider(provider: str):
|
||||
"""Set the default provider"""
|
||||
@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
@ -59,5 +57,5 @@ async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"providers_configured": len(llm._initialized_providers),
|
||||
"default_provider": llm.default_provider.value
|
||||
"default_provider": llm.default_provider.value,
|
||||
}
|
||||
|
@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
|
||||
import backstory_traceback as backstory_traceback
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database, get_current_user
|
||||
)
|
||||
from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
|
||||
from utils.dependencies import get_database, get_current_user
|
||||
from utils.responses import create_success_response
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/resumes", tags=["resumes"])
|
||||
|
||||
|
||||
@router.post("/{candidate_id}/{job_id}")
|
||||
async def create_candidate_resume(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
resume_content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new resume for a candidate/job combination"""
|
||||
|
||||
async def message_stream_generator():
|
||||
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
|
||||
|
||||
@ -39,7 +37,7 @@ async def create_candidate_resume(
|
||||
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Candidate with ID '{candidate_id}' not found"
|
||||
content=f"Candidate with ID '{candidate_id}' not found",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -50,13 +48,15 @@ async def create_candidate_resume(
|
||||
logger.error(f"❌ Job with ID '{job_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Job with ID '{job_id}' not found"
|
||||
content=f"Job with ID '{job_id}' not found",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
|
||||
logger.info(
|
||||
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
|
||||
)
|
||||
|
||||
# Job and Candidate are valid. Save the resume
|
||||
resume = Resume(
|
||||
@ -66,16 +66,13 @@ async def create_candidate_resume(
|
||||
)
|
||||
resume_message: ResumeMessage = ResumeMessage(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
resume=resume
|
||||
resume=resume,
|
||||
)
|
||||
|
||||
# Save to database
|
||||
success = await database.set_resume(current_user.id, resume.model_dump())
|
||||
if not success:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to save resume to database"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -84,10 +81,11 @@ async def create_candidate_resume(
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -111,22 +109,26 @@ async def create_candidate_resume(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Resume creation error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to create resume"
|
||||
).model_dump(mode='json', by_alias=True))]),
|
||||
media_type="text/event-stream"
|
||||
content="Failed to create resume",
|
||||
).model_dump(mode="json", by_alias=True)
|
||||
)
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_user_resumes(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get all resumes for the current user"""
|
||||
try:
|
||||
resumes_data = await database.get_all_resumes_for_user(current_user.id)
|
||||
resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data]
|
||||
resumes: List[Resume] = [Resume.model_validate(data) for data in resumes_data]
|
||||
for resume in resumes:
|
||||
job_data = await database.get_job(resume.job_id)
|
||||
if job_data:
|
||||
@ -135,19 +137,17 @@ async def get_user_resumes(
|
||||
if candidate_data:
|
||||
resume.candidate = Candidate.model_validate(candidate_data)
|
||||
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date
|
||||
return create_success_response({
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
})
|
||||
return create_success_response({"resumes": resumes, "count": len(resumes)})
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
|
||||
|
||||
|
||||
@router.get("/{resume_id}")
|
||||
async def get_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get a specific resume by ID"""
|
||||
try:
|
||||
@ -155,21 +155,19 @@ async def get_resume(
|
||||
if not resume:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"resume": resume
|
||||
}
|
||||
return {"success": True, "resume": resume}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
|
||||
|
||||
|
||||
@router.delete("/{resume_id}")
|
||||
async def delete_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a specific resume"""
|
||||
try:
|
||||
@ -177,102 +175,82 @@ async def delete_resume(
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} deleted successfully"
|
||||
}
|
||||
return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete resume")
|
||||
|
||||
|
||||
@router.get("/candidate/{candidate_id}")
|
||||
async def get_resumes_by_candidate(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get all resumes for a specific candidate"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
|
||||
return {
|
||||
"success": True,
|
||||
"candidate_id": candidate_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
|
||||
|
||||
|
||||
@router.get("/job/{job_id}")
|
||||
async def get_resumes_by_job(
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get all resumes for a specific job"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_job(current_user.id, job_id)
|
||||
return {
|
||||
"success": True,
|
||||
"job_id": job_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_resumes(
|
||||
q: str = Query(..., description="Search query"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Search resumes by content"""
|
||||
try:
|
||||
resumes = await database.search_resumes_for_user(current_user.id, q)
|
||||
return {
|
||||
"success": True,
|
||||
"query": q,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search resumes")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_resume_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get resume statistics for the current user"""
|
||||
try:
|
||||
stats = await database.get_resume_statistics(current_user.id)
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
return {"success": True, "statistics": stats}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
|
||||
|
||||
|
||||
@router.put("/{resume_id}")
|
||||
async def update_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
resume: str = Body(..., description="Updated resume content"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update the content of a specific resume"""
|
||||
try:
|
||||
updates = {
|
||||
"resume": resume,
|
||||
"updated_at": datetime.now(UTC).isoformat()
|
||||
}
|
||||
updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
|
||||
|
||||
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
|
||||
if not updated_resume_data:
|
||||
@ -280,11 +258,9 @@ async def update_resume(
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} updated successfully",
|
||||
"resume": updated_resume
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
@ -10,6 +10,7 @@ from utils.responses import create_success_response
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
"""Dependency to get Redis client"""
|
||||
return redis_manager.get_client()
|
||||
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
|
||||
async def get_system_info(request: Request):
|
||||
"""Get system information"""
|
||||
from system_info import system_info # Import system_info function from system_info module
|
||||
|
||||
system = system_info()
|
||||
|
||||
return create_success_response(system.model_dump(mode='json'))
|
||||
return create_success_response(system.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/redis/stats")
|
||||
async def redis_stats(redis: Redis = Depends(get_redis)):
|
||||
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
|
||||
"total_commands_processed": info.get("total_commands_processed"),
|
||||
"keyspace_hits": info.get("keyspace_hits"),
|
||||
"keyspace_misses": info.get("keyspace_misses"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds")
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds"),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
|
||||
|
@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
BaseUserWithType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database
|
||||
)
|
||||
from models import BaseUserWithType
|
||||
from utils.dependencies import get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
# reference can be candidateId, username, or email
|
||||
@router.get("/users/{reference}")
|
||||
async def get_user(
|
||||
reference: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get a candidate by username"""
|
||||
try:
|
||||
# Normalize reference to lowercase for case-insensitive search
|
||||
@ -31,41 +25,36 @@ async def get_user(
|
||||
|
||||
all_candidate_data = await database.get_all_candidates()
|
||||
if not all_candidate_data:
|
||||
logger.warning(f"⚠️ No users found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
logger.warning("⚠️ No users found in database")
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
|
||||
|
||||
user_data = None
|
||||
for user in all_candidate_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
if (
|
||||
user.get("id", "").lower() == query_lower
|
||||
or user.get("username", "").lower() == query_lower
|
||||
or user.get("email", "").lower() == query_lower
|
||||
):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
all_guest_data = await database.get_all_guests()
|
||||
if not all_guest_data:
|
||||
logger.warning(f"⚠️ No guests found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
logger.warning("⚠️ No guests found in database")
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
|
||||
for user in all_guest_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
if (
|
||||
user.get("id", "").lower() == query_lower
|
||||
or user.get("username", "").lower() == query_lower
|
||||
or user.get("email", "").lower() == query_lower
|
||||
):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
|
||||
|
||||
user = BaseUserWithType.model_validate(user_data)
|
||||
|
||||
@ -73,8 +62,4 @@ async def get_user(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get user error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
@ -4,6 +4,7 @@ import subprocess
|
||||
import math
|
||||
from models import SystemInfo
|
||||
|
||||
|
||||
def get_installed_ram():
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
@ -19,9 +20,7 @@ def get_graphics_cards():
|
||||
gpus = []
|
||||
try:
|
||||
# Run the ze-monitor utility
|
||||
result = subprocess.run(
|
||||
["ze-monitor"], capture_output=True, text=True, check=True
|
||||
)
|
||||
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
|
||||
|
||||
# Clean up the output (remove leading/trailing whitespace and newlines)
|
||||
output = result.stdout.strip()
|
||||
@ -71,6 +70,7 @@ def get_cpu_info():
|
||||
except Exception as e:
|
||||
return f"Error retrieving CPU info: {e}"
|
||||
|
||||
|
||||
def system_info() -> SystemInfo:
|
||||
"""
|
||||
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
|
||||
|
@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
|
||||
|
||||
# Process the forecast data into a simpler format
|
||||
forecast = {
|
||||
"location": data["properties"]
|
||||
.get("relativeLocation", {})
|
||||
.get("properties", {}),
|
||||
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
|
||||
"updated": data["properties"].get("updated", ""),
|
||||
"periods": [],
|
||||
}
|
||||
@ -181,7 +179,7 @@ def get_forecast(grid_endpoint):
|
||||
def TickerValue(ticker_symbols):
|
||||
api_key = os.getenv("TWELVEDATA_API_KEY", "")
|
||||
if not api_key:
|
||||
return {"error": f"Error fetching data: No API key for TwelveData"}
|
||||
return {"error": "Error fetching data: No API key for TwelveData"}
|
||||
|
||||
results = []
|
||||
for ticker_symbol in ticker_symbols.split(","):
|
||||
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
|
||||
if ticker_symbol == "":
|
||||
continue
|
||||
|
||||
url = (
|
||||
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
)
|
||||
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
|
||||
response = requests.get(url)
|
||||
data = response.json()
|
||||
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
|
||||
|
||||
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
results.append(
|
||||
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
|
||||
)
|
||||
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
|
||||
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
@ -278,8 +272,10 @@ def DateTime(timezone="America/Los_Angeles"):
|
||||
except Exception as e:
|
||||
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
|
||||
|
||||
|
||||
async def GenerateImage(llm, model: str, prompt: str):
|
||||
return { "image_id": "image-a830a83-bd831" }
|
||||
return {"image_id": "image-a830a83-bd831"}
|
||||
|
||||
|
||||
async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
"""
|
||||
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
return f"Error processing the website content: {str(e)}"
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
@ -355,59 +350,58 @@ class Function(BaseModel):
|
||||
parameters: Dict[str, Any]
|
||||
returns: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
tools : List[Tool] = [
|
||||
# Tool.model_validate({
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "GenerateImage",
|
||||
# "description": """\
|
||||
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
|
||||
|
||||
# 1. Call this tool when users request images, drawings, or visual content
|
||||
# 2. This tool returns an image_id (e.g., "img_abc123")
|
||||
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
|
||||
# 4. FORBIDDEN: DO NOT use markdown image syntax 
|
||||
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
|
||||
# 6. FORBIDDEN: DO NOT use any other image embedding format
|
||||
|
||||
# CORRECT EXAMPLE:
|
||||
# User: "Draw a cat"
|
||||
# Tool returns: {"image_id": "img_xyz789"}
|
||||
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
|
||||
|
||||
# WRONG EXAMPLES (DO NOT DO THIS):
|
||||
# - 
|
||||
# - 
|
||||
# - <img src="...">
|
||||
|
||||
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
|
||||
# """,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "prompt": {
|
||||
# "type": "string",
|
||||
# "description": "Detailed image description including style, colors, subject, composition"
|
||||
# }
|
||||
# },
|
||||
# "required": ["prompt"]
|
||||
# },
|
||||
# "returns": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "image_id": {
|
||||
# "type": "string",
|
||||
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }),
|
||||
Tool.model_validate({
|
||||
tools: List[Tool] = [
|
||||
# Tool.model_validate({
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "GenerateImage",
|
||||
# "description": """\
|
||||
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
|
||||
# 1. Call this tool when users request images, drawings, or visual content
|
||||
# 2. This tool returns an image_id (e.g., "img_abc123")
|
||||
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
|
||||
# 4. FORBIDDEN: DO NOT use markdown image syntax 
|
||||
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
|
||||
# 6. FORBIDDEN: DO NOT use any other image embedding format
|
||||
# CORRECT EXAMPLE:
|
||||
# User: "Draw a cat"
|
||||
# Tool returns: {"image_id": "img_xyz789"}
|
||||
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
|
||||
# WRONG EXAMPLES (DO NOT DO THIS):
|
||||
# - 
|
||||
# - 
|
||||
# - <img src="...">
|
||||
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
|
||||
# """,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "prompt": {
|
||||
# "type": "string",
|
||||
# "description": "Detailed image description including style, colors, subject, composition"
|
||||
# }
|
||||
# },
|
||||
# "required": ["prompt"]
|
||||
# },
|
||||
# "returns": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "image_id": {
|
||||
# "type": "string",
|
||||
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "TickerValue",
|
||||
@ -424,8 +418,10 @@ tools : List[Tool] = [
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "AnalyzeSite",
|
||||
@ -463,8 +459,10 @@ tools : List[Tool] = [
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "DateTime",
|
||||
@ -480,8 +478,10 @@ tools : List[Tool] = [
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WeatherForecast",
|
||||
@ -504,22 +504,27 @@ tools : List[Tool] = [
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ToolEntry(BaseModel):
|
||||
enabled: bool = True
|
||||
tool: Tool
|
||||
|
||||
|
||||
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
|
||||
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
|
||||
return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
|
||||
|
||||
|
||||
def all_tools() -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=tool) for tool in tools]
|
||||
|
||||
|
||||
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
|
||||
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
|
||||
|
||||
|
||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
|
||||
|
@ -14,6 +14,7 @@ from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordSecurity:
|
||||
"""Handles password hashing and verification using bcrypt"""
|
||||
|
||||
@ -32,9 +33,9 @@ class PasswordSecurity:
|
||||
salt = bcrypt.gensalt()
|
||||
|
||||
# Hash the password
|
||||
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt)
|
||||
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
|
||||
|
||||
return password_hash.decode('utf-8'), salt.decode('utf-8')
|
||||
return password_hash.decode("utf-8"), salt.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
@ -49,10 +50,7 @@ class PasswordSecurity:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
try:
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf-8'),
|
||||
password_hash.encode('utf-8')
|
||||
)
|
||||
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(f"Password verification error: {e}")
|
||||
return False
|
||||
@ -62,8 +60,10 @@ class PasswordSecurity:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
class AuthenticationRecord(BaseModel):
|
||||
"""Authentication record for storing user credentials"""
|
||||
|
||||
user_id: str
|
||||
password_hash: str
|
||||
salt: str
|
||||
@ -78,18 +78,19 @@ class AuthenticationRecord(BaseModel):
|
||||
locked_until: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class SecurityConfig:
|
||||
"""Security configuration constants"""
|
||||
|
||||
MAX_LOGIN_ATTEMPTS = 5
|
||||
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
|
||||
PASSWORD_MIN_LENGTH = 8
|
||||
TOKEN_EXPIRY_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
||||
|
||||
|
||||
class AuthenticationManager:
|
||||
"""Manages authentication operations with security features"""
|
||||
|
||||
@ -120,7 +121,7 @@ class AuthenticationManager:
|
||||
password_hash=password_hash,
|
||||
salt=salt,
|
||||
last_password_change=datetime.now(timezone.utc),
|
||||
login_attempts=0
|
||||
login_attempts=0,
|
||||
)
|
||||
|
||||
# Store in database
|
||||
@ -129,7 +130,9 @@ class AuthenticationManager:
|
||||
logger.info(f"🔐 Created authentication record for user {user_id}")
|
||||
return auth_record
|
||||
|
||||
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
async def verify_user_credentials(
|
||||
self, login: str, password: str
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
Verify user credentials with security checks
|
||||
|
||||
@ -164,7 +167,11 @@ class AuthenticationManager:
|
||||
seconds = int(total_seconds % 60)
|
||||
time_until_unlock_str = f"{minutes}m {seconds}s"
|
||||
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
|
||||
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}"
|
||||
return (
|
||||
False,
|
||||
None,
|
||||
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not self.password_security.verify_password(password, auth_data.password_hash):
|
||||
@ -176,7 +183,9 @@ class AuthenticationManager:
|
||||
auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
)
|
||||
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts")
|
||||
logger.warning(
|
||||
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
|
||||
)
|
||||
|
||||
# Update authentication record
|
||||
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
|
||||
@ -238,6 +247,7 @@ class AuthenticationManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error updating last login for user {user_id}: {e}")
|
||||
|
||||
|
||||
# Utility functions for common operations
|
||||
def validate_password_strength(password: str) -> Tuple[bool, list]:
|
||||
"""
|
||||
@ -270,6 +280,7 @@ def validate_password_strength(password: str) -> Tuple[bool, list]:
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
|
||||
def sanitize_login_input(login: str) -> str:
|
||||
"""Sanitize login input (email or username)"""
|
||||
return login.strip().lower() if login else ""
|
@ -19,7 +19,7 @@ from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
|
||||
from logger import logger
|
||||
from background_tasks import BackgroundTaskManager
|
||||
|
||||
#from . rate_limiter import RateLimiter
|
||||
# from . rate_limiter import RateLimiter
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
# Global database manager reference
|
||||
db_manager = None
|
||||
|
||||
|
||||
def set_db_manager(manager: DatabaseManager):
|
||||
"""Set the global database manager reference"""
|
||||
global db_manager
|
||||
db_manager = manager
|
||||
|
||||
|
||||
def get_database() -> RedisDatabase:
|
||||
"""
|
||||
Safe database dependency that checks for availability
|
||||
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
|
||||
|
||||
if db_manager is None:
|
||||
logger.error("Database manager not initialized")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database not available - service starting up"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Database not available - service starting up")
|
||||
|
||||
if db_manager.is_shutting_down:
|
||||
logger.warning("Database is shutting down")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service is shutting down"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Service is shutting down")
|
||||
|
||||
try:
|
||||
return db_manager.get_database()
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Database not available: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database connection not available"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Database connection not available")
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Enhanced token verification with guest session recovery"""
|
||||
try:
|
||||
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
|
||||
logger.error(f"❌ Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user from database"""
|
||||
try:
|
||||
@ -135,6 +130,7 @@ async def get_current_user(
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
from helpers.model_cast import cast_to_base_user_with_type
|
||||
|
||||
if candidate_data.get("is_AI"):
|
||||
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
|
||||
else:
|
||||
@ -152,16 +148,20 @@ async def get_current_user(
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
async def get_current_user_or_guest(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user (including guests) from database"""
|
||||
try:
|
||||
# Check candidates first
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
|
||||
return (
|
||||
Candidate.model_validate(candidate_data)
|
||||
if not candidate_data.get("is_AI")
|
||||
else CandidateAI.model_validate(candidate_data)
|
||||
)
|
||||
|
||||
# Check employers
|
||||
employer_data = await database.get_employer(user_id)
|
||||
@ -180,9 +180,9 @@ async def get_current_user_or_guest(
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
async def get_current_admin(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
user = await get_current_user(user_id=user_id, database=database)
|
||||
if isinstance(user, Candidate) and user.is_admin:
|
||||
@ -193,6 +193,7 @@ async def get_current_admin(
|
||||
logger.warning(f"⚠️ User {user_id} is not an admin")
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
prometheus_collector = CollectorRegistry()
|
||||
|
||||
# Keep the Instrumentator instance alive
|
||||
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
|
||||
should_ignore_untemplated=True,
|
||||
should_group_untemplated=True,
|
||||
excluded_handlers=[f"{defines.api_prefix}/metrics"],
|
||||
registry=prometheus_collector
|
||||
registry=prometheus_collector,
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import List, Dict
|
||||
from models import (Job)
|
||||
from models import Job
|
||||
|
||||
|
||||
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
requirements: List[Dict[str, str]] = []
|
||||
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
if job.requirements:
|
||||
if job.requirements.technical_skills:
|
||||
if job.requirements.technical_skills.required:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
])
|
||||
]
|
||||
)
|
||||
if job.requirements.technical_skills.preferred:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.experience_requirements:
|
||||
if job.requirements.experience_requirements.required:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
])
|
||||
]
|
||||
)
|
||||
if job.requirements.experience_requirements.preferred:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.soft_skills:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Soft Skills"}
|
||||
for req in job.requirements.soft_skills
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
|
||||
|
||||
if job.requirements.experience:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience"}
|
||||
for req in job.requirements.experience
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
|
||||
|
||||
if job.requirements.education:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Education"}
|
||||
for req in job.requirements.education
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
|
||||
|
||||
if job.requirements.certifications:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Certifications"}
|
||||
for req in job.requirements.certifications
|
||||
])
|
||||
requirements.extend(
|
||||
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
|
||||
)
|
||||
|
||||
if job.requirements.preferred_attributes:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
return requirements
|
@ -13,44 +13,13 @@ from fastapi.responses import StreamingResponse
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import DocumentType
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
# API
|
||||
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
|
||||
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
|
||||
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
|
||||
JobRequirements, JobRequirementsMessage, LoginRequest,
|
||||
CreateCandidateRequest, CreateEmployerRequest,
|
||||
|
||||
# User models
|
||||
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
|
||||
Authentication, AuthResponse, CandidateAI,
|
||||
|
||||
# Job models
|
||||
JobApplication, ApplicationStatus,
|
||||
|
||||
# Chat models
|
||||
ChatSession, ChatMessage, ChatContext, ChatQuery, ChatSenderType, ApiMessageType, ChatContextType,
|
||||
ChatMessageRagSearch,
|
||||
|
||||
# Document models
|
||||
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
|
||||
|
||||
# Supporting models
|
||||
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
|
||||
|
||||
# Email
|
||||
EmailVerificationRequest,
|
||||
ApiStatusType
|
||||
)
|
||||
from models import Job, ChatMessage, DocumentType, ApiStatusType
|
||||
|
||||
from typing import List, Dict
|
||||
from models import (Job)
|
||||
from models import Job
|
||||
import utils.llm_proxy as llm_manager
|
||||
|
||||
|
||||
async def get_last_item(generator):
|
||||
"""Get the last item from an async generator"""
|
||||
last_item = None
|
||||
@ -65,7 +34,7 @@ def filter_and_paginate(
|
||||
limit: int = 20,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "desc",
|
||||
filters: Optional[Dict] = None
|
||||
filters: Optional[Dict] = None,
|
||||
) -> Tuple[List[Any], int]:
|
||||
"""Filter, sort, and paginate items"""
|
||||
filtered_items = items.copy()
|
||||
@ -76,8 +45,7 @@ def filter_and_paginate(
|
||||
if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
|
||||
filtered_items = [item for item in filtered_items if item.get(key) == value]
|
||||
elif hasattr(filtered_items[0], key) if filtered_items else False:
|
||||
filtered_items = [item for item in filtered_items
|
||||
if getattr(item, key, None) == value]
|
||||
filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
|
||||
|
||||
# Sort items
|
||||
if sort_by and filtered_items:
|
||||
@ -101,9 +69,9 @@ def filter_and_paginate(
|
||||
|
||||
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
|
||||
"""Stream agent response with proper formatting"""
|
||||
|
||||
async def message_stream_generator():
|
||||
"""Generator to stream messages with persistence"""
|
||||
last_log = None
|
||||
final_message = None
|
||||
|
||||
import utils.llm_proxy as llm_manager
|
||||
@ -127,12 +95,15 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
|
||||
# metadata and other unnecessary fields for streaming
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
from models import ChatMessageStreaming, ChatMessageStatus
|
||||
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
|
||||
|
||||
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
|
||||
generated_message, ChatMessageStatus
|
||||
):
|
||||
raise TypeError(
|
||||
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
|
||||
)
|
||||
|
||||
json_data = generated_message.model_dump(mode='json', by_alias=True)
|
||||
json_data = generated_message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
|
||||
yield f"data: {json_str}\n\n"
|
||||
@ -177,16 +148,16 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
|
||||
extension = pathlib.Path(filename).suffix.lower()
|
||||
|
||||
type_mapping = {
|
||||
'.pdf': DocumentType.PDF,
|
||||
'.docx': DocumentType.DOCX,
|
||||
'.doc': DocumentType.DOCX,
|
||||
'.txt': DocumentType.TXT,
|
||||
'.md': DocumentType.MARKDOWN,
|
||||
'.markdown': DocumentType.MARKDOWN,
|
||||
'.png': DocumentType.IMAGE,
|
||||
'.jpg': DocumentType.IMAGE,
|
||||
'.jpeg': DocumentType.IMAGE,
|
||||
'.gif': DocumentType.IMAGE,
|
||||
".pdf": DocumentType.PDF,
|
||||
".docx": DocumentType.DOCX,
|
||||
".doc": DocumentType.DOCX,
|
||||
".txt": DocumentType.TXT,
|
||||
".md": DocumentType.MARKDOWN,
|
||||
".markdown": DocumentType.MARKDOWN,
|
||||
".png": DocumentType.IMAGE,
|
||||
".jpg": DocumentType.IMAGE,
|
||||
".jpeg": DocumentType.IMAGE,
|
||||
".gif": DocumentType.IMAGE,
|
||||
}
|
||||
|
||||
return type_mapping.get(extension, DocumentType.TXT)
|
||||
@ -206,17 +177,12 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
|
||||
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -229,26 +195,23 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
""",
|
||||
):
|
||||
pass
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
chat_message: ChatMessage = message
|
||||
try:
|
||||
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"✅ Successfully converted content to markdown")
|
||||
logger.info("✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
@ -256,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
|
||||
async def create_job_from_content(database, current_user, content: str):
|
||||
"""Create a job from content using AI analysis"""
|
||||
from models import (
|
||||
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
|
||||
ApiActivityType, ChatContextType, JobRequirementsMessage
|
||||
MOCK_UUID,
|
||||
ApiStatusType,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ApiActivityType,
|
||||
ChatContextType,
|
||||
JobRequirementsMessage,
|
||||
)
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
activity=ApiActivityType.INFO,
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
@ -278,10 +246,7 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -290,33 +255,28 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
session_id=MOCK_UUID, content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
content="Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Job extraction did not convert successfully"
|
||||
session_id=MOCK_UUID, content="Job extraction did not convert successfully"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -326,62 +286,63 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
|
||||
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
requirements: List[Dict[str, str]] = []
|
||||
|
||||
if job.requirements:
|
||||
if job.requirements.technical_skills:
|
||||
if job.requirements.technical_skills.required:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
])
|
||||
]
|
||||
)
|
||||
if job.requirements.technical_skills.preferred:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.experience_requirements:
|
||||
if job.requirements.experience_requirements.required:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
])
|
||||
]
|
||||
)
|
||||
if job.requirements.experience_requirements.preferred:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.soft_skills:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Soft Skills"}
|
||||
for req in job.requirements.soft_skills
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
|
||||
|
||||
if job.requirements.experience:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience"}
|
||||
for req in job.requirements.experience
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
|
||||
|
||||
if job.requirements.education:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Education"}
|
||||
for req in job.requirements.education
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
|
||||
|
||||
if job.requirements.certifications:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Certifications"}
|
||||
for req in job.requirements.certifications
|
||||
])
|
||||
requirements.extend(
|
||||
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
|
||||
)
|
||||
|
||||
if job.requirements.preferred_attributes:
|
||||
requirements.extend([
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
return requirements
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
from prometheus_client import Counter, Histogram # type: ignore
|
||||
from threading import Lock
|
||||
|
||||
|
||||
def singleton(cls):
|
||||
instance = None
|
||||
lock = Lock()
|
||||
|
@ -10,52 +10,64 @@ from pydantic import BaseModel # type: ignore
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
|
||||
from . dependencies import get_current_user_or_guest, get_database
|
||||
from .dependencies import get_current_user_or_guest, get_database
|
||||
|
||||
|
||||
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
|
||||
"""Dependency to get rate limiter instance"""
|
||||
return RateLimiter(database)
|
||||
|
||||
|
||||
class RateLimitConfig(BaseModel):
|
||||
"""Rate limit configuration"""
|
||||
|
||||
requests_per_minute: int
|
||||
requests_per_hour: int
|
||||
requests_per_day: int
|
||||
burst_limit: int # Maximum requests in a short burst
|
||||
burst_window_seconds: int = 60 # Window for burst detection
|
||||
|
||||
|
||||
class GuestRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for guest users - more restrictive"""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
requests_per_hour: int = 100
|
||||
requests_per_day: int = 500
|
||||
burst_limit: int = 15
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for authenticated users - more generous"""
|
||||
|
||||
requests_per_minute: int = 60
|
||||
requests_per_hour: int = 1000
|
||||
requests_per_day: int = 10000
|
||||
burst_limit: int = 100
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class PremiumUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for premium/admin users - most generous"""
|
||||
|
||||
requests_per_minute: int = 120
|
||||
requests_per_hour: int = 5000
|
||||
requests_per_day: int = 50000
|
||||
burst_limit: int = 200
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class RateLimitResult(BaseModel):
|
||||
"""Result of rate limit check"""
|
||||
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
retry_after_seconds: Optional[int] = None
|
||||
remaining_requests: Dict[str, int] = {}
|
||||
reset_times: Dict[str, datetime] = {}
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter using Redis for distributed rate limiting"""
|
||||
|
||||
@ -78,11 +90,7 @@ class RateLimiter:
|
||||
return self.user_config
|
||||
|
||||
async def check_rate_limit(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False,
|
||||
endpoint: Optional[str] = None
|
||||
self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check if user has exceeded rate limits
|
||||
@ -105,7 +113,7 @@ class RateLimiter:
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
"burst": f"{base_key}:burst",
|
||||
}
|
||||
|
||||
# Add endpoint-specific limiting if provided
|
||||
@ -125,7 +133,7 @@ class RateLimiter:
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
"burst": int(results[3] or 0),
|
||||
}
|
||||
|
||||
# Check limits
|
||||
@ -133,7 +141,7 @@ class RateLimiter:
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
"burst": config.burst_limit,
|
||||
}
|
||||
|
||||
# Check each limit
|
||||
@ -146,18 +154,22 @@ class RateLimiter:
|
||||
elif window == "hour":
|
||||
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
|
||||
elif window == "day":
|
||||
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
|
||||
retry_after = 86400 - (
|
||||
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
|
||||
)
|
||||
else: # burst
|
||||
retry_after = config.burst_window_seconds
|
||||
|
||||
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
|
||||
logger.warning(
|
||||
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
|
||||
)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
|
||||
retry_after_seconds=retry_after,
|
||||
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
reset_times=self._calculate_reset_times(current_time),
|
||||
)
|
||||
|
||||
# If we get here, request is allowed - increment counters
|
||||
@ -182,17 +194,12 @@ class RateLimiter:
|
||||
await pipe.execute()
|
||||
|
||||
# Calculate remaining requests
|
||||
remaining = {
|
||||
k: max(0, limits[k] - (current_counts[k] + 1))
|
||||
for k in current_counts.keys()
|
||||
}
|
||||
remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}
|
||||
|
||||
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining_requests=remaining,
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -206,18 +213,9 @@ class RateLimiter:
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
return {
|
||||
"minute": next_minute,
|
||||
"hour": next_hour,
|
||||
"day": next_day
|
||||
}
|
||||
return {"minute": next_minute, "hour": next_hour, "day": next_day}
|
||||
|
||||
async def get_user_rate_limit_status(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
|
||||
"""Get current rate limit status for a user"""
|
||||
config = self.get_config_for_user(user_type, is_admin)
|
||||
current_time = datetime.now(UTC)
|
||||
@ -227,7 +225,7 @@ class RateLimiter:
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
"burst": f"{base_key}:burst",
|
||||
}
|
||||
|
||||
try:
|
||||
@ -240,14 +238,14 @@ class RateLimiter:
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
"burst": int(results[3] or 0),
|
||||
}
|
||||
|
||||
limits = {
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
"burst": config.burst_limit,
|
||||
}
|
||||
|
||||
return {
|
||||
@ -258,7 +256,7 @@ class RateLimiter:
|
||||
"limits": limits,
|
||||
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
|
||||
"reset_times": self._calculate_reset_times(current_time),
|
||||
"config": config.model_dump()
|
||||
"config": config.model_dump(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -295,11 +293,9 @@ class RateLimiter:
|
||||
# Rate Limited Decorator
|
||||
# ============================
|
||||
|
||||
|
||||
def rate_limited(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120,
|
||||
endpoint_specific: bool = True
|
||||
guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
|
||||
):
|
||||
"""
|
||||
Decorator to easily apply rate limiting to endpoints
|
||||
@ -320,12 +316,14 @@ def rate_limited(
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract dependencies from function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
|
||||
inspect.signature(func)
|
||||
|
||||
# Get request, current_user, and rate_limiter from kwargs or args
|
||||
request = None
|
||||
@ -336,7 +334,7 @@ def rate_limited(
|
||||
for param_name, param_value in kwargs.items():
|
||||
if isinstance(param_value, Request):
|
||||
request = param_value
|
||||
elif hasattr(param_value, 'user_type'): # User-like object
|
||||
elif hasattr(param_value, "user_type"): # User-like object
|
||||
current_user = param_value
|
||||
elif isinstance(param_value, RateLimiter):
|
||||
rate_limiter = param_value
|
||||
@ -350,30 +348,33 @@ def rate_limited(
|
||||
# Apply rate limiting if we have the required components
|
||||
if request and current_user and rate_limiter:
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
|
||||
# Call the original function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
async def apply_custom_rate_limiting(
|
||||
request: Request,
|
||||
current_user,
|
||||
rate_limiter: RateLimiter,
|
||||
guest_per_minute: int,
|
||||
user_per_minute: int,
|
||||
admin_per_minute: int
|
||||
admin_per_minute: int,
|
||||
):
|
||||
"""Apply custom rate limiting with specified limits"""
|
||||
try:
|
||||
# Determine user info
|
||||
user_id = current_user.id
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
is_admin = getattr(current_user, 'is_admin', False)
|
||||
user_type = (
|
||||
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
|
||||
)
|
||||
is_admin = getattr(current_user, "is_admin", False)
|
||||
|
||||
# Determine appropriate limit
|
||||
if is_admin:
|
||||
@ -385,13 +386,17 @@ async def apply_custom_rate_limiting(
|
||||
|
||||
# Create custom rate limit key
|
||||
current_time = datetime.now(UTC)
|
||||
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
custom_key = (
|
||||
f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
)
|
||||
|
||||
# Check current usage
|
||||
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
|
||||
|
||||
if current_count >= requests_per_minute:
|
||||
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
|
||||
logger.warning(
|
||||
f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
@ -399,9 +404,9 @@ async def apply_custom_rate_limiting(
|
||||
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
|
||||
"retryAfter": 60 - current_time.second,
|
||||
"userType": user_type,
|
||||
"endpoint": request.url.path
|
||||
"endpoint": request.url.path,
|
||||
},
|
||||
headers={"Retry-After": str(60 - current_time.second)}
|
||||
headers={"Retry-After": str(60 - current_time.second)},
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
@ -410,7 +415,9 @@ async def apply_custom_rate_limiting(
|
||||
pipe.expire(custom_key, 120) # 2 minutes TTL
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
|
||||
logger.debug(
|
||||
f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@ -418,15 +425,13 @@ async def apply_custom_rate_limiting(
|
||||
logger.error(f"❌ Custom rate limiting error: {e}")
|
||||
# Fail open
|
||||
|
||||
|
||||
# ============================
|
||||
# Alternative: FastAPI Dependency-Based Rate Limiting
|
||||
# ============================
|
||||
|
||||
def create_rate_limit_dependency(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120
|
||||
):
|
||||
|
||||
def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
|
||||
"""
|
||||
Create a FastAPI dependency for rate limiting
|
||||
|
||||
@ -441,23 +446,25 @@ def create_rate_limit_dependency(
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
|
||||
async def rate_limit_dependency(
|
||||
request: Request,
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
):
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
return True
|
||||
|
||||
return rate_limit_dependency
|
||||
|
||||
|
||||
# ============================
|
||||
# Rate Limiting Utilities
|
||||
# ============================
|
||||
|
||||
|
||||
class EndpointRateLimiter:
|
||||
"""Utility class for endpoint-specific rate limiting"""
|
||||
|
||||
@ -477,9 +484,11 @@ class EndpointRateLimiter:
|
||||
return True # No custom limits set
|
||||
|
||||
limits = self.custom_limits[endpoint]
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
user_type = (
|
||||
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
|
||||
)
|
||||
|
||||
if getattr(current_user, 'is_admin', False):
|
||||
if getattr(current_user, "is_admin", False):
|
||||
user_type = "admin"
|
||||
|
||||
limit = limits.get(user_type, limits.get("default", 60))
|
||||
@ -491,8 +500,7 @@ class EndpointRateLimiter:
|
||||
|
||||
if current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
||||
status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
@ -501,9 +509,11 @@ class EndpointRateLimiter:
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Global endpoint rate limiter instance
|
||||
endpoint_rate_limiter = None
|
||||
|
||||
|
||||
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
|
||||
"""Get endpoint rate limiter instance"""
|
||||
global endpoint_rate_limiter
|
||||
@ -511,15 +521,14 @@ def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limit
|
||||
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
|
||||
|
||||
# Configure endpoint-specific limits
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
|
||||
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
|
||||
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
|
||||
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
|
||||
)
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
|
||||
)
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
|
||||
)
|
||||
|
||||
return endpoint_rate_limiter
|
||||
|
||||
|
@ -3,37 +3,17 @@ Response utility functions for consistent API responses
|
||||
"""
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
|
||||
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"meta": meta
|
||||
}
|
||||
return {"success": True, "data": data, "meta": meta}
|
||||
|
||||
|
||||
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
|
||||
return {
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details
|
||||
}
|
||||
}
|
||||
return {"success": False, "error": {"code": code, "message": message, "details": details}}
|
||||
|
||||
def create_paginated_response(
|
||||
data: List[Any],
|
||||
page: int,
|
||||
limit: int,
|
||||
total: int
|
||||
) -> Dict:
|
||||
|
||||
def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
|
||||
total_pages = (total + limit - 1) // limit
|
||||
has_more = page < total_pages
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"totalPages": total_pages,
|
||||
"hasMore": has_more
|
||||
}
|
||||
return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}
|
||||
|
Loading…
x
Reference in New Issue
Block a user