Compare commits
2 Commits
f53ff967cb
...
2cf3fa7b04
Author | SHA1 | Date | |
---|---|---|---|
2cf3fa7b04 | |||
f1c2e16389 |
@ -17,11 +17,9 @@ from logger import logger
|
||||
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
|
||||
|
||||
# Maps class_name to (module_name, class_name)
|
||||
class_registry: Dict[str, Tuple[str, str]] = (
|
||||
{}
|
||||
)
|
||||
class_registry: Dict[str, Tuple[str, str]] = {}
|
||||
|
||||
__all__ = ['get_or_create_agent']
|
||||
__all__ = ["get_or_create_agent"]
|
||||
|
||||
package_dir = pathlib.Path(__file__).parent
|
||||
package_name = __name__
|
||||
@ -38,16 +36,11 @@ for path in package_dir.glob("*.py"):
|
||||
|
||||
# Find all Agent subclasses in the module
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if (
|
||||
issubclass(obj, AnyAgent)
|
||||
and obj is not AnyAgent
|
||||
and obj is not Agent
|
||||
and name not in class_registry
|
||||
):
|
||||
if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
|
||||
class_registry[name] = (full_module_name, name)
|
||||
globals()[name] = obj
|
||||
logger.info(f"Adding agent: {name}")
|
||||
__all__.append(name) # type: ignore
|
||||
__all__.append(name) # type: ignore
|
||||
except ImportError as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"Error importing {full_module_name}: {e}")
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from typing import (
|
||||
Literal,
|
||||
get_args,
|
||||
@ -13,10 +13,10 @@ import time
|
||||
import re
|
||||
from abc import ABC
|
||||
from datetime import datetime, UTC
|
||||
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
import numpy as np # type: ignore
|
||||
import json_extractor as json_extractor
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from uuid import uuid4
|
||||
from typing import List, Optional, ClassVar, Any, Literal
|
||||
|
||||
@ -24,7 +24,7 @@ from datetime import datetime, UTC
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from uuid import uuid4
|
||||
from prometheus_client import CollectorRegistry, Counter # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
@ -32,19 +32,45 @@ from pathlib import Path
|
||||
from rag import start_file_watcher, ChromaDBFileWatcher
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
|
||||
from models import (
|
||||
Tunables,
|
||||
ChatMessageUser,
|
||||
ChatMessage,
|
||||
RagEntry,
|
||||
ChatMessageMetaData,
|
||||
ApiStatusType,
|
||||
Candidate,
|
||||
ChatContextType,
|
||||
)
|
||||
import utils.llm_proxy as llm_manager
|
||||
from database.manager import RedisDatabase
|
||||
from models import ChromaDBGetResponse
|
||||
from utils.metrics import Metrics
|
||||
|
||||
|
||||
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageRagSearch,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
LLMMessage,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatMessageUser,
|
||||
Tunables,
|
||||
ApiStatusType,
|
||||
ChatMessageMetaData,
|
||||
Candidate,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
from .registry import agent_registry
|
||||
|
||||
from models import ( ChromaDBGetResponse )
|
||||
from models import ChromaDBGetResponse
|
||||
|
||||
|
||||
class CandidateEntity(Candidate):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
|
||||
@ -59,13 +85,10 @@ class CandidateEntity(Candidate):
|
||||
CandidateEntity__agents: List[Agent] = []
|
||||
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
|
||||
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
|
||||
|
||||
CandidateEntity__metrics: Optional[Metrics] = Field(
|
||||
default=None,
|
||||
description="Metrics collector for this agent, used to track performance and usage."
|
||||
default=None, description="Metrics collector for this agent, used to track performance and usage."
|
||||
)
|
||||
|
||||
def __init__(self, candidate=None):
|
||||
@ -78,7 +101,7 @@ class CandidateEntity(Candidate):
|
||||
@classmethod
|
||||
def exists(cls, username: str):
|
||||
# Validate username format (only allow safe characters)
|
||||
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
|
||||
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
|
||||
return False # Invalid username characters
|
||||
|
||||
# Check for minimum and maximum length
|
||||
@ -117,11 +140,7 @@ class CandidateEntity(Candidate):
|
||||
if agent.agent_type == agent_type:
|
||||
return agent
|
||||
|
||||
return get_or_create_agent(
|
||||
agent_type=agent_type,
|
||||
user=self,
|
||||
prometheus_collector=self.prometheus_collector
|
||||
)
|
||||
return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)
|
||||
|
||||
# Wrapper properties that map into file_watcher
|
||||
@property
|
||||
@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
|
||||
|
||||
# Fields managed by initialize()
|
||||
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
|
||||
|
||||
@property
|
||||
def metrics(self) -> Metrics:
|
||||
if not self.CandidateEntity__metrics:
|
||||
@ -160,15 +180,10 @@ class CandidateEntity(Candidate):
|
||||
if not self.metrics:
|
||||
logger.warning("No metrics collector set for this agent.")
|
||||
return
|
||||
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
|
||||
response.usage.prompt_eval_count
|
||||
)
|
||||
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
|
||||
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
if self.CandidateEntity__initialized:
|
||||
# Initialization can only be attempted once; if there are multiple attempts, it means
|
||||
# a subsystem is failing or there is a logic bug in the code.
|
||||
@ -188,8 +203,8 @@ class CandidateEntity(Candidate):
|
||||
self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)
|
||||
|
||||
user_dir = os.path.join(defines.user_dir, self.username)
|
||||
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
|
||||
vector_db_dir = os.path.join(user_dir, defines.persist_directory)
|
||||
rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)
|
||||
|
||||
os.makedirs(vector_db_dir, exist_ok=True)
|
||||
os.makedirs(rag_content_dir, exist_ok=True)
|
||||
@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
|
||||
)
|
||||
has_username_rag = any(item.name == self.username for item in self.rags)
|
||||
if not has_username_rag:
|
||||
self.rags.append(RagEntry(
|
||||
name=self.username,
|
||||
description=f"Expert data about {self.full_name}.",
|
||||
))
|
||||
self.rags.append(
|
||||
RagEntry(
|
||||
name=self.username,
|
||||
description=f"Expert data about {self.full_name}.",
|
||||
)
|
||||
)
|
||||
self.rag_content_size = self.file_watcher.collection.count()
|
||||
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
Base class for all agent types.
|
||||
This class defines the common attributes and methods for all agent types.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
|
||||
|
||||
@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
|
||||
|
||||
conversation: List[ChatMessageUser] = Field(
|
||||
default_factory=list,
|
||||
description="Conversation history for this agent, used to maintain context across messages."
|
||||
description="Conversation history for this agent, used to maintain context across messages.",
|
||||
)
|
||||
|
||||
@property
|
||||
@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
|
||||
last_item = item
|
||||
return last_item
|
||||
|
||||
def set_optimal_context_size(
|
||||
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
|
||||
) -> int:
|
||||
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
|
||||
# Most models average 1.3-1.5 tokens per word
|
||||
word_count = len(prompt.split())
|
||||
tokens = int(word_count * 1.4)
|
||||
@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
|
||||
total_ctx = tokens + ctx_buffer
|
||||
|
||||
if total_ctx > self.context_size:
|
||||
logger.info(
|
||||
f"Increasing context size from {self.context_size} to {total_ctx}"
|
||||
)
|
||||
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
|
||||
|
||||
# Grow the context size if necessary
|
||||
self.context_size = max(self.context_size, total_ctx)
|
||||
@ -472,24 +487,24 @@ class Agent(BaseModel, ABC):
|
||||
context = []
|
||||
for chroma_results in rag_message.content:
|
||||
for index, metadata in enumerate(chroma_results.metadatas):
|
||||
content = "\n".join([
|
||||
line.strip()
|
||||
for line in chroma_results.documents[index].split("\n")
|
||||
if line
|
||||
]).strip()
|
||||
context.append(f"""
|
||||
content = "\n".join(
|
||||
[line.strip() for line in chroma_results.documents[index].split("\n") if line]
|
||||
).strip()
|
||||
context.append(
|
||||
f"""
|
||||
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
|
||||
Document reference: {chroma_results.ids[index]}
|
||||
Content: {content}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
return "\n".join(context)
|
||||
|
||||
async def generate_rag_results(
|
||||
self,
|
||||
session_id: str,
|
||||
prompt: str,
|
||||
top_k: int=defines.default_rag_top_k,
|
||||
threshold: float=defines.default_rag_threshold,
|
||||
top_k: int = defines.default_rag_top_k,
|
||||
threshold: float = defines.default_rag_threshold,
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
"""
|
||||
Generate RAG results for the given query.
|
||||
@ -501,15 +516,11 @@ Content: {content}
|
||||
A list of dictionaries containing the RAG results.
|
||||
"""
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for RAG generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
results : List[ChromaDBGetResponse] = []
|
||||
entries: int = 0
|
||||
results: List[ChromaDBGetResponse] = []
|
||||
user: CandidateEntity = self.user
|
||||
for rag in user.rags:
|
||||
if not rag.enabled:
|
||||
@ -518,20 +529,18 @@ Content: {content}
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
content = f"Searching RAG context {rag.name}..."
|
||||
content=f"Searching RAG context {rag.name}...",
|
||||
)
|
||||
yield status_message
|
||||
|
||||
try:
|
||||
chroma_results = await user.file_watcher.find_similar(
|
||||
query=prompt, top_k=top_k, threshold=threshold
|
||||
)
|
||||
chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
|
||||
if not chroma_results:
|
||||
continue
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
|
||||
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
|
||||
|
||||
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
|
||||
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
|
||||
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
|
||||
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
|
||||
|
||||
rag_metadata = ChromaDBGetResponse(
|
||||
name=rag.name,
|
||||
@ -549,7 +558,7 @@ Content: {content}
|
||||
continue_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
content=f"Error searching RAG context {rag.name}: {str(e)}"
|
||||
content=f"Error searching RAG context {rag.name}: {str(e)}",
|
||||
)
|
||||
yield continue_message
|
||||
|
||||
@ -562,23 +571,21 @@ Content: {content}
|
||||
return
|
||||
|
||||
async def llm_one_shot(
|
||||
self,
|
||||
llm: Any, model: str,
|
||||
session_id: str, prompt: str, system_prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
|
||||
self,
|
||||
llm: Any,
|
||||
model: str,
|
||||
session_id: str,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7,
|
||||
) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for chat generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.set_optimal_context_size(
|
||||
llm=llm, model=model, prompt=prompt+system_prompt
|
||||
)
|
||||
self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
|
||||
|
||||
options = ChatOptions(
|
||||
seed=8911,
|
||||
@ -592,9 +599,7 @@ Content: {content}
|
||||
]
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.GENERATING,
|
||||
content=f"Generating response..."
|
||||
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -610,10 +615,7 @@ Content: {content}
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -628,46 +630,34 @@ Content: {content}
|
||||
yield streaming_message
|
||||
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
|
||||
|
||||
chat_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
tunables=tunables,
|
||||
status=ApiStatusType.DONE,
|
||||
content=content,
|
||||
metadata = ChatMessageMetaData(
|
||||
metadata=ChatMessageMetaData(
|
||||
options=options,
|
||||
eval_count=response.usage.eval_count,
|
||||
eval_duration=response.usage.eval_duration,
|
||||
prompt_eval_count=response.usage.prompt_eval_count,
|
||||
prompt_eval_duration=response.usage.prompt_eval_duration,
|
||||
|
||||
)
|
||||
),
|
||||
)
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No user set for chat generation."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -675,12 +665,11 @@ Content: {content}
|
||||
session_id=session_id,
|
||||
content=prompt,
|
||||
)
|
||||
user = self.user
|
||||
|
||||
self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
|
||||
with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
|
||||
context = None
|
||||
rag_message : ChatMessageRagSearch | None = None
|
||||
rag_message: ChatMessageRagSearch | None = None
|
||||
if self.user:
|
||||
message = None
|
||||
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
|
||||
@ -692,38 +681,32 @@ Content: {content}
|
||||
yield message
|
||||
|
||||
if not isinstance(message, ChatMessageRagSearch):
|
||||
raise ValueError(
|
||||
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
|
||||
)
|
||||
raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
|
||||
|
||||
rag_message = message
|
||||
context = self.get_rag_context(rag_message)
|
||||
|
||||
# Create a pruned down message list based purely on the prompt and responses,
|
||||
# discarding the full preamble generated by prepare_message
|
||||
messages: List[LLMMessage] = [
|
||||
LLMMessage(role="system", content=self.system_prompt)
|
||||
]
|
||||
messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
|
||||
# Add the conversation history to the messages
|
||||
messages.extend([
|
||||
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
|
||||
for m in self.conversation
|
||||
])
|
||||
messages.extend(
|
||||
[
|
||||
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
|
||||
for m in self.conversation
|
||||
]
|
||||
)
|
||||
# Add the RAG context to the messages if available
|
||||
if context:
|
||||
messages.append(
|
||||
LLMMessage(
|
||||
role="user",
|
||||
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
|
||||
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Only the actual user query is provided with the full context message
|
||||
messages.append(
|
||||
LLMMessage(role="user", content=prompt)
|
||||
)
|
||||
|
||||
llm_history = messages
|
||||
messages.append(LLMMessage(role="user", content=prompt))
|
||||
|
||||
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
|
||||
# message.metadata.tools = {
|
||||
@ -827,16 +810,12 @@ Content: {content}
|
||||
|
||||
# not use_tools
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity=ApiActivityType.GENERATING,
|
||||
content=f"Generating response..."
|
||||
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
|
||||
)
|
||||
yield status_message
|
||||
|
||||
# Set the response for streaming
|
||||
self.set_optimal_context_size(
|
||||
llm, model, prompt=prompt
|
||||
)
|
||||
self.set_optimal_context_size(llm, model, prompt=prompt)
|
||||
|
||||
options = ChatOptions(
|
||||
seed=8911,
|
||||
@ -856,10 +835,7 @@ Content: {content}
|
||||
stream=True,
|
||||
):
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -873,17 +849,12 @@ Content: {content}
|
||||
yield streaming_message
|
||||
|
||||
if not response:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="No response from LLM."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
self.user.collect_metrics(agent=self, response=response)
|
||||
self.context_tokens = (
|
||||
response.usage.prompt_eval_count + response.usage.eval_count
|
||||
)
|
||||
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
|
||||
end_time = time.perf_counter()
|
||||
|
||||
chat_message = ChatMessage(
|
||||
@ -891,7 +862,7 @@ Content: {content}
|
||||
tunables=tunables,
|
||||
status=ApiStatusType.DONE,
|
||||
content=content,
|
||||
metadata = ChatMessageMetaData(
|
||||
metadata=ChatMessageMetaData(
|
||||
options=options,
|
||||
eval_count=response.usage.eval_count,
|
||||
eval_duration=response.usage.eval_duration,
|
||||
@ -902,10 +873,9 @@ Content: {content}
|
||||
"llm_streamed": end_time - start_time,
|
||||
"llm_with_tools": 0, # Placeholder for tool processing time
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Add the user and chat messages to the conversation
|
||||
self.conversation.append(user_message)
|
||||
self.conversation.append(chat_message)
|
||||
@ -999,12 +969,13 @@ Content: {content}
|
||||
|
||||
raise ValueError("No Markdown found in the response")
|
||||
|
||||
|
||||
_agents: List[Agent] = []
|
||||
|
||||
|
||||
def get_or_create_agent(
|
||||
agent_type: str,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
user: Optional[CandidateEntity]=None) -> Agent:
|
||||
agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
|
||||
) -> Agent:
|
||||
"""
|
||||
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
|
||||
|
||||
@ -1028,14 +999,16 @@ def get_or_create_agent(
|
||||
for agent_cls in Agent.__subclasses__():
|
||||
if agent_cls.model_fields["agent_type"].default == agent_type:
|
||||
# Create the agent instance with provided kwargs
|
||||
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
|
||||
user=user)
|
||||
agent = agent_cls(
|
||||
agent_type=agent_type, # type: ignore[call-arg]
|
||||
user=user,
|
||||
)
|
||||
_agents.append(agent)
|
||||
return agent
|
||||
|
||||
raise ValueError(f"No agent class found for agent_type: {agent_type}")
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Agent._agent_type, Agent)
|
||||
CandidateEntity.model_rebuild()
|
||||
|
||||
|
@ -5,10 +5,10 @@ from .base import Agent, agent_registry
|
||||
from logger import logger
|
||||
|
||||
from .registry import agent_registry
|
||||
from models import ( ApiMessage, Tunables, ApiStatusType)
|
||||
from models import ApiMessage, Tunables, ApiStatusType
|
||||
|
||||
|
||||
system_message = f"""
|
||||
system_message = """
|
||||
When answering queries, follow these steps:
|
||||
|
||||
- When any content from <|context|> is relevant, synthesize information from all sources to provide the most complete answer.
|
||||
@ -21,21 +21,19 @@ Always <|context|> when possible. Be concise, and never make up information. If
|
||||
Before answering, ensure you have spelled the candidate's name correctly.
|
||||
"""
|
||||
|
||||
|
||||
class CandidateChat(Agent):
|
||||
"""
|
||||
CandidateChat Agent
|
||||
"""
|
||||
|
||||
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
|
||||
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
user = self.user
|
||||
if not user:
|
||||
@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
|
||||
{system_message}
|
||||
"""
|
||||
|
||||
async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
|
||||
async for message in super().generate(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
|
||||
):
|
||||
if message.status == ApiStatusType.ERROR:
|
||||
yield message
|
||||
return
|
||||
yield message
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(CandidateChat._agent_type, CandidateChat)
|
||||
|
||||
|
@ -49,11 +49,13 @@ class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
agent_type: Literal["general"] = "general" # type: ignore
|
||||
|
||||
agent_type: Literal["general"] = "general" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
system_prompt: str = system_message
|
||||
|
||||
|
||||
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
|
||||
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
|
||||
# if not self.context:
|
||||
|
@ -1,14 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import random
|
||||
import time
|
||||
@ -16,7 +13,15 @@ import time
|
||||
import os
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
import backstory_traceback as traceback
|
||||
@ -26,18 +31,16 @@ from image_generator.profile_image import generate_image, ImageRequest
|
||||
seed = int(time.time())
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
class ImageGenerator(Agent):
|
||||
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
|
||||
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
agent_persist: bool = False
|
||||
|
||||
system_prompt: str = "" # No system prompt is used
|
||||
system_prompt: str = "" # No system prompt is used
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
if not self.user:
|
||||
logger.error("User is not set for ImageGenerator agent.")
|
||||
@ -57,11 +60,17 @@ class ImageGenerator(Agent):
|
||||
yield status_message
|
||||
|
||||
logger.info(f"Image generation: {file_path} <- {prompt}")
|
||||
request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
|
||||
request = ImageRequest(
|
||||
filepath=file_path,
|
||||
session_id=session_id,
|
||||
prompt=prompt,
|
||||
iterations=4,
|
||||
height=256,
|
||||
width=256,
|
||||
guidance_scale=7.5,
|
||||
)
|
||||
generated_message = None
|
||||
async for generated_message in generate_image(
|
||||
request=request
|
||||
):
|
||||
async for generated_message in generate_image(request=request):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -71,8 +80,7 @@ class ImageGenerator(Agent):
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Image generation failed to produce a valid response."
|
||||
session_id=session_id, content="Image generation failed to produce a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -86,21 +94,19 @@ class ImageGenerator(Agent):
|
||||
generated_image = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content = f"{defines.api_prefix}/profile/{user.username}",
|
||||
metadata=generated_message.metadata
|
||||
content=f"{defines.api_prefix}/profile/{user.username}",
|
||||
metadata=generated_message.metadata,
|
||||
)
|
||||
yield generated_image
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error generating image: {str(e)}"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)
|
||||
|
@ -1,16 +1,15 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field, BaseModel # type: ignore
|
||||
from pydantic import Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
cast,
|
||||
Any,
|
||||
Tuple,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import random
|
||||
import re
|
||||
@ -19,10 +18,18 @@ import time
|
||||
import time
|
||||
import os
|
||||
import random
|
||||
from names_dataset import NameDataset, NameWrapper # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ApiMessageType,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import defines
|
||||
import backstory_traceback as traceback
|
||||
@ -45,6 +52,7 @@ emptyUser = {
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
|
||||
def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
|
||||
return f"""\
|
||||
You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
|
||||
@ -86,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
|
||||
You are providing those only for use later by the system when casting individuals for the role.
|
||||
"""
|
||||
|
||||
|
||||
generate_resume_system_prompt = """
|
||||
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
|
||||
of that is to create an in-depth resume for the person. You will be provided with the following information:
|
||||
@ -117,10 +126,12 @@ import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EthnicNameGenerator:
|
||||
def __init__(self):
|
||||
try:
|
||||
from names_dataset import NameDataset # type: ignore
|
||||
from names_dataset import NameDataset # type: ignore
|
||||
|
||||
self.nd = NameDataset()
|
||||
except ImportError:
|
||||
logger.error("NameDataset not available. Please install: pip install names-dataset")
|
||||
@ -131,24 +142,24 @@ class EthnicNameGenerator:
|
||||
|
||||
# US Census 2020 approximate ethnic distribution
|
||||
self.ethnic_weights = {
|
||||
'White': 0.576,
|
||||
'Hispanic': 0.186,
|
||||
'Black': 0.134,
|
||||
'Asian': 0.062,
|
||||
'Native American': 0.013,
|
||||
'Pacific Islander': 0.003,
|
||||
'Mixed/Other': 0.026
|
||||
"White": 0.576,
|
||||
"Hispanic": 0.186,
|
||||
"Black": 0.134,
|
||||
"Asian": 0.062,
|
||||
"Native American": 0.013,
|
||||
"Pacific Islander": 0.003,
|
||||
"Mixed/Other": 0.026,
|
||||
}
|
||||
|
||||
# Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
|
||||
self.ethnic_country_mapping = {
|
||||
'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
|
||||
'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
|
||||
'Black': ['US'], # African American names
|
||||
'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
|
||||
'Native American': ['US'],
|
||||
'Pacific Islander': ['US'],
|
||||
'Mixed/Other': ['US']
|
||||
"White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
|
||||
"Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
|
||||
"Black": ["US"], # African American names
|
||||
"Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
|
||||
"Native American": ["US"],
|
||||
"Pacific Islander": ["US"],
|
||||
"Mixed/Other": ["US"],
|
||||
}
|
||||
|
||||
def get_weighted_ethnicity(self) -> str:
|
||||
@ -157,8 +168,9 @@ class EthnicNameGenerator:
|
||||
weights = list(self.ethnic_weights.values())
|
||||
return random.choices(ethnicities, weights=weights)[0]
|
||||
|
||||
def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
|
||||
n: int = 50, use_first_names: bool = True) -> List[str]:
|
||||
def get_names_by_criteria(
|
||||
self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
|
||||
) -> List[str]:
|
||||
"""Get names matching criteria using NameDataset's get_top_names method"""
|
||||
if not self.nd:
|
||||
return []
|
||||
@ -168,16 +180,13 @@ class EthnicNameGenerator:
|
||||
try:
|
||||
# Get top names for this country
|
||||
top_names = self.nd.get_top_names(
|
||||
n=n,
|
||||
use_first_names=use_first_names,
|
||||
country_alpha2=country_code,
|
||||
gender=gender
|
||||
n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
|
||||
)
|
||||
|
||||
if country_code in top_names:
|
||||
if use_first_names and gender:
|
||||
# For first names with gender specified
|
||||
gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
|
||||
gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
|
||||
if gender_key in top_names[country_code]:
|
||||
all_names.extend(top_names[country_code][gender_key])
|
||||
elif use_first_names:
|
||||
@ -194,25 +203,18 @@ class EthnicNameGenerator:
|
||||
|
||||
return list(set(all_names)) # Remove duplicates
|
||||
|
||||
def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
|
||||
def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
|
||||
"""Generate a name based on ethnicity using the correct NameDataset API"""
|
||||
if gender == 'random':
|
||||
gender = random.choice(['Male', 'Female'])
|
||||
if gender == "random":
|
||||
gender = random.choice(["Male", "Female"])
|
||||
|
||||
countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
|
||||
countries = self.ethnic_country_mapping.get(ethnicity, ["US"])
|
||||
|
||||
# Get first names
|
||||
first_names = self.get_names_by_criteria(
|
||||
countries=countries,
|
||||
gender=gender,
|
||||
use_first_names=True
|
||||
)
|
||||
first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)
|
||||
|
||||
# Get last names
|
||||
last_names = self.get_names_by_criteria(
|
||||
countries=countries,
|
||||
use_first_names=False
|
||||
)
|
||||
last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)
|
||||
|
||||
# Select names or use fallbacks
|
||||
if first_names:
|
||||
@ -232,57 +234,60 @@ class EthnicNameGenerator:
|
||||
def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
|
||||
"""Provide culturally appropriate fallback first names"""
|
||||
fallback_names = {
|
||||
'White': {
|
||||
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
|
||||
'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
|
||||
"White": {
|
||||
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
|
||||
"Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
|
||||
},
|
||||
'Hispanic': {
|
||||
'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
|
||||
'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
|
||||
"Hispanic": {
|
||||
"Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
|
||||
"Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
|
||||
},
|
||||
'Black': {
|
||||
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
|
||||
'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
|
||||
"Black": {
|
||||
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
|
||||
"Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
|
||||
},
|
||||
"Asian": {
|
||||
"Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
|
||||
"Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
|
||||
},
|
||||
'Asian': {
|
||||
'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
|
||||
'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
|
||||
}
|
||||
}
|
||||
|
||||
ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
|
||||
return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
|
||||
ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
|
||||
return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))
|
||||
|
||||
def _get_fallback_last_name(self, ethnicity: str) -> str:
|
||||
"""Provide culturally appropriate fallback last names"""
|
||||
fallback_surnames = {
|
||||
'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
|
||||
'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
|
||||
'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
|
||||
'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
|
||||
"White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
|
||||
"Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
|
||||
"Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
|
||||
"Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
|
||||
}
|
||||
|
||||
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
|
||||
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))
|
||||
|
||||
def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
|
||||
def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
|
||||
"""Generate a random name with ethnicity based on US demographics"""
|
||||
ethnicity = self.get_weighted_ethnicity()
|
||||
return self.get_name_by_ethnicity(ethnicity, gender)
|
||||
|
||||
def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
|
||||
def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
|
||||
"""Generate multiple random names"""
|
||||
names = []
|
||||
for _ in range(count):
|
||||
first, last, ethnicity, actual_gender = self.generate_random_name(gender)
|
||||
names.append({
|
||||
'full_name': f"{first} {last}",
|
||||
'first_name': first,
|
||||
'last_name': last,
|
||||
'ethnicity': ethnicity,
|
||||
'gender': actual_gender
|
||||
})
|
||||
names.append(
|
||||
{
|
||||
"full_name": f"{first} {last}",
|
||||
"first_name": first,
|
||||
"last_name": last,
|
||||
"ethnicity": ethnicity,
|
||||
"gender": actual_gender,
|
||||
}
|
||||
)
|
||||
return names
|
||||
|
||||
|
||||
class GeneratePersona(Agent):
|
||||
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -307,23 +312,20 @@ class GeneratePersona(Agent):
|
||||
self.full_name = f"{self.first_name} {self.last_name}"
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
|
||||
self.randomize()
|
||||
|
||||
original_prompt = prompt.strip()
|
||||
|
||||
persona = {
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity,
|
||||
"full_name": self.full_name,
|
||||
"first_name": self.first_name,
|
||||
"last_name": self.last_name,
|
||||
}
|
||||
"age": self.age,
|
||||
"gender": self.gender,
|
||||
"ethnicity": self.ethnicity,
|
||||
"full_name": self.full_name,
|
||||
"first_name": self.first_name,
|
||||
"last_name": self.last_name,
|
||||
}
|
||||
|
||||
prompt = f"""\
|
||||
```json
|
||||
@ -339,10 +341,11 @@ Incorporate the following into the job description: {original_prompt}
|
||||
#
|
||||
# Generate the persona
|
||||
#
|
||||
logger.info(f"🤖 Generating persona...")
|
||||
logger.info("🤖 Generating persona...")
|
||||
generating_message = None
|
||||
async for generating_message in self.llm_one_shot(
|
||||
llm=llm, model=model,
|
||||
llm=llm,
|
||||
model=model,
|
||||
session_id=session_id,
|
||||
prompt=prompt,
|
||||
system_prompt=generate_persona_system_prompt(persona=persona),
|
||||
@ -356,8 +359,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
|
||||
if not generating_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Persona generation failed to generate a response."
|
||||
session_id=session_id, content="Persona generation failed to generate a response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -375,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
self.username = persona.get("username", None)
|
||||
if not self.username:
|
||||
raise ValueError("LLM did not generate a username")
|
||||
self.username = re.sub(r'\s+', '.', self.username)
|
||||
self.username = re.sub(r"\s+", ".", self.username)
|
||||
user_dir = os.path.join(defines.user_dir, persona["username"])
|
||||
while os.path.exists(user_dir):
|
||||
match = re.match(r"^(.*?)(\d*)$", persona["username"])
|
||||
@ -398,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
|
||||
location_parts = persona["location"].split(",")
|
||||
if len(location_parts) == 3:
|
||||
city, state, country = [part.strip() for part in location_parts]
|
||||
persona["location"] = {
|
||||
"city": city,
|
||||
"state": state,
|
||||
"country": country
|
||||
}
|
||||
persona["location"] = {"city": city, "state": state, "country": country}
|
||||
else:
|
||||
logger.error(f"Invalid location format: {persona['location']}")
|
||||
persona["location"] = None
|
||||
persona["is_ai"] = True
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
|
||||
session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
|
||||
)
|
||||
logger.error(f"❌ Error parsing LLM response: {error_message.content}")
|
||||
logger.error(traceback.format_exc())
|
||||
@ -422,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
|
||||
|
||||
# Persona generated
|
||||
persona_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
type=ApiMessageType.JSON,
|
||||
content = json.dumps(persona)
|
||||
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
|
||||
)
|
||||
yield persona_message
|
||||
|
||||
@ -434,8 +428,8 @@ Incorporate the following into the job description: {original_prompt}
|
||||
#
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
activity = ApiActivityType.THINKING,
|
||||
content = f"Generating resume for {persona['full_name']}..."
|
||||
activity=ApiActivityType.THINKING,
|
||||
content=f"Generating resume for {persona['full_name']}...",
|
||||
)
|
||||
logger.info(f"🤖 {status_message.content}")
|
||||
yield status_message
|
||||
@ -458,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
|
||||
Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
|
||||
|
||||
async for generating_message in self.llm_one_shot(
|
||||
llm=llm, model=model,
|
||||
llm=llm,
|
||||
model=model,
|
||||
session_id=session_id,
|
||||
prompt=content,
|
||||
system_prompt=generate_resume_system_prompt,
|
||||
@ -472,8 +467,7 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
if not generating_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Resume generation failed to generate a response."
|
||||
session_id=session_id, content="Resume generation failed to generate a response."
|
||||
)
|
||||
logger.error(f"❌ {error_message.content}")
|
||||
yield error_message
|
||||
@ -481,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
resume = self.extract_markdown_from_text(generating_message.content)
|
||||
resume_message = ChatMessage(
|
||||
session_id=session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
type=ApiMessageType.TEXT,
|
||||
content=resume
|
||||
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
|
||||
)
|
||||
yield resume_message
|
||||
return
|
||||
@ -504,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
|
||||
|
||||
raise ValueError("No JSON found in the response")
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(GeneratePersona._agent_type, GeneratePersona)
|
||||
|
@ -1,30 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from logger import logger
|
||||
from .base import Agent, agent_registry
|
||||
from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ApiStatusType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageResume,
|
||||
ChatMessageStatus,
|
||||
SkillAssessment,
|
||||
SkillStrength,
|
||||
)
|
||||
|
||||
|
||||
class GenerateResume(Agent):
|
||||
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
|
||||
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
def generate_resume_prompt(
|
||||
self,
|
||||
skills: List[SkillAssessment]
|
||||
):
|
||||
def generate_resume_prompt(self, skills: List[SkillAssessment]):
|
||||
"""
|
||||
Generate a professional resume based on skill assessment results
|
||||
|
||||
@ -45,7 +49,7 @@ class GenerateResume(Agent):
|
||||
SkillStrength.STRONG: [],
|
||||
SkillStrength.MODERATE: [],
|
||||
SkillStrength.WEAK: [],
|
||||
SkillStrength.NONE: []
|
||||
SkillStrength.NONE: [],
|
||||
}
|
||||
|
||||
experience_evidence = {}
|
||||
@ -67,11 +71,7 @@ class GenerateResume(Agent):
|
||||
experience_evidence[source] = []
|
||||
|
||||
experience_evidence[source].append(
|
||||
{
|
||||
"skill": skill,
|
||||
"quote": evidence.quote,
|
||||
"context": evidence.context
|
||||
}
|
||||
{"skill": skill, "quote": evidence.quote, "context": evidence.context}
|
||||
)
|
||||
|
||||
# Build the system prompt
|
||||
@ -171,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
# Stage 1A: Analyze job requirements
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content = f"Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
system_prompt, prompt = self.generate_resume_prompt(skills=skills)
|
||||
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -189,8 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response."
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -198,8 +197,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
|
||||
if not isinstance(generated_message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis did not return a valid message."
|
||||
session_id=session_id, content="Job requirements analysis did not return a valid message."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -215,8 +213,9 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
yield resume_message
|
||||
logger.info(f"✅ Resume generation completed successfully.")
|
||||
logger.info("✅ Resume generation completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(GenerateResume._agent_type, GenerateResume)
|
||||
|
@ -1,24 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import inspect
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
|
||||
from models import (
|
||||
ApiActivityType,
|
||||
ApiMessage,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ChatMessageStreaming,
|
||||
ApiStatusType,
|
||||
Job,
|
||||
JobRequirements,
|
||||
JobRequirementsMessage,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import backstory_traceback as traceback
|
||||
|
||||
|
||||
class JobRequirementsAgent(Agent):
|
||||
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
@ -94,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
"""Analyze job requirements from job description."""
|
||||
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content="Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
logger.info(f"🔍 {status_message.content}")
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
yield generated_message
|
||||
return
|
||||
@ -110,8 +120,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response.")
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
@ -132,18 +142,18 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
display = {
|
||||
"technical_skills": {
|
||||
"required": reqs.technical_skills.required,
|
||||
"preferred": reqs.technical_skills.preferred
|
||||
"preferred": reqs.technical_skills.preferred,
|
||||
},
|
||||
"experience_requirements": {
|
||||
"required": reqs.experience_requirements.required,
|
||||
"preferred": reqs.experience_requirements.preferred
|
||||
"preferred": reqs.experience_requirements.preferred,
|
||||
},
|
||||
"soft_skills": reqs.soft_skills,
|
||||
"experience": reqs.experience,
|
||||
"education": reqs.education,
|
||||
"certifications": reqs.certifications,
|
||||
"preferred_attributes": reqs.preferred_attributes,
|
||||
"company_values": reqs.company_values
|
||||
"company_values": reqs.company_values,
|
||||
}
|
||||
|
||||
return display
|
||||
@ -152,19 +162,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="User is not set for this agent."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
# Stage 1A: Analyze job requirements
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=session_id,
|
||||
content = f"Analyzing job requirements",
|
||||
activity=ApiActivityType.THINKING
|
||||
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -178,8 +183,7 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
|
||||
if not generated_message:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Job requirements analysis failed to generate a response."
|
||||
session_id=session_id, content="Job requirements analysis failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -214,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
return
|
||||
except Exception as e:
|
||||
status_message.status = ApiStatusType.ERROR
|
||||
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
||||
status_message.content = (
|
||||
f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {status_message.content}")
|
||||
yield status_message
|
||||
@ -238,8 +244,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
|
||||
job=job,
|
||||
)
|
||||
yield job_requirements_message
|
||||
logger.info(f"✅ Job requirements analysis completed successfully.")
|
||||
logger.info("✅ Job requirements analysis completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)
|
||||
|
@ -5,20 +5,19 @@ from .base import Agent, agent_registry
|
||||
from logger import logger
|
||||
|
||||
from .registry import agent_registry
|
||||
from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables )
|
||||
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables
|
||||
|
||||
|
||||
class Chat(Agent):
|
||||
"""
|
||||
Chat Agent
|
||||
"""
|
||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||
|
||||
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
"""
|
||||
Generate a response based on the user message and the provided LLM.
|
||||
@ -44,8 +43,7 @@ class Chat(Agent):
|
||||
if not isinstance(rag_message, ChatMessageRagSearch):
|
||||
logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -53,5 +51,6 @@ class Chat(Agent):
|
||||
rag_message.status = ApiStatusType.DONE
|
||||
yield rag_message
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(Chat._agent_type, Chat)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Optional, Type
|
||||
|
||||
|
||||
# We'll use a registry pattern rather than hardcoded strings
|
||||
class AgentRegistry:
|
||||
"""Registry for agent types and classes"""
|
||||
|
@ -1,26 +1,32 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import model_validator, Field # type: ignore
|
||||
from typing import (
|
||||
Dict,
|
||||
Literal,
|
||||
ClassVar,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
List,
|
||||
Optional
|
||||
# override
|
||||
Optional,
|
||||
# override
|
||||
) # NOTE: You must import Optional for late binding to work
|
||||
import json
|
||||
import numpy as np # type: ignore
|
||||
|
||||
from .base import Agent, agent_registry
|
||||
from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail,
|
||||
SkillAssessment, Tunables)
|
||||
from models import (
|
||||
ApiMessage,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageRagSearch,
|
||||
ChatMessageSkillAssessment,
|
||||
ApiStatusType,
|
||||
EvidenceDetail,
|
||||
SkillAssessment,
|
||||
Tunables,
|
||||
)
|
||||
from logger import logger
|
||||
import backstory_traceback as traceback
|
||||
|
||||
|
||||
class SkillMatchAgent(Agent):
|
||||
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
|
||||
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
|
||||
_agent_type: ClassVar[str] = agent_type # Add this for registration
|
||||
|
||||
def generate_skill_assessment_prompt(self, skill, rag_context):
|
||||
@ -100,15 +106,12 @@ JSON RESPONSE:"""
|
||||
return system_prompt, prompt
|
||||
|
||||
async def generate(
|
||||
self, llm: Any, model: str,
|
||||
session_id: str, prompt: str,
|
||||
tunables: Optional[Tunables] = None,
|
||||
temperature=0.7
|
||||
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
|
||||
) -> AsyncGenerator[ApiMessage, None]:
|
||||
if not self.user:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Agent not attached to user. Attach the agent to a user before generating responses."
|
||||
content="Agent not attached to user. Attach the agent to a user before generating responses.",
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -116,10 +119,7 @@ JSON RESPONSE:"""
|
||||
|
||||
skill = prompt.strip()
|
||||
if not skill:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill cannot be empty."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
return
|
||||
@ -134,8 +134,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -144,19 +143,20 @@ JSON RESPONSE:"""
|
||||
if not isinstance(generated_message, ChatMessageRagSearch):
|
||||
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="RAG search did not return a valid response."
|
||||
session_id=session_id, content="RAG search did not return a valid response."
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
rag_message : ChatMessageRagSearch = generated_message
|
||||
rag_message: ChatMessageRagSearch = generated_message
|
||||
|
||||
rag_context = self.get_rag_context(rag_message)
|
||||
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
|
||||
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
|
||||
|
||||
generated_message = None
|
||||
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
|
||||
async for generated_message in self.llm_one_shot(
|
||||
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
|
||||
):
|
||||
if generated_message.status == ApiStatusType.ERROR:
|
||||
logger.error(f"⚠️ {generated_message.content}")
|
||||
yield generated_message
|
||||
@ -166,8 +166,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if generated_message is None:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment failed to generate a response."
|
||||
session_id=session_id, content="Skill assessment failed to generate a response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -175,8 +174,7 @@ JSON RESPONSE:"""
|
||||
|
||||
if not isinstance(generated_message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Skill assessment did not return a valid message."
|
||||
session_id=session_id, content="Skill assessment did not return a valid message."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -199,14 +197,15 @@ JSON RESPONSE:"""
|
||||
EvidenceDetail(
|
||||
source=evidence.get("source", ""),
|
||||
quote=evidence.get("quote", ""),
|
||||
context=evidence.get("context", "")
|
||||
) for evidence in skill_assessment_data.get("evidence_details", [])
|
||||
]
|
||||
context=evidence.get("context", ""),
|
||||
)
|
||||
for evidence in skill_assessment_data.get("evidence_details", [])
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
|
||||
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
|
||||
)
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
@ -233,8 +232,9 @@ JSON RESPONSE:"""
|
||||
skill_assessment=skill_assessment,
|
||||
)
|
||||
yield skill_assessment_message
|
||||
logger.info(f"✅ Skill assessment completed successfully.")
|
||||
logger.info("✅ Skill assessment completed successfully.")
|
||||
return
|
||||
|
||||
|
||||
# Register the base agent
|
||||
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)
|
||||
|
@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Any, Callable
|
||||
from logger import logger
|
||||
from database.manager import DatabaseManager
|
||||
|
||||
|
||||
class BackgroundTaskManager:
|
||||
"""Manages background tasks for the application using asyncio instead of threading"""
|
||||
|
||||
@ -65,10 +66,12 @@ class BackgroundTaskManager:
|
||||
stats = await database.get_guest_statistics()
|
||||
|
||||
# Log interesting statistics
|
||||
if stats.get('total_guests', 0) > 0:
|
||||
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
|
||||
f"{stats['active_last_hour']} active in last hour, "
|
||||
f"{stats['converted_guests']} converted")
|
||||
if stats.get("total_guests", 0) > 0:
|
||||
logger.info(
|
||||
f"📊 Guest stats: {stats['total_guests']} total, "
|
||||
f"{stats['active_last_hour']} active in last hour, "
|
||||
f"{stats['converted_guests']} converted"
|
||||
)
|
||||
|
||||
return stats
|
||||
except Exception as e:
|
||||
@ -84,6 +87,7 @@ class BackgroundTaskManager:
|
||||
|
||||
# Get Redis client safely (using the event loop safe method)
|
||||
from database.manager import redis_manager
|
||||
|
||||
redis = await redis_manager.get_client()
|
||||
|
||||
# Clean up rate limit keys older than specified days
|
||||
@ -101,7 +105,7 @@ class BackgroundTaskManager:
|
||||
try:
|
||||
ttl = await redis.ttl(key)
|
||||
if ttl == -1: # No expiration set, check creation time
|
||||
creation_time = await redis.hget(key, "created_at") # type: ignore
|
||||
creation_time = await redis.hget(key, "created_at") # type: ignore
|
||||
if creation_time:
|
||||
creation_time = datetime.fromisoformat(creation_time).replace(tzinfo=UTC)
|
||||
if creation_time < cutoff_time:
|
||||
@ -192,10 +196,7 @@ class BackgroundTaskManager:
|
||||
|
||||
# Create asyncio tasks for each periodic task
|
||||
for name, func, interval, *args in periodic_tasks:
|
||||
task = asyncio.create_task(
|
||||
self._run_periodic_task(name, func, interval, *args),
|
||||
name=f"background_{name}"
|
||||
)
|
||||
task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
|
||||
self.tasks.append(task)
|
||||
logger.info(f"📅 Scheduled background task: {name}")
|
||||
|
||||
@ -238,10 +239,7 @@ class BackgroundTaskManager:
|
||||
# Wait for all tasks to complete with timeout
|
||||
if self.tasks:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(*self.tasks, return_exceptions=True),
|
||||
timeout=30.0
|
||||
)
|
||||
await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
|
||||
logger.info("✅ All background tasks stopped gracefully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ Some background tasks did not stop within timeout")
|
||||
@ -258,7 +256,7 @@ class BackgroundTaskManager:
|
||||
"main_loop_id": id(self.main_loop) if self.main_loop else None,
|
||||
"current_loop_id": None,
|
||||
"task_count": len(self.tasks),
|
||||
"tasks": []
|
||||
"tasks": [],
|
||||
}
|
||||
|
||||
try:
|
||||
@ -317,6 +315,7 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun
|
||||
await task_manager.start()
|
||||
return task_manager
|
||||
|
||||
|
||||
# For integration with your existing app startup
|
||||
async def initialize_with_background_tasks(database_manager: DatabaseManager):
|
||||
"""Initialize database and background tasks together"""
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
import defines
|
||||
|
||||
|
||||
def filter_traceback(tb, app_path=None, module_name=None):
|
||||
"""
|
||||
Filter traceback to include only frames from the specified application path or module.
|
||||
@ -36,7 +37,8 @@ def filter_traceback(tb, app_path=None, module_name=None):
|
||||
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
|
||||
|
||||
# Combine the filtered stack trace with the exception message
|
||||
return ''.join(formatted_stack + formatted_exc)
|
||||
return "".join(formatted_stack + formatted_exc)
|
||||
|
||||
|
||||
def format_exc(app_path=defines.app_path, module_name=None):
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .core import RedisDatabase
|
||||
from .manager import DatabaseManager, redis_manager
|
||||
|
||||
__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager']
|
||||
__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]
|
||||
|
@ -1,15 +1,15 @@
|
||||
KEY_PREFIXES = {
|
||||
'viewers': 'viewer:',
|
||||
'candidates': 'candidate:',
|
||||
'employers': 'employer:',
|
||||
'jobs': 'job:',
|
||||
'job_applications': 'job_application:',
|
||||
'chat_sessions': 'chat_session:',
|
||||
'chat_messages': 'chat_messages:',
|
||||
'ai_parameters': 'ai_parameters:',
|
||||
'users': 'user:',
|
||||
'candidate_documents': 'candidate_documents:',
|
||||
'job_requirements': 'job_requirements:',
|
||||
'resumes': 'resume:',
|
||||
'user_resumes': 'user_resumes:',
|
||||
"viewers": "viewer:",
|
||||
"candidates": "candidate:",
|
||||
"employers": "employer:",
|
||||
"jobs": "job:",
|
||||
"job_applications": "job_application:",
|
||||
"chat_sessions": "chat_session:",
|
||||
"chat_messages": "chat_messages:",
|
||||
"ai_parameters": "ai_parameters:",
|
||||
"users": "user:",
|
||||
"candidate_documents": "candidate_documents:",
|
||||
"job_requirements": "job_requirements:",
|
||||
"resumes": "resume:",
|
||||
"user_resumes": "user_resumes:",
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ from .mixins.job import JobMixin
|
||||
from .mixins.skill import SkillMixin
|
||||
from .mixins.ai import AIMixin
|
||||
|
||||
|
||||
# RedisDatabase is the main class that combines all mixins for a
|
||||
# comprehensive Redis database interface.
|
||||
class RedisDatabase(
|
||||
|
@ -1,18 +1,15 @@
|
||||
from redis.asyncio import (Redis, ConnectionPool)
|
||||
from redis.asyncio import Redis, ConnectionPool
|
||||
from typing import Optional, Optional
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, UTC
|
||||
import asyncio
|
||||
from models import (
|
||||
# User models
|
||||
Candidate, Employer, BaseUser, EvidenceDetail, Guest, Authentication, AuthResponse, SkillAssessment,
|
||||
)
|
||||
from .core import RedisDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# _RedisManager is a singleton class that manages the Redis connection and
|
||||
# provides methods for connecting, disconnecting, and performing health checks.
|
||||
#
|
||||
@ -46,12 +43,10 @@ class _RedisManager:
|
||||
retry_on_timeout=True,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options={},
|
||||
health_check_interval=30
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
self.redis = Redis(
|
||||
connection_pool=self._connection_pool
|
||||
)
|
||||
self.redis = Redis(connection_pool=self._connection_pool)
|
||||
|
||||
if not self.redis:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
@ -135,7 +130,7 @@ class _RedisManager:
|
||||
"uptime_seconds": info.get("uptime_in_seconds", 0),
|
||||
"connected_clients": info.get("connected_clients", 0),
|
||||
"used_memory_human": info.get("used_memory_human", "unknown"),
|
||||
"total_commands_processed": info.get("total_commands_processed", 0)
|
||||
"total_commands_processed": info.get("total_commands_processed", 0),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
@ -177,9 +172,11 @@ class _RedisManager:
|
||||
logger.error(f"Failed to get Redis info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Global Redis manager instance
|
||||
redis_manager = _RedisManager()
|
||||
|
||||
|
||||
# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
|
||||
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
|
||||
class DatabaseManager:
|
||||
@ -231,7 +228,7 @@ class DatabaseManager:
|
||||
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
|
||||
|
||||
# Save to local file (you might want to save to cloud storage instead)
|
||||
with open(backup_filename, 'w') as f:
|
||||
with open(backup_filename, "w") as f:
|
||||
json.dump(backup_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Backup created: {backup_filename}")
|
||||
@ -314,5 +311,3 @@ class DatabaseManager:
|
||||
if self._shutdown_initiated:
|
||||
raise RuntimeError("Application is shutting down")
|
||||
return self.db
|
||||
|
||||
|
@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIMixin(DatabaseProtocol):
|
||||
"""Mixin for AI operations"""
|
||||
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
|
||||
"""Get AI parameters by ID"""
|
||||
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
|
||||
@ -36,7 +38,7 @@ class AIMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
|
||||
param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
|
||||
result[param_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -45,4 +47,3 @@ class AIMixin(DatabaseProtocol):
|
||||
"""Delete AI parameters"""
|
||||
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
@ -7,6 +7,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalyticsMixin:
|
||||
"""Mixin for analytics-related database operations"""
|
||||
|
||||
|
@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMixin(DatabaseProtocol):
|
||||
"""Mixin for auth-related database operations"""
|
||||
|
||||
@ -25,8 +26,9 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data = await self.redis.get(key)
|
||||
if token_data:
|
||||
verification_info = json.loads(token_data)
|
||||
if (verification_info.get("email", "").lower() == email_lower and
|
||||
not verification_info.get("verified", False)):
|
||||
if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
|
||||
"verified", False
|
||||
):
|
||||
# Extract token from key
|
||||
token = key.replace("email_verification:", "")
|
||||
verification_info["token"] = token
|
||||
@ -115,10 +117,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
window_start = current_time - timedelta(hours=24)
|
||||
|
||||
# Filter out old attempts
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > window_start
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
|
||||
|
||||
return len(recent_attempts)
|
||||
|
||||
@ -141,16 +140,13 @@ class AuthMixin(DatabaseProtocol):
|
||||
|
||||
# Keep only last 24 hours of attempts
|
||||
window_start = current_time - timedelta(hours=24)
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > window_start
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
|
||||
|
||||
# Store with 24 hour expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # 24 hours
|
||||
json.dumps(recent_attempts)
|
||||
json.dumps(recent_attempts),
|
||||
)
|
||||
|
||||
return True
|
||||
@ -169,14 +165,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"user_data": user_data,
|
||||
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"verified": False
|
||||
"verified": False,
|
||||
}
|
||||
|
||||
# Store with 24 hour expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # 24 hours in seconds
|
||||
json.dumps(verification_data, default=str)
|
||||
json.dumps(verification_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"📧 Stored email verification token for {email}")
|
||||
@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
24 * 60 * 60, # Keep for remaining TTL
|
||||
json.dumps(token_data, default=str)
|
||||
json.dumps(token_data, default=str),
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@ -219,7 +215,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
|
||||
"""Store MFA code for verification"""
|
||||
try:
|
||||
logger.info("🔐 Storing MFA code for email: %s", email )
|
||||
logger.info("🔐 Storing MFA code for email: %s", email)
|
||||
key = f"mfa_code:{email.lower()}:{device_id}"
|
||||
mfa_data = {
|
||||
"code": code,
|
||||
@ -228,14 +224,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"attempts": 0,
|
||||
"verified": False
|
||||
"verified": False,
|
||||
}
|
||||
|
||||
# Store with 10 minute expiration
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # 10 minutes in seconds
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"🔐 Stored MFA code for {email}")
|
||||
@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # Keep original TTL
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
return mfa_data["attempts"]
|
||||
return 0
|
||||
@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
await self.redis.setex(
|
||||
key,
|
||||
10 * 60, # Keep for remaining TTL
|
||||
json.dumps(mfa_data, default=str)
|
||||
json.dumps(mfa_data, default=str),
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@ -327,7 +323,9 @@ class AuthMixin(DatabaseProtocol):
|
||||
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
|
||||
return False
|
||||
|
||||
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
|
||||
async def store_refresh_token(
|
||||
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
|
||||
) -> bool:
|
||||
"""Store refresh token for a user"""
|
||||
try:
|
||||
key = f"refresh_token:{token}"
|
||||
@ -337,7 +335,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
"device": device_info.get("device", "unknown"),
|
||||
"ip_address": device_info.get("ip_address", "unknown"),
|
||||
"is_revoked": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store with expiration
|
||||
@ -374,7 +372,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data["is_revoked"] = True
|
||||
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
|
||||
await self.redis.set(key, json.dumps(token_data, default=str))
|
||||
logger.info(f"🔐 Revoked refresh token")
|
||||
logger.info("🔐 Revoked refresh token")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
@ -420,7 +418,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
"email": email.lower(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"used": False,
|
||||
"created_at": datetime.now(timezone.utc).isoformat()
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store with expiration
|
||||
@ -457,7 +455,7 @@ class AuthMixin(DatabaseProtocol):
|
||||
token_data["used"] = True
|
||||
token_data["used_at"] = datetime.now(timezone.utc).isoformat()
|
||||
await self.redis.set(key, json.dumps(token_data, default=str))
|
||||
logger.info(f"🔐 Marked password reset token as used")
|
||||
logger.info("🔐 Marked password reset token as used")
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
@ -473,14 +471,14 @@ class AuthMixin(DatabaseProtocol):
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"user_id": user_id,
|
||||
"event_type": event_type,
|
||||
"details": details
|
||||
"details": details,
|
||||
}
|
||||
|
||||
# Add to list (latest events first)
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore
|
||||
await self.redis.lpush(key, json.dumps(event_data, default=str)) # type: ignore
|
||||
|
||||
# Keep only last 100 events per day
|
||||
await self.redis.ltrim(key, 0, 99)# type: ignore
|
||||
await self.redis.ltrim(key, 0, 99) # type: ignore
|
||||
|
||||
# Set expiration for 30 days
|
||||
await self.redis.expire(key, 30 * 24 * 60 * 60)
|
||||
@ -496,10 +494,10 @@ class AuthMixin(DatabaseProtocol):
|
||||
try:
|
||||
events = []
|
||||
for i in range(days):
|
||||
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
|
||||
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
|
||||
key = f"security_log:{user_id}:{date}"
|
||||
|
||||
daily_events = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
daily_events = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
for event_json in daily_events:
|
||||
events.append(json.loads(event_json))
|
||||
|
||||
@ -509,4 +507,3 @@ class AuthMixin(DatabaseProtocol):
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
|
||||
return []
|
||||
|
||||
|
@ -5,11 +5,13 @@ from typing import Any, Dict, TYPE_CHECKING
|
||||
from .protocols import DatabaseProtocol
|
||||
|
||||
from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class BaseMixin(DatabaseProtocol):
|
||||
"""Base mixin with core Redis operations and utilities"""
|
||||
|
||||
@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
|
||||
keys = await self.redis.keys(pattern)
|
||||
if keys:
|
||||
await self.redis.delete(*keys)
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatMixin(DatabaseProtocol):
|
||||
"""Mixin for chat-related database operations"""
|
||||
|
||||
@ -22,7 +23,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"total_sessions": 0,
|
||||
"total_messages": 0,
|
||||
"first_chat": None,
|
||||
"last_chat": None
|
||||
"last_chat": None,
|
||||
}
|
||||
|
||||
total_messages = 0
|
||||
@ -41,7 +42,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"total_messages": total_messages,
|
||||
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
|
||||
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
|
||||
"recent_sessions": sessions[:5] # Last 5 sessions
|
||||
"recent_sessions": sessions[:5], # Last 5 sessions
|
||||
}
|
||||
|
||||
# Chat Sessions operations
|
||||
@ -71,13 +72,13 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
|
||||
session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
|
||||
result[session_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
|
||||
async def delete_chat_session(self, session_id: str) -> bool:
|
||||
'''Delete a chat session from Redis'''
|
||||
"""Delete a chat session from Redis"""
|
||||
try:
|
||||
result = await self.redis.delete(f"chat_session:{session_id}")
|
||||
return result > 0
|
||||
@ -86,11 +87,11 @@ class ChatMixin(DatabaseProtocol):
|
||||
raise
|
||||
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
|
||||
'''Delete a specific chat message from Redis'''
|
||||
"""Delete a specific chat message from Redis"""
|
||||
try:
|
||||
# Remove from the session's message list
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.lrem(key, 0, message_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, message_id) # type: ignore
|
||||
# Delete the message data itself
|
||||
result = await self.redis.delete(f"chat_message:{message_id}")
|
||||
return result > 0
|
||||
@ -102,13 +103,13 @@ class ChatMixin(DatabaseProtocol):
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
||||
"""Get chat messages for a session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
|
||||
messages = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
|
||||
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict):
|
||||
"""Add a chat message to a session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
|
||||
await self.redis.rpush(key, self._serialize(message_data)) # type: ignore
|
||||
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
||||
"""Set all chat messages for a session (replaces existing)"""
|
||||
@ -120,7 +121,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
# Add new messages
|
||||
if messages:
|
||||
serialized_messages = [self._serialize(msg) for msg in messages]
|
||||
await self.redis.rpush(key, *serialized_messages)# type: ignore
|
||||
await self.redis.rpush(key, *serialized_messages) # type: ignore
|
||||
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
||||
"""Get all chat messages grouped by session"""
|
||||
@ -132,8 +133,8 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key in keys:
|
||||
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
|
||||
messages = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
|
||||
messages = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
|
||||
|
||||
return result
|
||||
@ -164,8 +165,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
for session_data in all_sessions.values():
|
||||
context = session_data.get("context", {})
|
||||
if (context.get("relatedEntityType") == "candidate" and
|
||||
context.get("relatedEntityId") == candidate_id):
|
||||
if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
|
||||
candidate_sessions.append(session_data)
|
||||
|
||||
# Sort by last activity (most recent first)
|
||||
@ -188,7 +188,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
async def get_chat_message_count(self, session_id: str) -> int:
|
||||
"""Get the total number of messages in a chat session"""
|
||||
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
|
||||
return await self.redis.llen(key)# type: ignore
|
||||
return await self.redis.llen(key) # type: ignore
|
||||
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
||||
"""Search for messages containing specific text in a session"""
|
||||
@ -236,7 +236,6 @@ class ChatMixin(DatabaseProtocol):
|
||||
|
||||
return archived_count
|
||||
|
||||
|
||||
# Analytics and Reporting
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive chat statistics"""
|
||||
@ -250,7 +249,7 @@ class ChatMixin(DatabaseProtocol):
|
||||
"archived_sessions": 0,
|
||||
"sessions_by_type": {},
|
||||
"sessions_with_candidates": 0,
|
||||
"average_messages_per_session": 0
|
||||
"average_messages_per_session": 0,
|
||||
}
|
||||
|
||||
# Analyze sessions
|
||||
|
@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentMixin(DatabaseProtocol):
|
||||
"""Mixin for document-related database operations"""
|
||||
|
||||
@ -31,7 +32,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all document IDs for this candidate
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
|
||||
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
|
||||
if not document_ids:
|
||||
logger.info(f"No documents found for candidate {candidate_id}")
|
||||
@ -64,7 +65,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
"""Get all documents for a specific candidate"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
|
||||
|
||||
if not document_ids:
|
||||
return []
|
||||
@ -83,7 +84,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
documents.append(doc_data)
|
||||
else:
|
||||
# Clean up orphaned document ID
|
||||
await self.redis.lrem(key, 0, doc_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, doc_id) # type: ignore
|
||||
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
|
||||
|
||||
return documents
|
||||
@ -91,12 +92,12 @@ class DocumentMixin(DatabaseProtocol):
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Add a document ID to a candidate's document list"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.rpush(key, document_id)# type: ignore
|
||||
await self.redis.rpush(key, document_id) # type: ignore
|
||||
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
|
||||
"""Remove a document ID from a candidate's document list"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
await self.redis.lrem(key, 0, document_id)# type: ignore
|
||||
await self.redis.lrem(key, 0, document_id) # type: ignore
|
||||
|
||||
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
|
||||
"""Update document metadata"""
|
||||
@ -128,7 +129,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
|
||||
"""Get total number of documents for a candidate"""
|
||||
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
|
||||
return await self.redis.llen(key)# type: ignore
|
||||
return await self.redis.llen(key) # type: ignore
|
||||
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
|
||||
"""Search documents by filename for a candidate"""
|
||||
@ -136,8 +137,7 @@ class DocumentMixin(DatabaseProtocol):
|
||||
query_lower = query.lower()
|
||||
|
||||
return [
|
||||
doc for doc in all_documents
|
||||
if (query_lower in doc.get("filename", "").lower() or
|
||||
query_lower in doc.get("originalName", "").lower())
|
||||
doc
|
||||
for doc in all_documents
|
||||
if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
|
||||
]
|
||||
|
||||
|
@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobMixin(DatabaseProtocol):
|
||||
"""Mixin for job-related database operations"""
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]:
|
||||
"""Get job by ID"""
|
||||
key = f"{KEY_PREFIXES['jobs']}{job_id}"
|
||||
@ -36,7 +38,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
job_id = key.replace(KEY_PREFIXES['jobs'], '')
|
||||
job_id = key.replace(KEY_PREFIXES["jobs"], "")
|
||||
result[job_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -124,7 +126,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
|
||||
app_id = key.replace(KEY_PREFIXES["job_applications"], "")
|
||||
result[app_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -189,7 +191,7 @@ class JobMixin(DatabaseProtocol):
|
||||
requirements_with_meta = {
|
||||
**requirements,
|
||||
"cached_at": datetime.now(UTC).isoformat(),
|
||||
"document_id": document_id
|
||||
"document_id": document_id,
|
||||
}
|
||||
|
||||
await self.redis.set(key, self._serialize(requirements_with_meta))
|
||||
@ -232,7 +234,7 @@ class JobMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
|
||||
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
|
||||
if value:
|
||||
result[document_id] = self._deserialize(value)
|
||||
|
||||
@ -247,11 +249,7 @@ class JobMixin(DatabaseProtocol):
|
||||
pattern = f"{KEY_PREFIXES['job_requirements']}*"
|
||||
keys = await self.redis.keys(pattern)
|
||||
|
||||
stats = {
|
||||
"total_cached_requirements": len(keys),
|
||||
"cache_dates": {},
|
||||
"documents_with_requirements": []
|
||||
}
|
||||
stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}
|
||||
|
||||
if keys:
|
||||
# Get cache dates for analysis
|
||||
@ -264,7 +262,7 @@ class JobMixin(DatabaseProtocol):
|
||||
if value:
|
||||
requirements_data = self._deserialize(value)
|
||||
if requirements_data:
|
||||
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
|
||||
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
|
||||
stats["documents_with_requirements"].append(document_id)
|
||||
|
||||
# Track cache dates
|
||||
@ -277,4 +275,3 @@ class JobMixin(DatabaseProtocol):
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting job requirements stats: {e}")
|
||||
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}
|
||||
|
||||
|
@ -7,153 +7,412 @@ if TYPE_CHECKING:
|
||||
|
||||
from models import SkillAssessment
|
||||
|
||||
|
||||
class DatabaseProtocol(Protocol):
|
||||
# Base mixin
|
||||
redis: Redis
|
||||
def _serialize(self, data) -> str: ...
|
||||
def _deserialize(self, data: str): ...
|
||||
|
||||
def _serialize(self, data) -> str:
|
||||
...
|
||||
|
||||
def _deserialize(self, data: str):
|
||||
...
|
||||
|
||||
# Chat mixin
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict): ...
|
||||
async def archive_chat_session(self, session_id: str): ...
|
||||
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ...
|
||||
async def delete_chat_messages(self, session_id: str): ...
|
||||
async def delete_chat_session_completely(self, session_id: str): ...
|
||||
async def delete_chat_session(self, session_id: str) -> bool: ...
|
||||
async def add_chat_message(self, session_id: str, message_data: Dict):
|
||||
...
|
||||
|
||||
async def archive_chat_session(self, session_id: str):
|
||||
...
|
||||
|
||||
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
|
||||
...
|
||||
|
||||
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_chat_messages(self, session_id: str):
|
||||
...
|
||||
|
||||
async def delete_chat_session_completely(self, session_id: str):
|
||||
...
|
||||
|
||||
async def delete_chat_session(self, session_id: str) -> bool:
|
||||
...
|
||||
|
||||
# Document mixin
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ...
|
||||
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ...
|
||||
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
|
||||
...
|
||||
|
||||
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
|
||||
...
|
||||
|
||||
# Job mixin
|
||||
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ...
|
||||
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ...
|
||||
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
|
||||
...
|
||||
|
||||
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
|
||||
...
|
||||
|
||||
# User mixin
|
||||
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ...
|
||||
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ...
|
||||
async def delete_employer(self, employer_id: str): ...
|
||||
async def delete_guest(self, guest_id: str) -> bool: ...
|
||||
async def delete_user(self, email: str): ...
|
||||
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
|
||||
async def get_all_users(self) -> Dict[str, Any]: ...
|
||||
async def get_all_viewers(self) -> Dict[str, Any]: ...
|
||||
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ...
|
||||
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
|
||||
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_guest_statistics(self) -> Dict[str, Any]: ...
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_user_by_username(self, username: str) -> Optional[Dict]: ...
|
||||
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
|
||||
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ...
|
||||
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
|
||||
async def invalidate_user_skill_cache(self, user_id: str) -> int: ...
|
||||
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ...
|
||||
async def set_employer(self, employer_id: str, employer_data: Dict): ...
|
||||
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ...
|
||||
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ...
|
||||
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
|
||||
async def update_user_rag_timestamp(self, user_id: str) -> bool: ...
|
||||
async def user_exists_by_email(self, email: str) -> bool: ...
|
||||
async def user_exists_by_username(self, username: str) -> bool: ...
|
||||
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
|
||||
...
|
||||
|
||||
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
async def delete_employer(self, employer_id: str):
|
||||
...
|
||||
|
||||
async def delete_guest(self, guest_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_user(self, email: str):
|
||||
...
|
||||
|
||||
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_users(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_viewers(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_employer(self, employer_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_guest_statistics(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_user_by_username(self, username: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
|
||||
...
|
||||
|
||||
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def invalidate_user_skill_cache(self, user_id: str) -> int:
|
||||
...
|
||||
|
||||
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
|
||||
...
|
||||
|
||||
async def set_employer(self, employer_id: str, employer_data: Dict):
|
||||
...
|
||||
|
||||
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
|
||||
...
|
||||
|
||||
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def update_user_rag_timestamp(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def user_exists_by_email(self, email: str) -> bool:
|
||||
...
|
||||
|
||||
async def user_exists_by_username(self, username: str) -> bool:
|
||||
...
|
||||
|
||||
# Auth mixin
|
||||
async def cleanup_expired_verification_tokens(self) -> int: ...
|
||||
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ...
|
||||
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
|
||||
async def cleanup_orphaned_job_requirements(self) -> int: ...
|
||||
async def clear_all_data(self: "DatabaseProtocol"): ...
|
||||
async def clear_all_skill_match_cache(self) -> int: ...
|
||||
async def cleanup_expired_verification_tokens(self) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
|
||||
...
|
||||
|
||||
async def cleanup_orphaned_job_requirements(self) -> int:
|
||||
...
|
||||
|
||||
async def clear_all_data(self: "DatabaseProtocol"):
|
||||
...
|
||||
|
||||
async def clear_all_skill_match_cache(self) -> int:
|
||||
...
|
||||
|
||||
# Resume mixin
|
||||
|
||||
async def delete_ai_parameters(self, param_id: str): ...
|
||||
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ...
|
||||
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
|
||||
async def delete_authentication(self, user_id: str) -> bool: ...
|
||||
async def delete_document(self, document_id: str): ...
|
||||
async def delete_ai_parameters(self, param_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_application(self, application_id: str): ...
|
||||
async def delete_job_requirements(self, document_id: str) -> bool: ...
|
||||
async def delete_job(self, job_id: str): ...
|
||||
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
|
||||
async def delete_viewer(self, viewer_id: str): ...
|
||||
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
|
||||
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
|
||||
async def get_all_candidates(self) -> Dict[str, Any]: ...
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
|
||||
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
|
||||
async def get_all_employers(self) -> Dict[str, Any]: ...
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
|
||||
async def get_all_job_applications(self) -> Dict[str, Any]: ...
|
||||
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
|
||||
async def get_all_jobs(self) -> Dict[str, Any]: ...
|
||||
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
|
||||
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
|
||||
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
|
||||
async def get_chat_message_count(self, session_id: str) -> int: ...
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
|
||||
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
|
||||
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]: ...
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
|
||||
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
|
||||
async def get_document(self, document_id: str) -> Optional[Dict]: ...
|
||||
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
|
||||
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
|
||||
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]: ...
|
||||
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
|
||||
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_pending_verifications_count(self) -> int: ...
|
||||
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
|
||||
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
|
||||
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
|
||||
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
|
||||
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
|
||||
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
|
||||
async def get_stats(self) -> Dict[str, int]: ...
|
||||
async def get_verification_attempts_count(self, email: str) -> int: ...
|
||||
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
|
||||
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
|
||||
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
|
||||
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
|
||||
async def mark_email_verified(self, token: str) -> bool: ...
|
||||
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
|
||||
async def mark_password_reset_token_used(self, token: str) -> bool: ...
|
||||
async def record_verification_attempt(self, email: str) -> bool: ...
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
|
||||
async def revoke_refresh_token(self, token: str) -> bool: ...
|
||||
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
|
||||
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
|
||||
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
|
||||
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
|
||||
async def set_chat_session(self, session_id: str, session_data: Dict): ...
|
||||
async def set_document(self, document_id: str, document_data: Dict): ...
|
||||
async def set_job_application(self, application_id: str, application_data: Dict): ...
|
||||
async def set_job(self, job_id: str, job_data: Dict): ...
|
||||
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
|
||||
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
|
||||
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
|
||||
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
|
||||
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
|
||||
async def update_chat_session_activity(self, session_id: str): ...
|
||||
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
|
||||
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def delete_all_resumes_for_user(self, user_id: str) -> int:
|
||||
...
|
||||
|
||||
async def delete_authentication(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_document(self, document_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_application(self, application_id: str):
|
||||
...
|
||||
|
||||
async def delete_job_requirements(self, document_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_job(self, job_id: str):
|
||||
...
|
||||
|
||||
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def delete_viewer(self, viewer_id: str):
|
||||
...
|
||||
|
||||
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_ai_parameters(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_candidates(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
|
||||
...
|
||||
|
||||
async def get_all_chat_sessions(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_employers(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_all_job_applications(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_job_requirements(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_jobs(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
|
||||
...
|
||||
|
||||
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
|
||||
...
|
||||
|
||||
async def get_chat_message_count(self, session_id: str) -> int:
|
||||
...
|
||||
|
||||
async def get_chat_messages(self, session_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_chat_statistics(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
|
||||
...
|
||||
|
||||
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_document(self, document_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_job_application(self, application_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_job_requirements_stats(self) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
|
||||
...
|
||||
|
||||
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_pending_verifications_count(self) -> int:
|
||||
...
|
||||
|
||||
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
async def get_stats(self) -> Dict[str, int]:
|
||||
...
|
||||
|
||||
async def get_verification_attempts_count(self, email: str) -> int:
|
||||
...
|
||||
|
||||
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
|
||||
...
|
||||
|
||||
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
|
||||
...
|
||||
|
||||
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def mark_email_verified(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def mark_password_reset_token_used(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def record_verification_attempt(self, email: str) -> bool:
|
||||
...
|
||||
|
||||
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
|
||||
...
|
||||
|
||||
async def revoke_all_user_tokens(self, user_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def revoke_refresh_token(self, token: str) -> bool:
|
||||
...
|
||||
|
||||
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
|
||||
...
|
||||
|
||||
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
|
||||
...
|
||||
|
||||
async def set_ai_parameters(self, param_id: str, param_data: Dict):
|
||||
...
|
||||
|
||||
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
|
||||
...
|
||||
|
||||
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
|
||||
...
|
||||
|
||||
async def set_chat_session(self, session_id: str, session_data: Dict):
|
||||
...
|
||||
|
||||
async def set_document(self, document_id: str, document_data: Dict):
|
||||
...
|
||||
|
||||
async def set_job_application(self, application_id: str, application_data: Dict):
|
||||
...
|
||||
|
||||
async def set_job(self, job_id: str, job_data: Dict):
|
||||
...
|
||||
|
||||
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
|
||||
...
|
||||
|
||||
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
|
||||
...
|
||||
|
||||
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
|
||||
...
|
||||
|
||||
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
|
||||
...
|
||||
|
||||
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
|
||||
...
|
||||
|
||||
async def store_refresh_token(
|
||||
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
async def update_chat_session_activity(self, session_id: str):
|
||||
...
|
||||
|
||||
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
|
||||
...
|
||||
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
|
||||
...
|
||||
|
@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResumeMixin(DatabaseProtocol):
|
||||
"""Mixin for resume-related database operations"""
|
||||
|
||||
@ -14,10 +15,10 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Save a resume for a user"""
|
||||
try:
|
||||
# Generate resume_id if not present
|
||||
if 'id' not in resume_data:
|
||||
if "id" not in resume_data:
|
||||
raise ValueError("Resume data must include an 'id' field")
|
||||
|
||||
resume_id = resume_data['id']
|
||||
resume_id = resume_data["id"]
|
||||
|
||||
# Store the resume data
|
||||
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
|
||||
@ -25,7 +26,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
|
||||
# Add resume_id to user's resume list
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
|
||||
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
|
||||
|
||||
logger.info(f"📄 Saved resume {resume_id} for user {user_id}")
|
||||
return True
|
||||
@ -53,7 +54,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -73,7 +74,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
resumes.append(resume_data)
|
||||
else:
|
||||
# Clean up orphaned resume ID
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
|
||||
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
|
||||
|
||||
# Sort by created_at timestamp (most recent first)
|
||||
@ -94,7 +95,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
|
||||
# Remove from user's resume list
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
|
||||
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
|
||||
|
||||
if result > 0:
|
||||
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
|
||||
@ -111,7 +112,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
try:
|
||||
# Get all resume IDs for this user
|
||||
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
|
||||
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
|
||||
|
||||
if not resume_ids:
|
||||
logger.info(f"📄 No resumes found for user {user_id}")
|
||||
@ -159,7 +160,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
for key, value in zip(keys, values):
|
||||
if value:
|
||||
# Extract user_id from key format: resume:{user_id}:{resume_id}
|
||||
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
|
||||
key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
|
||||
if len(key_parts) >= 1:
|
||||
user_id = key_parts[0]
|
||||
resume_data = self._deserialize(value)
|
||||
@ -186,12 +187,14 @@ class ResumeMixin(DatabaseProtocol):
|
||||
matching_resumes = []
|
||||
for resume in all_resumes:
|
||||
# Search in resume content, job_id, candidate_id, etc.
|
||||
searchable_text = " ".join([
|
||||
resume.get("resume", ""),
|
||||
resume.get("job_id", ""),
|
||||
resume.get("candidate_id", ""),
|
||||
str(resume.get("created_at", ""))
|
||||
]).lower()
|
||||
searchable_text = " ".join(
|
||||
[
|
||||
resume.get("resume", ""),
|
||||
resume.get("job_id", ""),
|
||||
resume.get("candidate_id", ""),
|
||||
str(resume.get("created_at", "")),
|
||||
]
|
||||
).lower()
|
||||
|
||||
if query_lower in searchable_text:
|
||||
matching_resumes.append(resume)
|
||||
@ -206,10 +209,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Get all resumes for a specific candidate created by a user"""
|
||||
try:
|
||||
all_resumes = await self.get_all_resumes_for_user(user_id)
|
||||
candidate_resumes = [
|
||||
resume for resume in all_resumes
|
||||
if resume.get("candidate_id") == candidate_id
|
||||
]
|
||||
candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]
|
||||
|
||||
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
|
||||
return candidate_resumes
|
||||
@ -221,10 +221,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"""Get all resumes for a specific job created by a user"""
|
||||
try:
|
||||
all_resumes = await self.get_all_resumes_for_user(user_id)
|
||||
job_resumes = [
|
||||
resume for resume in all_resumes
|
||||
if resume.get("job_id") == job_id
|
||||
]
|
||||
job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]
|
||||
|
||||
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
|
||||
return job_resumes
|
||||
@ -242,7 +239,7 @@ class ResumeMixin(DatabaseProtocol):
|
||||
"resumes_by_candidate": {},
|
||||
"resumes_by_job": {},
|
||||
"creation_timeline": {},
|
||||
"recent_resumes": []
|
||||
"recent_resumes": [],
|
||||
}
|
||||
|
||||
for resume in all_resumes:
|
||||
@ -269,7 +266,13 @@ class ResumeMixin(DatabaseProtocol):
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
|
||||
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
|
||||
return {
|
||||
"total_resumes": 0,
|
||||
"resumes_by_candidate": {},
|
||||
"resumes_by_job": {},
|
||||
"creation_timeline": {},
|
||||
"recent_resumes": [],
|
||||
}
|
||||
|
||||
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
|
||||
"""Update specific fields of a resume"""
|
||||
|
@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillMixin(DatabaseProtocol):
|
||||
"""Mixin for Skill-related database operations"""
|
||||
|
||||
@ -74,7 +75,9 @@ class SkillMixin(DatabaseProtocol):
|
||||
# Cache for 1 hour by default
|
||||
await self.redis.set(
|
||||
cache_key,
|
||||
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
|
||||
json.dumps(
|
||||
assessment.model_dump(mode="json", by_alias=True), default=str
|
||||
), # Serialize with datetime handling
|
||||
)
|
||||
logger.info(f"💾 Skill match cached: {cache_key}")
|
||||
except Exception as e:
|
||||
|
@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserMixin(DatabaseProtocol):
|
||||
"""Mixin for user operations"""
|
||||
|
||||
@ -23,13 +24,13 @@ class UserMixin(DatabaseProtocol):
|
||||
guest_data["last_activity"] = datetime.now(UTC).isoformat()
|
||||
|
||||
# Store in Redis with both hash and individual key for redundancy
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore
|
||||
await self.redis.hset("guests", guest_id, json.dumps(guest_data)) # type: ignore
|
||||
|
||||
# Also store with a longer TTL as backup
|
||||
await self.redis.setex(
|
||||
f"guest_backup:{guest_id}",
|
||||
86400 * 7, # 7 days TTL
|
||||
json.dumps(guest_data)
|
||||
json.dumps(guest_data),
|
||||
)
|
||||
|
||||
logger.info(f"💾 Guest stored with backup: {guest_id}")
|
||||
@ -41,7 +42,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"""Get guest data with fallback to backup"""
|
||||
try:
|
||||
# Try primary storage first
|
||||
data = await self.redis.hget("guests", guest_id)# type: ignore
|
||||
data = await self.redis.hget("guests", guest_id) # type: ignore
|
||||
if data:
|
||||
guest_data = json.loads(data)
|
||||
# Update last activity when accessed
|
||||
@ -82,11 +83,8 @@ class UserMixin(DatabaseProtocol):
|
||||
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get all guests"""
|
||||
try:
|
||||
data = await self.redis.hgetall("guests")# type: ignore
|
||||
return {
|
||||
guest_id: json.loads(guest_json)
|
||||
for guest_id, guest_json in data.items()
|
||||
}
|
||||
data = await self.redis.hgetall("guests") # type: ignore
|
||||
return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting all guests: {e}")
|
||||
return {}
|
||||
@ -94,7 +92,7 @@ class UserMixin(DatabaseProtocol):
|
||||
async def delete_guest(self, guest_id: str) -> bool:
|
||||
"""Delete a guest"""
|
||||
try:
|
||||
result = await self.redis.hdel("guests", guest_id)# type: ignore
|
||||
result = await self.redis.hdel("guests", guest_id) # type: ignore
|
||||
if result:
|
||||
logger.info(f"🗑️ Guest deleted: {guest_id}")
|
||||
return True
|
||||
@ -120,7 +118,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
# Skip cleanup if guest is very new (less than 1 hour old)
|
||||
if created_at_str:
|
||||
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
|
||||
if current_time - created_at < timedelta(hours=1):
|
||||
preserved_count += 1
|
||||
logger.info(f"🛡️ Preserving new guest: {guest_id}")
|
||||
@ -130,7 +128,7 @@ class UserMixin(DatabaseProtocol):
|
||||
should_delete = False
|
||||
if last_activity_str:
|
||||
try:
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
|
||||
if last_activity < cutoff_time:
|
||||
should_delete = True
|
||||
except ValueError:
|
||||
@ -172,7 +170,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"active_last_day": 0,
|
||||
"converted_guests": 0,
|
||||
"by_ip": {},
|
||||
"creation_timeline": {}
|
||||
"creation_timeline": {},
|
||||
}
|
||||
|
||||
hour_ago = current_time - timedelta(hours=1)
|
||||
@ -183,7 +181,7 @@ class UserMixin(DatabaseProtocol):
|
||||
last_activity_str = guest_data.get("last_activity")
|
||||
if last_activity_str:
|
||||
try:
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
|
||||
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
|
||||
if last_activity > hour_ago:
|
||||
stats["active_last_hour"] += 1
|
||||
if last_activity > day_ago:
|
||||
@ -203,8 +201,8 @@ class UserMixin(DatabaseProtocol):
|
||||
created_at_str = guest_data.get("created_at")
|
||||
if created_at_str:
|
||||
try:
|
||||
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
date_key = created_at.strftime('%Y-%m-%d')
|
||||
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
|
||||
date_key = created_at.strftime("%Y-%m-%d")
|
||||
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
|
||||
except ValueError:
|
||||
pass
|
||||
@ -278,7 +276,7 @@ class UserMixin(DatabaseProtocol):
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get user lookup data by user ID"""
|
||||
try:
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
|
||||
data = await self.redis.hget("user_lookup_by_id", user_id) # type: ignore
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
@ -321,7 +319,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
email = key.replace(KEY_PREFIXES['users'], '')
|
||||
email = key.replace(KEY_PREFIXES["users"], "")
|
||||
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
|
||||
if type(value) == str:
|
||||
result[email] = value
|
||||
@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
|
||||
logger.error(f"❌ Error storing user {login}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ================
|
||||
# Employers
|
||||
# ================
|
||||
@ -394,7 +391,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
employer_id = key.replace(KEY_PREFIXES['employers'], '')
|
||||
employer_id = key.replace(KEY_PREFIXES["employers"], "")
|
||||
result[employer_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -404,7 +401,6 @@ class UserMixin(DatabaseProtocol):
|
||||
key = f"{KEY_PREFIXES['employers']}{employer_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
||||
# ================
|
||||
# Candidates
|
||||
# ================
|
||||
@ -435,7 +431,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
|
||||
candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
|
||||
result[candidate_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -456,7 +452,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"security_logs": 0,
|
||||
"ai_parameters": 0,
|
||||
"candidate_record": 0,
|
||||
"resumes": 0
|
||||
"resumes": 0,
|
||||
}
|
||||
|
||||
logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")
|
||||
@ -495,7 +491,9 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
deletion_stats["chat_sessions"] = len(candidate_sessions)
|
||||
deletion_stats["chat_messages"] = messages_deleted
|
||||
logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
|
||||
logger.info(
|
||||
f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting chat sessions: {e}")
|
||||
|
||||
@ -528,9 +526,11 @@ class UserMixin(DatabaseProtocol):
|
||||
logger.info(f"🗑️ Deleted user record by email: {candidate_email}")
|
||||
|
||||
# Delete by username (if different from email)
|
||||
if (candidate_username and
|
||||
candidate_username != candidate_email and
|
||||
await self.user_exists_by_username(candidate_username)):
|
||||
if (
|
||||
candidate_username
|
||||
and candidate_username != candidate_email
|
||||
and await self.user_exists_by_username(candidate_username)
|
||||
):
|
||||
await self.delete_user(candidate_username)
|
||||
user_records_deleted += 1
|
||||
logger.info(f"🗑️ Deleted user record by username: {candidate_username}")
|
||||
@ -593,8 +593,7 @@ class UserMixin(DatabaseProtocol):
|
||||
candidate_ai_params = []
|
||||
|
||||
for param_id, param_data in all_ai_params.items():
|
||||
if (param_data.get("candidateId") == candidate_id or
|
||||
param_data.get("userId") == candidate_id):
|
||||
if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
|
||||
candidate_ai_params.append(param_id)
|
||||
|
||||
# Delete each AI parameter set
|
||||
@ -630,7 +629,9 @@ class UserMixin(DatabaseProtocol):
|
||||
break
|
||||
|
||||
if tokens_deleted > 0:
|
||||
logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
|
||||
logger.info(
|
||||
f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting email verification tokens: {e}")
|
||||
|
||||
@ -717,8 +718,10 @@ class UserMixin(DatabaseProtocol):
|
||||
# 15. Log the deletion as a security event (if we have admin/system user context)
|
||||
try:
|
||||
total_items_deleted = sum(deletion_stats.values())
|
||||
logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
|
||||
f"Total items deleted: {total_items_deleted}")
|
||||
logger.info(
|
||||
f"✅ Completed cascading delete for candidate {candidate_id}. "
|
||||
f"Total items deleted: {total_items_deleted}"
|
||||
)
|
||||
logger.info(f"📊 Deletion breakdown: {deletion_stats}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error logging deletion summary: {e}")
|
||||
@ -774,8 +777,8 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_candidates_processed": len(candidate_ids),
|
||||
"successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
|
||||
"failed_deletions": len([r for r in batch_results.values() if "error" in r]),
|
||||
"total_items_deleted": sum(total_stats.values())
|
||||
}
|
||||
"total_items_deleted": sum(total_stats.values()),
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -816,7 +819,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_sessions": 0,
|
||||
"total_messages": 0,
|
||||
"first_chat": None,
|
||||
"last_chat": None
|
||||
"last_chat": None,
|
||||
}
|
||||
|
||||
total_messages = 0
|
||||
@ -835,7 +838,7 @@ class UserMixin(DatabaseProtocol):
|
||||
"total_messages": total_messages,
|
||||
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
|
||||
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
|
||||
"recent_sessions": sessions[:5] # Last 5 sessions
|
||||
"recent_sessions": sessions[:5], # Last 5 sessions
|
||||
}
|
||||
|
||||
# ================
|
||||
@ -868,7 +871,7 @@ class UserMixin(DatabaseProtocol):
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
|
||||
viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
|
||||
result[viewer_id] = self._deserialize(value)
|
||||
|
||||
return result
|
||||
@ -877,4 +880,3 @@ class UserMixin(DatabaseProtocol):
|
||||
"""Delete viewer"""
|
||||
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
|
||||
await self.redis.delete(key)
|
||||
|
||||
|
@ -3,20 +3,20 @@ import os
|
||||
ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
|
||||
|
||||
user_dir = "/opt/backstory/users"
|
||||
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
|
||||
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
|
||||
default_username = "jketreno"
|
||||
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
|
||||
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
|
||||
# Path to candidate full resume
|
||||
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}
|
||||
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}
|
||||
resume_doc = "resume.md"
|
||||
persist_directory = "db" # Relative to "{user_dir}/{user}"
|
||||
persist_directory = "db" # Relative to "{user_dir}/{user}"
|
||||
|
||||
# Model name License Notes
|
||||
# model = "deepseek-r1:7b" # MIT Tool calls don"t work
|
||||
# model = "gemma3:4b" # Gemma Requires newer ollama https://ai.google.dev/gemma/terms
|
||||
# model = "llama3.2" # Llama Good results; qwen seems slightly better https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/LICENSE.txt
|
||||
# model = "mistral:7b" # Apache 2.0 Tool calls don"t work
|
||||
model = "qwen2.5:7b" # Apache 2.0 Good results
|
||||
model = "qwen2.5:7b" # Apache 2.0 Good results
|
||||
# model = "qwen3:8b" # Apache 2.0 Requires newer ollama
|
||||
model = os.getenv("MODEL_NAME", model)
|
||||
|
||||
@ -40,7 +40,7 @@ logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
|
||||
# RAG and Vector DB settings
|
||||
## Where to read RAG content
|
||||
|
||||
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
|
||||
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
|
||||
|
||||
# Maximum number of entries for ChromaDB to find
|
||||
default_rag_top_k = 50
|
||||
@ -60,7 +60,7 @@ cert_path = "/opt/backstory/keys/cert.pem"
|
||||
host = os.getenv("BACKSTORY_HOST", "0.0.0.0")
|
||||
port = int(os.getenv("BACKSTORY_PORT", "8911"))
|
||||
api_prefix = "/api/1.0"
|
||||
debug=os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
|
||||
debug = os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
# Used for filtering tracebacks
|
||||
app_path="/opt/backstory/src/backend"
|
||||
app_path = "/opt/backstory/src/backend"
|
||||
|
@ -6,6 +6,7 @@ from datetime import datetime, timezone
|
||||
from user_agents import parse
|
||||
import json
|
||||
|
||||
|
||||
class DeviceManager:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
self.database = database
|
||||
@ -13,7 +14,6 @@ class DeviceManager:
|
||||
def generate_device_fingerprint(self, request: Request) -> str:
|
||||
"""Generate device fingerprint from request"""
|
||||
user_agent = request.headers.get("user-agent", "")
|
||||
ip_address = request.client.host if request.client else "unknown"
|
||||
accept_language = request.headers.get("accept-language", "")
|
||||
|
||||
# Create fingerprint
|
||||
@ -35,7 +35,7 @@ class DeviceManager:
|
||||
"os": user_agent.os.family,
|
||||
"os_version": user_agent.os.version_string,
|
||||
"ip_address": request.client.host if request.client else "unknown",
|
||||
"user_agent": user_agent_string
|
||||
"user_agent": user_agent_string,
|
||||
}
|
||||
|
||||
async def is_trusted_device(self, user_id: str, device_id: str) -> bool:
|
||||
@ -55,14 +55,14 @@ class DeviceManager:
|
||||
device_data = {
|
||||
**device_info,
|
||||
"added_at": datetime.now(timezone.utc).isoformat(),
|
||||
"last_used": datetime.now(timezone.utc).isoformat()
|
||||
"last_used": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
# Store for 90 days
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
90 * 24 * 60 * 60, # 90 days in seconds
|
||||
json.dumps(device_data, default=str)
|
||||
json.dumps(device_data, default=str),
|
||||
)
|
||||
|
||||
logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")
|
||||
@ -80,7 +80,7 @@ class DeviceManager:
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
90 * 24 * 60 * 60, # Reset 90 day expiry
|
||||
json.dumps(device_info, default=str)
|
||||
json.dumps(device_info, default=str),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating device last used: {e}")
|
||||
|
@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
|
||||
import json
|
||||
from database.manager import RedisDatabase
|
||||
|
||||
|
||||
class EmailService:
|
||||
def __init__(self):
|
||||
# Configure these in your .env file
|
||||
@ -30,36 +31,25 @@ class EmailService:
|
||||
def _format_template(self, template: str, **kwargs) -> str:
|
||||
"""Format template with provided variables"""
|
||||
return template.format(
|
||||
app_name=self.app_name,
|
||||
from_name=self.from_name,
|
||||
frontend_url=self.frontend_url,
|
||||
**kwargs
|
||||
app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
|
||||
)
|
||||
|
||||
async def send_verification_email(
|
||||
self,
|
||||
to_email: str,
|
||||
verification_token: str,
|
||||
user_name: str,
|
||||
user_type: str = "user"
|
||||
self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
|
||||
):
|
||||
"""Send email verification email using template"""
|
||||
try:
|
||||
template = self._get_template("verification")
|
||||
verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"
|
||||
|
||||
subject = self._format_template(
|
||||
template["subject"],
|
||||
user_name=user_name,
|
||||
to_email=to_email
|
||||
)
|
||||
subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)
|
||||
|
||||
html_content = self._format_template(
|
||||
template["html"],
|
||||
user_name=user_name,
|
||||
user_type=user_type,
|
||||
to_email=to_email,
|
||||
verification_link=verification_link
|
||||
verification_link=verification_link,
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -70,12 +60,7 @@ class EmailService:
|
||||
raise
|
||||
|
||||
async def send_mfa_email(
|
||||
self,
|
||||
to_email: str,
|
||||
mfa_code: str,
|
||||
device_name: str,
|
||||
user_name: str,
|
||||
ip_address: str = "Unknown"
|
||||
self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
|
||||
):
|
||||
"""Send MFA code email using template"""
|
||||
try:
|
||||
@ -91,7 +76,7 @@ class EmailService:
|
||||
ip_address=ip_address,
|
||||
login_time=login_time,
|
||||
mfa_code=mfa_code,
|
||||
to_email=to_email
|
||||
to_email=to_email,
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -101,12 +86,7 @@ class EmailService:
|
||||
logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
|
||||
raise
|
||||
|
||||
async def send_password_reset_email(
|
||||
self,
|
||||
to_email: str,
|
||||
reset_token: str,
|
||||
user_name: str
|
||||
):
|
||||
async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
|
||||
"""Send password reset email using template"""
|
||||
try:
|
||||
template = self._get_template("password_reset")
|
||||
@ -115,10 +95,7 @@ class EmailService:
|
||||
subject = self._format_template(template["subject"])
|
||||
|
||||
html_content = self._format_template(
|
||||
template["html"],
|
||||
user_name=user_name,
|
||||
reset_link=reset_link,
|
||||
to_email=to_email
|
||||
template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
|
||||
)
|
||||
|
||||
await self._send_email(to_email, subject, html_content)
|
||||
@ -134,14 +111,14 @@ class EmailService:
|
||||
if not self.email_user:
|
||||
raise ValueError("Email user is not configured")
|
||||
# Create message
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['From'] = f"{self.from_name} <{self.email_user}>"
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = subject
|
||||
msg['Reply-To'] = self.email_user
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = f"{self.from_name} <{self.email_user}>"
|
||||
msg["To"] = to_email
|
||||
msg["Subject"] = subject
|
||||
msg["Reply-To"] = self.email_user
|
||||
|
||||
# Add HTML content
|
||||
html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
html_part = MIMEText(html_content, "html", "utf-8")
|
||||
msg.attach(html_part)
|
||||
|
||||
# Send email with connection pooling and retry logic
|
||||
@ -157,7 +134,6 @@ class EmailService:
|
||||
text = msg.as_string()
|
||||
server.sendmail(self.email_user, to_email, text)
|
||||
break # Success, exit retry loop
|
||||
|
||||
except smtplib.SMTPException as e:
|
||||
if attempt == max_retries - 1: # Last attempt
|
||||
raise
|
||||
@ -170,6 +146,7 @@ class EmailService:
|
||||
logger.error(f"❌ SMTP error sending to {to_email}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class EmailRateLimiter:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
self.database = database
|
||||
@ -191,10 +168,7 @@ class EmailRateLimiter:
|
||||
email_records = json.loads(count_data)
|
||||
|
||||
# Filter out old records
|
||||
recent_records = [
|
||||
record for record in email_records
|
||||
if datetime.fromisoformat(record) > window_start
|
||||
]
|
||||
recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]
|
||||
|
||||
if len(recent_records) >= limit:
|
||||
logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")
|
||||
@ -202,11 +176,7 @@ class EmailRateLimiter:
|
||||
|
||||
# Add current email to records
|
||||
recent_records.append(current_time.isoformat())
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
window_minutes * 60,
|
||||
json.dumps(recent_records)
|
||||
)
|
||||
await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))
|
||||
|
||||
return True
|
||||
|
||||
@ -217,11 +187,8 @@ class EmailRateLimiter:
|
||||
|
||||
async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
|
||||
"""Record that an email was sent"""
|
||||
await self.database.redis.setex(
|
||||
key,
|
||||
ttl_minutes * 60,
|
||||
json.dumps([timestamp.isoformat()])
|
||||
)
|
||||
await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))
|
||||
|
||||
|
||||
class VerificationEmailRateLimiter:
|
||||
def __init__(self, database: RedisDatabase):
|
||||
@ -242,7 +209,10 @@ class VerificationEmailRateLimiter:
|
||||
# Check daily limit
|
||||
daily_count = await self.database.get_verification_attempts_count(email)
|
||||
if daily_count >= self.max_attempts_per_day:
|
||||
return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
|
||||
return (
|
||||
False,
|
||||
f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
|
||||
)
|
||||
|
||||
# Check hourly limit
|
||||
hour_ago = current_time - timedelta(hours=1)
|
||||
@ -251,13 +221,13 @@ class VerificationEmailRateLimiter:
|
||||
|
||||
if data:
|
||||
attempts_data = json.loads(data)
|
||||
recent_attempts = [
|
||||
attempt for attempt in attempts_data
|
||||
if datetime.fromisoformat(attempt) > hour_ago
|
||||
]
|
||||
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]
|
||||
|
||||
if len(recent_attempts) >= self.max_attempts_per_hour:
|
||||
return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
|
||||
return (
|
||||
False,
|
||||
f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
|
||||
)
|
||||
|
||||
# Check cooldown period
|
||||
if recent_attempts:
|
||||
@ -280,7 +250,4 @@ class VerificationEmailRateLimiter:
|
||||
await self.database.record_verification_attempt(email)
|
||||
|
||||
|
||||
|
||||
email_service = EmailService()
|
||||
|
||||
|
@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
""",
|
||||
},
|
||||
|
||||
"mfa": {
|
||||
"subject": "Security Code for Backstory",
|
||||
"html": """
|
||||
@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
""",
|
||||
},
|
||||
|
||||
"password_reset": {
|
||||
"subject": "Reset your Backstory password",
|
||||
"html": """
|
||||
@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
}
|
||||
""",
|
||||
},
|
||||
}
|
@ -3,12 +3,13 @@ import weakref
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
from pydantic import BaseModel, Field # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
|
||||
from models import Candidate
|
||||
from agents.base import CandidateEntity
|
||||
from database.manager import RedisDatabase
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
from prometheus_client import CollectorRegistry # type: ignore
|
||||
|
||||
|
||||
class EntityManager(BaseModel):
|
||||
"""Manages lifecycle of CandidateEntity instances"""
|
||||
@ -36,10 +37,7 @@ class EntityManager(BaseModel):
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
prometheus_collector: CollectorRegistry,
|
||||
database: RedisDatabase):
|
||||
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
|
||||
"""Initialize the EntityManager with Prometheus collector"""
|
||||
self._prometheus_collector = prometheus_collector
|
||||
self._database = database
|
||||
@ -58,9 +56,7 @@ class EntityManager(BaseModel):
|
||||
raise ValueError("EntityManager has not been initialized with required components.")
|
||||
|
||||
entity = CandidateEntity(candidate=candidate)
|
||||
await entity.initialize(
|
||||
prometheus_collector=self._prometheus_collector,
|
||||
database=self._database)
|
||||
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
|
||||
|
||||
# Store with reference tracking
|
||||
self._entities[candidate.id] = entity
|
||||
@ -105,10 +101,12 @@ class EntityManager(BaseModel):
|
||||
|
||||
def _on_entity_deleted(self, user_id: str):
|
||||
"""Callback when entity is garbage collected"""
|
||||
|
||||
def cleanup_callback(weak_ref):
|
||||
self._entities.pop(user_id, None)
|
||||
self._weak_refs.pop(user_id, None)
|
||||
print(f"Entity {user_id} garbage collected")
|
||||
|
||||
return cleanup_callback
|
||||
|
||||
async def release_entity(self, user_id: str):
|
||||
@ -138,8 +136,7 @@ class EntityManager(BaseModel):
|
||||
time_since_access = current_time - entity.last_accessed
|
||||
|
||||
# Remove if TTL exceeded and no active references
|
||||
if (time_since_access > timedelta(minutes=self._ttl_minutes)
|
||||
and entity.reference_count == 0):
|
||||
if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
|
||||
expired_entities.append(user_id)
|
||||
|
||||
for user_id in expired_entities:
|
||||
@ -153,6 +150,7 @@ class EntityManager(BaseModel):
|
||||
# Global entity manager instance
|
||||
entity_manager = EntityManager(default_ttl_minutes=30)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_candidate_entity(candidate: Candidate):
|
||||
"""Context manager for safe entity access with automatic reference management"""
|
||||
@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
|
||||
finally:
|
||||
await entity_manager.release_entity(candidate.id)
|
||||
|
||||
|
||||
EntityManager.model_rebuild()
|
||||
|
@ -6,10 +6,7 @@ without getting caught up in serialization format complexities
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from models import (
|
||||
UserStatus, UserType, SkillLevel, EmploymentType,
|
||||
Candidate, Employer, Location, Skill
|
||||
)
|
||||
from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill
|
||||
|
||||
|
||||
def test_model_creation():
|
||||
@ -23,37 +20,38 @@ def test_model_creation():
|
||||
# Create candidate
|
||||
candidate = Candidate(
|
||||
email="test@example.com",
|
||||
user_type=UserType.CANDIDATE,
|
||||
username="test_candidate",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
firstName="John",
|
||||
lastName="Doe",
|
||||
fullName="John Doe",
|
||||
first_name="John",
|
||||
last_name="Doe",
|
||||
full_name="John Doe",
|
||||
skills=[skill],
|
||||
experience=[],
|
||||
education=[],
|
||||
preferredJobTypes=[EmploymentType.FULL_TIME],
|
||||
preferred_job_types=[EmploymentType.FULL_TIME],
|
||||
location=location,
|
||||
languages=[],
|
||||
certifications=[]
|
||||
certifications=[],
|
||||
)
|
||||
|
||||
# Create employer
|
||||
employer = Employer(
|
||||
firstName="Mary",
|
||||
lastName="Smith",
|
||||
fullName="Mary Smith",
|
||||
user_type=UserType.EMPLOYER,
|
||||
first_name="Mary",
|
||||
last_name="Smith",
|
||||
full_name="Mary Smith",
|
||||
email="hr@company.com",
|
||||
username="test_employer",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
companyName="Test Company",
|
||||
company_name="Test Company",
|
||||
industry="Technology",
|
||||
companySize="50-200",
|
||||
companyDescription="A test company",
|
||||
location=location
|
||||
company_size="50-200",
|
||||
company_description="A test company",
|
||||
location=location,
|
||||
)
|
||||
|
||||
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
|
||||
@ -62,6 +60,7 @@ def test_model_creation():
|
||||
|
||||
return candidate, employer
|
||||
|
||||
|
||||
def test_json_api_format():
|
||||
"""Test JSON serialization in API format (the most important use case)"""
|
||||
print("\n📡 Testing JSON API format...")
|
||||
@ -84,11 +83,12 @@ def test_json_api_format():
|
||||
assert candidate_back.first_name == candidate.first_name
|
||||
assert employer_back.company_name == employer.company_name
|
||||
|
||||
print(f"✅ JSON round-trip successful")
|
||||
print(f"✅ Data integrity verified")
|
||||
print("✅ JSON round-trip successful")
|
||||
print("✅ Data integrity verified")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_api_dict_format():
|
||||
"""Test dictionary format with aliases (for API requests/responses)"""
|
||||
print("\n📊 Testing API dictionary format...")
|
||||
@ -105,8 +105,8 @@ def test_api_dict_format():
|
||||
assert "createdAt" in candidate_dict
|
||||
assert "companyName" in employer_dict
|
||||
|
||||
print(f"✅ API format dictionaries created")
|
||||
print(f"✅ CamelCase aliases verified")
|
||||
print("✅ API format dictionaries created")
|
||||
print("✅ CamelCase aliases verified")
|
||||
|
||||
# Test deserializing from API format
|
||||
candidate_back = Candidate.model_validate(candidate_dict)
|
||||
@ -115,25 +115,27 @@ def test_api_dict_format():
|
||||
assert candidate_back.email == candidate.email
|
||||
assert employer_back.company_name == employer.company_name
|
||||
|
||||
print(f"✅ API format round-trip successful")
|
||||
print("✅ API format round-trip successful")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def test_validation_constraints():
|
||||
"""Test that validation constraints work"""
|
||||
print("\n🔒 Testing validation constraints...")
|
||||
|
||||
try:
|
||||
# Create a candidate with invalid email
|
||||
invalid_candidate = Candidate(
|
||||
Candidate(
|
||||
user_type=UserType.CANDIDATE,
|
||||
email="invalid-email",
|
||||
username="test_invalid",
|
||||
createdAt=datetime.now(),
|
||||
updatedAt=datetime.now(),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
status=UserStatus.ACTIVE,
|
||||
firstName="Jane",
|
||||
lastName="Doe",
|
||||
fullName="Jane Doe"
|
||||
first_name="Jane",
|
||||
last_name="Doe",
|
||||
full_name="Jane Doe",
|
||||
)
|
||||
print("❌ Validation should have failed but didn't")
|
||||
return False
|
||||
@ -141,6 +143,7 @@ def test_validation_constraints():
|
||||
print(f"✅ Validation error caught: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def test_enum_values():
|
||||
"""Test that enum values work correctly"""
|
||||
print("\n📋 Testing enum values...")
|
||||
@ -155,11 +158,12 @@ def test_enum_values():
|
||||
assert candidate_dict["userType"] == "candidate"
|
||||
assert employer.user_type == UserType.EMPLOYER
|
||||
|
||||
print(f"✅ Enum values correctly serialized")
|
||||
print("✅ Enum values correctly serialized")
|
||||
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all focused tests"""
|
||||
print("🎯 Focused Pydantic Model Tests")
|
||||
@ -172,7 +176,7 @@ def main():
|
||||
test_validation_constraints()
|
||||
test_enum_values()
|
||||
|
||||
print(f"\n🎉 All focused tests passed!")
|
||||
print("\n🎉 All focused tests passed!")
|
||||
print("=" * 40)
|
||||
print("✅ Models work correctly")
|
||||
print("✅ JSON API format works")
|
||||
@ -185,10 +189,12 @@ def main():
|
||||
except Exception as e:
|
||||
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
print(f"\n❌ {traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
import json
|
||||
from typing import Any, List, Set
|
||||
|
||||
|
||||
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
|
||||
"""
|
||||
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.
|
||||
|
@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from
|
||||
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
|
||||
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
|
||||
data = {field: getattr(source, field) for field in model_cls.__fields__}
|
||||
return model_cls(**data)
|
||||
|
||||
|
||||
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
|
||||
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
|
||||
return model_cls(**data)
|
||||
|
||||
|
||||
def cast_to_base_user_with_type(user) -> BaseUserWithType:
|
||||
"""
|
||||
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.
|
||||
|
@ -7,7 +7,8 @@ from typing import Any
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline, FluxPipeline
|
||||
|
||||
class ImageModelCache: # Stay loaded for 3 hours
|
||||
|
||||
class ImageModelCache: # Stay loaded for 3 hours
|
||||
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
|
||||
self._pipe = None
|
||||
self._model_name = None
|
||||
@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
cached_model_type = self._get_model_type(self._model_name) if self._model_name else None
|
||||
|
||||
if (
|
||||
self._pipe is not None and
|
||||
self._model_name == model and
|
||||
self._device == device and
|
||||
current_model_type == cached_model_type and
|
||||
current_time - self._last_access_time < self._timeout_seconds
|
||||
self._pipe is not None
|
||||
and self._model_name == model
|
||||
and self._device == device
|
||||
and current_model_type == cached_model_type
|
||||
and current_time - self._last_access_time < self._timeout_seconds
|
||||
):
|
||||
self._last_access_time = current_time
|
||||
return self._pipe
|
||||
@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
model,
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
|
||||
def dummy_safety_checker(images, clip_input):
|
||||
return images, [False] * len(images)
|
||||
|
||||
pipe.safety_checker = dummy_safety_checker
|
||||
else:
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
||||
)
|
||||
try:
|
||||
pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
|
||||
pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to load LoRA weights: {str(e)}")
|
||||
|
||||
@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded for 3 hours
|
||||
|
||||
async def cleanup_if_expired(self):
|
||||
async with self._lock:
|
||||
if (
|
||||
self._pipe is not None and
|
||||
time.time() - self._last_access_time >= self._timeout_seconds
|
||||
):
|
||||
if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
|
||||
await self._unload_model()
|
||||
|
||||
async def _periodic_cleanup(self):
|
||||
|
@ -29,6 +29,7 @@ TIME_ESTIMATES = {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str
|
||||
filepath: str
|
||||
@ -39,18 +40,22 @@ class ImageRequest(BaseModel):
|
||||
width: int = 256
|
||||
guidance_scale: float = 7.5
|
||||
|
||||
|
||||
# Global model cache instance
|
||||
model_cache = ImageModelCache()
|
||||
|
||||
|
||||
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
|
||||
"""Background worker for Flux image generation"""
|
||||
try:
|
||||
# Flux: Run generation in the background and yield progress updates
|
||||
status_queue.put(ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Initializing image generation.",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
status_queue.put(
|
||||
ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content="Initializing image generation.",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
)
|
||||
)
|
||||
|
||||
# Start the generation task
|
||||
start_gen_time = time.time()
|
||||
@ -58,13 +63,15 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
# Simulate your pipe call with progress updates
|
||||
def status_callback(pipeline, step, timestep, callback_kwargs):
|
||||
# Send progress updates
|
||||
progress = int((step+1) / params.iterations * 100)
|
||||
progress = int((step + 1) / params.iterations * 100)
|
||||
|
||||
status_queue.put(ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
))
|
||||
status_queue.put(
|
||||
ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
|
||||
activity=ApiActivityType.GENERATING_IMAGE,
|
||||
)
|
||||
)
|
||||
return callback_kwargs
|
||||
|
||||
# Replace this block with your actual Flux pipe call:
|
||||
@ -84,22 +91,28 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
|
||||
image.save(params.filepath)
|
||||
|
||||
# Final completion status
|
||||
status_queue.put(ChatMessage(
|
||||
session_id=params.session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
|
||||
))
|
||||
status_queue.put(
|
||||
ChatMessage(
|
||||
session_id=params.session_id,
|
||||
status=ApiStatusType.DONE,
|
||||
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
status_queue.put(ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f"Error during image generation: {str(e)}",
|
||||
))
|
||||
status_queue.put(
|
||||
ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f"Error during image generation: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
|
||||
async def async_generate_image(
|
||||
pipe: Any, params: ImageRequest
|
||||
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
|
||||
"""
|
||||
Single async function that handles background Flux generation with status streaming
|
||||
"""
|
||||
@ -109,18 +122,14 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
|
||||
|
||||
try:
|
||||
# Start background worker thread
|
||||
worker_thread = Thread(
|
||||
target=flux_worker,
|
||||
args=(pipe, params, status_queue, task_id),
|
||||
daemon=True
|
||||
)
|
||||
worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
|
||||
worker_thread.start()
|
||||
|
||||
# Initial status
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=params.session_id,
|
||||
content=f"Starting image generation with task ID {task_id}",
|
||||
activity=ApiActivityType.THINKING
|
||||
activity=ApiActivityType.THINKING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -177,18 +186,16 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
|
||||
yield final_status
|
||||
|
||||
except Exception as e:
|
||||
error_status = ChatMessageError(
|
||||
session_id=params.session_id,
|
||||
content=f'Server error: {str(e)}'
|
||||
)
|
||||
error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
|
||||
logger.error(error_status)
|
||||
yield error_status
|
||||
|
||||
finally:
|
||||
# Cleanup: ensure thread completion
|
||||
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
|
||||
if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
|
||||
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
|
||||
|
||||
|
||||
def status(session_id: str, status: str) -> ChatMessageStatus:
|
||||
"""Update chat message status and return it."""
|
||||
chat_message = ChatMessageStatus(
|
||||
@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
|
||||
)
|
||||
return chat_message
|
||||
|
||||
|
||||
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
|
||||
"""Generate an image with specified dimensions and yield status updates with time estimates."""
|
||||
session_id = request.session_id
|
||||
@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
try:
|
||||
# Validate prompt
|
||||
if not prompt:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Prompt cannot be empty."
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
return
|
||||
@ -216,15 +221,14 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
# Validate dimensions
|
||||
if request.height <= 0 or request.width <= 0:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Height and width must be positive integers."
|
||||
session_id=session_id, content="Height and width must be positive integers."
|
||||
)
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
filedir = os.path.dirname(request.filepath)
|
||||
filename = os.path.basename(request.filepath)
|
||||
os.path.basename(request.filepath)
|
||||
os.makedirs(filedir, exist_ok=True)
|
||||
|
||||
model_type = "flux"
|
||||
@ -233,14 +237,17 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
# Get initial time estimate, scaled by resolution
|
||||
estimates = TIME_ESTIMATES[model_type][device]
|
||||
resolution_scale = (request.height * request.width) / (512 * 512)
|
||||
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
|
||||
estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
|
||||
|
||||
# Initialize or get cached pipeline
|
||||
start_time = time.time()
|
||||
yield status(session_id, f"Loading generative image model...")
|
||||
yield status(session_id, "Loading generative image model...")
|
||||
pipe = await model_cache.get_pipeline(request.model, device)
|
||||
load_time = time.time() - start_time
|
||||
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
|
||||
yield status(
|
||||
session_id,
|
||||
f"Model loaded in {load_time:.1f} seconds.",
|
||||
)
|
||||
|
||||
progress = None
|
||||
async for progress in async_generate_image(pipe, request):
|
||||
@ -252,8 +259,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
|
||||
if not progress:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content="Image generation failed to produce a valid response."
|
||||
session_id=session_id, content="Image generation failed to produce a valid response."
|
||||
)
|
||||
logger.error(f"⚠️ {error_message.content}")
|
||||
yield error_message
|
||||
@ -269,10 +275,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
|
||||
yield chat_message
|
||||
|
||||
except Exception as e:
|
||||
error_message = ChatMessageError(
|
||||
session_id=session_id,
|
||||
content=f"Error during image generation: {str(e)}"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(error_message.content)
|
||||
yield error_message
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
|
||||
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
|
||||
"""
|
||||
Extract JSON blocks from text, even if surrounded by markdown or noisy text.
|
||||
@ -31,13 +32,14 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
|
||||
"""Extract standalone JSON objects or arrays from text using proper brace counting."""
|
||||
found = []
|
||||
i = 0
|
||||
|
||||
while i < len(text):
|
||||
if text[i] in '{[':
|
||||
if text[i] in "{[":
|
||||
# Found potential JSON start
|
||||
json_str = _extract_complete_json_at_position(text, i)
|
||||
if json_str:
|
||||
@ -55,16 +57,17 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un
|
||||
|
||||
return found
|
||||
|
||||
|
||||
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
"""
|
||||
Extract a complete JSON object or array starting at the given position.
|
||||
Uses proper brace/bracket counting and string escape handling.
|
||||
"""
|
||||
if start_pos >= len(text) or text[start_pos] not in '{[':
|
||||
if start_pos >= len(text) or text[start_pos] not in "{[":
|
||||
return ""
|
||||
|
||||
start_char = text[start_pos]
|
||||
end_char = '}' if start_char == '{' else ']'
|
||||
end_char = "}" if start_char == "{" else "]"
|
||||
|
||||
count = 1
|
||||
i = start_pos + 1
|
||||
@ -76,7 +79,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
elif char == '\\' and in_string:
|
||||
elif char == "\\" and in_string:
|
||||
escape_next = True
|
||||
elif char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
@ -92,6 +95,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
|
||||
return text[start_pos:i]
|
||||
return ""
|
||||
|
||||
|
||||
def extract_json_from_text(text: str) -> str:
|
||||
"""Extract JSON string from text that may contain other content."""
|
||||
return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
|
||||
|
@ -2,11 +2,11 @@ import os
|
||||
import warnings
|
||||
import logging
|
||||
import defines
|
||||
|
||||
|
||||
def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="Overriding a previously registered kernel"
|
||||
)
|
||||
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
|
||||
warnings.filterwarnings("ignore", message="Warning only once for all operators")
|
||||
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
|
||||
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
|
||||
@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
|
||||
# Create a custom formatter
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
# Create a handler (e.g., StreamHandler for console output)
|
||||
@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
|
||||
logger = logging.getLogger(__name__)
|
||||
return logger
|
||||
|
||||
|
||||
logger = _setup_logging(level=defines.logging_level)
|
||||
logger.debug(f"Logging initialized with level: {defines.logging_level}")
|
@ -58,6 +58,7 @@ background_task_manager = None
|
||||
prev_int = signal.getsignal(signal.SIGINT)
|
||||
prev_term = signal.getsignal(signal.SIGTERM)
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
|
||||
# now call the old handler (it might raise KeyboardInterrupt or exit)
|
||||
@ -66,9 +67,11 @@ def signal_handler(signum, frame):
|
||||
elif signum == signal.SIGTERM and callable(prev_term):
|
||||
prev_term(signum, frame)
|
||||
|
||||
|
||||
# Global background task manager
|
||||
background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup
|
||||
@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
|
||||
if db_manager:
|
||||
await db_manager.graceful_shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
lifespan=lifespan,
|
||||
title="Backstory API",
|
||||
@ -129,11 +133,9 @@ app = FastAPI(
|
||||
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
|
||||
|
||||
if ssl_enabled:
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000",
|
||||
"https://backstory-beta.ketrenos.com"]
|
||||
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
|
||||
else:
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000",
|
||||
"http://backstory-beta.ketrenos.com"]
|
||||
allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
@ -144,12 +146,14 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Debug data type failures
|
||||
# ============================
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
|
||||
@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
|
||||
content=json.dumps({"detail": str(exc)}),
|
||||
)
|
||||
|
||||
|
||||
# ============================
|
||||
# Create API router with prefix
|
||||
# ============================
|
||||
@ -181,9 +186,10 @@ api_router.include_router(users.router)
|
||||
# Health Check and Info Endpoints
|
||||
# ============================
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(
|
||||
database = Depends(get_database),
|
||||
database=Depends(get_database),
|
||||
):
|
||||
"""Health check endpoint"""
|
||||
try:
|
||||
@ -202,15 +208,12 @@ async def health_check(
|
||||
return {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"database": {
|
||||
"status": "connected",
|
||||
"stats": stats
|
||||
},
|
||||
"database": {"status": "connected", "stats": stats},
|
||||
"redis": {
|
||||
"version": redis_info.get("redis_version", "unknown"),
|
||||
"uptime": redis_info.get("uptime_in_seconds", 0),
|
||||
"memory_used": redis_info.get("used_memory_human", "unknown")
|
||||
}
|
||||
"memory_used": redis_info.get("used_memory_human", "unknown"),
|
||||
},
|
||||
}
|
||||
|
||||
except RuntimeError as e:
|
||||
@ -219,6 +222,7 @@ async def health_check(
|
||||
logger.error(f"❌ Health check failed: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
# ============================
|
||||
# Include Router in App
|
||||
# ============================
|
||||
@ -231,11 +235,14 @@ app.include_router(api_router)
|
||||
# ============================
|
||||
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
try:
|
||||
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
|
||||
logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
|
||||
logger.info(
|
||||
f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
|
||||
)
|
||||
response = await call_next(request)
|
||||
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
|
||||
if response.status_code < 200 or response.status_code >= 300:
|
||||
@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
|
||||
return response
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
|
||||
return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})
|
||||
|
||||
|
||||
# ============================
|
||||
# Request tracking middleware
|
||||
# ============================
|
||||
@ -266,6 +275,7 @@ async def track_requests(request, call_next):
|
||||
finally:
|
||||
db_manager.decrement_requests()
|
||||
|
||||
|
||||
# ============================
|
||||
# FastAPI Metrics
|
||||
# ============================
|
||||
@ -277,7 +287,7 @@ instrumentator = Instrumentator(
|
||||
should_ignore_untemplated=True,
|
||||
should_group_untemplated=True,
|
||||
excluded_handlers=[f"{defines.api_prefix}/metrics"],
|
||||
registry=prometheus_collector
|
||||
registry=prometheus_collector,
|
||||
)
|
||||
|
||||
# Instrument the FastAPI app
|
||||
@ -291,6 +301,7 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
|
||||
# Static File Serving
|
||||
# ============================
|
||||
|
||||
|
||||
@app.get("/{path:path}")
|
||||
async def serve_static(path: str, request: Request):
|
||||
full_path = os.path.join(defines.static_content, path)
|
||||
@ -300,6 +311,7 @@ async def serve_static(path: str, request: Request):
|
||||
|
||||
return FileResponse(os.path.join(defines.static_content, "index.html"))
|
||||
|
||||
|
||||
# Root endpoint when no static files
|
||||
@app.get("/", include_in_schema=False)
|
||||
async def root():
|
||||
@ -309,9 +321,10 @@ async def root():
|
||||
"version": "1.0.0",
|
||||
"api_prefix": defines.api_prefix,
|
||||
"documentation": f"{defines.api_prefix}/docs",
|
||||
"health": f"{defines.api_prefix}/health"
|
||||
"health": f"{defines.api_prefix}/health",
|
||||
}
|
||||
|
||||
|
||||
async def periodic_verification_cleanup():
|
||||
"""Background task to periodically clean up expired verification tokens"""
|
||||
try:
|
||||
@ -324,6 +337,7 @@ async def periodic_verification_cleanup():
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in periodic verification cleanup: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
host = defines.host
|
||||
port = defines.port
|
||||
|
File diff suppressed because it is too large
Load Diff
23
src/backend/pyproject.toml
Normal file
23
src/backend/pyproject.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[tool.ruff]
|
||||
# Set line length (default is 88, Black-compatible)
|
||||
line-length = 120
|
||||
exclude = [
|
||||
".git",
|
||||
"__pycache__",
|
||||
"migrations",
|
||||
"venv",
|
||||
"dist",
|
||||
"build",
|
||||
]
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.format]
|
||||
# Use Ruff's formatter (Black-compatible by default)
|
||||
quote-style = "double" # Enforce double quotes
|
||||
indent-style = "space" # Use spaces for indentation
|
||||
skip-magic-trailing-comma = false # Add trailing commas where applicable
|
||||
|
||||
[tool.ruff.per-file-ignores]
|
||||
# Ignore specific rules for certain files (e.g., tests or scripts)
|
||||
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
|
||||
"scripts/**" = ["E402"] # Ignore import order in scripts
|
@ -1,7 +1,3 @@
|
||||
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
|
||||
__all__ = [
|
||||
"ChromaDBFileWatcher",
|
||||
"start_file_watcher",
|
||||
"RagEntry"
|
||||
]
|
||||
|
||||
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]
|
||||
|
@ -7,10 +7,12 @@ import logging
|
||||
|
||||
import defines
|
||||
|
||||
|
||||
class Chunk(TypedDict):
|
||||
text: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
def clear_chunk(chunk: Chunk):
|
||||
chunk["text"] = ""
|
||||
chunk["metadata"] = {
|
||||
@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
|
||||
}
|
||||
return chunk
|
||||
|
||||
|
||||
class MarkdownChunker:
|
||||
def __init__(self):
|
||||
# Initialize the Markdown parser
|
||||
@ -76,10 +79,7 @@ class MarkdownChunker:
|
||||
# Initialize a chunk structure
|
||||
chunk: Chunk = {
|
||||
"text": "",
|
||||
"metadata": {
|
||||
"source_file": file_path,
|
||||
"lines": total_lines
|
||||
},
|
||||
"metadata": {"source_file": file_path, "lines": total_lines},
|
||||
}
|
||||
clear_chunk(chunk)
|
||||
|
||||
@ -89,9 +89,7 @@ class MarkdownChunker:
|
||||
return chunks
|
||||
|
||||
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
|
||||
}
|
||||
return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}
|
||||
|
||||
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
|
||||
lines = []
|
||||
@ -114,7 +112,7 @@ class MarkdownChunker:
|
||||
chunks: List[Chunk],
|
||||
chunk: Chunk,
|
||||
level: int,
|
||||
buffer: int = defines.chunk_buffer
|
||||
buffer: int = defines.chunk_buffer,
|
||||
) -> int:
|
||||
is_list = False
|
||||
# Handle heading nodes
|
||||
@ -141,7 +139,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
@ -186,7 +184,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
@ -198,9 +196,7 @@ class MarkdownChunker:
|
||||
# Recursively process children
|
||||
if not is_list:
|
||||
for child in node.children:
|
||||
level = self._process_node(
|
||||
child, current_headings, chunks, chunk, level=level
|
||||
)
|
||||
level = self._process_node(child, current_headings, chunks, chunk, level=level)
|
||||
|
||||
# After root-level recursion, finalize any remaining chunk
|
||||
if node.type == "document":
|
||||
@ -211,7 +207,7 @@ class MarkdownChunker:
|
||||
if node.nester_tokens:
|
||||
opening, closing = node.nester_tokens
|
||||
if opening and opening.map:
|
||||
( begin, end ) = opening.map
|
||||
(begin, end) = opening.map
|
||||
metadata = chunk["metadata"]
|
||||
metadata["chunk_begin"] = max(0, begin - buffer)
|
||||
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from typing import List, Optional, Dict, Any
|
||||
import os
|
||||
import glob
|
||||
@ -11,10 +11,10 @@ import json
|
||||
import numpy as np # type: ignore
|
||||
import traceback
|
||||
|
||||
import chromadb # type: ignore
|
||||
import chromadb # type: ignore
|
||||
from watchdog.observers import Observer # type: ignore
|
||||
from watchdog.events import FileSystemEventHandler # type: ignore
|
||||
import umap # type: ignore
|
||||
import umap # type: ignore
|
||||
from markitdown import MarkItDown # type: ignore
|
||||
from chromadb.api.models.Collection import Collection # type: ignore
|
||||
|
||||
@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
|
||||
DEFAULT_CHUNK_SIZE = 750
|
||||
DEFAULT_CHUNK_OVERLAP = 100
|
||||
|
||||
|
||||
class RagEntry(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
def __init__(
|
||||
self,
|
||||
@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
# Path for storing file hash state
|
||||
self.hash_state_path = os.path.join(
|
||||
self.persist_directory, f"{collection_name}_hash_state.json"
|
||||
)
|
||||
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
|
||||
|
||||
# Flag to track if this is a new collection
|
||||
self.is_new_collection = False
|
||||
@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
process_all: If True, process all files regardless of hash status
|
||||
"""
|
||||
# Check for new or modified files
|
||||
file_paths = glob.glob(
|
||||
os.path.join(self.watch_directory, "**/*"), recursive=True
|
||||
)
|
||||
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
|
||||
files_checked = 0
|
||||
files_processed = 0
|
||||
files_to_process = []
|
||||
@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
continue
|
||||
|
||||
# If file is new, changed, or we're processing all files
|
||||
if (
|
||||
process_all
|
||||
or file_path not in self.file_hashes
|
||||
or self.file_hashes[file_path] != current_hash
|
||||
):
|
||||
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
|
||||
self.file_hashes[file_path] = current_hash
|
||||
files_to_process.append(file_path)
|
||||
logging.info(
|
||||
f"File {'found' if process_all else 'changed'}: {file_path}"
|
||||
)
|
||||
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
|
||||
|
||||
logging.info(
|
||||
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
|
||||
)
|
||||
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
|
||||
|
||||
# Check for deleted files
|
||||
deleted_files = []
|
||||
@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not os.path.exists(file_path):
|
||||
deleted_files.append(file_path)
|
||||
# Schedule removal
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
# Don't block on result, just let it run
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if not current_hash: # File might have been deleted or is inaccessible
|
||||
return
|
||||
|
||||
if (
|
||||
file_path in self.file_hashes
|
||||
and self.file_hashes[file_path] == current_hash
|
||||
):
|
||||
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
|
||||
# File hasn't actually changed in content
|
||||
logging.info(f"Hash has not changed for {file_path}")
|
||||
return
|
||||
@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
if results and "ids" in results and results["ids"]:
|
||||
self.collection.delete(ids=results["ids"])
|
||||
await self.database.update_user_rag_timestamp(self.user_id)
|
||||
logging.info(
|
||||
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
|
||||
)
|
||||
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
|
||||
|
||||
# Remove from hash dictionary
|
||||
if file_path in self.file_hashes:
|
||||
@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
|
||||
def _update_umaps(self):
|
||||
# Update the UMAP embeddings
|
||||
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
|
||||
include=["embeddings", "documents", "metadatas"]
|
||||
))
|
||||
self._umap_collection = ChromaDBGetResponse.model_validate(
|
||||
self._collection.get(include=["embeddings", "documents", "metadatas"])
|
||||
)
|
||||
if not self._umap_collection or not len(self._umap_collection.embeddings):
|
||||
logging.warning("⚠️ No embeddings found in the collection.")
|
||||
return
|
||||
|
||||
# During initialization
|
||||
logging.info(
|
||||
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
|
||||
vectors = np.array(self._umap_collection.embeddings)
|
||||
self._umap_model_2d = umap.UMAP(
|
||||
n_components=2,
|
||||
@ -323,14 +304,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
n_neighbors=30,
|
||||
min_dist=0.1,
|
||||
)
|
||||
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
|
||||
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
|
||||
# logging.info(
|
||||
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
|
||||
# ) # Should be 2
|
||||
|
||||
logging.info(
|
||||
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
|
||||
)
|
||||
logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
|
||||
self._umap_model_3d = umap.UMAP(
|
||||
n_components=3,
|
||||
random_state=8911,
|
||||
@ -338,7 +317,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
n_neighbors=30,
|
||||
min_dist=0.01,
|
||||
)
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)# type: ignore
|
||||
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors) # type: ignore
|
||||
# logging.info(
|
||||
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
|
||||
# ) # Should be 3
|
||||
@ -373,9 +352,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
self.is_new_collection = True
|
||||
logging.info(f"Recreating collection: {self.collection_name}")
|
||||
|
||||
return chroma_client.get_or_create_collection(
|
||||
name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
|
||||
|
||||
async def get_embedding(self, text: str) -> np.ndarray:
|
||||
"""Generate and normalize an embedding for the given text."""
|
||||
@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Generate a more unique ID based on content and metadata
|
||||
path_hash = ""
|
||||
if "path" in metadata:
|
||||
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
|
||||
:8
|
||||
]
|
||||
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
|
||||
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
|
||||
chunk_id = f"{path_hash}_{i}_{content_hash}"
|
||||
|
||||
@ -438,7 +413,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
logging.error(traceback.format_exc())
|
||||
logging.error(chunk)
|
||||
|
||||
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
|
||||
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
|
||||
source_file = meta.get("source_file")
|
||||
try:
|
||||
source_file = meta["source_file"]
|
||||
@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
return
|
||||
|
||||
file_path = event.src_path
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.remove_file_from_collection(file_path), self.loop
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
|
||||
logging.info(f"File deleted: {file_path}")
|
||||
|
||||
def on_moved(self, event):
|
||||
@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
try:
|
||||
# Remove existing entries for this file
|
||||
existing_results = self.collection.get(where={"path": file_path})
|
||||
if (
|
||||
existing_results
|
||||
and "ids" in existing_results
|
||||
and existing_results["ids"]
|
||||
):
|
||||
if existing_results and "ids" in existing_results and existing_results["ids"]:
|
||||
self.collection.delete(ids=existing_results["ids"])
|
||||
await self.database.update_user_rag_timestamp(self.user_id)
|
||||
|
||||
@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
p = Path(file_path)
|
||||
p_as_md = p.with_suffix(".md")
|
||||
if p_as_md.exists():
|
||||
logging.info(
|
||||
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
|
||||
)
|
||||
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
|
||||
|
||||
# If file_path.md doesn't exist or file_path is newer than file_path.md,
|
||||
# fire off markitdown
|
||||
if (not p_as_md.exists()) or (
|
||||
p.stat().st_mtime > p_as_md.stat().st_mtime
|
||||
):
|
||||
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
|
||||
self._markitdown(file_path, p_as_md)
|
||||
return
|
||||
|
||||
@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
|
||||
# Process all files regardless of hash state
|
||||
num_processed = await self.scan_directory(process_all=True)
|
||||
|
||||
logging.info(
|
||||
f"Vectorstore initialized with {self.collection.count()} documents"
|
||||
)
|
||||
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
|
||||
|
||||
self._update_umaps()
|
||||
|
||||
@ -676,7 +639,7 @@ def start_file_watcher(
|
||||
persist_directory=persist_directory,
|
||||
collection_name=collection_name,
|
||||
recreate=recreate,
|
||||
database=database
|
||||
database=database,
|
||||
)
|
||||
|
||||
# Process all files if:
|
||||
|
@ -13,14 +13,4 @@ from . import employers
|
||||
from . import admin
|
||||
from . import system
|
||||
|
||||
__all__ = [
|
||||
"auth",
|
||||
"candidates",
|
||||
"resumes",
|
||||
"jobs",
|
||||
"chat",
|
||||
"users",
|
||||
"employers",
|
||||
"admin",
|
||||
"system"
|
||||
]
|
||||
__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]
|
||||
|
@ -5,8 +5,18 @@ import json
|
||||
from datetime import datetime, timezone, UTC
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, HTTPException, Depends, Body, Request, HTTPException,
|
||||
Depends, Query, Path, Body, APIRouter, Request
|
||||
APIRouter,
|
||||
HTTPException,
|
||||
Depends,
|
||||
Body,
|
||||
Request,
|
||||
HTTPException,
|
||||
Depends,
|
||||
Query,
|
||||
Path,
|
||||
Body,
|
||||
APIRouter,
|
||||
Request,
|
||||
)
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
@ -14,52 +24,51 @@ from fastapi.responses import JSONResponse
|
||||
from utils.rate_limiter import RateLimiter, get_rate_limiter
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_current_user_or_guest, get_database, background_task_manager
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager
|
||||
from utils.responses import (
|
||||
create_paginated_response, create_success_response, create_error_response,
|
||||
create_paginated_response,
|
||||
create_success_response,
|
||||
create_error_response,
|
||||
)
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-guests")
|
||||
async def manual_guest_cleanup(
|
||||
inactive_hours: int = Body(24, embed=True),
|
||||
current_user = Depends(get_current_admin),
|
||||
admin_user = Depends(get_current_admin)
|
||||
current_user=Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
):
|
||||
"""Manually trigger guest cleanup (admin only)"""
|
||||
try:
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
|
||||
|
||||
logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual guest cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-tokens")
|
||||
async def manual_token_cleanup(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def manual_token_cleanup(admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger verification token cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
@ -67,31 +76,28 @@ async def manual_token_cleanup(
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
|
||||
|
||||
logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual token cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/cleanup-rate-limits")
|
||||
async def manual_rate_limit_cleanup(
|
||||
days_old: int = Body(7, embed=True),
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger rate limit data cleanup (admin only)"""
|
||||
try:
|
||||
global background_task_manager
|
||||
@ -99,60 +105,55 @@ async def manual_rate_limit_cleanup(
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
|
||||
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
|
||||
)
|
||||
|
||||
cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
|
||||
|
||||
logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
|
||||
"cleaned_count": cleaned_count,
|
||||
"triggered_by": admin_user.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Manual rate limit cleanup error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# System Health and Maintenance Endpoints
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/system/health")
|
||||
async def get_system_health(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_system_health(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Get comprehensive system health status (admin only)"""
|
||||
try:
|
||||
# Database health
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
database_manager = getattr(request.app.state, "database_manager", None)
|
||||
db_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if database_manager:
|
||||
try:
|
||||
database = database_manager.get_database()
|
||||
database_manager.get_database()
|
||||
from database.manager import redis_manager
|
||||
|
||||
redis_health = await redis_manager.health_check()
|
||||
db_health = {
|
||||
"status": redis_health.get("status", "unknown"),
|
||||
"healthy": redis_health.get("status") == "healthy",
|
||||
"details": redis_health
|
||||
"details": redis_health,
|
||||
}
|
||||
except Exception as e:
|
||||
db_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
db_health = {"status": "error", "healthy": False, "error": str(e)}
|
||||
|
||||
# Background task health
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
task_health = {"status": "unavailable", "healthy": False}
|
||||
|
||||
if background_task_manager:
|
||||
@ -166,39 +167,30 @@ async def get_system_health(
|
||||
"healthy": task_status["running"] and failed_tasks == 0,
|
||||
"running_tasks": running_tasks,
|
||||
"failed_tasks": failed_tasks,
|
||||
"total_tasks": task_status["task_count"]
|
||||
"total_tasks": task_status["task_count"],
|
||||
}
|
||||
except Exception as e:
|
||||
task_health = {
|
||||
"status": "error",
|
||||
"healthy": False,
|
||||
"error": str(e)
|
||||
}
|
||||
task_health = {"status": "error", "healthy": False, "error": str(e)}
|
||||
|
||||
# Overall health
|
||||
overall_healthy = db_health["healthy"] and task_health["healthy"]
|
||||
|
||||
return create_success_response({
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"overall_healthy": overall_healthy,
|
||||
"components": {
|
||||
"database": db_health,
|
||||
"background_tasks": task_health
|
||||
return create_success_response(
|
||||
{
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"overall_healthy": overall_healthy,
|
||||
"components": {"database": db_health, "background_tasks": task_health},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting system health: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/maintenance/cleanup")
|
||||
async def run_maintenance_cleanup(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Run comprehensive maintenance cleanup (admin only)"""
|
||||
try:
|
||||
@ -217,50 +209,40 @@ async def run_maintenance_cleanup(
|
||||
cleanup_results[operation_name] = {
|
||||
"success": True,
|
||||
"cleaned_count": result,
|
||||
"message": f"Cleaned {result} items"
|
||||
"message": f"Cleaned {result} items",
|
||||
}
|
||||
except Exception as e:
|
||||
cleanup_results[operation_name] = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": f"Failed: {str(e)}"
|
||||
}
|
||||
cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"}
|
||||
|
||||
# Calculate totals
|
||||
total_cleaned = sum(
|
||||
result.get("cleaned_count", 0)
|
||||
for result in cleanup_results.values()
|
||||
if result.get("success", False)
|
||||
result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False)
|
||||
)
|
||||
|
||||
successful_operations = len([
|
||||
r for r in cleanup_results.values()
|
||||
if r.get("success", False)
|
||||
])
|
||||
successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)])
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
|
||||
"total_cleaned": total_cleaned,
|
||||
"successful_operations": successful_operations,
|
||||
"details": cleanup_results
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
|
||||
"total_cleaned": total_cleaned,
|
||||
"successful_operations": successful_operations,
|
||||
"details": cleanup_results,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in maintenance cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# Background Task Statistics
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/tasks/stats")
|
||||
async def get_task_statistics(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get background task execution statistics (admin only)"""
|
||||
try:
|
||||
@ -268,7 +250,7 @@ async def get_task_statistics(
|
||||
guest_stats = await database.get_guest_statistics()
|
||||
|
||||
# Get background task manager status
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
task_manager_stats = {}
|
||||
|
||||
if background_task_manager:
|
||||
@ -276,7 +258,7 @@ async def get_task_statistics(
|
||||
task_manager_stats = {
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"task_breakdown": {}
|
||||
"task_breakdown": {},
|
||||
}
|
||||
|
||||
# Count tasks by status
|
||||
@ -284,40 +266,35 @@ async def get_task_statistics(
|
||||
status = task["status"]
|
||||
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
|
||||
|
||||
return create_success_response({
|
||||
"guest_statistics": guest_stats,
|
||||
"task_manager": task_manager_stats,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"guest_statistics": guest_stats,
|
||||
"task_manager": task_manager_stats,
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting task statistics: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
# ========================================
|
||||
# Background Task Status Endpoints
|
||||
# ========================================
|
||||
|
||||
|
||||
@router.get("/tasks/status")
|
||||
async def get_background_task_status(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Get background task manager status (admin only)"""
|
||||
try:
|
||||
# Get background task manager from app state
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not background_task_manager:
|
||||
return create_success_response({
|
||||
"running": False,
|
||||
"message": "Background task manager not initialized",
|
||||
"tasks": [],
|
||||
"task_count": 0
|
||||
})
|
||||
return create_success_response(
|
||||
{"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0}
|
||||
)
|
||||
|
||||
# Get comprehensive task status using the new method
|
||||
task_status = await background_task_manager.get_task_status()
|
||||
@ -325,91 +302,68 @@ async def get_background_task_status(
|
||||
# Add additional system info
|
||||
system_info = {
|
||||
"uptime_seconds": None, # Could calculate from start time if stored
|
||||
"last_cleanup": None, # Could track last cleanup time
|
||||
"last_cleanup": None, # Could track last cleanup time
|
||||
}
|
||||
|
||||
# Format the response
|
||||
return create_success_response({
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"loop_status": {
|
||||
"main_loop_id": task_status["main_loop_id"],
|
||||
"current_loop_id": task_status["current_loop_id"],
|
||||
"loop_matches": task_status.get("loop_matches", False)
|
||||
},
|
||||
"tasks": task_status["tasks"],
|
||||
"system_info": system_info
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"running": task_status["running"],
|
||||
"task_count": task_status["task_count"],
|
||||
"loop_status": {
|
||||
"main_loop_id": task_status["main_loop_id"],
|
||||
"current_loop_id": task_status["current_loop_id"],
|
||||
"loop_matches": task_status.get("loop_matches", False),
|
||||
},
|
||||
"tasks": task_status["tasks"],
|
||||
"system_info": system_info,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/run/{task_name}")
|
||||
async def run_background_task(
|
||||
task_name: str,
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Manually trigger a specific background task (admin only)"""
|
||||
try:
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not background_task_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"MANAGER_UNAVAILABLE",
|
||||
"Background task manager not initialized"
|
||||
)
|
||||
content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"),
|
||||
)
|
||||
|
||||
# List of available tasks
|
||||
available_tasks = [
|
||||
"guest_cleanup",
|
||||
"token_cleanup",
|
||||
"guest_stats",
|
||||
"rate_limit_cleanup",
|
||||
"orphaned_cleanup"
|
||||
]
|
||||
available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"]
|
||||
|
||||
if task_name not in available_tasks:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_TASK",
|
||||
f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
)
|
||||
"INVALID_TASK", f"Unknown task: {task_name}. Available: {available_tasks}"
|
||||
),
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = await background_task_manager.force_run_task(task_name)
|
||||
|
||||
return create_success_response({
|
||||
"task_name": task_name,
|
||||
"result": result,
|
||||
"message": f"Task {task_name} completed successfully"
|
||||
})
|
||||
return create_success_response(
|
||||
{"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"}
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TASK", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error running task {task_name}: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/tasks/list")
|
||||
async def list_available_tasks(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def list_available_tasks(admin_user=Depends(get_current_admin)):
|
||||
"""List all available background tasks (admin only)"""
|
||||
try:
|
||||
tasks = [
|
||||
@ -417,63 +371,51 @@ async def list_available_tasks(
|
||||
"name": "guest_cleanup",
|
||||
"description": "Clean up inactive guest sessions",
|
||||
"interval": "6 hours",
|
||||
"parameters": ["inactive_hours (default: 48)"]
|
||||
"parameters": ["inactive_hours (default: 48)"],
|
||||
},
|
||||
{
|
||||
"name": "token_cleanup",
|
||||
"description": "Clean up expired email verification tokens",
|
||||
"interval": "12 hours",
|
||||
"parameters": []
|
||||
"parameters": [],
|
||||
},
|
||||
{
|
||||
"name": "guest_stats",
|
||||
"description": "Update guest usage statistics",
|
||||
"interval": "1 hour",
|
||||
"parameters": []
|
||||
"parameters": [],
|
||||
},
|
||||
{
|
||||
"name": "rate_limit_cleanup",
|
||||
"description": "Clean up old rate limiting data",
|
||||
"interval": "24 hours",
|
||||
"parameters": ["days_old (default: 7)"]
|
||||
"parameters": ["days_old (default: 7)"],
|
||||
},
|
||||
{
|
||||
"name": "orphaned_cleanup",
|
||||
"description": "Clean up orphaned database records",
|
||||
"interval": "6 hours",
|
||||
"parameters": []
|
||||
}
|
||||
"parameters": [],
|
||||
},
|
||||
]
|
||||
|
||||
return create_success_response({
|
||||
"total_tasks": len(tasks),
|
||||
"tasks": tasks
|
||||
})
|
||||
return create_success_response({"total_tasks": len(tasks), "tasks": tasks})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error listing tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LIST_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/tasks/restart")
|
||||
async def restart_background_tasks(
|
||||
request: Request,
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)):
|
||||
"""Restart the background task manager (admin only)"""
|
||||
try:
|
||||
database_manager = getattr(request.app.state, 'database_manager', None)
|
||||
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
|
||||
database_manager = getattr(request.app.state, "database_manager", None)
|
||||
background_task_manager = getattr(request.app.state, "background_task_manager", None)
|
||||
|
||||
if not database_manager:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content=create_error_response(
|
||||
"DATABASE_UNAVAILABLE",
|
||||
"Database manager not available"
|
||||
)
|
||||
status_code=503, content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available")
|
||||
)
|
||||
|
||||
# Stop existing background tasks
|
||||
@ -483,6 +425,7 @@ async def restart_background_tasks(
|
||||
|
||||
# Create and start new background task manager
|
||||
from background_tasks import BackgroundTaskManager
|
||||
|
||||
new_background_task_manager = BackgroundTaskManager(database_manager)
|
||||
await new_background_task_manager.start()
|
||||
|
||||
@ -492,22 +435,20 @@ async def restart_background_tasks(
|
||||
# Get status of new manager
|
||||
status = await new_background_task_manager.get_task_status()
|
||||
|
||||
return create_success_response({
|
||||
"message": "Background task manager restarted successfully",
|
||||
"new_status": status
|
||||
})
|
||||
return create_success_response(
|
||||
{"message": "Background task manager restarted successfully", "new_status": status}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error restarting background tasks: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESTART_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e)))
|
||||
|
||||
|
||||
# ============================
|
||||
# Task Monitoring and Metrics
|
||||
# ============================
|
||||
|
||||
|
||||
class TaskMetrics:
|
||||
"""Collect metrics for background tasks"""
|
||||
|
||||
@ -544,36 +485,32 @@ class TaskMetrics:
|
||||
metrics[task_name] = {
|
||||
"total_runs": self.task_runs[task_name],
|
||||
"total_errors": self.task_errors[task_name],
|
||||
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
|
||||
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name]
|
||||
if self.task_runs[task_name] > 0
|
||||
else 0,
|
||||
"average_duration": avg_duration,
|
||||
"last_runs": durations[-10:] if durations else []
|
||||
"last_runs": durations[-10:] if durations else [],
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
# Global task metrics
|
||||
task_metrics = TaskMetrics()
|
||||
|
||||
|
||||
@router.get("/tasks/metrics")
|
||||
async def get_task_metrics(
|
||||
admin_user = Depends(get_current_admin)
|
||||
):
|
||||
async def get_task_metrics(admin_user=Depends(get_current_admin)):
|
||||
"""Get background task metrics (admin only)"""
|
||||
try:
|
||||
global task_metrics
|
||||
metrics = task_metrics.get_metrics()
|
||||
|
||||
return create_success_response({
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
})
|
||||
return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get task metrics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("METRICS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e)))
|
||||
|
||||
|
||||
# ============================
|
||||
@ -581,8 +518,7 @@ async def get_task_metrics(
|
||||
# ============================
|
||||
# @router.get("/verification-stats")
|
||||
async def get_verification_statistics(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get verification statistics (admin only)"""
|
||||
try:
|
||||
@ -591,22 +527,19 @@ async def get_verification_statistics(
|
||||
|
||||
stats = {
|
||||
"pending_verifications": await database.get_pending_verifications_count(),
|
||||
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
|
||||
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(),
|
||||
}
|
||||
|
||||
return create_success_response(stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting verification stats: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/cleanup-verifications")
|
||||
async def cleanup_verification_tokens(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Manually trigger cleanup of expired verification tokens (admin only)"""
|
||||
try:
|
||||
@ -617,24 +550,24 @@ async def cleanup_verification_tokens(
|
||||
|
||||
logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
|
||||
|
||||
return create_success_response({
|
||||
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
|
||||
"cleaned_count": cleaned_count
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
|
||||
"cleaned_count": cleaned_count,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error in manual cleanup: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CLEANUP_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/pending-verifications")
|
||||
async def get_pending_verifications(
|
||||
current_user = Depends(get_current_admin),
|
||||
current_user=Depends(get_current_admin),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get list of pending email verifications (admin only)"""
|
||||
try:
|
||||
@ -656,14 +589,16 @@ async def get_pending_verifications(
|
||||
if not verification_info.get("verified", False):
|
||||
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
|
||||
|
||||
pending_verifications.append({
|
||||
"email": verification_info.get("email"),
|
||||
"user_type": verification_info.get("user_type"),
|
||||
"created_at": verification_info.get("created_at"),
|
||||
"expires_at": verification_info.get("expires_at"),
|
||||
"is_expired": current_time > expires_at,
|
||||
"resend_count": verification_info.get("resend_count", 0)
|
||||
})
|
||||
pending_verifications.append(
|
||||
{
|
||||
"email": verification_info.get("email"),
|
||||
"user_type": verification_info.get("user_type"),
|
||||
"created_at": verification_info.get("created_at"),
|
||||
"expires_at": verification_info.get("expires_at"),
|
||||
"is_expired": current_time > expires_at,
|
||||
"resend_count": verification_info.get("resend_count", 0),
|
||||
}
|
||||
)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
@ -677,35 +612,27 @@ async def get_pending_verifications(
|
||||
end = start + limit
|
||||
paginated_verifications = pending_verifications[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
paginated_verifications,
|
||||
page, limit, total
|
||||
)
|
||||
paginated_response = create_paginated_response(paginated_verifications, page, limit, total)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error getting pending verifications: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/rate-limits/info")
|
||||
async def get_user_rate_limit_status(
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(current_user.id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
@ -725,27 +652,22 @@ async def get_user_rate_limit_status(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/rate-limits/{user_id}")
|
||||
async def get_anyone_rate_limit_status(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get rate limit status for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
is_admin = False
|
||||
@ -765,47 +687,36 @@ async def get_anyone_rate_limit_status(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get rate limit status error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATUS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/rate-limits/{user_id}/reset")
|
||||
async def reset_user_rate_limits(
|
||||
user_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
admin_user=Depends(get_current_admin),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Reset rate limits for a user (admin only)"""
|
||||
try:
|
||||
# Get user to determine type
|
||||
user_data = await database.get_user_by_id(user_id)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
user_type = user_data.get("type", "unknown")
|
||||
success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
|
||||
|
||||
if success:
|
||||
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
|
||||
return create_success_response({
|
||||
"message": f"Rate limits reset for {user_type} {user_id}",
|
||||
"resetBy": admin_user.id
|
||||
})
|
||||
return create_success_response(
|
||||
{"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id}
|
||||
)
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
|
||||
status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset rate limits error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
|
||||
|
@ -20,18 +20,25 @@ from device_manager import DeviceManager
|
||||
from email_service import VerificationEmailRateLimiter, email_service
|
||||
from logger import logger
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, Candidate,
|
||||
Employer, Guest, AuthResponse, MFARequest,
|
||||
MFAData, MFAVerifyRequest, ResendVerificationRequest,
|
||||
MFARequestResponse, MFARequestResponse
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user, create_access_token
|
||||
LoginRequest,
|
||||
CreateCandidateRequest,
|
||||
Candidate,
|
||||
Employer,
|
||||
Guest,
|
||||
AuthResponse,
|
||||
MFARequest,
|
||||
MFAData,
|
||||
MFAVerifyRequest,
|
||||
ResendVerificationRequest,
|
||||
MFARequestResponse,
|
||||
MFARequestResponse,
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from utils.rate_limiter import get_rate_limiter
|
||||
from utils.auth_utils import (
|
||||
AuthenticationManager, SecurityConfig,
|
||||
AuthenticationManager,
|
||||
SecurityConfig,
|
||||
validate_password_strength,
|
||||
)
|
||||
|
||||
@ -43,28 +50,31 @@ if JWT_SECRET_KEY == "":
|
||||
raise ValueError("JWT_SECRET_KEY environment variable is not set")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
# ============================
|
||||
# Password Reset Endpoints
|
||||
# ============================
|
||||
class PasswordResetRequest(BaseModel):
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseModel):
|
||||
token: str
|
||||
new_password: str
|
||||
|
||||
@field_validator('new_password')
|
||||
@field_validator("new_password")
|
||||
def validate_password_strength(cls, v):
|
||||
is_valid, issues = validate_password_strength(v)
|
||||
if not is_valid:
|
||||
raise ValueError('; '.join(issues))
|
||||
raise ValueError("; ".join(issues))
|
||||
return v
|
||||
|
||||
|
||||
@router.post("/guest")
|
||||
async def create_guest_session_enhanced(
|
||||
request: Request,
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter)
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
):
|
||||
"""Create a guest session with enhanced validation and persistence"""
|
||||
try:
|
||||
@ -73,21 +83,15 @@ async def create_guest_session_enhanced(
|
||||
|
||||
# Check rate limits for guest session creation
|
||||
rate_result = await rate_limiter.check_rate_limit(
|
||||
user_id=ip_address,
|
||||
user_type="guest_creation",
|
||||
is_admin=False,
|
||||
endpoint="/guest"
|
||||
user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest"
|
||||
)
|
||||
|
||||
if not rate_result.allowed:
|
||||
logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response(
|
||||
"RATE_LIMITED",
|
||||
rate_result.reason or "Too many guest sessions created"
|
||||
),
|
||||
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}
|
||||
content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"),
|
||||
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)},
|
||||
)
|
||||
|
||||
# Generate unique guest identifier with timestamp for uniqueness
|
||||
@ -126,7 +130,7 @@ async def create_guest_session_enhanced(
|
||||
"user_agent": request.headers.get("user-agent", "Unknown"),
|
||||
"converted_to_user_id": None,
|
||||
"browser_session": True, # Mark as browser session
|
||||
"persistent": True, # Mark as persistent
|
||||
"persistent": True, # Mark as persistent
|
||||
}
|
||||
|
||||
# Store guest with enhanced persistence
|
||||
@ -139,7 +143,7 @@ async def create_guest_session_enhanced(
|
||||
"email": guest_data["email"],
|
||||
"username": guest_username,
|
||||
"session_id": session_id,
|
||||
"created_at": current_time.isoformat()
|
||||
"created_at": current_time.isoformat(),
|
||||
}
|
||||
|
||||
await database.set_user(guest_data["email"], user_auth_data)
|
||||
@ -149,11 +153,11 @@ async def create_guest_session_enhanced(
|
||||
# Create authentication tokens with longer expiry for guests
|
||||
access_token = create_access_token(
|
||||
data={"sub": guest_id, "type": "guest"},
|
||||
expires_delta=timedelta(hours=48) # Longer expiry for guests
|
||||
expires_delta=timedelta(hours=48), # Longer expiry for guests
|
||||
)
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": guest_id, "type": "refresh_guest"},
|
||||
expires_delta=timedelta(days=14) # 2 weeks refresh for guests
|
||||
expires_delta=timedelta(days=14), # 2 weeks refresh for guests
|
||||
)
|
||||
|
||||
# Verify guest was stored correctly
|
||||
@ -161,8 +165,7 @@ async def create_guest_session_enhanced(
|
||||
if not verification:
|
||||
logger.error(f"❌ Failed to verify guest storage: {guest_id}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
|
||||
status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
|
||||
)
|
||||
|
||||
# Create guest object for response
|
||||
@ -178,7 +181,7 @@ async def create_guest_session_enhanced(
|
||||
"user": guest.model_dump(by_alias=True),
|
||||
"expiresAt": int((current_time + timedelta(hours=48)).timestamp()),
|
||||
"userType": "guest",
|
||||
"isGuest": True
|
||||
"isGuest": True,
|
||||
}
|
||||
|
||||
return create_success_response(auth_response)
|
||||
@ -186,25 +189,25 @@ async def create_guest_session_enhanced(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Guest session creation error: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
|
||||
status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/guest/convert")
|
||||
async def convert_guest_to_user(
|
||||
registration_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Convert a guest session to a permanent user account"""
|
||||
try:
|
||||
# Verify current user is a guest
|
||||
if current_user.user_type != "guest":
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("NOT_GUEST", "Only guest users can be converted")
|
||||
status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted")
|
||||
)
|
||||
|
||||
guest: Guest = current_user
|
||||
@ -215,25 +218,18 @@ async def convert_guest_to_user(
|
||||
try:
|
||||
candidate_request = CreateCandidateRequest.model_validate(registration_data)
|
||||
except ValidationError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e)))
|
||||
|
||||
# Check if email/username already exists
|
||||
auth_manager = AuthenticationManager(database)
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(
|
||||
candidate_request.email,
|
||||
candidate_request.username
|
||||
candidate_request.email, candidate_request.username
|
||||
)
|
||||
|
||||
if user_exists:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content=create_error_response(
|
||||
"USER_EXISTS",
|
||||
f"A user with this {conflict_field} already exists"
|
||||
)
|
||||
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
|
||||
)
|
||||
|
||||
# Create candidate
|
||||
@ -253,7 +249,7 @@ async def convert_guest_to_user(
|
||||
"updated_at": current_time.isoformat(),
|
||||
"status": "active",
|
||||
"is_admin": False,
|
||||
"converted_from_guest": guest.id
|
||||
"converted_from_guest": guest.id,
|
||||
}
|
||||
|
||||
candidate = Candidate.model_validate(candidate_data)
|
||||
@ -269,7 +265,7 @@ async def convert_guest_to_user(
|
||||
"id": candidate_id,
|
||||
"type": "candidate",
|
||||
"email": candidate.email,
|
||||
"username": candidate.username
|
||||
"username": candidate.username,
|
||||
}
|
||||
|
||||
await database.set_user(candidate.email, user_auth_data)
|
||||
@ -286,43 +282,45 @@ async def convert_guest_to_user(
|
||||
access_token = create_access_token(data={"sub": candidate_id})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": candidate_id, "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
auth_response = AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=candidate,
|
||||
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()),
|
||||
)
|
||||
|
||||
logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Guest account successfully converted to candidate",
|
||||
"auth": auth_response.model_dump(by_alias=True),
|
||||
"conversionType": "candidate"
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Guest account successfully converted to candidate",
|
||||
"auth": auth_response.model_dump(by_alias=True),
|
||||
"conversionType": "candidate",
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported")
|
||||
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Guest conversion error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
|
||||
status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(
|
||||
access_token: str = Body(..., alias="accessToken"),
|
||||
refresh_token: str = Body(..., alias="refreshToken"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Logout endpoint - revokes both access and refresh tokens"""
|
||||
logger.info(f"🔑 User {current_user.id} is logging out")
|
||||
@ -336,21 +334,18 @@ async def logout(
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
except jwt.PyJWTError as e:
|
||||
logger.warning(f"⚠️ Invalid refresh token during logout: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
|
||||
# Verify that the refresh token belongs to the current user
|
||||
if user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user")
|
||||
)
|
||||
|
||||
# Get Redis client
|
||||
@ -362,12 +357,14 @@ async def logout(
|
||||
await redis.setex(
|
||||
f"blacklisted_token:{refresh_token}",
|
||||
refresh_ttl,
|
||||
json.dumps({
|
||||
"user_id": user_id,
|
||||
"token_type": "refresh",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout"
|
||||
})
|
||||
json.dumps(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"token_type": "refresh",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout",
|
||||
}
|
||||
),
|
||||
)
|
||||
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
|
||||
|
||||
@ -385,12 +382,14 @@ async def logout(
|
||||
await redis.setex(
|
||||
f"blacklisted_token:{access_token}",
|
||||
access_ttl,
|
||||
json.dumps({
|
||||
"user_id": user_id,
|
||||
"token_type": "access",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout"
|
||||
})
|
||||
json.dumps(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"token_type": "access",
|
||||
"revoked_at": datetime.now(UTC).isoformat(),
|
||||
"reason": "user_logout",
|
||||
}
|
||||
),
|
||||
)
|
||||
logger.info(f"🔒 Blacklisted access token for user {user_id}")
|
||||
else:
|
||||
@ -409,26 +408,20 @@ async def logout(
|
||||
# )
|
||||
|
||||
logger.info(f"🔑 User {user_id} logged out successfully")
|
||||
return create_success_response({
|
||||
"message": "Logged out successfully",
|
||||
"tokensRevoked": {
|
||||
"refreshToken": True,
|
||||
"accessToken": bool(access_token)
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Logged out successfully",
|
||||
"tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Logout error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGOUT_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/logout-all")
|
||||
async def logout_all_devices(
|
||||
current_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)):
|
||||
"""Logout from all devices by revoking all tokens for the user"""
|
||||
try:
|
||||
redis = redis_manager.get_client()
|
||||
@ -437,25 +430,20 @@ async def logout_all_devices(
|
||||
await redis.setex(
|
||||
f"user_tokens_revoked:{current_user.id}",
|
||||
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
|
||||
datetime.now(UTC).isoformat()
|
||||
datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
|
||||
return create_success_response({
|
||||
"message": "Logged out from all devices successfully"
|
||||
})
|
||||
return create_success_response({"message": "Logged out from all devices successfully"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Logout all devices error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_token_endpoint(
|
||||
refresh_token: str = Body(..., alias="refreshToken"),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Refresh token endpoint"""
|
||||
try:
|
||||
@ -466,8 +454,7 @@ async def refresh_token_endpoint(
|
||||
|
||||
if not user_id or token_type != "refresh":
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
|
||||
# Create new access token
|
||||
@ -484,37 +471,29 @@ async def refresh_token_endpoint(
|
||||
user = Employer.model_validate(employer_data)
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
auth_response = AuthResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token, # Keep same refresh token
|
||||
user=user,
|
||||
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp())
|
||||
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()),
|
||||
)
|
||||
|
||||
return create_success_response(auth_response.model_dump(by_alias=True))
|
||||
|
||||
except jwt.PyJWTError:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
|
||||
)
|
||||
return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token"))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Token refresh error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("REFRESH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/resend-verification")
|
||||
async def resend_verification_email(
|
||||
request: ResendVerificationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Resend verification email with comprehensive rate limiting and validation"""
|
||||
try:
|
||||
@ -527,10 +506,7 @@ async def resend_verification_email(
|
||||
can_send, reason = await rate_limiter.can_send_verification_email(email_lower)
|
||||
if not can_send:
|
||||
logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response("RATE_LIMITED", reason)
|
||||
)
|
||||
return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason))
|
||||
|
||||
# Clean up expired tokens first
|
||||
await database.cleanup_expired_verification_tokens()
|
||||
@ -541,9 +517,11 @@ async def resend_verification_email(
|
||||
# User exists and is verified - don't reveal this for security
|
||||
logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}")
|
||||
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
|
||||
return create_success_response({
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
}
|
||||
)
|
||||
|
||||
# Look for pending verification token
|
||||
verification_data = await database.find_verification_token_by_email(email_lower)
|
||||
@ -552,9 +530,11 @@ async def resend_verification_email(
|
||||
# No pending verification found - don't reveal this for security
|
||||
logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}")
|
||||
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
|
||||
return create_success_response({
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "If your email is in our system and pending verification, a new verification email has been sent."
|
||||
}
|
||||
)
|
||||
|
||||
# Check if verification token has expired
|
||||
expires_at = datetime.fromisoformat(verification_data["expires_at"])
|
||||
@ -568,8 +548,8 @@ async def resend_verification_email(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"TOKEN_EXPIRED",
|
||||
"Your verification link has expired. Please register again to create a new account."
|
||||
)
|
||||
"Your verification link has expired. Please register again to create a new account.",
|
||||
),
|
||||
)
|
||||
|
||||
# Generate new verification token (invalidate old one)
|
||||
@ -577,20 +557,19 @@ async def resend_verification_email(
|
||||
new_token = secrets.token_urlsafe(32)
|
||||
|
||||
# Update verification data with new token and reset attempts
|
||||
verification_data.update({
|
||||
"token": new_token,
|
||||
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
|
||||
"resent_at": current_time.isoformat(),
|
||||
"resend_count": verification_data.get("resend_count", 0) + 1
|
||||
})
|
||||
verification_data.update(
|
||||
{
|
||||
"token": new_token,
|
||||
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
|
||||
"resent_at": current_time.isoformat(),
|
||||
"resend_count": verification_data.get("resend_count", 0) + 1,
|
||||
}
|
||||
)
|
||||
|
||||
# Store new token and remove old one
|
||||
await database.redis.delete(f"email_verification:{old_token}")
|
||||
await database.store_email_verification_token(
|
||||
email_lower,
|
||||
new_token,
|
||||
verification_data["user_type"],
|
||||
verification_data["user_data"]
|
||||
email_lower, new_token, verification_data["user_type"], verification_data["user_data"]
|
||||
)
|
||||
|
||||
# Get user name for email
|
||||
@ -610,73 +589,66 @@ async def resend_verification_email(
|
||||
await rate_limiter.record_email_sent(email_lower)
|
||||
|
||||
# Send new verification email in background
|
||||
background_tasks.add_task(
|
||||
email_service.send_verification_email,
|
||||
email_lower,
|
||||
new_token,
|
||||
user_name,
|
||||
user_type
|
||||
)
|
||||
background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type)
|
||||
|
||||
# Log security event
|
||||
await database.log_security_event(
|
||||
verification_data["user_data"].get("candidate_data", {}).get("id") or
|
||||
verification_data["user_data"].get("employer_data", {}).get("id") or "unknown",
|
||||
verification_data["user_data"].get("candidate_data", {}).get("id")
|
||||
or verification_data["user_data"].get("employer_data", {}).get("id")
|
||||
or "unknown",
|
||||
"verification_resend",
|
||||
{
|
||||
"email": email_lower,
|
||||
"user_type": user_type,
|
||||
"resend_count": verification_data.get("resend_count", 1),
|
||||
"old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging
|
||||
"ip_address": "unknown" # You can extract this from request if needed
|
||||
"ip_address": "unknown", # You can extract this from request if needed
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})"
|
||||
)
|
||||
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
|
||||
"resendCount": verification_data.get("resend_count", 1),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})")
|
||||
|
||||
return create_success_response({
|
||||
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
|
||||
"resendCount": verification_data.get("resend_count", 1)
|
||||
})
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning(f"⚠️ Invalid resend verification request: {ve}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(ve))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Resend verification email error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.")
|
||||
content=create_error_response(
|
||||
"RESEND_FAILED", "An error occurred while processing your request. Please try again later."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/request")
|
||||
async def request_mfa(
|
||||
request: MFARequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
http_request: Request,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Request MFA for login from new device"""
|
||||
try:
|
||||
# Verify credentials first
|
||||
auth_manager = AuthenticationManager(database)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
|
||||
request.email,
|
||||
request.password
|
||||
)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password)
|
||||
|
||||
if not is_valid or not user_data:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("AUTH_FAILED", "Invalid credentials")
|
||||
)
|
||||
return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials"))
|
||||
|
||||
# Check if device is trusted
|
||||
device_manager = DeviceManager(database)
|
||||
device_info = device_manager.parse_device_info(http_request)
|
||||
device_manager.parse_device_info(http_request)
|
||||
|
||||
is_trusted = await device_manager.is_trusted_device(user_data["id"], request.device_id)
|
||||
|
||||
@ -684,10 +656,7 @@ async def request_mfa(
|
||||
# Device is trusted, proceed with normal login
|
||||
await device_manager.update_device_last_used(user_data["id"], request.device_id)
|
||||
|
||||
return create_success_response({
|
||||
"mfa_required": False,
|
||||
"message": "Device is trusted, proceed with login"
|
||||
})
|
||||
return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"})
|
||||
|
||||
# Generate MFA code
|
||||
mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code
|
||||
@ -709,8 +678,7 @@ async def request_mfa(
|
||||
|
||||
if not email:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
)
|
||||
|
||||
# Store MFA code
|
||||
@ -718,13 +686,7 @@ async def request_mfa(
|
||||
logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}")
|
||||
|
||||
# Send MFA code via email
|
||||
background_tasks.add_task(
|
||||
email_service.send_mfa_email,
|
||||
email,
|
||||
mfa_code,
|
||||
request.device_name,
|
||||
user_name
|
||||
)
|
||||
background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name)
|
||||
|
||||
logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}")
|
||||
|
||||
@ -735,25 +697,22 @@ async def request_mfa(
|
||||
device_id=request.device_id,
|
||||
device_name=request.device_name,
|
||||
)
|
||||
mfa_response = MFARequestResponse(
|
||||
mfa_required=True,
|
||||
mfa_data=mfa_data
|
||||
)
|
||||
mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data)
|
||||
return create_success_response(mfa_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MFA request error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
|
||||
status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
request: LoginRequest,
|
||||
http_request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""login with automatic MFA email sending for new devices"""
|
||||
try:
|
||||
@ -766,16 +725,12 @@ async def login(
|
||||
device_id = device_info["device_id"]
|
||||
|
||||
# Verify credentials first
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
|
||||
request.login,
|
||||
request.password
|
||||
)
|
||||
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password)
|
||||
|
||||
if not is_valid or not user_data:
|
||||
logger.warning(f"⚠️ Failed login attempt for: {request.login}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
|
||||
status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
|
||||
)
|
||||
|
||||
# Check if device is trusted
|
||||
@ -804,8 +759,7 @@ async def login(
|
||||
|
||||
if not email:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
|
||||
)
|
||||
|
||||
# Store MFA code
|
||||
@ -817,12 +771,7 @@ async def login(
|
||||
|
||||
# Send MFA code via email in background
|
||||
background_tasks.add_task(
|
||||
email_service.send_mfa_email,
|
||||
email,
|
||||
mfa_code,
|
||||
device_info["device_name"],
|
||||
user_name,
|
||||
ip_address
|
||||
email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address
|
||||
)
|
||||
|
||||
# Log security event
|
||||
@ -834,8 +783,8 @@ async def login(
|
||||
"device_name": device_info["device_name"],
|
||||
"ip_address": ip_address,
|
||||
"user_agent": device_info.get("user_agent", ""),
|
||||
"auto_sent": True
|
||||
}
|
||||
"auto_sent": True,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}")
|
||||
@ -847,8 +796,8 @@ async def login(
|
||||
email=email,
|
||||
device_id=device_id,
|
||||
device_name=device_info["device_name"],
|
||||
code_sent=mfa_code
|
||||
)
|
||||
code_sent=mfa_code,
|
||||
),
|
||||
)
|
||||
return create_success_response(mfa_response.model_dump(by_alias=True))
|
||||
|
||||
@ -860,7 +809,7 @@ async def login(
|
||||
access_token = create_access_token(data={"sub": user_data["id"]})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user_data["id"], "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
# Get user object
|
||||
@ -876,8 +825,7 @@ async def login(
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
)
|
||||
|
||||
# Log successful login from trusted device
|
||||
@ -888,8 +836,8 @@ async def login(
|
||||
"device_id": device_id,
|
||||
"device_name": device_info["device_name"],
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"trusted_device": True
|
||||
}
|
||||
"trusted_device": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Create response
|
||||
@ -897,7 +845,9 @@ async def login(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user,
|
||||
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int(
|
||||
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"🔑 User {request.login} logged in successfully from trusted device")
|
||||
@ -908,17 +858,12 @@ async def login(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Login error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("LOGIN_ERROR", "An error occurred during login")
|
||||
status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mfa/verify")
|
||||
async def verify_mfa(
|
||||
request: MFAVerifyRequest,
|
||||
http_request: Request,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)):
|
||||
"""Verify MFA code and complete login with error handling"""
|
||||
try:
|
||||
# Get MFA data
|
||||
@ -928,13 +873,17 @@ async def verify_mfa(
|
||||
logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.")
|
||||
content=create_error_response(
|
||||
"NO_MFA_SESSION", "No active MFA session found. Please try logging in again."
|
||||
),
|
||||
)
|
||||
|
||||
if mfa_data.get("verified"):
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.")
|
||||
content=create_error_response(
|
||||
"ALREADY_VERIFIED", "This MFA code has already been used. Please login again."
|
||||
),
|
||||
)
|
||||
|
||||
# Check expiration
|
||||
@ -944,7 +893,7 @@ async def verify_mfa(
|
||||
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.")
|
||||
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."),
|
||||
)
|
||||
|
||||
# Check attempts
|
||||
@ -954,7 +903,9 @@ async def verify_mfa(
|
||||
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.")
|
||||
content=create_error_response(
|
||||
"TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again."
|
||||
),
|
||||
)
|
||||
|
||||
# Verify code
|
||||
@ -965,9 +916,8 @@ async def verify_mfa(
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response(
|
||||
"INVALID_CODE",
|
||||
f"Invalid MFA code. {remaining_attempts} attempts remaining."
|
||||
)
|
||||
"INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining."
|
||||
),
|
||||
)
|
||||
|
||||
# Mark as verified
|
||||
@ -976,20 +926,13 @@ async def verify_mfa(
|
||||
# Get user data
|
||||
user_data = await database.get_user(request.email)
|
||||
if not user_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
|
||||
|
||||
# Add device to trusted devices if requested
|
||||
if request.remember_device:
|
||||
device_manager = DeviceManager(database)
|
||||
device_info = device_manager.parse_device_info(http_request)
|
||||
await device_manager.add_trusted_device(
|
||||
user_data["id"],
|
||||
request.device_id,
|
||||
device_info
|
||||
)
|
||||
await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info)
|
||||
logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}")
|
||||
|
||||
# Update last login
|
||||
@ -1000,7 +943,7 @@ async def verify_mfa(
|
||||
access_token = create_access_token(data={"sub": user_data["id"]})
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user_data["id"], "type": "refresh"},
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
|
||||
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
|
||||
)
|
||||
|
||||
# Get user object
|
||||
@ -1016,8 +959,7 @@ async def verify_mfa(
|
||||
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
|
||||
)
|
||||
|
||||
# Log successful MFA verification and login
|
||||
@ -1028,8 +970,8 @@ async def verify_mfa(
|
||||
"device_id": request.device_id,
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"device_remembered": request.remember_device,
|
||||
"attempts_used": current_attempts + 1
|
||||
}
|
||||
"attempts_used": current_attempts + 1,
|
||||
},
|
||||
)
|
||||
|
||||
await database.log_security_event(
|
||||
@ -1039,8 +981,8 @@ async def verify_mfa(
|
||||
"device_id": request.device_id,
|
||||
"ip_address": http_request.client.host if http_request.client else "Unknown",
|
||||
"mfa_verified": True,
|
||||
"new_device": True
|
||||
}
|
||||
"new_device": True,
|
||||
},
|
||||
)
|
||||
|
||||
# Clean up MFA session
|
||||
@ -1051,7 +993,9 @@ async def verify_mfa(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user=user,
|
||||
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
|
||||
expires_at=int(
|
||||
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(f"✅ MFA verified and login completed for {request.email}")
|
||||
@ -1062,15 +1006,12 @@ async def verify_mfa(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ MFA verification error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
|
||||
status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/request")
|
||||
async def request_password_reset(
|
||||
request: PasswordResetRequest,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
|
||||
"""Request password reset"""
|
||||
try:
|
||||
# Check if user exists
|
||||
@ -1100,15 +1041,12 @@ async def request_password_reset(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Password reset request error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", "An error occurred processing the request")
|
||||
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/password-reset/confirm")
|
||||
async def confirm_password_reset(
|
||||
request: PasswordResetConfirm,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
|
||||
"""Confirm password reset with token"""
|
||||
try:
|
||||
# Find user by reset token
|
||||
@ -1122,8 +1060,5 @@ async def confirm_password_reset(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Password reset confirm error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
|
||||
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
|
||||
)
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -4,125 +4,69 @@ Chat routes
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, UTC
|
||||
from typing import (Dict, Any)
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import (
|
||||
APIRouter, Depends, Body, Depends, Query, Path,
|
||||
Body, APIRouter
|
||||
)
|
||||
from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_database, get_current_user, get_current_user_or_guest
|
||||
)
|
||||
from utils.responses import (
|
||||
create_success_response, create_error_response, create_paginated_response
|
||||
)
|
||||
from utils.helpers import (
|
||||
stream_agent_response
|
||||
)
|
||||
from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
|
||||
from utils.responses import create_success_response, create_error_response, create_paginated_response
|
||||
from utils.helpers import stream_agent_response
|
||||
import backstory_traceback
|
||||
|
||||
import entities.entity_manager as entities
|
||||
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
# API
|
||||
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
|
||||
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
|
||||
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
|
||||
JobRequirements, JobRequirementsMessage, LoginRequest,
|
||||
CreateCandidateRequest, CreateEmployerRequest,
|
||||
|
||||
# User models
|
||||
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
|
||||
Authentication, AuthResponse, CandidateAI,
|
||||
|
||||
# Job models
|
||||
JobApplication, ApplicationStatus,
|
||||
|
||||
# Chat models
|
||||
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
|
||||
ChatMessageRagSearch,
|
||||
|
||||
# Document models
|
||||
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
|
||||
|
||||
# Supporting models
|
||||
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
|
||||
|
||||
# Email
|
||||
EmailVerificationRequest
|
||||
)
|
||||
from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
|
||||
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/archive")
|
||||
async def archive_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Archive a chat session"""
|
||||
try:
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
# Check if user owns this session or is admin
|
||||
if session_data.get("userId") != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
|
||||
)
|
||||
|
||||
await database.archive_chat_session(session_id)
|
||||
|
||||
return create_success_response({
|
||||
"message": "Chat session archived successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Archive chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("ARCHIVE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("/statistics")
|
||||
async def get_chat_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get chat statistics (admin/analytics endpoint)"""
|
||||
try:
|
||||
stats = await database.get_chat_statistics()
|
||||
return create_success_response(stats)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat statistics error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STATS_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.post("/sessions")
|
||||
async def create_chat_session(
|
||||
session_data: Dict[str, Any] = Body(...),
|
||||
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new chat session with optional candidate username association"""
|
||||
try:
|
||||
@ -140,15 +84,14 @@ async def create_chat_session(
|
||||
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
|
||||
|
||||
# Find candidate by username (case-insensitive)
|
||||
matching_candidates = [
|
||||
c for c in candidates_list
|
||||
if c.username.lower() == username.lower()
|
||||
]
|
||||
matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
|
||||
|
||||
if not matching_candidates:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
|
||||
content=create_error_response(
|
||||
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
|
||||
),
|
||||
)
|
||||
|
||||
candidate_data = matching_candidates[0]
|
||||
@ -177,7 +120,7 @@ async def create_chat_session(
|
||||
"username": candidate_data.username,
|
||||
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
|
||||
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
|
||||
"location": candidate_data.location.city if candidate_data.location else "Unknown"
|
||||
"location": candidate_data.location.city if candidate_data.location else "Unknown",
|
||||
}
|
||||
context["additionalContext"] = additional_context
|
||||
|
||||
@ -191,8 +134,10 @@ async def create_chat_session(
|
||||
chat_session = ChatSession.model_validate(session_data)
|
||||
await database.set_chat_session(chat_session.id, chat_session.model_dump())
|
||||
|
||||
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
|
||||
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
|
||||
logger.info(
|
||||
f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
|
||||
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
|
||||
)
|
||||
|
||||
return create_success_response(chat_session.model_dump(by_alias=True))
|
||||
|
||||
@ -200,49 +145,58 @@ async def create_chat_session(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat session creation error: {e}")
|
||||
logger.info(json.dumps(session_data, indent=2))
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("/sessions/messages/stream")
|
||||
async def post_chat_session_message_stream(
|
||||
user_message: ChatMessageUser = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Post a message to a chat session and stream the response with persistence"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(user_message.session_id)
|
||||
if not chat_session_data:
|
||||
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
chat_session = ChatSession.model_validate(chat_session_data)
|
||||
chat_type = chat_session.context.type
|
||||
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
|
||||
candidate_info = (
|
||||
chat_session.context.additional_context.get("candidateInfo", {})
|
||||
if chat_session.context and chat_session.context.additional_context
|
||||
else None
|
||||
)
|
||||
|
||||
# Get candidate info if this chat is about a specific candidate
|
||||
if candidate_info:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
|
||||
logger.info(
|
||||
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
|
||||
logger.info(
|
||||
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
|
||||
content=create_error_response(
|
||||
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
|
||||
),
|
||||
)
|
||||
|
||||
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
|
||||
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
|
||||
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
|
||||
if not candidate:
|
||||
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
|
||||
logger.info(
|
||||
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
|
||||
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
|
||||
)
|
||||
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
|
||||
logger.info(
|
||||
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
|
||||
)
|
||||
|
||||
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
|
||||
# Entity automatically released when done
|
||||
@ -251,7 +205,7 @@ async def post_chat_session_message_stream(
|
||||
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
|
||||
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
|
||||
)
|
||||
|
||||
# Persist user message to database
|
||||
@ -269,30 +223,25 @@ async def post_chat_session_message_stream(
|
||||
chat_session_data=chat_session_data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Chat message streaming error")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("STREAMING_ERROR", "")
|
||||
)
|
||||
logger.error("❌ Chat message streaming error")
|
||||
return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/messages")
|
||||
async def get_chat_session_messages(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get persisted chat messages for a session"""
|
||||
try:
|
||||
chat_session_data = await database.get_chat_session(session_id)
|
||||
if not chat_session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
# Get messages from database
|
||||
chat_messages = await database.get_chat_messages(session_id)
|
||||
@ -317,43 +266,36 @@ async def get_chat_session_messages(
|
||||
paginated_messages = messages_list[start:end]
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[m.model_dump(by_alias=True) for m in paginated_messages],
|
||||
page, limit, total
|
||||
[m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get chat messages error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.patch("/sessions/{session_id}")
|
||||
async def update_chat_session(
|
||||
session_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update a chat session's properties"""
|
||||
try:
|
||||
# Get the existing session
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only update their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
|
||||
)
|
||||
|
||||
# Validate and apply updates
|
||||
@ -362,8 +304,7 @@ async def update_chat_session(
|
||||
|
||||
if not filtered_updates:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
|
||||
status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
|
||||
)
|
||||
|
||||
# Apply updates to session data
|
||||
@ -417,40 +358,31 @@ async def update_chat_session(
|
||||
|
||||
except ValueError as ve:
|
||||
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("VALIDATION_ERROR", str(ve))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("UPDATE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}")
|
||||
async def delete_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
@ -469,42 +401,34 @@ async def delete_chat_session(
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session deleted successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.patch("/sessions/{session_id}/reset")
|
||||
async def reset_chat_session(
|
||||
session_id: str = Path(...),
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a chat session and all its messages"""
|
||||
try:
|
||||
# Get the session to verify it exists and check ownership
|
||||
session_data = await database.get_chat_session(session_id)
|
||||
if not session_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Chat session not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
|
||||
|
||||
session = ChatSession.model_validate(session_data)
|
||||
|
||||
# Check authorization - user can only delete their own sessions
|
||||
if session.user_id != current_user.id:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
|
||||
)
|
||||
|
||||
# Delete all messages associated with this session
|
||||
@ -518,20 +442,12 @@ async def reset_chat_session(
|
||||
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
|
||||
# Continue with session deletion even if message deletion fails
|
||||
|
||||
|
||||
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": "Chat session reset successfully",
|
||||
"sessionId": session_id
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Reset chat session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("RESET_ERROR", str(e))
|
||||
)
|
||||
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
|
||||
|
@ -9,24 +9,21 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/auth", tags=["authentication"])
|
||||
|
||||
|
||||
@router.get("/guest/{guest_id}")
|
||||
async def debug_guest_session(
|
||||
guest_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Debug guest session issues (admin only)"""
|
||||
try:
|
||||
# Check primary storage
|
||||
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
|
||||
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
|
||||
primary_exists = primary_data is not None
|
||||
|
||||
# Check backup storage
|
||||
@ -37,7 +34,7 @@ async def debug_guest_session(
|
||||
user_lookup = await database.get_user_by_id(guest_id)
|
||||
|
||||
# Get TTL info
|
||||
primary_ttl = await database.redis.ttl(f"guests")
|
||||
primary_ttl = await database.redis.ttl("guests")
|
||||
backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}")
|
||||
|
||||
debug_info = {
|
||||
@ -45,22 +42,19 @@ async def debug_guest_session(
|
||||
"primary_storage": {
|
||||
"exists": primary_exists,
|
||||
"data": json.loads(primary_data) if primary_data else None,
|
||||
"ttl": primary_ttl
|
||||
"ttl": primary_ttl,
|
||||
},
|
||||
"backup_storage": {
|
||||
"exists": backup_exists,
|
||||
"data": json.loads(backup_data) if backup_data else None,
|
||||
"ttl": backup_ttl
|
||||
"ttl": backup_ttl,
|
||||
},
|
||||
"user_lookup": user_lookup,
|
||||
"timestamp": datetime.now(UTC).isoformat()
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
|
||||
return create_success_response(debug_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Debug guest session error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DEBUG_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))
|
||||
|
@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
CreateEmployerRequest
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database
|
||||
)
|
||||
from models import CreateEmployerRequest
|
||||
from utils.dependencies import get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
from email_service import email_service
|
||||
from utils.auth_utils import AuthenticationManager
|
||||
@ -23,29 +19,22 @@ from utils.auth_utils import AuthenticationManager
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/employers", tags=["employers"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_employer_with_verification(
|
||||
request: CreateEmployerRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Create a new employer with email verification"""
|
||||
try:
|
||||
# Similar to candidate creation but for employer
|
||||
auth_manager = AuthenticationManager(database)
|
||||
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(
|
||||
request.email,
|
||||
request.username
|
||||
)
|
||||
user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
|
||||
|
||||
if user_exists and conflict_field:
|
||||
return JSONResponse(
|
||||
status_code=409,
|
||||
content=create_error_response(
|
||||
"USER_EXISTS",
|
||||
f"A user with this {conflict_field} already exists"
|
||||
)
|
||||
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
|
||||
)
|
||||
|
||||
employer_id = str(uuid.uuid4())
|
||||
@ -64,12 +53,8 @@ async def create_employer_with_verification(
|
||||
"updatedAt": current_time.isoformat(),
|
||||
"status": "pending", # Not active until verified
|
||||
"userType": "employer",
|
||||
"location": {
|
||||
"city": "",
|
||||
"country": "",
|
||||
"remote": False
|
||||
},
|
||||
"socialLinks": []
|
||||
"location": {"city": "", "country": "", "remote": False},
|
||||
"socialLinks": [],
|
||||
}
|
||||
|
||||
verification_token = secrets.token_urlsafe(32)
|
||||
@ -78,32 +63,25 @@ async def create_employer_with_verification(
|
||||
request.email,
|
||||
verification_token,
|
||||
"employer",
|
||||
{
|
||||
"employer_data": employer_data,
|
||||
"password": request.password,
|
||||
"username": request.username
|
||||
}
|
||||
{"employer_data": employer_data, "password": request.password, "username": request.username},
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
email_service.send_verification_email,
|
||||
request.email,
|
||||
verification_token,
|
||||
request.company_name
|
||||
email_service.send_verification_email, request.email, verification_token, request.company_name
|
||||
)
|
||||
|
||||
logger.info(f"✅ Employer registration initiated for: {request.email}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Registration successful! Please check your email to verify your account.",
|
||||
"email": request.email,
|
||||
"verificationRequired": True
|
||||
})
|
||||
return create_success_response(
|
||||
{
|
||||
"message": "Registration successful! Please check your email to verify your account.",
|
||||
"email": request.email,
|
||||
"verificationRequired": True,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Employer creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
|
||||
status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
|
||||
)
|
||||
|
||||
|
@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_current_admin, get_database, get_current_user
|
||||
MOCK_UUID,
|
||||
ApiActivityType,
|
||||
ApiStatusType,
|
||||
ChatContextType,
|
||||
ChatMessage,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
DocumentType,
|
||||
Job,
|
||||
JobRequirementsMessage,
|
||||
Candidate,
|
||||
Employer,
|
||||
)
|
||||
from utils.dependencies import get_current_admin, get_database, get_current_user
|
||||
from utils.responses import create_paginated_response, create_success_response, create_error_response
|
||||
import utils.llm_proxy as llm_manager
|
||||
import entities.entity_manager as entities
|
||||
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
content="No agent found for job requirements chat type",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
content="Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
""",
|
||||
):
|
||||
pass
|
||||
|
||||
@ -66,16 +76,16 @@ Return only the markdown content, no other text. Make sure all content is includ
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
content="Failed to reformat job description",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
chat_message : ChatMessage = message
|
||||
chat_message: ChatMessage = message
|
||||
try:
|
||||
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"✅ Successfully converted content to markdown")
|
||||
logger.info("✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
@ -84,10 +94,10 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
activity=ApiActivityType.INFO,
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
|
||||
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
|
||||
message = None
|
||||
@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to reformat job description"
|
||||
content="Failed to reformat job description",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="No agent found for job requirements chat type"
|
||||
content="No agent found for job requirements chat type",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
content="Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
@ -132,23 +139,22 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Job extraction did not convert successfully"
|
||||
content="Job extraction did not convert successfully",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
job_requirements : JobRequirementsMessage = message
|
||||
job_requirements: JobRequirementsMessage = message
|
||||
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new job"""
|
||||
try:
|
||||
@ -165,20 +171,17 @@ async def create_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_candidate_job(
|
||||
job_data: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new job"""
|
||||
is_employer = isinstance(current_user, Employer)
|
||||
isinstance(current_user, Employer)
|
||||
|
||||
try:
|
||||
job = Job.model_validate(job_data)
|
||||
@ -194,28 +197,22 @@ async def create_candidate_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Job creation error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("CREATION_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.patch("/{job_id}")
|
||||
async def update_job(
|
||||
job_id: str = Path(...),
|
||||
updates: Dict[str, Any] = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update a candidate"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Job not found for update: {job_data}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
@ -223,8 +220,7 @@ async def update_job(
|
||||
if current_user.is_admin is False and job.owner_id != current_user.id:
|
||||
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
|
||||
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
|
||||
)
|
||||
|
||||
# Apply updates
|
||||
@ -239,25 +235,22 @@ async def update_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Update job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("UPDATE_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.post("/from-content")
|
||||
async def create_job_from_description(
|
||||
content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Upload a document for the current candidate"""
|
||||
|
||||
async def content_stream_generator(content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
content="Only candidates can upload documents",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -277,10 +270,11 @@ async def create_job_from_description(
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -304,18 +298,25 @@ async def create_job_from_description(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def create_job_from_file(
|
||||
file: UploadFile = File(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Upload a job document for the current candidate and create a Job"""
|
||||
# Check file size (limit to 10MB)
|
||||
@ -324,55 +325,70 @@ async def create_job_from_file(
|
||||
if len(file_content) > max_size:
|
||||
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File size exceeds 10MB limit"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File size exceeds 10MB limit",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
if len(file_content) == 0:
|
||||
logger.info(f"⚠️ File is empty: {file.filename}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File is empty"
|
||||
).model_dump(by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File is empty",
|
||||
).model_dump(by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
"""Upload a document for the current candidate"""
|
||||
|
||||
async def upload_stream_generator(file_content):
|
||||
# Verify user is a candidate
|
||||
if current_user.user_type != "candidate":
|
||||
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Only candidates can upload documents"
|
||||
content="Only candidates can upload documents",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
|
||||
file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
|
||||
if not file.filename or file.filename.strip() == "":
|
||||
logger.warning("⚠️ File upload attempt with missing filename")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="File must have a valid filename"
|
||||
content="File must have a valid filename",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
|
||||
logger.info(
|
||||
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
|
||||
)
|
||||
|
||||
# Validate file type
|
||||
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
|
||||
allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
|
||||
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
|
||||
|
||||
if file_extension not in allowed_types:
|
||||
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
|
||||
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -383,7 +399,7 @@ async def create_job_from_file(
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Converting content from {document_type}...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
activity=ApiActivityType.CONVERTING,
|
||||
)
|
||||
yield status_message
|
||||
try:
|
||||
@ -391,7 +407,7 @@ async def create_job_from_file(
|
||||
stream = io.BytesIO(file_content)
|
||||
stream_info = StreamInfo(
|
||||
extension=file_extension, # e.g., ".pdf"
|
||||
url=file.filename # optional, helps with logging and guessing
|
||||
url=file.filename, # optional, helps with logging and guessing
|
||||
)
|
||||
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
|
||||
file_content = result.text_content
|
||||
@ -405,15 +421,18 @@ async def create_job_from_file(
|
||||
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
|
||||
return
|
||||
|
||||
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
|
||||
async for message in create_job_from_content(
|
||||
database=database, current_user=current_user, content=file_content
|
||||
):
|
||||
yield message
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -437,26 +456,27 @@ async def create_job_from_file(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Document upload error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document"
|
||||
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
|
||||
media_type="text/event-stream"
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to upload document",
|
||||
).model_dump(mode="json", by_alias=True)
|
||||
).encode("utf-8")
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{job_id}")
|
||||
async def get_job(
|
||||
job_id: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get a job by ID"""
|
||||
try:
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
# Increment view count
|
||||
job_data["views"] = job_data.get("views", 0) + 1
|
||||
@ -467,10 +487,8 @@ async def get_job(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_jobs(
|
||||
@ -479,7 +497,7 @@ async def get_jobs(
|
||||
sortBy: Optional[str] = Query(None, alias="sortBy"),
|
||||
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
|
||||
filters: Optional[str] = Query(None),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get paginated list of jobs"""
|
||||
try:
|
||||
@ -493,23 +511,18 @@ async def get_jobs(
|
||||
for job in all_jobs_data.values():
|
||||
jobs_list.append(Job.model_validate(job))
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, sortBy, sortOrder, filter_dict
|
||||
)
|
||||
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("FETCH_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_jobs(
|
||||
@ -517,7 +530,7 @@ async def search_jobs(
|
||||
filters: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Search jobs"""
|
||||
try:
|
||||
@ -532,69 +545,52 @@ async def search_jobs(
|
||||
if query:
|
||||
query_lower = query.lower()
|
||||
jobs_list = [
|
||||
j for j in jobs_list
|
||||
if ((j.title and query_lower in j.title.lower()) or
|
||||
(j.description and query_lower in j.description.lower()) or
|
||||
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
|
||||
j
|
||||
for j in jobs_list
|
||||
if (
|
||||
(j.title and query_lower in j.title.lower())
|
||||
or (j.description and query_lower in j.description.lower())
|
||||
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
|
||||
)
|
||||
]
|
||||
|
||||
paginated_jobs, total = filter_and_paginate(
|
||||
jobs_list, page, limit, filters=filter_dict
|
||||
)
|
||||
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
|
||||
|
||||
paginated_response = create_paginated_response(
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs],
|
||||
page, limit, total
|
||||
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
|
||||
)
|
||||
|
||||
return create_success_response(paginated_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Search jobs error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content=create_error_response("SEARCH_FAILED", str(e))
|
||||
)
|
||||
return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
|
||||
|
||||
|
||||
@router.delete("/{job_id}")
|
||||
async def delete_job(
|
||||
job_id: str = Path(...),
|
||||
admin_user = Depends(get_current_admin),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Delete a Job"""
|
||||
try:
|
||||
# Check if admin user
|
||||
if not admin_user.is_admin:
|
||||
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content=create_error_response("FORBIDDEN", "Only admins can delete")
|
||||
)
|
||||
return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
|
||||
|
||||
# Get candidate data
|
||||
job_data = await database.get_job(job_id)
|
||||
if not job_data:
|
||||
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "Job not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
|
||||
|
||||
# Delete job from database
|
||||
await database.delete_job(job_id)
|
||||
|
||||
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
|
||||
|
||||
return create_success_response({
|
||||
"message": "Job deleted successfully",
|
||||
"jobId": job_id
|
||||
})
|
||||
return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Delete job error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("DELETE_ERROR", "Failed to delete job")
|
||||
)
|
||||
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))
|
||||
|
@ -6,6 +6,7 @@ from utils.llm_proxy import LLMProvider, get_llm
|
||||
|
||||
router = APIRouter(prefix="/providers", tags=["providers"])
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def list_models(provider: Optional[str] = None):
|
||||
"""List available models for a provider"""
|
||||
@ -17,29 +18,25 @@ async def list_models(provider: Optional[str] = None):
|
||||
try:
|
||||
provider_enum = LLMProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported provider: {provider}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
|
||||
|
||||
models = await llm.list_models(provider_enum)
|
||||
return {
|
||||
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
|
||||
"models": models
|
||||
}
|
||||
return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_providers():
|
||||
"""List all configured providers"""
|
||||
llm = get_llm()
|
||||
return {
|
||||
"providers": [provider.value for provider in llm._initialized_providers],
|
||||
"default": llm.default_provider.value
|
||||
"default": llm.default_provider.value,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{provider}/set-default")
|
||||
async def set_default_provider(provider: str):
|
||||
"""Set the default provider"""
|
||||
@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
@ -59,5 +57,5 @@ async def health_check():
|
||||
return {
|
||||
"status": "healthy",
|
||||
"providers_configured": len(llm._initialized_providers),
|
||||
"default_provider": llm.default_provider.value
|
||||
"default_provider": llm.default_provider.value,
|
||||
}
|
||||
|
@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
|
||||
import backstory_traceback as backstory_traceback
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database, get_current_user
|
||||
)
|
||||
from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
|
||||
from utils.dependencies import get_database, get_current_user
|
||||
from utils.responses import create_success_response
|
||||
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/resumes", tags=["resumes"])
|
||||
|
||||
|
||||
@router.post("/{candidate_id}/{job_id}")
|
||||
async def create_candidate_resume(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
resume_content: str = Body(...),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Create a new resume for a candidate/job combination"""
|
||||
|
||||
async def message_stream_generator():
|
||||
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
|
||||
|
||||
@ -39,7 +37,7 @@ async def create_candidate_resume(
|
||||
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Candidate with ID '{candidate_id}' not found"
|
||||
content=f"Candidate with ID '{candidate_id}' not found",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -50,13 +48,15 @@ async def create_candidate_resume(
|
||||
logger.error(f"❌ Job with ID '{job_id}' not found")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content=f"Job with ID '{job_id}' not found"
|
||||
content=f"Job with ID '{job_id}' not found",
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
job = Job.model_validate(job_data)
|
||||
|
||||
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
|
||||
logger.info(
|
||||
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
|
||||
)
|
||||
|
||||
# Job and Candidate are valid. Save the resume
|
||||
resume = Resume(
|
||||
@ -66,16 +66,13 @@ async def create_candidate_resume(
|
||||
)
|
||||
resume_message: ResumeMessage = ResumeMessage(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
resume=resume
|
||||
resume=resume,
|
||||
)
|
||||
|
||||
# Save to database
|
||||
success = await database.set_resume(current_user.id, resume.model_dump())
|
||||
if not success:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to save resume to database"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -84,10 +81,11 @@ async def create_candidate_resume(
|
||||
return
|
||||
|
||||
try:
|
||||
|
||||
async def to_json(method):
|
||||
try:
|
||||
async for message in method:
|
||||
json_data = message.model_dump(mode='json', by_alias=True)
|
||||
json_data = message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
yield f"data: {json_str}\n\n".encode("utf-8")
|
||||
except Exception as e:
|
||||
@ -111,22 +109,26 @@ async def create_candidate_resume(
|
||||
logger.error(backstory_traceback.format_exc())
|
||||
logger.error(f"❌ Resume creation error: {e}")
|
||||
return StreamingResponse(
|
||||
iter([json.dumps(ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to create resume"
|
||||
).model_dump(mode='json', by_alias=True))]),
|
||||
media_type="text/event-stream"
|
||||
iter(
|
||||
[
|
||||
json.dumps(
|
||||
ChatMessageError(
|
||||
session_id=MOCK_UUID, # No session ID for document uploads
|
||||
content="Failed to create resume",
|
||||
).model_dump(mode="json", by_alias=True)
|
||||
)
|
||||
]
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_user_resumes(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get all resumes for the current user"""
|
||||
try:
|
||||
resumes_data = await database.get_all_resumes_for_user(current_user.id)
|
||||
resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data]
|
||||
resumes: List[Resume] = [Resume.model_validate(data) for data in resumes_data]
|
||||
for resume in resumes:
|
||||
job_data = await database.get_job(resume.job_id)
|
||||
if job_data:
|
||||
@ -135,19 +137,17 @@ async def get_user_resumes(
|
||||
if candidate_data:
|
||||
resume.candidate = Candidate.model_validate(candidate_data)
|
||||
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date
|
||||
return create_success_response({
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
})
|
||||
return create_success_response({"resumes": resumes, "count": len(resumes)})
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
|
||||
|
||||
|
||||
@router.get("/{resume_id}")
|
||||
async def get_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get a specific resume by ID"""
|
||||
try:
|
||||
@ -155,21 +155,19 @@ async def get_resume(
|
||||
if not resume:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"resume": resume
|
||||
}
|
||||
return {"success": True, "resume": resume}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
|
||||
|
||||
|
||||
@router.delete("/{resume_id}")
|
||||
async def delete_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Delete a specific resume"""
|
||||
try:
|
||||
@ -177,102 +175,82 @@ async def delete_resume(
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} deleted successfully"
|
||||
}
|
||||
return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete resume")
|
||||
|
||||
|
||||
@router.get("/candidate/{candidate_id}")
|
||||
async def get_resumes_by_candidate(
|
||||
candidate_id: str = Path(..., description="ID of the candidate"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get all resumes for a specific candidate"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
|
||||
return {
|
||||
"success": True,
|
||||
"candidate_id": candidate_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
|
||||
|
||||
|
||||
@router.get("/job/{job_id}")
|
||||
async def get_resumes_by_job(
|
||||
job_id: str = Path(..., description="ID of the job"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Get all resumes for a specific job"""
|
||||
try:
|
||||
resumes = await database.get_resumes_by_job(current_user.id, job_id)
|
||||
return {
|
||||
"success": True,
|
||||
"job_id": job_id,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
|
||||
|
||||
|
||||
@router.get("/search")
|
||||
async def search_resumes(
|
||||
q: str = Query(..., description="Search query"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Search resumes by content"""
|
||||
try:
|
||||
resumes = await database.search_resumes_for_user(current_user.id, q)
|
||||
return {
|
||||
"success": True,
|
||||
"query": q,
|
||||
"resumes": resumes,
|
||||
"count": len(resumes)
|
||||
}
|
||||
return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to search resumes")
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_resume_statistics(
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
"""Get resume statistics for the current user"""
|
||||
try:
|
||||
stats = await database.get_resume_statistics(current_user.id)
|
||||
return {
|
||||
"success": True,
|
||||
"statistics": stats
|
||||
}
|
||||
return {"success": True, "statistics": stats}
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
|
||||
|
||||
|
||||
@router.put("/{resume_id}")
|
||||
async def update_resume(
|
||||
resume_id: str = Path(..., description="ID of the resume"),
|
||||
resume: str = Body(..., description="Updated resume content"),
|
||||
current_user = Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
current_user=Depends(get_current_user),
|
||||
database: RedisDatabase = Depends(get_database),
|
||||
):
|
||||
"""Update the content of a specific resume"""
|
||||
try:
|
||||
updates = {
|
||||
"resume": resume,
|
||||
"updated_at": datetime.now(UTC).isoformat()
|
||||
}
|
||||
updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
|
||||
|
||||
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
|
||||
if not updated_resume_data:
|
||||
@ -280,11 +258,9 @@ async def update_resume(
|
||||
raise HTTPException(status_code=404, detail="Resume not found")
|
||||
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
|
||||
|
||||
return create_success_response({
|
||||
"success": True,
|
||||
"message": f"Resume {resume_id} updated successfully",
|
||||
"resume": updated_resume
|
||||
})
|
||||
return create_success_response(
|
||||
{"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
@ -10,6 +10,7 @@ from utils.responses import create_success_response
|
||||
# Create router for authentication endpoints
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
|
||||
|
||||
async def get_redis() -> Redis:
|
||||
"""Dependency to get Redis client"""
|
||||
return redis_manager.get_client()
|
||||
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
|
||||
async def get_system_info(request: Request):
|
||||
"""Get system information"""
|
||||
from system_info import system_info # Import system_info function from system_info module
|
||||
|
||||
system = system_info()
|
||||
|
||||
return create_success_response(system.model_dump(mode='json'))
|
||||
return create_success_response(system.model_dump(mode="json"))
|
||||
|
||||
|
||||
@router.get("/redis/stats")
|
||||
async def redis_stats(redis: Redis = Depends(get_redis)):
|
||||
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
|
||||
"total_commands_processed": info.get("total_commands_processed"),
|
||||
"keyspace_hits": info.get("keyspace_hits"),
|
||||
"keyspace_misses": info.get("keyspace_misses"),
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds")
|
||||
"uptime_in_seconds": info.get("uptime_in_seconds"),
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")
|
||||
|
@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
|
||||
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
from models import (
|
||||
BaseUserWithType
|
||||
)
|
||||
from utils.dependencies import (
|
||||
get_database
|
||||
)
|
||||
from models import BaseUserWithType
|
||||
from utils.dependencies import get_database
|
||||
from utils.responses import create_success_response, create_error_response
|
||||
|
||||
# Create router for job endpoints
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
# reference can be candidateId, username, or email
|
||||
@router.get("/users/{reference}")
|
||||
async def get_user(
|
||||
reference: str = Path(...),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
):
|
||||
async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
|
||||
"""Get a candidate by username"""
|
||||
try:
|
||||
# Normalize reference to lowercase for case-insensitive search
|
||||
@ -31,41 +25,36 @@ async def get_user(
|
||||
|
||||
all_candidate_data = await database.get_all_candidates()
|
||||
if not all_candidate_data:
|
||||
logger.warning(f"⚠️ No users found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
logger.warning("⚠️ No users found in database")
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
|
||||
|
||||
user_data = None
|
||||
for user in all_candidate_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
if (
|
||||
user.get("id", "").lower() == query_lower
|
||||
or user.get("username", "").lower() == query_lower
|
||||
or user.get("email", "").lower() == query_lower
|
||||
):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
all_guest_data = await database.get_all_guests()
|
||||
if not all_guest_data:
|
||||
logger.warning(f"⚠️ No guests found in database")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "No users found")
|
||||
)
|
||||
logger.warning("⚠️ No guests found in database")
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
|
||||
for user in all_guest_data.values():
|
||||
if (user.get("id", "").lower() == query_lower or
|
||||
user.get("username", "").lower() == query_lower or
|
||||
user.get("email", "").lower() == query_lower):
|
||||
if (
|
||||
user.get("id", "").lower() == query_lower
|
||||
or user.get("username", "").lower() == query_lower
|
||||
or user.get("email", "").lower() == query_lower
|
||||
):
|
||||
user_data = user
|
||||
break
|
||||
|
||||
if not user_data:
|
||||
logger.warning(f"⚠️ User nor Guest found for reference: {reference}")
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content=create_error_response("NOT_FOUND", "User not found")
|
||||
)
|
||||
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
|
||||
|
||||
user = BaseUserWithType.model_validate(user_data)
|
||||
|
||||
@ -73,8 +62,4 @@ async def get_user(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Get user error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=create_error_response("FETCH_ERROR", str(e))
|
||||
)
|
||||
|
||||
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
|
||||
|
@ -4,6 +4,7 @@ import subprocess
|
||||
import math
|
||||
from models import SystemInfo
|
||||
|
||||
|
||||
def get_installed_ram():
|
||||
try:
|
||||
with open("/proc/meminfo", "r") as f:
|
||||
@ -19,9 +20,7 @@ def get_graphics_cards():
|
||||
gpus = []
|
||||
try:
|
||||
# Run the ze-monitor utility
|
||||
result = subprocess.run(
|
||||
["ze-monitor"], capture_output=True, text=True, check=True
|
||||
)
|
||||
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
|
||||
|
||||
# Clean up the output (remove leading/trailing whitespace and newlines)
|
||||
output = result.stdout.strip()
|
||||
@ -71,6 +70,7 @@ def get_cpu_info():
|
||||
except Exception as e:
|
||||
return f"Error retrieving CPU info: {e}"
|
||||
|
||||
|
||||
def system_info() -> SystemInfo:
|
||||
"""
|
||||
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.
|
||||
|
@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
|
||||
|
||||
# Process the forecast data into a simpler format
|
||||
forecast = {
|
||||
"location": data["properties"]
|
||||
.get("relativeLocation", {})
|
||||
.get("properties", {}),
|
||||
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
|
||||
"updated": data["properties"].get("updated", ""),
|
||||
"periods": [],
|
||||
}
|
||||
@ -181,7 +179,7 @@ def get_forecast(grid_endpoint):
|
||||
def TickerValue(ticker_symbols):
|
||||
api_key = os.getenv("TWELVEDATA_API_KEY", "")
|
||||
if not api_key:
|
||||
return {"error": f"Error fetching data: No API key for TwelveData"}
|
||||
return {"error": "Error fetching data: No API key for TwelveData"}
|
||||
|
||||
results = []
|
||||
for ticker_symbol in ticker_symbols.split(","):
|
||||
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
|
||||
if ticker_symbol == "":
|
||||
continue
|
||||
|
||||
url = (
|
||||
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
)
|
||||
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
|
||||
|
||||
response = requests.get(url)
|
||||
data = response.json()
|
||||
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
|
||||
|
||||
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
|
||||
logging.error(traceback.format_exc())
|
||||
results.append(
|
||||
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
|
||||
)
|
||||
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
|
||||
|
||||
return results[0] if len(results) == 1 else results
|
||||
|
||||
@ -278,8 +272,10 @@ def DateTime(timezone="America/Los_Angeles"):
|
||||
except Exception as e:
|
||||
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
|
||||
|
||||
|
||||
async def GenerateImage(llm, model: str, prompt: str):
|
||||
return { "image_id": "image-a830a83-bd831" }
|
||||
return {"image_id": "image-a830a83-bd831"}
|
||||
|
||||
|
||||
async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
"""
|
||||
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
return f"Error processing the website content: {str(e)}"
|
||||
|
||||
|
||||
|
||||
# %%
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
@ -355,171 +350,181 @@ class Function(BaseModel):
|
||||
parameters: Dict[str, Any]
|
||||
returns: Optional[Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
tools : List[Tool] = [
|
||||
# Tool.model_validate({
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "GenerateImage",
|
||||
# "description": """\
|
||||
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
|
||||
|
||||
# 1. Call this tool when users request images, drawings, or visual content
|
||||
# 2. This tool returns an image_id (e.g., "img_abc123")
|
||||
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
|
||||
# 4. FORBIDDEN: DO NOT use markdown image syntax 
|
||||
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
|
||||
# 6. FORBIDDEN: DO NOT use any other image embedding format
|
||||
|
||||
# CORRECT EXAMPLE:
|
||||
# User: "Draw a cat"
|
||||
# Tool returns: {"image_id": "img_xyz789"}
|
||||
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
|
||||
|
||||
# WRONG EXAMPLES (DO NOT DO THIS):
|
||||
# - 
|
||||
# - 
|
||||
# - <img src="...">
|
||||
|
||||
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
|
||||
# """,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "prompt": {
|
||||
# "type": "string",
|
||||
# "description": "Detailed image description including style, colors, subject, composition"
|
||||
# }
|
||||
# },
|
||||
# "required": ["prompt"]
|
||||
# },
|
||||
# "returns": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "image_id": {
|
||||
# "type": "string",
|
||||
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "TickerValue",
|
||||
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ticker": {
|
||||
"type": "string",
|
||||
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
|
||||
tools: List[Tool] = [
|
||||
# Tool.model_validate({
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": "GenerateImage",
|
||||
# "description": """\
|
||||
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
|
||||
# 1. Call this tool when users request images, drawings, or visual content
|
||||
# 2. This tool returns an image_id (e.g., "img_abc123")
|
||||
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
|
||||
# 4. FORBIDDEN: DO NOT use markdown image syntax 
|
||||
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
|
||||
# 6. FORBIDDEN: DO NOT use any other image embedding format
|
||||
# CORRECT EXAMPLE:
|
||||
# User: "Draw a cat"
|
||||
# Tool returns: {"image_id": "img_xyz789"}
|
||||
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
|
||||
# WRONG EXAMPLES (DO NOT DO THIS):
|
||||
# - 
|
||||
# - 
|
||||
# - <img src="...">
|
||||
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
|
||||
# """,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "prompt": {
|
||||
# "type": "string",
|
||||
# "description": "Detailed image description including style, colors, subject, composition"
|
||||
# }
|
||||
# },
|
||||
# "required": ["prompt"]
|
||||
# },
|
||||
# "returns": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "image_id": {
|
||||
# "type": "string",
|
||||
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }
|
||||
# }),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "TickerValue",
|
||||
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ticker": {
|
||||
"type": "string",
|
||||
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
|
||||
},
|
||||
},
|
||||
"required": ["ticker"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["ticker"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "AnalyzeSite",
|
||||
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The website URL to download and process",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the second LLM about the content",
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "AnalyzeSite",
|
||||
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "The website URL to download and process",
|
||||
},
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the second LLM about the content",
|
||||
},
|
||||
},
|
||||
"required": ["url", "question"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"required": ["url", "question"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "Identifier for the source LLM",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The complete response from the second LLM",
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"description": "Additional information about the response",
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"source": {
|
||||
"type": "string",
|
||||
"description": "Identifier for the source LLM",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The complete response from the second LLM",
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"description": "Additional information about the response",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "DateTime",
|
||||
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WeatherForecast",
|
||||
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
|
||||
"minLength": 2,
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
|
||||
"minLength": 2,
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "DateTime",
|
||||
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}),
|
||||
}
|
||||
),
|
||||
Tool.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WeatherForecast",
|
||||
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
|
||||
"minLength": 2,
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
|
||||
"minLength": 2,
|
||||
},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class ToolEntry(BaseModel):
|
||||
enabled: bool = True
|
||||
tool: Tool
|
||||
|
||||
|
||||
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
|
||||
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
|
||||
return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
|
||||
|
||||
|
||||
def all_tools() -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=tool) for tool in tools]
|
||||
|
||||
|
||||
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
|
||||
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
|
||||
|
||||
|
||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
|
||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
|
||||
|
@ -14,6 +14,7 @@ from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordSecurity:
|
||||
"""Handles password hashing and verification using bcrypt"""
|
||||
|
||||
@ -32,9 +33,9 @@ class PasswordSecurity:
|
||||
salt = bcrypt.gensalt()
|
||||
|
||||
# Hash the password
|
||||
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt)
|
||||
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
|
||||
|
||||
return password_hash.decode('utf-8'), salt.decode('utf-8')
|
||||
return password_hash.decode("utf-8"), salt.decode("utf-8")
|
||||
|
||||
@staticmethod
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
@ -49,10 +50,7 @@ class PasswordSecurity:
|
||||
True if password matches, False otherwise
|
||||
"""
|
||||
try:
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf-8'),
|
||||
password_hash.encode('utf-8')
|
||||
)
|
||||
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(f"Password verification error: {e}")
|
||||
return False
|
||||
@ -62,8 +60,10 @@ class PasswordSecurity:
|
||||
"""Generate a cryptographically secure random token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
class AuthenticationRecord(BaseModel):
|
||||
"""Authentication record for storing user credentials"""
|
||||
|
||||
user_id: str
|
||||
password_hash: str
|
||||
salt: str
|
||||
@ -78,18 +78,19 @@ class AuthenticationRecord(BaseModel):
|
||||
locked_until: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
json_encoders = {
|
||||
datetime: lambda v: v.isoformat() if v else None
|
||||
}
|
||||
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
|
||||
|
||||
|
||||
class SecurityConfig:
|
||||
"""Security configuration constants"""
|
||||
|
||||
MAX_LOGIN_ATTEMPTS = 5
|
||||
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
|
||||
PASSWORD_MIN_LENGTH = 8
|
||||
TOKEN_EXPIRY_HOURS = 24
|
||||
REFRESH_TOKEN_EXPIRY_DAYS = 30
|
||||
|
||||
|
||||
class AuthenticationManager:
|
||||
"""Manages authentication operations with security features"""
|
||||
|
||||
@ -120,7 +121,7 @@ class AuthenticationManager:
|
||||
password_hash=password_hash,
|
||||
salt=salt,
|
||||
last_password_change=datetime.now(timezone.utc),
|
||||
login_attempts=0
|
||||
login_attempts=0,
|
||||
)
|
||||
|
||||
# Store in database
|
||||
@ -129,7 +130,9 @@ class AuthenticationManager:
|
||||
logger.info(f"🔐 Created authentication record for user {user_id}")
|
||||
return auth_record
|
||||
|
||||
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
async def verify_user_credentials(
|
||||
self, login: str, password: str
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
Verify user credentials with security checks
|
||||
|
||||
@ -164,7 +167,11 @@ class AuthenticationManager:
|
||||
seconds = int(total_seconds % 60)
|
||||
time_until_unlock_str = f"{minutes}m {seconds}s"
|
||||
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
|
||||
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}"
|
||||
return (
|
||||
False,
|
||||
None,
|
||||
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
|
||||
)
|
||||
|
||||
# Verify password
|
||||
if not self.password_security.verify_password(password, auth_data.password_hash):
|
||||
@ -176,7 +183,9 @@ class AuthenticationManager:
|
||||
auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
|
||||
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
|
||||
)
|
||||
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts")
|
||||
logger.warning(
|
||||
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
|
||||
)
|
||||
|
||||
# Update authentication record
|
||||
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
|
||||
@ -238,6 +247,7 @@ class AuthenticationManager:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error updating last login for user {user_id}: {e}")
|
||||
|
||||
|
||||
# Utility functions for common operations
|
||||
def validate_password_strength(password: str) -> Tuple[bool, list]:
|
||||
"""
|
||||
@ -270,6 +280,7 @@ def validate_password_strength(password: str) -> Tuple[bool, list]:
|
||||
|
||||
return len(issues) == 0, issues
|
||||
|
||||
|
||||
def sanitize_login_input(login: str) -> str:
|
||||
"""Sanitize login input (email or username)"""
|
||||
return login.strip().lower() if login else ""
|
@ -19,7 +19,7 @@ from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
|
||||
from logger import logger
|
||||
from background_tasks import BackgroundTaskManager
|
||||
|
||||
#from . rate_limiter import RateLimiter
|
||||
# from . rate_limiter import RateLimiter
|
||||
|
||||
# Security
|
||||
security = HTTPBearer()
|
||||
@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
|
||||
# Global database manager reference
|
||||
db_manager = None
|
||||
|
||||
|
||||
def set_db_manager(manager: DatabaseManager):
|
||||
"""Set the global database manager reference"""
|
||||
global db_manager
|
||||
db_manager = manager
|
||||
|
||||
|
||||
def get_database() -> RedisDatabase:
|
||||
"""
|
||||
Safe database dependency that checks for availability
|
||||
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
|
||||
|
||||
if db_manager is None:
|
||||
logger.error("Database manager not initialized")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database not available - service starting up"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Database not available - service starting up")
|
||||
|
||||
if db_manager.is_shutting_down:
|
||||
logger.warning("Database is shutting down")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Service is shutting down"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Service is shutting down")
|
||||
|
||||
try:
|
||||
return db_manager.get_database()
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Database not available: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Database connection not available"
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Database connection not available")
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Enhanced token verification with guest session recovery"""
|
||||
try:
|
||||
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
|
||||
logger.error(f"❌ Token verification error: {e}")
|
||||
raise HTTPException(status_code=401, detail="Token verification failed")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user from database"""
|
||||
try:
|
||||
@ -135,6 +130,7 @@ async def get_current_user(
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
from helpers.model_cast import cast_to_base_user_with_type
|
||||
|
||||
if candidate_data.get("is_AI"):
|
||||
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
|
||||
else:
|
||||
@ -152,16 +148,20 @@ async def get_current_user(
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
async def get_current_user_or_guest(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
"""Get current user (including guests) from database"""
|
||||
try:
|
||||
# Check candidates first
|
||||
candidate_data = await database.get_candidate(user_id)
|
||||
if candidate_data:
|
||||
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
|
||||
return (
|
||||
Candidate.model_validate(candidate_data)
|
||||
if not candidate_data.get("is_AI")
|
||||
else CandidateAI.model_validate(candidate_data)
|
||||
)
|
||||
|
||||
# Check employers
|
||||
employer_data = await database.get_employer(user_id)
|
||||
@ -180,9 +180,9 @@ async def get_current_user_or_guest(
|
||||
logger.error(f"❌ Error getting current user: {e}")
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
|
||||
async def get_current_admin(
|
||||
user_id: str = Depends(verify_token_with_blacklist),
|
||||
database: RedisDatabase = Depends(get_database)
|
||||
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
|
||||
) -> BaseUserWithType:
|
||||
user = await get_current_user(user_id=user_id, database=database)
|
||||
if isinstance(user, Candidate) and user.is_admin:
|
||||
@ -193,6 +193,7 @@ async def get_current_admin(
|
||||
logger.warning(f"⚠️ User {user_id} is not an admin")
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
|
||||
prometheus_collector = CollectorRegistry()
|
||||
|
||||
# Keep the Instrumentator instance alive
|
||||
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
|
||||
should_ignore_untemplated=True,
|
||||
should_group_untemplated=True,
|
||||
excluded_handlers=[f"{defines.api_prefix}/metrics"],
|
||||
registry=prometheus_collector
|
||||
registry=prometheus_collector,
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import List, Dict
|
||||
from models import (Job)
|
||||
from models import Job
|
||||
|
||||
|
||||
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
requirements: List[Dict[str, str]] = []
|
||||
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
if job.requirements:
|
||||
if job.requirements.technical_skills:
|
||||
if job.requirements.technical_skills.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
]
|
||||
)
|
||||
if job.requirements.technical_skills.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.experience_requirements:
|
||||
if job.requirements.experience_requirements.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
]
|
||||
)
|
||||
if job.requirements.experience_requirements.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.soft_skills:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Soft Skills"}
|
||||
for req in job.requirements.soft_skills
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
|
||||
|
||||
if job.requirements.experience:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience"}
|
||||
for req in job.requirements.experience
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
|
||||
|
||||
if job.requirements.education:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Education"}
|
||||
for req in job.requirements.education
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
|
||||
|
||||
if job.requirements.certifications:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Certifications"}
|
||||
for req in job.requirements.certifications
|
||||
])
|
||||
requirements.extend(
|
||||
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
|
||||
)
|
||||
|
||||
if job.requirements.preferred_attributes:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
]
|
||||
)
|
||||
|
||||
return requirements
|
@ -13,44 +13,13 @@ from fastapi.responses import StreamingResponse
|
||||
import defines
|
||||
from logger import logger
|
||||
from models import DocumentType
|
||||
from models import (
|
||||
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
|
||||
Candidate, Employer, Guest, AuthResponse,
|
||||
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
|
||||
EmailVerificationRequest, ResendVerificationRequest,
|
||||
# API
|
||||
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
|
||||
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
|
||||
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
|
||||
JobRequirements, JobRequirementsMessage, LoginRequest,
|
||||
CreateCandidateRequest, CreateEmployerRequest,
|
||||
|
||||
# User models
|
||||
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
|
||||
Authentication, AuthResponse, CandidateAI,
|
||||
|
||||
# Job models
|
||||
JobApplication, ApplicationStatus,
|
||||
|
||||
# Chat models
|
||||
ChatSession, ChatMessage, ChatContext, ChatQuery, ChatSenderType, ApiMessageType, ChatContextType,
|
||||
ChatMessageRagSearch,
|
||||
|
||||
# Document models
|
||||
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
|
||||
|
||||
# Supporting models
|
||||
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
|
||||
|
||||
# Email
|
||||
EmailVerificationRequest,
|
||||
ApiStatusType
|
||||
)
|
||||
from models import Job, ChatMessage, DocumentType, ApiStatusType
|
||||
|
||||
from typing import List, Dict
|
||||
from models import (Job)
|
||||
from models import Job
|
||||
import utils.llm_proxy as llm_manager
|
||||
|
||||
|
||||
async def get_last_item(generator):
|
||||
"""Get the last item from an async generator"""
|
||||
last_item = None
|
||||
@ -65,7 +34,7 @@ def filter_and_paginate(
|
||||
limit: int = 20,
|
||||
sort_by: Optional[str] = None,
|
||||
sort_order: str = "desc",
|
||||
filters: Optional[Dict] = None
|
||||
filters: Optional[Dict] = None,
|
||||
) -> Tuple[List[Any], int]:
|
||||
"""Filter, sort, and paginate items"""
|
||||
filtered_items = items.copy()
|
||||
@ -76,8 +45,7 @@ def filter_and_paginate(
|
||||
if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
|
||||
filtered_items = [item for item in filtered_items if item.get(key) == value]
|
||||
elif hasattr(filtered_items[0], key) if filtered_items else False:
|
||||
filtered_items = [item for item in filtered_items
|
||||
if getattr(item, key, None) == value]
|
||||
filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
|
||||
|
||||
# Sort items
|
||||
if sort_by and filtered_items:
|
||||
@ -101,9 +69,9 @@ def filter_and_paginate(
|
||||
|
||||
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
|
||||
"""Stream agent response with proper formatting"""
|
||||
|
||||
async def message_stream_generator():
|
||||
"""Generator to stream messages with persistence"""
|
||||
last_log = None
|
||||
final_message = None
|
||||
|
||||
import utils.llm_proxy as llm_manager
|
||||
@ -127,12 +95,15 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
|
||||
# metadata and other unnecessary fields for streaming
|
||||
if generated_message.status != ApiStatusType.DONE:
|
||||
from models import ChatMessageStreaming, ChatMessageStatus
|
||||
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
|
||||
|
||||
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
|
||||
generated_message, ChatMessageStatus
|
||||
):
|
||||
raise TypeError(
|
||||
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
|
||||
)
|
||||
|
||||
json_data = generated_message.model_dump(mode='json', by_alias=True)
|
||||
json_data = generated_message.model_dump(mode="json", by_alias=True)
|
||||
json_str = json.dumps(json_data)
|
||||
|
||||
yield f"data: {json_str}\n\n"
|
||||
@ -177,16 +148,16 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
|
||||
extension = pathlib.Path(filename).suffix.lower()
|
||||
|
||||
type_mapping = {
|
||||
'.pdf': DocumentType.PDF,
|
||||
'.docx': DocumentType.DOCX,
|
||||
'.doc': DocumentType.DOCX,
|
||||
'.txt': DocumentType.TXT,
|
||||
'.md': DocumentType.MARKDOWN,
|
||||
'.markdown': DocumentType.MARKDOWN,
|
||||
'.png': DocumentType.IMAGE,
|
||||
'.jpg': DocumentType.IMAGE,
|
||||
'.jpeg': DocumentType.IMAGE,
|
||||
'.gif': DocumentType.IMAGE,
|
||||
".pdf": DocumentType.PDF,
|
||||
".docx": DocumentType.DOCX,
|
||||
".doc": DocumentType.DOCX,
|
||||
".txt": DocumentType.TXT,
|
||||
".md": DocumentType.MARKDOWN,
|
||||
".markdown": DocumentType.MARKDOWN,
|
||||
".png": DocumentType.IMAGE,
|
||||
".jpg": DocumentType.IMAGE,
|
||||
".jpeg": DocumentType.IMAGE,
|
||||
".gif": DocumentType.IMAGE,
|
||||
}
|
||||
|
||||
return type_mapping.get(extension, DocumentType.TXT)
|
||||
@ -206,17 +177,12 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
|
||||
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Reformatting job description as markdown...",
|
||||
activity=ApiActivityType.CONVERTING
|
||||
session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
|
||||
)
|
||||
yield status_message
|
||||
|
||||
@ -229,26 +195,23 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
|
||||
system_prompt="""
|
||||
You are a document editor. Take the provided job description and reformat as legible markdown.
|
||||
Return only the markdown content, no other text. Make sure all content is included.
|
||||
"""
|
||||
""",
|
||||
):
|
||||
pass
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
logger.error("❌ Failed to reformat job description to markdown")
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
chat_message: ChatMessage = message
|
||||
try:
|
||||
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"✅ Successfully converted content to markdown")
|
||||
logger.info("✅ Successfully converted content to markdown")
|
||||
yield chat_message
|
||||
return
|
||||
|
||||
@ -256,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
|
||||
async def create_job_from_content(database, current_user, content: str):
|
||||
"""Create a job from content using AI analysis"""
|
||||
from models import (
|
||||
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
|
||||
ApiActivityType, ChatContextType, JobRequirementsMessage
|
||||
MOCK_UUID,
|
||||
ApiStatusType,
|
||||
ChatMessageError,
|
||||
ChatMessageStatus,
|
||||
ApiActivityType,
|
||||
ChatContextType,
|
||||
JobRequirementsMessage,
|
||||
)
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
|
||||
activity=ApiActivityType.INFO
|
||||
activity=ApiActivityType.INFO,
|
||||
)
|
||||
yield status_message
|
||||
await asyncio.sleep(0) # Let the status message propagate
|
||||
@ -278,10 +246,7 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, ChatMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Failed to reformat job description"
|
||||
)
|
||||
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
|
||||
yield error_message
|
||||
return
|
||||
|
||||
@ -290,33 +255,28 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
|
||||
if not chat_agent:
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="No agent found for job requirements chat type"
|
||||
session_id=MOCK_UUID, content="No agent found for job requirements chat type"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
|
||||
status_message = ChatMessageStatus(
|
||||
session_id=MOCK_UUID,
|
||||
content=f"Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING
|
||||
content="Analyzing document for company and requirement details...",
|
||||
activity=ApiActivityType.SEARCHING,
|
||||
)
|
||||
yield status_message
|
||||
|
||||
message = None
|
||||
async for message in chat_agent.generate(
|
||||
llm=llm_manager.get_llm(),
|
||||
model=defines.model,
|
||||
session_id=MOCK_UUID,
|
||||
prompt=markdown_message.content
|
||||
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
|
||||
):
|
||||
if message.status != ApiStatusType.DONE:
|
||||
yield message
|
||||
|
||||
if not message or not isinstance(message, JobRequirementsMessage):
|
||||
error_message = ChatMessageError(
|
||||
session_id=MOCK_UUID,
|
||||
content="Job extraction did not convert successfully"
|
||||
session_id=MOCK_UUID, content="Job extraction did not convert successfully"
|
||||
)
|
||||
yield error_message
|
||||
return
|
||||
@ -326,62 +286,63 @@ async def create_job_from_content(database, current_user, content: str):
|
||||
yield job_requirements
|
||||
return
|
||||
|
||||
|
||||
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
|
||||
requirements: List[Dict[str, str]] = []
|
||||
|
||||
if job.requirements:
|
||||
if job.requirements.technical_skills:
|
||||
if job.requirements.technical_skills.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (required)"}
|
||||
for req in job.requirements.technical_skills.required
|
||||
]
|
||||
)
|
||||
if job.requirements.technical_skills.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Technical Skills (preferred)"}
|
||||
for req in job.requirements.technical_skills.preferred
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.experience_requirements:
|
||||
if job.requirements.experience_requirements.required:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (required)"}
|
||||
for req in job.requirements.experience_requirements.required
|
||||
]
|
||||
)
|
||||
if job.requirements.experience_requirements.preferred:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Experience (preferred)"}
|
||||
for req in job.requirements.experience_requirements.preferred
|
||||
]
|
||||
)
|
||||
|
||||
if job.requirements.soft_skills:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Soft Skills"}
|
||||
for req in job.requirements.soft_skills
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
|
||||
|
||||
if job.requirements.experience:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Experience"}
|
||||
for req in job.requirements.experience
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
|
||||
|
||||
if job.requirements.education:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Education"}
|
||||
for req in job.requirements.education
|
||||
])
|
||||
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
|
||||
|
||||
if job.requirements.certifications:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Certifications"}
|
||||
for req in job.requirements.certifications
|
||||
])
|
||||
requirements.extend(
|
||||
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
|
||||
)
|
||||
|
||||
if job.requirements.preferred_attributes:
|
||||
requirements.extend([
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
])
|
||||
requirements.extend(
|
||||
[
|
||||
{"requirement": req, "domain": "Preferred Attributes"}
|
||||
for req in job.requirements.preferred_attributes
|
||||
]
|
||||
)
|
||||
|
||||
return requirements
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
from prometheus_client import Counter, Histogram # type: ignore
|
||||
from threading import Lock
|
||||
|
||||
|
||||
def singleton(cls):
|
||||
instance = None
|
||||
lock = Lock()
|
||||
|
@ -6,56 +6,68 @@ from functools import wraps
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Callable, Dict, Optional, Any
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from pydantic import BaseModel # type: ignore
|
||||
from database.manager import RedisDatabase
|
||||
from logger import logger
|
||||
|
||||
from . dependencies import get_current_user_or_guest, get_database
|
||||
from .dependencies import get_current_user_or_guest, get_database
|
||||
|
||||
|
||||
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
|
||||
"""Dependency to get rate limiter instance"""
|
||||
return RateLimiter(database)
|
||||
|
||||
|
||||
class RateLimitConfig(BaseModel):
|
||||
"""Rate limit configuration"""
|
||||
|
||||
requests_per_minute: int
|
||||
requests_per_hour: int
|
||||
requests_per_day: int
|
||||
burst_limit: int # Maximum requests in a short burst
|
||||
burst_window_seconds: int = 60 # Window for burst detection
|
||||
|
||||
|
||||
class GuestRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for guest users - more restrictive"""
|
||||
|
||||
requests_per_minute: int = 10
|
||||
requests_per_hour: int = 100
|
||||
requests_per_day: int = 500
|
||||
burst_limit: int = 15
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for authenticated users - more generous"""
|
||||
|
||||
requests_per_minute: int = 60
|
||||
requests_per_hour: int = 1000
|
||||
requests_per_day: int = 10000
|
||||
burst_limit: int = 100
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class PremiumUserRateLimitConfig(RateLimitConfig):
|
||||
"""Rate limits for premium/admin users - most generous"""
|
||||
|
||||
requests_per_minute: int = 120
|
||||
requests_per_hour: int = 5000
|
||||
requests_per_day: int = 50000
|
||||
burst_limit: int = 200
|
||||
burst_window_seconds: int = 60
|
||||
|
||||
|
||||
class RateLimitResult(BaseModel):
|
||||
"""Result of rate limit check"""
|
||||
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
retry_after_seconds: Optional[int] = None
|
||||
remaining_requests: Dict[str, int] = {}
|
||||
reset_times: Dict[str, datetime] = {}
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter using Redis for distributed rate limiting"""
|
||||
|
||||
@ -78,11 +90,7 @@ class RateLimiter:
|
||||
return self.user_config
|
||||
|
||||
async def check_rate_limit(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False,
|
||||
endpoint: Optional[str] = None
|
||||
self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
|
||||
) -> RateLimitResult:
|
||||
"""
|
||||
Check if user has exceeded rate limits
|
||||
@ -105,7 +113,7 @@ class RateLimiter:
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
"burst": f"{base_key}:burst",
|
||||
}
|
||||
|
||||
# Add endpoint-specific limiting if provided
|
||||
@ -125,7 +133,7 @@ class RateLimiter:
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
"burst": int(results[3] or 0),
|
||||
}
|
||||
|
||||
# Check limits
|
||||
@ -133,7 +141,7 @@ class RateLimiter:
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
"burst": config.burst_limit,
|
||||
}
|
||||
|
||||
# Check each limit
|
||||
@ -146,18 +154,22 @@ class RateLimiter:
|
||||
elif window == "hour":
|
||||
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
|
||||
elif window == "day":
|
||||
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
|
||||
retry_after = 86400 - (
|
||||
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
|
||||
)
|
||||
else: # burst
|
||||
retry_after = config.burst_window_seconds
|
||||
|
||||
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
|
||||
logger.warning(
|
||||
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
|
||||
)
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=False,
|
||||
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
|
||||
retry_after_seconds=retry_after,
|
||||
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
reset_times=self._calculate_reset_times(current_time),
|
||||
)
|
||||
|
||||
# If we get here, request is allowed - increment counters
|
||||
@ -182,17 +194,12 @@ class RateLimiter:
|
||||
await pipe.execute()
|
||||
|
||||
# Calculate remaining requests
|
||||
remaining = {
|
||||
k: max(0, limits[k] - (current_counts[k] + 1))
|
||||
for k in current_counts.keys()
|
||||
}
|
||||
remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}
|
||||
|
||||
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
|
||||
|
||||
return RateLimitResult(
|
||||
allowed=True,
|
||||
remaining_requests=remaining,
|
||||
reset_times=self._calculate_reset_times(current_time)
|
||||
allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -206,18 +213,9 @@ class RateLimiter:
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
|
||||
|
||||
return {
|
||||
"minute": next_minute,
|
||||
"hour": next_hour,
|
||||
"day": next_day
|
||||
}
|
||||
return {"minute": next_minute, "hour": next_hour, "day": next_day}
|
||||
|
||||
async def get_user_rate_limit_status(
|
||||
self,
|
||||
user_id: str,
|
||||
user_type: str,
|
||||
is_admin: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
|
||||
"""Get current rate limit status for a user"""
|
||||
config = self.get_config_for_user(user_type, is_admin)
|
||||
current_time = datetime.now(UTC)
|
||||
@ -227,7 +225,7 @@ class RateLimiter:
|
||||
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
|
||||
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
|
||||
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
|
||||
"burst": f"{base_key}:burst"
|
||||
"burst": f"{base_key}:burst",
|
||||
}
|
||||
|
||||
try:
|
||||
@ -240,14 +238,14 @@ class RateLimiter:
|
||||
"minute": int(results[0] or 0),
|
||||
"hour": int(results[1] or 0),
|
||||
"day": int(results[2] or 0),
|
||||
"burst": int(results[3] or 0)
|
||||
"burst": int(results[3] or 0),
|
||||
}
|
||||
|
||||
limits = {
|
||||
"minute": config.requests_per_minute,
|
||||
"hour": config.requests_per_hour,
|
||||
"day": config.requests_per_day,
|
||||
"burst": config.burst_limit
|
||||
"burst": config.burst_limit,
|
||||
}
|
||||
|
||||
return {
|
||||
@ -258,7 +256,7 @@ class RateLimiter:
|
||||
"limits": limits,
|
||||
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
|
||||
"reset_times": self._calculate_reset_times(current_time),
|
||||
"config": config.model_dump()
|
||||
"config": config.model_dump(),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -295,11 +293,9 @@ class RateLimiter:
|
||||
# Rate Limited Decorator
|
||||
# ============================
|
||||
|
||||
|
||||
def rate_limited(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120,
|
||||
endpoint_specific: bool = True
|
||||
guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
|
||||
):
|
||||
"""
|
||||
Decorator to easily apply rate limiting to endpoints
|
||||
@ -320,12 +316,14 @@ def rate_limited(
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Extract dependencies from function signature
|
||||
import inspect
|
||||
sig = inspect.signature(func)
|
||||
|
||||
inspect.signature(func)
|
||||
|
||||
# Get request, current_user, and rate_limiter from kwargs or args
|
||||
request = None
|
||||
@ -336,7 +334,7 @@ def rate_limited(
|
||||
for param_name, param_value in kwargs.items():
|
||||
if isinstance(param_value, Request):
|
||||
request = param_value
|
||||
elif hasattr(param_value, 'user_type'): # User-like object
|
||||
elif hasattr(param_value, "user_type"): # User-like object
|
||||
current_user = param_value
|
||||
elif isinstance(param_value, RateLimiter):
|
||||
rate_limiter = param_value
|
||||
@ -350,30 +348,33 @@ def rate_limited(
|
||||
# Apply rate limiting if we have the required components
|
||||
if request and current_user and rate_limiter:
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
|
||||
# Call the original function
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
async def apply_custom_rate_limiting(
|
||||
request: Request,
|
||||
current_user,
|
||||
rate_limiter: RateLimiter,
|
||||
guest_per_minute: int,
|
||||
user_per_minute: int,
|
||||
admin_per_minute: int
|
||||
admin_per_minute: int,
|
||||
):
|
||||
"""Apply custom rate limiting with specified limits"""
|
||||
try:
|
||||
# Determine user info
|
||||
user_id = current_user.id
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
is_admin = getattr(current_user, 'is_admin', False)
|
||||
user_type = (
|
||||
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
|
||||
)
|
||||
is_admin = getattr(current_user, "is_admin", False)
|
||||
|
||||
# Determine appropriate limit
|
||||
if is_admin:
|
||||
@ -385,13 +386,17 @@ async def apply_custom_rate_limiting(
|
||||
|
||||
# Create custom rate limit key
|
||||
current_time = datetime.now(UTC)
|
||||
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
custom_key = (
|
||||
f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
|
||||
)
|
||||
|
||||
# Check current usage
|
||||
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
|
||||
|
||||
if current_count >= requests_per_minute:
|
||||
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
|
||||
logger.warning(
|
||||
f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
@ -399,9 +404,9 @@ async def apply_custom_rate_limiting(
|
||||
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
|
||||
"retryAfter": 60 - current_time.second,
|
||||
"userType": user_type,
|
||||
"endpoint": request.url.path
|
||||
"endpoint": request.url.path,
|
||||
},
|
||||
headers={"Retry-After": str(60 - current_time.second)}
|
||||
headers={"Retry-After": str(60 - current_time.second)},
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
@ -410,7 +415,9 @@ async def apply_custom_rate_limiting(
|
||||
pipe.expire(custom_key, 120) # 2 minutes TTL
|
||||
await pipe.execute()
|
||||
|
||||
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
|
||||
logger.debug(
|
||||
f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@ -418,15 +425,13 @@ async def apply_custom_rate_limiting(
|
||||
logger.error(f"❌ Custom rate limiting error: {e}")
|
||||
# Fail open
|
||||
|
||||
|
||||
# ============================
|
||||
# Alternative: FastAPI Dependency-Based Rate Limiting
|
||||
# ============================
|
||||
|
||||
def create_rate_limit_dependency(
|
||||
guest_per_minute: int = 10,
|
||||
user_per_minute: int = 60,
|
||||
admin_per_minute: int = 120
|
||||
):
|
||||
|
||||
def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
|
||||
"""
|
||||
Create a FastAPI dependency for rate limiting
|
||||
|
||||
@ -441,23 +446,25 @@ def create_rate_limit_dependency(
|
||||
):
|
||||
return {"message": "Rate limited endpoint"}
|
||||
"""
|
||||
|
||||
async def rate_limit_dependency(
|
||||
request: Request,
|
||||
current_user = Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter)
|
||||
current_user=Depends(get_current_user_or_guest),
|
||||
rate_limiter: RateLimiter = Depends(get_rate_limiter),
|
||||
):
|
||||
await apply_custom_rate_limiting(
|
||||
request, current_user, rate_limiter,
|
||||
guest_per_minute, user_per_minute, admin_per_minute
|
||||
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
|
||||
)
|
||||
return True
|
||||
|
||||
return rate_limit_dependency
|
||||
|
||||
|
||||
# ============================
|
||||
# Rate Limiting Utilities
|
||||
# ============================
|
||||
|
||||
|
||||
class EndpointRateLimiter:
|
||||
"""Utility class for endpoint-specific rate limiting"""
|
||||
|
||||
@ -477,9 +484,11 @@ class EndpointRateLimiter:
|
||||
return True # No custom limits set
|
||||
|
||||
limits = self.custom_limits[endpoint]
|
||||
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
|
||||
user_type = (
|
||||
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
|
||||
)
|
||||
|
||||
if getattr(current_user, 'is_admin', False):
|
||||
if getattr(current_user, "is_admin", False):
|
||||
user_type = "admin"
|
||||
|
||||
limit = limits.get(user_type, limits.get("default", 60))
|
||||
@ -491,8 +500,7 @@ class EndpointRateLimiter:
|
||||
|
||||
if current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
||||
status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
@ -501,9 +509,11 @@ class EndpointRateLimiter:
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Global endpoint rate limiter instance
|
||||
endpoint_rate_limiter = None
|
||||
|
||||
|
||||
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
|
||||
"""Get endpoint rate limiter instance"""
|
||||
global endpoint_rate_limiter
|
||||
@ -511,15 +521,14 @@ def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limit
|
||||
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
|
||||
|
||||
# Configure endpoint-specific limits
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
|
||||
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
|
||||
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
|
||||
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
|
||||
})
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
|
||||
)
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
|
||||
)
|
||||
endpoint_rate_limiter.set_endpoint_limits(
|
||||
"/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
|
||||
)
|
||||
|
||||
return endpoint_rate_limiter
|
||||
|
||||
|
@ -3,37 +3,17 @@ Response utility functions for consistent API responses
|
||||
"""
|
||||
from typing import Any, Optional, Dict, List
|
||||
|
||||
|
||||
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
|
||||
return {
|
||||
"success": True,
|
||||
"data": data,
|
||||
"meta": meta
|
||||
}
|
||||
return {"success": True, "data": data, "meta": meta}
|
||||
|
||||
|
||||
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
|
||||
return {
|
||||
"success": False,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message,
|
||||
"details": details
|
||||
}
|
||||
}
|
||||
return {"success": False, "error": {"code": code, "message": message, "details": details}}
|
||||
|
||||
def create_paginated_response(
|
||||
data: List[Any],
|
||||
page: int,
|
||||
limit: int,
|
||||
total: int
|
||||
) -> Dict:
|
||||
|
||||
def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
|
||||
total_pages = (total + limit - 1) // limit
|
||||
has_more = page < total_pages
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"totalPages": total_pages,
|
||||
"hasMore": has_more
|
||||
}
|
||||
return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}
|
||||
|
Loading…
x
Reference in New Issue
Block a user