Compare commits

...

2 Commits

SHA1        Message                  Date
2cf3fa7b04  ruff reformat            2025-06-18 13:53:07 -07:00
f1c2e16389  Reformatting with ruff   2025-06-18 13:30:54 -07:00
69 changed files with 5012 additions and 5145 deletions
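
Both commits apply an automated ruff formatting pass across the repository: multi-line call signatures are collapsed where they fit within the line length, over-long lines are wrapped, string literals are normalized to double quotes, and trailing commas are added. As a minimal sketch, the pass could be reproduced roughly as follows, assuming ruff is installed and a line length of about 120 characters (inferred from the reformatted code; the actual project configuration is not part of this compare):

    # Hypothetical reproduction of the reformat commits. The 120-character line
    # length is an assumption inferred from the resulting code, not a setting
    # confirmed by this compare view.
    import subprocess

    subprocess.run(["ruff", "format", ".", "--line-length", "120"], check=True)
    # Optionally apply ruff's safe lint autofixes as well:
    subprocess.run(["ruff", "check", ".", "--fix"], check=True)

In practice the line-length and quote-style settings would normally live in the project's ruff configuration so that repeated runs stay consistent.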

View File

@ -17,11 +17,9 @@ from logger import logger
AnyAgent: TypeAlias = Agent # BaseModel covers Agent and subclasses
# Maps class_name to (module_name, class_name)
class_registry: Dict[str, Tuple[str, str]] = (
{}
)
class_registry: Dict[str, Tuple[str, str]] = {}
__all__ = ['get_or_create_agent']
__all__ = ["get_or_create_agent"]
package_dir = pathlib.Path(__file__).parent
package_name = __name__
@ -38,16 +36,11 @@ for path in package_dir.glob("*.py"):
# Find all Agent subclasses in the module
for name, obj in inspect.getmembers(module, inspect.isclass):
if (
issubclass(obj, AnyAgent)
and obj is not AnyAgent
and obj is not Agent
and name not in class_registry
):
if issubclass(obj, AnyAgent) and obj is not AnyAgent and obj is not Agent and name not in class_registry:
class_registry[name] = (full_module_name, name)
globals()[name] = obj
logger.info(f"Adding agent: {name}")
__all__.append(name) # type: ignore
__all__.append(name) # type: ignore
except ImportError as e:
logger.error(traceback.format_exc())
logger.error(f"Error importing {full_module_name}: {e}")

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field # type: ignore
from typing import (
Literal,
get_args,
@ -13,10 +13,10 @@ import time
import re
from abc import ABC
from datetime import datetime, UTC
from prometheus_client import Counter, Summary, CollectorRegistry # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
import numpy as np # type: ignore
import json_extractor as json_extractor
from pydantic import BaseModel, Field, model_validator # type: ignore
from pydantic import BaseModel, Field # type: ignore
from uuid import uuid4
from typing import List, Optional, ClassVar, Any, Literal
@ -24,7 +24,7 @@ from datetime import datetime, UTC
import numpy as np # type: ignore
from uuid import uuid4
from prometheus_client import CollectorRegistry, Counter # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
import os
import re
from pathlib import Path
@ -32,19 +32,45 @@ from pathlib import Path
from rag import start_file_watcher, ChromaDBFileWatcher
import defines
from logger import logger
from models import (Tunables, ChatMessageUser, ChatMessage, RagEntry, ChatMessageMetaData, ApiStatusType, Candidate, ChatContextType)
from models import (
Tunables,
ChatMessageUser,
ChatMessage,
RagEntry,
ChatMessageMetaData,
ApiStatusType,
Candidate,
ChatContextType,
)
import utils.llm_proxy as llm_manager
from database.manager import RedisDatabase
from models import ChromaDBGetResponse
from utils.metrics import Metrics
from models import ( ApiActivityType, ApiMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageStatus, ChatMessageStreaming, LLMMessage, ChatMessage, ChatOptions, ChatMessageUser, Tunables, ApiStatusType, ChatMessageMetaData, Candidate)
from models import (
ApiActivityType,
ApiMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageStatus,
ChatMessageStreaming,
LLMMessage,
ChatMessage,
ChatOptions,
ChatMessageUser,
Tunables,
ApiStatusType,
ChatMessageMetaData,
Candidate,
)
from logger import logger
import defines
from .registry import agent_registry
from models import ( ChromaDBGetResponse )
from models import ChromaDBGetResponse
class CandidateEntity(Candidate):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@ -59,13 +85,10 @@ class CandidateEntity(Candidate):
CandidateEntity__agents: List[Agent] = []
CandidateEntity__observer: Optional[Any] = Field(default=None, exclude=True)
CandidateEntity__file_watcher: Optional[ChromaDBFileWatcher] = Field(default=None, exclude=True)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(
default=None, exclude=True
)
CandidateEntity__prometheus_collector: Optional[CollectorRegistry] = Field(default=None, exclude=True)
CandidateEntity__metrics: Optional[Metrics] = Field(
default=None,
description="Metrics collector for this agent, used to track performance and usage."
default=None, description="Metrics collector for this agent, used to track performance and usage."
)
def __init__(self, candidate=None):
@ -78,7 +101,7 @@ class CandidateEntity(Candidate):
@classmethod
def exists(cls, username: str):
# Validate username format (only allow safe characters)
if not re.match(r'^[a-zA-Z0-9_-]+$', username):
if not re.match(r"^[a-zA-Z0-9_-]+$", username):
return False # Invalid username characters
# Check for minimum and maximum length
@ -117,11 +140,7 @@ class CandidateEntity(Candidate):
if agent.agent_type == agent_type:
return agent
return get_or_create_agent(
agent_type=agent_type,
user=self,
prometheus_collector=self.prometheus_collector
)
return get_or_create_agent(agent_type=agent_type, user=self, prometheus_collector=self.prometheus_collector)
# Wrapper properties that map into file_watcher
@property
@ -132,6 +151,7 @@ class CandidateEntity(Candidate):
# Fields managed by initialize()
CandidateEntity__initialized: bool = Field(default=False, exclude=True)
@property
def metrics(self) -> Metrics:
if not self.CandidateEntity__metrics:
@ -160,15 +180,10 @@ class CandidateEntity(Candidate):
if not self.metrics:
logger.warning("No metrics collector set for this agent.")
return
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(
response.usage.prompt_eval_count
)
self.metrics.tokens_prompt.labels(agent=agent.agent_type).inc(response.usage.prompt_eval_count)
self.metrics.tokens_eval.labels(agent=agent.agent_type).inc(response.usage.eval_count)
async def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
async def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
if self.CandidateEntity__initialized:
# Initialization can only be attempted once; if there are multiple attempts, it means
# a subsystem is failing or there is a logic bug in the code.
@ -188,8 +203,8 @@ class CandidateEntity(Candidate):
self.CandidateEntity__metrics = Metrics(prometheus_collector=self.prometheus_collector)
user_dir = os.path.join(defines.user_dir, self.username)
vector_db_dir=os.path.join(user_dir, defines.persist_directory)
rag_content_dir=os.path.join(user_dir, defines.rag_content_dir)
vector_db_dir = os.path.join(user_dir, defines.persist_directory)
rag_content_dir = os.path.join(user_dir, defines.rag_content_dir)
os.makedirs(vector_db_dir, exist_ok=True)
os.makedirs(rag_content_dir, exist_ok=True)
@ -205,17 +220,21 @@ class CandidateEntity(Candidate):
)
has_username_rag = any(item.name == self.username for item in self.rags)
if not has_username_rag:
self.rags.append(RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
))
self.rags.append(
RagEntry(
name=self.username,
description=f"Expert data about {self.full_name}.",
)
)
self.rag_content_size = self.file_watcher.collection.count()
class Agent(BaseModel, ABC):
"""
Base class for all agent types.
This class defines the common attributes and methods for all agent types.
"""
class Config:
arbitrary_types_allowed = True # Allow arbitrary types like RedisDatabase
@ -237,7 +256,7 @@ class Agent(BaseModel, ABC):
conversation: List[ChatMessageUser] = Field(
default_factory=list,
description="Conversation history for this agent, used to maintain context across messages."
description="Conversation history for this agent, used to maintain context across messages.",
)
@property
@ -254,9 +273,7 @@ class Agent(BaseModel, ABC):
last_item = item
return last_item
def set_optimal_context_size(
self, llm: Any, model: str, prompt: str, ctx_buffer=2048
) -> int:
def set_optimal_context_size(self, llm: Any, model: str, prompt: str, ctx_buffer=2048) -> int:
# Most models average 1.3-1.5 tokens per word
word_count = len(prompt.split())
tokens = int(word_count * 1.4)
@ -265,9 +282,7 @@ class Agent(BaseModel, ABC):
total_ctx = tokens + ctx_buffer
if total_ctx > self.context_size:
logger.info(
f"Increasing context size from {self.context_size} to {total_ctx}"
)
logger.info(f"Increasing context size from {self.context_size} to {total_ctx}")
# Grow the context size if necessary
self.context_size = max(self.context_size, total_ctx)
@ -472,24 +487,24 @@ class Agent(BaseModel, ABC):
context = []
for chroma_results in rag_message.content:
for index, metadata in enumerate(chroma_results.metadatas):
content = "\n".join([
line.strip()
for line in chroma_results.documents[index].split("\n")
if line
]).strip()
context.append(f"""
content = "\n".join(
[line.strip() for line in chroma_results.documents[index].split("\n") if line]
).strip()
context.append(
f"""
Source: {metadata.get("doc_type", "unknown")}: {metadata.get("path", "")}
Document reference: {chroma_results.ids[index]}
Content: {content}
""")
"""
)
return "\n".join(context)
async def generate_rag_results(
self,
session_id: str,
prompt: str,
top_k: int=defines.default_rag_top_k,
threshold: float=defines.default_rag_threshold,
top_k: int = defines.default_rag_top_k,
threshold: float = defines.default_rag_threshold,
) -> AsyncGenerator[ApiMessage, None]:
"""
Generate RAG results for the given query.
@ -501,15 +516,11 @@ Content: {content}
A list of dictionaries containing the RAG results.
"""
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for RAG generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for RAG generation.")
yield error_message
return
results : List[ChromaDBGetResponse] = []
entries: int = 0
results: List[ChromaDBGetResponse] = []
user: CandidateEntity = self.user
for rag in user.rags:
if not rag.enabled:
@ -518,20 +529,18 @@ Content: {content}
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content = f"Searching RAG context {rag.name}..."
content=f"Searching RAG context {rag.name}...",
)
yield status_message
try:
chroma_results = await user.file_watcher.find_similar(
query=prompt, top_k=top_k, threshold=threshold
)
chroma_results = await user.file_watcher.find_similar(query=prompt, top_k=top_k, threshold=threshold)
if not chroma_results:
continue
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
query_embedding = np.array(chroma_results["query_embedding"]).flatten() # type: ignore
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
umap_2d = user.file_watcher.umap_model_2d.transform([query_embedding])[0] # type: ignore
umap_3d = user.file_watcher.umap_model_3d.transform([query_embedding])[0] # type: ignore
rag_metadata = ChromaDBGetResponse(
name=rag.name,
@ -549,7 +558,7 @@ Content: {content}
continue_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.SEARCHING,
content=f"Error searching RAG context {rag.name}: {str(e)}"
content=f"Error searching RAG context {rag.name}: {str(e)}",
)
yield continue_message
@ -562,23 +571,21 @@ Content: {content}
return
async def llm_one_shot(
self,
llm: Any, model: str,
session_id: str, prompt: str, system_prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
self,
llm: Any,
model: str,
session_id: str,
prompt: str,
system_prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7,
) -> AsyncGenerator[ChatMessageStatus | ChatMessageError | ChatMessageStreaming | ChatMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
self.set_optimal_context_size(
llm=llm, model=model, prompt=prompt+system_prompt
)
self.set_optimal_context_size(llm=llm, model=model, prompt=prompt + system_prompt)
options = ChatOptions(
seed=8911,
@ -592,9 +599,7 @@ Content: {content}
]
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content=f"Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
@ -610,10 +615,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@ -628,46 +630,34 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
chat_message = ChatMessage(
session_id=session_id,
tunables=tunables,
status=ApiStatusType.DONE,
content=content,
metadata = ChatMessageMetaData(
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
prompt_eval_count=response.usage.prompt_eval_count,
prompt_eval_duration=response.usage.prompt_eval_duration,
)
),
)
yield chat_message
return
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="No user set for chat generation."
)
error_message = ChatMessageError(session_id=session_id, content="No user set for chat generation.")
yield error_message
return
@ -675,12 +665,11 @@ Content: {content}
session_id=session_id,
content=prompt,
)
user = self.user
self.user.metrics.generate_count.labels(agent=self.agent_type).inc()
with self.user.metrics.generate_duration.labels(agent=self.agent_type).time():
context = None
rag_message : ChatMessageRagSearch | None = None
rag_message: ChatMessageRagSearch | None = None
if self.user:
message = None
async for message in self.generate_rag_results(session_id=session_id, prompt=prompt):
@ -692,38 +681,32 @@ Content: {content}
yield message
if not isinstance(message, ChatMessageRagSearch):
raise ValueError(
f"Expected ChatMessageRagSearch, got {type(rag_message)}"
)
raise ValueError(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
rag_message = message
context = self.get_rag_context(rag_message)
# Create a pruned down message list based purely on the prompt and responses,
# discarding the full preamble generated by prepare_message
messages: List[LLMMessage] = [
LLMMessage(role="system", content=self.system_prompt)
]
messages: List[LLMMessage] = [LLMMessage(role="system", content=self.system_prompt)]
# Add the conversation history to the messages
messages.extend([
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
for m in self.conversation
])
messages.extend(
[
LLMMessage(role="user" if isinstance(m, ChatMessageUser) else "assistant", content=m.content)
for m in self.conversation
]
)
# Add the RAG context to the messages if available
if context:
messages.append(
LLMMessage(
role="user",
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n"
content=f"<|context|>\nThe following is context information about {self.user.full_name}:\n{context}\n</|context|>\n\nPrompt to respond to:\n{prompt}\n",
)
)
else:
# Only the actual user query is provided with the full context message
messages.append(
LLMMessage(role="user", content=prompt)
)
llm_history = messages
messages.append(LLMMessage(role="user", content=prompt))
# use_tools = message.tunables.enable_tools and len(self.context.tools) > 0
# message.metadata.tools = {
@ -827,16 +810,12 @@ Content: {content}
# not use_tools
status_message = ChatMessageStatus(
session_id=session_id,
activity=ApiActivityType.GENERATING,
content=f"Generating response..."
session_id=session_id, activity=ApiActivityType.GENERATING, content="Generating response..."
)
yield status_message
# Set the response for streaming
self.set_optimal_context_size(
llm, model, prompt=prompt
)
self.set_optimal_context_size(llm, model, prompt=prompt)
options = ChatOptions(
seed=8911,
@ -856,10 +835,7 @@ Content: {content}
stream=True,
):
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
@ -873,17 +849,12 @@ Content: {content}
yield streaming_message
if not response:
error_message = ChatMessageError(
session_id=session_id,
content="No response from LLM."
)
error_message = ChatMessageError(session_id=session_id, content="No response from LLM.")
yield error_message
return
self.user.collect_metrics(agent=self, response=response)
self.context_tokens = (
response.usage.prompt_eval_count + response.usage.eval_count
)
self.context_tokens = response.usage.prompt_eval_count + response.usage.eval_count
end_time = time.perf_counter()
chat_message = ChatMessage(
@ -891,7 +862,7 @@ Content: {content}
tunables=tunables,
status=ApiStatusType.DONE,
content=content,
metadata = ChatMessageMetaData(
metadata=ChatMessageMetaData(
options=options,
eval_count=response.usage.eval_count,
eval_duration=response.usage.eval_duration,
@ -902,10 +873,9 @@ Content: {content}
"llm_streamed": end_time - start_time,
"llm_with_tools": 0, # Placeholder for tool processing time
},
)
),
)
# Add the user and chat messages to the conversation
self.conversation.append(user_message)
self.conversation.append(chat_message)
@ -999,12 +969,13 @@ Content: {content}
raise ValueError("No Markdown found in the response")
_agents: List[Agent] = []
def get_or_create_agent(
agent_type: str,
prometheus_collector: CollectorRegistry,
user: Optional[CandidateEntity]=None) -> Agent:
agent_type: str, prometheus_collector: CollectorRegistry, user: Optional[CandidateEntity] = None
) -> Agent:
"""
Get or create and append a new agent of the specified type, ensuring only one agent per type exists.
@ -1028,14 +999,16 @@ def get_or_create_agent(
for agent_cls in Agent.__subclasses__():
if agent_cls.model_fields["agent_type"].default == agent_type:
# Create the agent instance with provided kwargs
agent = agent_cls(agent_type=agent_type, # type: ignore[call-arg]
user=user)
agent = agent_cls(
agent_type=agent_type, # type: ignore[call-arg]
user=user,
)
_agents.append(agent)
return agent
raise ValueError(f"No agent class found for agent_type: {agent_type}")
# Register the base agent
agent_registry.register(Agent._agent_type, Agent)
CandidateEntity.model_rebuild()

View File

@ -5,10 +5,10 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, Tunables, ApiStatusType)
from models import ApiMessage, Tunables, ApiStatusType
system_message = f"""
system_message = """
When answering queries, follow these steps:
- When any content from <|context|> is relevant, synthesize information from all sources to provide the most complete answer.
@ -21,21 +21,19 @@ Always <|context|> when possible. Be concise, and never make up information. If
Before answering, ensure you have spelled the candidate's name correctly.
"""
class CandidateChat(Agent):
"""
CandidateChat Agent
"""
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
agent_type: Literal["candidate_chat"] = "candidate_chat" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
user = self.user
if not user:
@ -54,12 +52,14 @@ Use that spelling instead of any spelling you may find in the <|context|>.
{system_message}
"""
async for message in super().generate(llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables):
async for message in super().generate(
llm=llm, model=model, session_id=session_id, prompt=prompt, temperature=temperature, tunables=tunables
):
if message.status == ApiStatusType.ERROR:
yield message
return
yield message
# Register the base agent
agent_registry.register(CandidateChat._agent_type, CandidateChat)

View File

@ -49,11 +49,13 @@ class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["general"] = "general" # type: ignore
agent_type: Literal["general"] = "general" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
system_prompt: str = system_message
# async def prepare_message(self, message: Message) -> AsyncGenerator[Message, None]:
# logger.info(f"{self.agent_type} - {inspect.stack()[0].function}")
# if not self.context:

View File

@ -1,14 +1,11 @@
from __future__ import annotations
from typing import (
Dict,
Literal,
ClassVar,
cast,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
import time
@ -16,7 +13,15 @@ import time
import os
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@ -26,18 +31,16 @@ from image_generator.profile_image import generate_image, ImageRequest
seed = int(time.time())
random.seed(seed)
class ImageGenerator(Agent):
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
agent_type: Literal["generate_image"] = "generate_image" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
agent_persist: bool = False
system_prompt: str = "" # No system prompt is used
system_prompt: str = "" # No system prompt is used
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
if not self.user:
logger.error("User is not set for ImageGenerator agent.")
@ -57,11 +60,17 @@ class ImageGenerator(Agent):
yield status_message
logger.info(f"Image generation: {file_path} <- {prompt}")
request = ImageRequest(filepath=file_path, session_id=session_id, prompt=prompt, iterations=4, height=256, width=256, guidance_scale=7.5)
request = ImageRequest(
filepath=file_path,
session_id=session_id,
prompt=prompt,
iterations=4,
height=256,
width=256,
guidance_scale=7.5,
)
generated_message = None
async for generated_message in generate_image(
request=request
):
async for generated_message in generate_image(request=request):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -71,8 +80,7 @@ class ImageGenerator(Agent):
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -86,21 +94,19 @@ class ImageGenerator(Agent):
generated_image = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
content = f"{defines.api_prefix}/profile/{user.username}",
metadata=generated_message.metadata
content=f"{defines.api_prefix}/profile/{user.username}",
metadata=generated_message.metadata,
)
yield generated_image
return
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error generating image: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error generating image: {str(e)}")
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Register the base agent
agent_registry.register(ImageGenerator._agent_type, ImageGenerator)

View File

@ -1,16 +1,15 @@
from __future__ import annotations
from pydantic import model_validator, Field, BaseModel # type: ignore
from pydantic import Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
cast,
Any,
Tuple,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import random
import re
@ -19,10 +18,18 @@ import time
import time
import os
import random
from names_dataset import NameDataset, NameWrapper # type: ignore
from .base import Agent, agent_registry
from models import ApiActivityType, ChatMessage, ChatMessageError, ApiMessageType, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Tunables
from models import (
ApiActivityType,
ChatMessage,
ChatMessageError,
ApiMessageType,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Tunables,
)
from logger import logger
import defines
import backstory_traceback as traceback
@ -45,6 +52,7 @@ emptyUser = {
"questions": [],
}
def generate_persona_system_prompt(persona: Dict[str, Any]) -> str:
return f"""\
You are a casting director for a movie. Your job is to provide information on ficticious personas for use in a screen play.
@ -86,6 +94,7 @@ DO NOT infer, imply, abbreviate, or state the ethnicity, gender, or age in the u
You are providing those only for use later by the system when casting individuals for the role.
"""
generate_resume_system_prompt = """
You are a creative writing casting director. As part of the casting, you are building backstories about individuals. The first part
of that is to create an in-depth resume for the person. You will be provided with the following information:
@ -117,10 +126,12 @@ import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EthnicNameGenerator:
def __init__(self):
try:
from names_dataset import NameDataset # type: ignore
from names_dataset import NameDataset # type: ignore
self.nd = NameDataset()
except ImportError:
logger.error("NameDataset not available. Please install: pip install names-dataset")
@ -131,24 +142,24 @@ class EthnicNameGenerator:
# US Census 2020 approximate ethnic distribution
self.ethnic_weights = {
'White': 0.576,
'Hispanic': 0.186,
'Black': 0.134,
'Asian': 0.062,
'Native American': 0.013,
'Pacific Islander': 0.003,
'Mixed/Other': 0.026
"White": 0.576,
"Hispanic": 0.186,
"Black": 0.134,
"Asian": 0.062,
"Native American": 0.013,
"Pacific Islander": 0.003,
"Mixed/Other": 0.026,
}
# Map ethnicities to countries (using alpha-2 codes that NameDataset uses)
self.ethnic_country_mapping = {
'White': ['US', 'GB', 'DE', 'IE', 'IT', 'PL', 'FR', 'CA', 'AU'],
'Hispanic': ['MX', 'ES', 'CO', 'PE', 'AR', 'CU', 'VE', 'CL'],
'Black': ['US'], # African American names
'Asian': ['CN', 'IN', 'PH', 'VN', 'KR', 'JP', 'TH', 'MY'],
'Native American': ['US'],
'Pacific Islander': ['US'],
'Mixed/Other': ['US']
"White": ["US", "GB", "DE", "IE", "IT", "PL", "FR", "CA", "AU"],
"Hispanic": ["MX", "ES", "CO", "PE", "AR", "CU", "VE", "CL"],
"Black": ["US"], # African American names
"Asian": ["CN", "IN", "PH", "VN", "KR", "JP", "TH", "MY"],
"Native American": ["US"],
"Pacific Islander": ["US"],
"Mixed/Other": ["US"],
}
def get_weighted_ethnicity(self) -> str:
@ -157,8 +168,9 @@ class EthnicNameGenerator:
weights = list(self.ethnic_weights.values())
return random.choices(ethnicities, weights=weights)[0]
def get_names_by_criteria(self, countries: List[str], gender: Optional[str] = None,
n: int = 50, use_first_names: bool = True) -> List[str]:
def get_names_by_criteria(
self, countries: List[str], gender: Optional[str] = None, n: int = 50, use_first_names: bool = True
) -> List[str]:
"""Get names matching criteria using NameDataset's get_top_names method"""
if not self.nd:
return []
@ -168,16 +180,13 @@ class EthnicNameGenerator:
try:
# Get top names for this country
top_names = self.nd.get_top_names(
n=n,
use_first_names=use_first_names,
country_alpha2=country_code,
gender=gender
n=n, use_first_names=use_first_names, country_alpha2=country_code, gender=gender
)
if country_code in top_names:
if use_first_names and gender:
# For first names with gender specified
gender_key = 'M' if gender.upper() in ['M', 'MALE'] else 'F'
gender_key = "M" if gender.upper() in ["M", "MALE"] else "F"
if gender_key in top_names[country_code]:
all_names.extend(top_names[country_code][gender_key])
elif use_first_names:
@ -194,25 +203,18 @@ class EthnicNameGenerator:
return list(set(all_names)) # Remove duplicates
def get_name_by_ethnicity(self, ethnicity: str, gender: str = 'random') -> Tuple[str, str, str, str]:
def get_name_by_ethnicity(self, ethnicity: str, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a name based on ethnicity using the correct NameDataset API"""
if gender == 'random':
gender = random.choice(['Male', 'Female'])
if gender == "random":
gender = random.choice(["Male", "Female"])
countries = self.ethnic_country_mapping.get(ethnicity, ['US'])
countries = self.ethnic_country_mapping.get(ethnicity, ["US"])
# Get first names
first_names = self.get_names_by_criteria(
countries=countries,
gender=gender,
use_first_names=True
)
first_names = self.get_names_by_criteria(countries=countries, gender=gender, use_first_names=True)
# Get last names
last_names = self.get_names_by_criteria(
countries=countries,
use_first_names=False
)
last_names = self.get_names_by_criteria(countries=countries, use_first_names=False)
# Select names or use fallbacks
if first_names:
@ -232,57 +234,60 @@ class EthnicNameGenerator:
def _get_fallback_first_name(self, gender: str, ethnicity: str) -> str:
"""Provide culturally appropriate fallback first names"""
fallback_names = {
'White': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Joseph'],
'Female': ['Mary', 'Patricia', 'Jennifer', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica']
"White": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Joseph"],
"Female": ["Mary", "Patricia", "Jennifer", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica"],
},
'Hispanic': {
'Male': ['José', 'Luis', 'Miguel', 'Juan', 'Francisco', 'Alejandro', 'Antonio', 'Carlos'],
'Female': ['María', 'Guadalupe', 'Juana', 'Margarita', 'Francisca', 'Teresa', 'Rosa', 'Ana']
"Hispanic": {
"Male": ["José", "Luis", "Miguel", "Juan", "Francisco", "Alejandro", "Antonio", "Carlos"],
"Female": ["María", "Guadalupe", "Juana", "Margarita", "Francisca", "Teresa", "Rosa", "Ana"],
},
'Black': {
'Male': ['James', 'Robert', 'John', 'Michael', 'William', 'David', 'Richard', 'Charles'],
'Female': ['Mary', 'Patricia', 'Linda', 'Elizabeth', 'Barbara', 'Susan', 'Jessica', 'Sarah']
"Black": {
"Male": ["James", "Robert", "John", "Michael", "William", "David", "Richard", "Charles"],
"Female": ["Mary", "Patricia", "Linda", "Elizabeth", "Barbara", "Susan", "Jessica", "Sarah"],
},
"Asian": {
"Male": ["Wei", "Ming", "Chen", "Li", "Kumar", "Raj", "Hiroshi", "Takeshi"],
"Female": ["Mei", "Lin", "Ling", "Priya", "Yuki", "Soo", "Hana", "Anh"],
},
'Asian': {
'Male': ['Wei', 'Ming', 'Chen', 'Li', 'Kumar', 'Raj', 'Hiroshi', 'Takeshi'],
'Female': ['Mei', 'Lin', 'Ling', 'Priya', 'Yuki', 'Soo', 'Hana', 'Anh']
}
}
ethnicity_names = fallback_names.get(ethnicity, fallback_names['White'])
return random.choice(ethnicity_names.get(gender, ethnicity_names['Male']))
ethnicity_names = fallback_names.get(ethnicity, fallback_names["White"])
return random.choice(ethnicity_names.get(gender, ethnicity_names["Male"]))
def _get_fallback_last_name(self, ethnicity: str) -> str:
"""Provide culturally appropriate fallback last names"""
fallback_surnames = {
'White': ['Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Miller', 'Wilson', 'Moore'],
'Hispanic': ['García', 'Rodríguez', 'Martínez', 'López', 'González', 'Pérez', 'Sánchez', 'Ramírez'],
'Black': ['Johnson', 'Williams', 'Brown', 'Jones', 'Davis', 'Miller', 'Wilson', 'Moore'],
'Asian': ['Li', 'Wang', 'Zhang', 'Liu', 'Chen', 'Yang', 'Huang', 'Zhao']
"White": ["Smith", "Johnson", "Williams", "Brown", "Jones", "Miller", "Wilson", "Moore"],
"Hispanic": ["García", "Rodríguez", "Martínez", "López", "González", "Pérez", "Sánchez", "Ramírez"],
"Black": ["Johnson", "Williams", "Brown", "Jones", "Davis", "Miller", "Wilson", "Moore"],
"Asian": ["Li", "Wang", "Zhang", "Liu", "Chen", "Yang", "Huang", "Zhao"],
}
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames['White']))
return random.choice(fallback_surnames.get(ethnicity, fallback_surnames["White"]))
def generate_random_name(self, gender: str = 'random') -> Tuple[str, str, str, str]:
def generate_random_name(self, gender: str = "random") -> Tuple[str, str, str, str]:
"""Generate a random name with ethnicity based on US demographics"""
ethnicity = self.get_weighted_ethnicity()
return self.get_name_by_ethnicity(ethnicity, gender)
def generate_multiple_names(self, count: int = 10, gender: str = 'random') -> List[Dict]:
def generate_multiple_names(self, count: int = 10, gender: str = "random") -> List[Dict]:
"""Generate multiple random names"""
names = []
for _ in range(count):
first, last, ethnicity, actual_gender = self.generate_random_name(gender)
names.append({
'full_name': f"{first} {last}",
'first_name': first,
'last_name': last,
'ethnicity': ethnicity,
'gender': actual_gender
})
names.append(
{
"full_name": f"{first} {last}",
"first_name": first,
"last_name": last,
"ethnicity": ethnicity,
"gender": actual_gender,
}
)
return names
class GeneratePersona(Agent):
agent_type: Literal["generate_persona"] = "generate_persona" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@ -307,23 +312,20 @@ class GeneratePersona(Agent):
self.full_name = f"{self.first_name} {self.last_name}"
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError | ChatMessageStreaming, None]:
self.randomize()
original_prompt = prompt.strip()
persona = {
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity,
"full_name": self.full_name,
"first_name": self.first_name,
"last_name": self.last_name,
}
"age": self.age,
"gender": self.gender,
"ethnicity": self.ethnicity,
"full_name": self.full_name,
"first_name": self.first_name,
"last_name": self.last_name,
}
prompt = f"""\
```json
@ -339,10 +341,11 @@ Incorporate the following into the job description: {original_prompt}
#
# Generate the persona
#
logger.info(f"🤖 Generating persona...")
logger.info("🤖 Generating persona...")
generating_message = None
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=prompt,
system_prompt=generate_persona_system_prompt(persona=persona),
@ -356,8 +359,7 @@ Incorporate the following into the job description: {original_prompt}
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Persona generation failed to generate a response."
session_id=session_id, content="Persona generation failed to generate a response."
)
yield error_message
return
@ -375,7 +377,7 @@ Incorporate the following into the job description: {original_prompt}
self.username = persona.get("username", None)
if not self.username:
raise ValueError("LLM did not generate a username")
self.username = re.sub(r'\s+', '.', self.username)
self.username = re.sub(r"\s+", ".", self.username)
user_dir = os.path.join(defines.user_dir, persona["username"])
while os.path.exists(user_dir):
match = re.match(r"^(.*?)(\d*)$", persona["username"])
@ -398,19 +400,14 @@ Incorporate the following into the job description: {original_prompt}
location_parts = persona["location"].split(",")
if len(location_parts) == 3:
city, state, country = [part.strip() for part in location_parts]
persona["location"] = {
"city": city,
"state": state,
"country": country
}
persona["location"] = {"city": city, "state": state, "country": country}
else:
logger.error(f"Invalid location format: {persona['location']}")
persona["location"] = None
persona["is_ai"] = True
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
session_id=session_id, content=f"Error parsing LLM response: {str(e)}\n\n{json_str}"
)
logger.error(f"❌ Error parsing LLM response: {error_message.content}")
logger.error(traceback.format_exc())
@ -422,10 +419,7 @@ Incorporate the following into the job description: {original_prompt}
# Persona generated
persona_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.JSON,
content = json.dumps(persona)
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.JSON, content=json.dumps(persona)
)
yield persona_message
@ -434,8 +428,8 @@ Incorporate the following into the job description: {original_prompt}
#
status_message = ChatMessageStatus(
session_id=session_id,
activity = ApiActivityType.THINKING,
content = f"Generating resume for {persona['full_name']}..."
activity=ApiActivityType.THINKING,
content=f"Generating resume for {persona['full_name']}...",
)
logger.info(f"🤖 {status_message.content}")
yield status_message
@ -458,7 +452,8 @@ Incorporate the following into the job description: {original_prompt}
Make sure at least one of the candidate's job descriptions take into account the following: {original_prompt}."""
async for generating_message in self.llm_one_shot(
llm=llm, model=model,
llm=llm,
model=model,
session_id=session_id,
prompt=content,
system_prompt=generate_resume_system_prompt,
@ -472,8 +467,7 @@ Make sure at least one of the candidate's job descriptions take into account the
if not generating_message:
error_message = ChatMessageError(
session_id=session_id,
content="Resume generation failed to generate a response."
session_id=session_id, content="Resume generation failed to generate a response."
)
logger.error(f"{error_message.content}")
yield error_message
@ -481,10 +475,7 @@ Make sure at least one of the candidate's job descriptions take into account the
resume = self.extract_markdown_from_text(generating_message.content)
resume_message = ChatMessage(
session_id=session_id,
status=ApiStatusType.DONE,
type=ApiMessageType.TEXT,
content=resume
session_id=session_id, status=ApiStatusType.DONE, type=ApiMessageType.TEXT, content=resume
)
yield resume_message
return
@ -504,5 +495,6 @@ Make sure at least one of the candidate's job descriptions take into account the
raise ValueError("No JSON found in the response")
# Register the base agent
agent_registry.register(GeneratePersona._agent_type, GeneratePersona)

View File

@ -1,30 +1,34 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
# override
) # NOTE: You must import Optional for late binding to work
import json
import numpy as np # type: ignore
from logger import logger
from .base import Agent, agent_registry
from models import (ApiActivityType, ApiMessage, ApiStatusType, ChatMessage, ChatMessageError, ChatMessageResume, ChatMessageStatus, SkillAssessment, SkillStrength)
from models import (
ApiActivityType,
ApiMessage,
ApiStatusType,
ChatMessage,
ChatMessageError,
ChatMessageResume,
ChatMessageStatus,
SkillAssessment,
SkillStrength,
)
class GenerateResume(Agent):
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
agent_type: Literal["generate_resume"] = "generate_resume" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_resume_prompt(
self,
skills: List[SkillAssessment]
):
def generate_resume_prompt(self, skills: List[SkillAssessment]):
"""
Generate a professional resume based on skill assessment results
@ -45,7 +49,7 @@ class GenerateResume(Agent):
SkillStrength.STRONG: [],
SkillStrength.MODERATE: [],
SkillStrength.WEAK: [],
SkillStrength.NONE: []
SkillStrength.NONE: [],
}
experience_evidence = {}
@ -67,11 +71,7 @@ class GenerateResume(Agent):
experience_evidence[source] = []
experience_evidence[source].append(
{
"skill": skill,
"quote": evidence.quote,
"context": evidence.context
}
{"skill": skill, "quote": evidence.quote, "context": evidence.context}
)
# Build the system prompt
@ -171,16 +171,16 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
) -> AsyncGenerator[ApiMessage, None]:
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = f"Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
system_prompt, prompt = self.generate_resume_prompt(skills=skills)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -189,8 +189,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -198,8 +197,7 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis did not return a valid message."
session_id=session_id, content="Job requirements analysis did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -215,8 +213,9 @@ Format it in clean, ATS-friendly markdown. Provide ONLY the resume with no comme
system_prompt=system_prompt,
)
yield resume_message
logger.info(f"✅ Resume generation completed successfully.")
logger.info("✅ Resume generation completed successfully.")
return
# Register the base agent
agent_registry.register(GenerateResume._agent_type, GenerateResume)

View File

@ -1,24 +1,34 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import inspect
import json
import numpy as np # type: ignore
from .base import Agent, agent_registry
from models import ApiActivityType, ApiMessage, ChatMessage, ChatMessageError, ChatMessageStatus, ChatMessageStreaming, ApiStatusType, Job, JobRequirements, JobRequirementsMessage, Tunables
from models import (
ApiActivityType,
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
ChatMessageStreaming,
ApiStatusType,
Job,
JobRequirements,
JobRequirementsMessage,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class JobRequirementsAgent(Agent):
agent_type: Literal["job_requirements"] = "job_requirements" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
@ -94,14 +104,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
"""Analyze job requirements from job description."""
system_prompt, prompt = self.create_job_analysis_prompt(prompt)
status_message = ChatMessageStatus(
session_id=session_id,
content="Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
logger.info(f"🔍 {status_message.content}")
generated_message = None
async for generated_message in self.llm_one_shot(llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt):
async for generated_message in self.llm_one_shot(
llm, model, session_id=session_id, prompt=prompt, system_prompt=system_prompt
):
if generated_message.status == ApiStatusType.ERROR:
yield generated_message
return
@ -110,8 +120,8 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response.")
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
@ -132,18 +142,18 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
display = {
"technical_skills": {
"required": reqs.technical_skills.required,
"preferred": reqs.technical_skills.preferred
"preferred": reqs.technical_skills.preferred,
},
"experience_requirements": {
"required": reqs.experience_requirements.required,
"preferred": reqs.experience_requirements.preferred
"preferred": reqs.experience_requirements.preferred,
},
"soft_skills": reqs.soft_skills,
"experience": reqs.experience,
"education": reqs.education,
"certifications": reqs.certifications,
"preferred_attributes": reqs.preferred_attributes,
"company_values": reqs.company_values
"company_values": reqs.company_values,
}
return display
@ -152,19 +162,14 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="User is not set for this agent."
)
error_message = ChatMessageError(session_id=session_id, content="User is not set for this agent.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
# Stage 1A: Analyze job requirements
status_message = ChatMessageStatus(
session_id=session_id,
content = f"Analyzing job requirements",
activity=ApiActivityType.THINKING
session_id=session_id, content="Analyzing job requirements", activity=ApiActivityType.THINKING
)
yield status_message
@ -178,8 +183,7 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
if not generated_message:
error_message = ChatMessageError(
session_id=session_id,
content="Job requirements analysis failed to generate a response."
session_id=session_id, content="Job requirements analysis failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -214,7 +218,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
return
except Exception as e:
status_message.status = ApiStatusType.ERROR
status_message.content = f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
status_message.content = (
f"Unexpected error processing job requirements: {str(e)}\n\n{job_requirements_data}"
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {status_message.content}")
yield status_message
@ -238,8 +244,9 @@ Avoid vague categorizations and be precise about whether skills are explicitly r
job=job,
)
yield job_requirements_message
logger.info(f"✅ Job requirements analysis completed successfully.")
logger.info("✅ Job requirements analysis completed successfully.")
return
# Register the base agent
agent_registry.register(JobRequirementsAgent._agent_type, JobRequirementsAgent)

View File

@ -5,20 +5,19 @@ from .base import Agent, agent_registry
from logger import logger
from .registry import agent_registry
from models import ( ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables )
from models import ApiMessage, ApiStatusType, ChatMessageError, ChatMessageRagSearch, ApiStatusType, Tunables
class Chat(Agent):
"""
Chat Agent
"""
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
agent_type: Literal["rag_search"] = "rag_search" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
"""
Generate a response based on the user message and the provided LLM.
@ -44,8 +43,7 @@ class Chat(Agent):
if not isinstance(rag_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(rag_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
@ -53,5 +51,6 @@ class Chat(Agent):
rag_message.status = ApiStatusType.DONE
yield rag_message
# Register the base agent
agent_registry.register(Chat._agent_type, Chat)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Dict, Optional, Type
# We'll use a registry pattern rather than hardcoded strings
class AgentRegistry:
"""Registry for agent types and classes"""

View File

@ -1,26 +1,32 @@
from __future__ import annotations
from pydantic import model_validator, Field # type: ignore
from typing import (
Dict,
Literal,
ClassVar,
Any,
AsyncGenerator,
List,
Optional
# override
Optional,
# override
) # NOTE: You must import Optional for late binding to work
import json
import numpy as np # type: ignore
from .base import Agent, agent_registry
from models import (ApiMessage, ChatMessage, ChatMessageError, ChatMessageRagSearch, ChatMessageSkillAssessment, ApiStatusType, EvidenceDetail,
SkillAssessment, Tunables)
from models import (
ApiMessage,
ChatMessage,
ChatMessageError,
ChatMessageRagSearch,
ChatMessageSkillAssessment,
ApiStatusType,
EvidenceDetail,
SkillAssessment,
Tunables,
)
from logger import logger
import backstory_traceback as traceback
class SkillMatchAgent(Agent):
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
agent_type: Literal["skill_match"] = "skill_match" # type: ignore
_agent_type: ClassVar[str] = agent_type # Add this for registration
def generate_skill_assessment_prompt(self, skill, rag_context):
@ -100,15 +106,12 @@ JSON RESPONSE:"""
return system_prompt, prompt
async def generate(
self, llm: Any, model: str,
session_id: str, prompt: str,
tunables: Optional[Tunables] = None,
temperature=0.7
self, llm: Any, model: str, session_id: str, prompt: str, tunables: Optional[Tunables] = None, temperature=0.7
) -> AsyncGenerator[ApiMessage, None]:
if not self.user:
error_message = ChatMessageError(
session_id=session_id,
content="Agent not attached to user. Attach the agent to a user before generating responses."
content="Agent not attached to user. Attach the agent to a user before generating responses.",
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -116,10 +119,7 @@ JSON RESPONSE:"""
skill = prompt.strip()
if not skill:
error_message = ChatMessageError(
session_id=session_id,
content="Skill cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Skill cannot be empty.")
logger.error(f"⚠️ {error_message.content}")
yield error_message
return
@ -134,8 +134,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -144,19 +143,20 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessageRagSearch):
logger.error(f"Expected ChatMessageRagSearch, got {type(generated_message)}")
error_message = ChatMessageError(
session_id=session_id,
content="RAG search did not return a valid response."
session_id=session_id, content="RAG search did not return a valid response."
)
yield error_message
return
rag_message : ChatMessageRagSearch = generated_message
rag_message: ChatMessageRagSearch = generated_message
rag_context = self.get_rag_context(rag_message)
logger.info(f"🔍 RAG content retrieved {len(rag_context)} bytes of context")
system_prompt, prompt = self.generate_skill_assessment_prompt(skill=skill, rag_context=rag_context)
generated_message = None
async for generated_message in self.llm_one_shot(llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7):
async for generated_message in self.llm_one_shot(
llm=llm, model=model, session_id=session_id, prompt=prompt, system_prompt=system_prompt, temperature=0.7
):
if generated_message.status == ApiStatusType.ERROR:
logger.error(f"⚠️ {generated_message.content}")
yield generated_message
@ -166,8 +166,7 @@ JSON RESPONSE:"""
if generated_message is None:
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment failed to generate a response."
session_id=session_id, content="Skill assessment failed to generate a response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -175,8 +174,7 @@ JSON RESPONSE:"""
if not isinstance(generated_message, ChatMessage):
error_message = ChatMessageError(
session_id=session_id,
content="Skill assessment did not return a valid message."
session_id=session_id, content="Skill assessment did not return a valid message."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -199,14 +197,15 @@ JSON RESPONSE:"""
EvidenceDetail(
source=evidence.get("source", ""),
quote=evidence.get("quote", ""),
context=evidence.get("context", "")
) for evidence in skill_assessment_data.get("evidence_details", [])
]
context=evidence.get("context", ""),
)
for evidence in skill_assessment_data.get("evidence_details", [])
],
)
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n"
content=f"Failed to parse Skill assessment JSON: {str(e)}\n\n{generated_message.content}\n\nJSON:\n{json_str}\n\n",
)
logger.error(traceback.format_exc())
logger.error(f"⚠️ {error_message.content}")
@ -233,8 +232,9 @@ JSON RESPONSE:"""
skill_assessment=skill_assessment,
)
yield skill_assessment_message
logger.info(f"✅ Skill assessment completed successfully.")
logger.info("✅ Skill assessment completed successfully.")
return
# Register the base agent
agent_registry.register(SkillMatchAgent._agent_type, SkillMatchAgent)

View File

@ -9,6 +9,7 @@ from typing import Optional, List, Dict, Any, Callable
from logger import logger
from database.manager import DatabaseManager
class BackgroundTaskManager:
"""Manages background tasks for the application using asyncio instead of threading"""
@ -65,10 +66,12 @@ class BackgroundTaskManager:
stats = await database.get_guest_statistics()
# Log interesting statistics
if stats.get('total_guests', 0) > 0:
logger.info(f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted")
if stats.get("total_guests", 0) > 0:
logger.info(
f"📊 Guest stats: {stats['total_guests']} total, "
f"{stats['active_last_hour']} active in last hour, "
f"{stats['converted_guests']} converted"
)
return stats
except Exception as e:
@ -84,6 +87,7 @@ class BackgroundTaskManager:
# Get Redis client safely (using the event loop safe method)
from database.manager import redis_manager
redis = await redis_manager.get_client()
# Clean up rate limit keys older than specified days
@ -101,7 +105,7 @@ class BackgroundTaskManager:
try:
ttl = await redis.ttl(key)
if ttl == -1: # No expiration set, check creation time
creation_time = await redis.hget(key, "created_at") # type: ignore
creation_time = await redis.hget(key, "created_at") # type: ignore
if creation_time:
creation_time = datetime.fromisoformat(creation_time).replace(tzinfo=UTC)
if creation_time < cutoff_time:
@ -192,10 +196,7 @@ class BackgroundTaskManager:
# Create asyncio tasks for each periodic task
for name, func, interval, *args in periodic_tasks:
task = asyncio.create_task(
self._run_periodic_task(name, func, interval, *args),
name=f"background_{name}"
)
task = asyncio.create_task(self._run_periodic_task(name, func, interval, *args), name=f"background_{name}")
self.tasks.append(task)
logger.info(f"📅 Scheduled background task: {name}")
@ -238,10 +239,7 @@ class BackgroundTaskManager:
# Wait for all tasks to complete with timeout
if self.tasks:
try:
await asyncio.wait_for(
asyncio.gather(*self.tasks, return_exceptions=True),
timeout=30.0
)
await asyncio.wait_for(asyncio.gather(*self.tasks, return_exceptions=True), timeout=30.0)
logger.info("✅ All background tasks stopped gracefully")
except asyncio.TimeoutError:
logger.warning("⚠️ Some background tasks did not stop within timeout")
@ -258,7 +256,7 @@ class BackgroundTaskManager:
"main_loop_id": id(self.main_loop) if self.main_loop else None,
"current_loop_id": None,
"task_count": len(self.tasks),
"tasks": []
"tasks": [],
}
try:
@ -317,6 +315,7 @@ async def setup_background_tasks(database_manager: DatabaseManager) -> Backgroun
await task_manager.start()
return task_manager
# For integration with your existing app startup
async def initialize_with_background_tasks(database_manager: DatabaseManager):
"""Initialize database and background tasks together"""

View File

@ -3,6 +3,7 @@ import os
import sys
import defines
def filter_traceback(tb, app_path=None, module_name=None):
"""
Filter traceback to include only frames from the specified application path or module.
@ -36,7 +37,8 @@ def filter_traceback(tb, app_path=None, module_name=None):
formatted_exc = traceback.format_exception_only(exc_type, exc_value)
# Combine the filtered stack trace with the exception message
return ''.join(formatted_stack + formatted_exc)
return "".join(formatted_stack + formatted_exc)
def format_exc(app_path=defines.app_path, module_name=None):
"""

View File

@ -1,4 +1,4 @@
from .core import RedisDatabase
from .manager import DatabaseManager, redis_manager
__all__ = ['RedisDatabase', 'DatabaseManager', 'redis_manager']
__all__ = ["RedisDatabase", "DatabaseManager", "redis_manager"]

View File

@ -1,15 +1,15 @@
KEY_PREFIXES = {
'viewers': 'viewer:',
'candidates': 'candidate:',
'employers': 'employer:',
'jobs': 'job:',
'job_applications': 'job_application:',
'chat_sessions': 'chat_session:',
'chat_messages': 'chat_messages:',
'ai_parameters': 'ai_parameters:',
'users': 'user:',
'candidate_documents': 'candidate_documents:',
'job_requirements': 'job_requirements:',
'resumes': 'resume:',
'user_resumes': 'user_resumes:',
"viewers": "viewer:",
"candidates": "candidate:",
"employers": "employer:",
"jobs": "job:",
"job_applications": "job_application:",
"chat_sessions": "chat_session:",
"chat_messages": "chat_messages:",
"ai_parameters": "ai_parameters:",
"users": "user:",
"candidate_documents": "candidate_documents:",
"job_requirements": "job_requirements:",
"resumes": "resume:",
"user_resumes": "user_resumes:",
}
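These prefixes are combined with an entity ID to form the Redis key, and stripped again when scanning keys back into IDs. A minimal sketch of that pattern as the mixins use it (the import path and helper names here are illustrative only):

from database.constants import KEY_PREFIXES  # illustrative import path

def candidate_key(candidate_id: str) -> str:
    # e.g. "candidate:1234" -- the same f-string pattern used throughout the mixins
    return f"{KEY_PREFIXES['candidates']}{candidate_id}"

def strip_prefix(key: str, kind: str) -> str:
    # reverse operation used when turning scanned keys back into IDs
    return key.replace(KEY_PREFIXES[kind], "")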

View File

@ -10,6 +10,7 @@ from .mixins.job import JobMixin
from .mixins.skill import SkillMixin
from .mixins.ai import AIMixin
# RedisDatabase is the main class that combines all mixins for a
# comprehensive Redis database interface.
class RedisDatabase(

View File

@ -1,18 +1,15 @@
from redis.asyncio import (Redis, ConnectionPool)
from redis.asyncio import Redis, ConnectionPool
from typing import Optional
import json
import logging
import os
from datetime import datetime, UTC
import asyncio
from models import (
# User models
Candidate, Employer, BaseUser, EvidenceDetail, Guest, Authentication, AuthResponse, SkillAssessment,
)
from .core import RedisDatabase
logger = logging.getLogger(__name__)
# _RedisManager is a singleton class that manages the Redis connection and
# provides methods for connecting, disconnecting, and performing health checks.
#
@ -46,12 +43,10 @@ class _RedisManager:
retry_on_timeout=True,
socket_keepalive=True,
socket_keepalive_options={},
health_check_interval=30
health_check_interval=30,
)
self.redis = Redis(
connection_pool=self._connection_pool
)
self.redis = Redis(connection_pool=self._connection_pool)
if not self.redis:
raise RuntimeError("Redis client not initialized")
@ -135,7 +130,7 @@ class _RedisManager:
"uptime_seconds": info.get("uptime_in_seconds", 0),
"connected_clients": info.get("connected_clients", 0),
"used_memory_human": info.get("used_memory_human", "unknown"),
"total_commands_processed": info.get("total_commands_processed", 0)
"total_commands_processed": info.get("total_commands_processed", 0),
}
except Exception as e:
logger.error(f"Redis health check failed: {e}")
@ -177,9 +172,11 @@ class _RedisManager:
logger.error(f"Failed to get Redis info: {e}")
return None
# Global Redis manager instance
redis_manager = _RedisManager()
# DatabaseManager is an enhanced database manager that provides graceful shutdown capabilities
# It manages the Redis connection, tracks active requests, and allows for data backup before shutdown.
class DatabaseManager:
@ -231,7 +228,7 @@ class DatabaseManager:
backup_filename = f"backup_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.json"
# Save to local file (you might want to save to cloud storage instead)
with open(backup_filename, 'w') as f:
with open(backup_filename, "w") as f:
json.dump(backup_data, f, indent=2, default=str)
logger.info(f"Backup created: {backup_filename}")
@ -314,5 +311,3 @@ class DatabaseManager:
if self._shutdown_initiated:
raise RuntimeError("Application is shutting down")
return self.db
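The `get_database()` guard at the end of this class is the core of the graceful-shutdown contract: once `_shutdown_initiated` is set, new work is refused while in-flight requests drain. A stripped-down sketch of that idea (the attribute names follow the diff, the request counter and drain loop are assumptions):

import asyncio

class ShutdownGate:
    def __init__(self):
        self._shutdown_initiated = False
        self._active_requests = 0  # assumed counter; the real class tracks active requests

    def get_database(self):
        # refuse new work once shutdown has begun, as in the diff above
        if self._shutdown_initiated:
            raise RuntimeError("Application is shutting down")
        return "db-handle"

    async def shutdown(self, timeout: float = 30.0):
        self._shutdown_initiated = True
        # wait for in-flight requests to drain, up to the timeout
        deadline = asyncio.get_running_loop().time() + timeout
        while self._active_requests and asyncio.get_running_loop().time() < deadline:
            await asyncio.sleep(0.1)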

View File

@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class AIMixin(DatabaseProtocol):
"""Mixin for AI operations"""
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
"""Get AI parameters by ID"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
@ -36,7 +38,7 @@ class AIMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
param_id = key.replace(KEY_PREFIXES['ai_parameters'], '')
param_id = key.replace(KEY_PREFIXES["ai_parameters"], "")
result[param_id] = self._deserialize(value)
return result
@ -45,4 +47,3 @@ class AIMixin(DatabaseProtocol):
"""Delete AI parameters"""
key = f"{KEY_PREFIXES['ai_parameters']}{param_id}"
await self.redis.delete(key)

View File

@ -7,6 +7,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class AnalyticsMixin:
"""Mixin for analytics-related database operations"""

View File

@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class AuthMixin(DatabaseProtocol):
"""Mixin for auth-related database operations"""
@ -25,8 +26,9 @@ class AuthMixin(DatabaseProtocol):
token_data = await self.redis.get(key)
if token_data:
verification_info = json.loads(token_data)
if (verification_info.get("email", "").lower() == email_lower and
not verification_info.get("verified", False)):
if verification_info.get("email", "").lower() == email_lower and not verification_info.get(
"verified", False
):
# Extract token from key
token = key.replace("email_verification:", "")
verification_info["token"] = token
@ -115,10 +117,7 @@ class AuthMixin(DatabaseProtocol):
window_start = current_time - timedelta(hours=24)
# Filter out old attempts
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
return len(recent_attempts)
@ -141,16 +140,13 @@ class AuthMixin(DatabaseProtocol):
# Keep only last 24 hours of attempts
window_start = current_time - timedelta(hours=24)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > window_start
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > window_start]
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours
json.dumps(recent_attempts)
json.dumps(recent_attempts),
)
return True
@ -169,14 +165,14 @@ class AuthMixin(DatabaseProtocol):
"user_data": user_data,
"expires_at": (datetime.now(timezone.utc) + timedelta(hours=24)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"verified": False
"verified": False,
}
# Store with 24 hour expiration
await self.redis.setex(
key,
24 * 60 * 60, # 24 hours in seconds
json.dumps(verification_data, default=str)
json.dumps(verification_data, default=str),
)
logger.info(f"📧 Stored email verification token for {email}")
@ -208,7 +204,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
24 * 60 * 60, # Keep for remaining TTL
json.dumps(token_data, default=str)
json.dumps(token_data, default=str),
)
return True
return False
@ -219,7 +215,7 @@ class AuthMixin(DatabaseProtocol):
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
"""Store MFA code for verification"""
try:
logger.info("🔐 Storing MFA code for email: %s", email )
logger.info("🔐 Storing MFA code for email: %s", email)
key = f"mfa_code:{email.lower()}:{device_id}"
mfa_data = {
"code": code,
@ -228,14 +224,14 @@ class AuthMixin(DatabaseProtocol):
"expires_at": (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(),
"created_at": datetime.now(timezone.utc).isoformat(),
"attempts": 0,
"verified": False
"verified": False,
}
# Store with 10 minute expiration
await self.redis.setex(
key,
10 * 60, # 10 minutes in seconds
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
logger.info(f"🔐 Stored MFA code for {email}")
@ -266,7 +262,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep original TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return mfa_data["attempts"]
return 0
@ -285,7 +281,7 @@ class AuthMixin(DatabaseProtocol):
await self.redis.setex(
key,
10 * 60, # Keep for remaining TTL
json.dumps(mfa_data, default=str)
json.dumps(mfa_data, default=str),
)
return True
return False
@ -327,7 +323,9 @@ class AuthMixin(DatabaseProtocol):
logger.error(f"❌ Error deleting authentication record for {user_id}: {e}")
return False
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool:
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
"""Store refresh token for a user"""
try:
key = f"refresh_token:{token}"
@ -337,7 +335,7 @@ class AuthMixin(DatabaseProtocol):
"device": device_info.get("device", "unknown"),
"ip_address": device_info.get("ip_address", "unknown"),
"is_revoked": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
@ -374,7 +372,7 @@ class AuthMixin(DatabaseProtocol):
token_data["is_revoked"] = True
token_data["revoked_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_data, default=str))
logger.info(f"🔐 Revoked refresh token")
logger.info("🔐 Revoked refresh token")
return True
return False
except Exception as e:
@ -420,7 +418,7 @@ class AuthMixin(DatabaseProtocol):
"email": email.lower(),
"expires_at": expires_at.isoformat(),
"used": False,
"created_at": datetime.now(timezone.utc).isoformat()
"created_at": datetime.now(timezone.utc).isoformat(),
}
# Store with expiration
@ -457,7 +455,7 @@ class AuthMixin(DatabaseProtocol):
token_data["used"] = True
token_data["used_at"] = datetime.now(timezone.utc).isoformat()
await self.redis.set(key, json.dumps(token_data, default=str))
logger.info(f"🔐 Marked password reset token as used")
logger.info("🔐 Marked password reset token as used")
return True
return False
except Exception as e:
@ -473,14 +471,14 @@ class AuthMixin(DatabaseProtocol):
"timestamp": datetime.now(timezone.utc).isoformat(),
"user_id": user_id,
"event_type": event_type,
"details": details
"details": details,
}
# Add to list (latest events first)
await self.redis.lpush(key, json.dumps(event_data, default=str))# type: ignore
await self.redis.lpush(key, json.dumps(event_data, default=str)) # type: ignore
# Keep only last 100 events per day
await self.redis.ltrim(key, 0, 99)# type: ignore
await self.redis.ltrim(key, 0, 99) # type: ignore
# Set expiration for 30 days
await self.redis.expire(key, 30 * 24 * 60 * 60)
@ -496,10 +494,10 @@ class AuthMixin(DatabaseProtocol):
try:
events = []
for i in range(days):
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime('%Y-%m-%d')
date = (datetime.now(timezone.utc) - timedelta(days=i)).strftime("%Y-%m-%d")
key = f"security_log:{user_id}:{date}"
daily_events = await self.redis.lrange(key, 0, -1)# type: ignore
daily_events = await self.redis.lrange(key, 0, -1) # type: ignore
for event_json in daily_events:
events.append(json.loads(event_json))
@ -509,4 +507,3 @@ class AuthMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error retrieving security log for {user_id}: {e}")
return []
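The security log above is a per-day, capped, expiring Redis list: newest events are pushed to the front, the list is trimmed to 100 entries, and each day's key expires after 30 days. The same pattern in isolation, assuming a redis.asyncio client:

import json
from datetime import datetime, timezone
from redis.asyncio import Redis

async def log_event(redis: Redis, user_id: str, event: dict) -> None:
    date = datetime.now(timezone.utc).strftime("%Y-%m-%d")
    key = f"security_log:{user_id}:{date}"
    await redis.lpush(key, json.dumps(event, default=str))  # newest first
    await redis.ltrim(key, 0, 99)                           # keep only the last 100 events per day
    await redis.expire(key, 30 * 24 * 60 * 60)              # drop the whole day after 30 days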

View File

@ -5,11 +5,13 @@ from typing import Any, Dict, TYPE_CHECKING
from .protocols import DatabaseProtocol
from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
pass
class BaseMixin(DatabaseProtocol):
"""Base mixin with core Redis operations and utilities"""
@ -45,6 +47,3 @@ class BaseMixin(DatabaseProtocol):
keys = await self.redis.keys(pattern)
if keys:
await self.redis.delete(*keys)

View File

@ -8,6 +8,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ChatMixin(DatabaseProtocol):
"""Mixin for chat-related database operations"""
@ -22,7 +23,7 @@ class ChatMixin(DatabaseProtocol):
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
@ -41,7 +42,7 @@ class ChatMixin(DatabaseProtocol):
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# Chat Sessions operations
@ -71,13 +72,13 @@ class ChatMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
session_id = key.replace(KEY_PREFIXES['chat_sessions'], '')
session_id = key.replace(KEY_PREFIXES["chat_sessions"], "")
result[session_id] = self._deserialize(value)
return result
async def delete_chat_session(self, session_id: str) -> bool:
'''Delete a chat session from Redis'''
"""Delete a chat session from Redis"""
try:
result = await self.redis.delete(f"chat_session:{session_id}")
return result > 0
@ -86,11 +87,11 @@ class ChatMixin(DatabaseProtocol):
raise
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
'''Delete a specific chat message from Redis'''
"""Delete a specific chat message from Redis"""
try:
# Remove from the session's message list
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.lrem(key, 0, message_id)# type: ignore
await self.redis.lrem(key, 0, message_id) # type: ignore
# Delete the message data itself
result = await self.redis.delete(f"chat_message:{message_id}")
return result > 0
@ -102,13 +103,13 @@ class ChatMixin(DatabaseProtocol):
async def get_chat_messages(self, session_id: str) -> List[Dict]:
"""Get chat messages for a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
messages = await self.redis.lrange(key, 0, -1)# type: ignore
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
messages = await self.redis.lrange(key, 0, -1) # type: ignore
return [self._deserialize(msg) for msg in messages if msg] # type: ignore
async def add_chat_message(self, session_id: str, message_data: Dict):
"""Add a chat message to a session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
await self.redis.rpush(key, self._serialize(message_data))# type: ignore
await self.redis.rpush(key, self._serialize(message_data)) # type: ignore
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
"""Set all chat messages for a session (replaces existing)"""
@ -120,7 +121,7 @@ class ChatMixin(DatabaseProtocol):
# Add new messages
if messages:
serialized_messages = [self._serialize(msg) for msg in messages]
await self.redis.rpush(key, *serialized_messages)# type: ignore
await self.redis.rpush(key, *serialized_messages) # type: ignore
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
"""Get all chat messages grouped by session"""
@ -132,8 +133,8 @@ class ChatMixin(DatabaseProtocol):
result = {}
for key in keys:
session_id = key.replace(KEY_PREFIXES['chat_messages'], '')
messages = await self.redis.lrange(key, 0, -1)# type: ignore
session_id = key.replace(KEY_PREFIXES["chat_messages"], "")
messages = await self.redis.lrange(key, 0, -1) # type: ignore
result[session_id] = [self._deserialize(msg) for msg in messages if msg]
return result
@ -164,8 +165,7 @@ class ChatMixin(DatabaseProtocol):
for session_data in all_sessions.values():
context = session_data.get("context", {})
if (context.get("relatedEntityType") == "candidate" and
context.get("relatedEntityId") == candidate_id):
if context.get("relatedEntityType") == "candidate" and context.get("relatedEntityId") == candidate_id:
candidate_sessions.append(session_data)
# Sort by last activity (most recent first)
@ -188,7 +188,7 @@ class ChatMixin(DatabaseProtocol):
async def get_chat_message_count(self, session_id: str) -> int:
"""Get the total number of messages in a chat session"""
key = f"{KEY_PREFIXES['chat_messages']}{session_id}"
return await self.redis.llen(key)# type: ignore
return await self.redis.llen(key) # type: ignore
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
"""Search for messages containing specific text in a session"""
@ -236,7 +236,6 @@ class ChatMixin(DatabaseProtocol):
return archived_count
# Analytics and Reporting
async def get_chat_statistics(self) -> Dict[str, Any]:
"""Get comprehensive chat statistics"""
@ -250,7 +249,7 @@ class ChatMixin(DatabaseProtocol):
"archived_sessions": 0,
"sessions_by_type": {},
"sessions_with_candidates": 0,
"average_messages_per_session": 0
"average_messages_per_session": 0,
}
# Analyze sessions

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class DocumentMixin(DatabaseProtocol):
"""Mixin for document-related database operations"""
@ -31,7 +32,7 @@ class DocumentMixin(DatabaseProtocol):
try:
# Get all document IDs for this candidate
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1)# type: ignore
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
if not document_ids:
logger.info(f"No documents found for candidate {candidate_id}")
@ -64,7 +65,7 @@ class DocumentMixin(DatabaseProtocol):
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
"""Get all documents for a specific candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
document_ids = await self.redis.lrange(key, 0, -1) # type: ignore
if not document_ids:
return []
@ -83,7 +84,7 @@ class DocumentMixin(DatabaseProtocol):
documents.append(doc_data)
else:
# Clean up orphaned document ID
await self.redis.lrem(key, 0, doc_id)# type: ignore
await self.redis.lrem(key, 0, doc_id) # type: ignore
logger.warning(f"Removed orphaned document ID {doc_id} for candidate {candidate_id}")
return documents
@ -91,12 +92,12 @@ class DocumentMixin(DatabaseProtocol):
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
"""Add a document ID to a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.rpush(key, document_id)# type: ignore
await self.redis.rpush(key, document_id) # type: ignore
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
"""Remove a document ID from a candidate's document list"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
await self.redis.lrem(key, 0, document_id)# type: ignore
await self.redis.lrem(key, 0, document_id) # type: ignore
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
"""Update document metadata"""
@ -128,7 +129,7 @@ class DocumentMixin(DatabaseProtocol):
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
"""Get total number of documents for a candidate"""
key = f"{KEY_PREFIXES['candidate_documents']}{candidate_id}"
return await self.redis.llen(key)# type: ignore
return await self.redis.llen(key) # type: ignore
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
"""Search documents by filename for a candidate"""
@ -136,8 +137,7 @@ class DocumentMixin(DatabaseProtocol):
query_lower = query.lower()
return [
doc for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or
query_lower in doc.get("originalName", "").lower())
doc
for doc in all_documents
if (query_lower in doc.get("filename", "").lower() or query_lower in doc.get("originalName", "").lower())
]

View File

@ -8,8 +8,10 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class JobMixin(DatabaseProtocol):
"""Mixin for job-related database operations"""
async def get_job(self, job_id: str) -> Optional[Dict]:
"""Get job by ID"""
key = f"{KEY_PREFIXES['jobs']}{job_id}"
@ -36,7 +38,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
job_id = key.replace(KEY_PREFIXES['jobs'], '')
job_id = key.replace(KEY_PREFIXES["jobs"], "")
result[job_id] = self._deserialize(value)
return result
@ -124,7 +126,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
app_id = key.replace(KEY_PREFIXES['job_applications'], '')
app_id = key.replace(KEY_PREFIXES["job_applications"], "")
result[app_id] = self._deserialize(value)
return result
@ -189,7 +191,7 @@ class JobMixin(DatabaseProtocol):
requirements_with_meta = {
**requirements,
"cached_at": datetime.now(UTC).isoformat(),
"document_id": document_id
"document_id": document_id,
}
await self.redis.set(key, self._serialize(requirements_with_meta))
@ -232,7 +234,7 @@ class JobMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
if value:
result[document_id] = self._deserialize(value)
@ -247,11 +249,7 @@ class JobMixin(DatabaseProtocol):
pattern = f"{KEY_PREFIXES['job_requirements']}*"
keys = await self.redis.keys(pattern)
stats = {
"total_cached_requirements": len(keys),
"cache_dates": {},
"documents_with_requirements": []
}
stats = {"total_cached_requirements": len(keys), "cache_dates": {}, "documents_with_requirements": []}
if keys:
# Get cache dates for analysis
@ -264,7 +262,7 @@ class JobMixin(DatabaseProtocol):
if value:
requirements_data = self._deserialize(value)
if requirements_data:
document_id = key.replace(KEY_PREFIXES['job_requirements'], '')
document_id = key.replace(KEY_PREFIXES["job_requirements"], "")
stats["documents_with_requirements"].append(document_id)
# Track cache dates
@ -277,4 +275,3 @@ class JobMixin(DatabaseProtocol):
except Exception as e:
logger.error(f"❌ Error getting job requirements stats: {e}")
return {"total_cached_requirements": 0, "cache_dates": {}, "documents_with_requirements": []}

View File

@ -7,153 +7,412 @@ if TYPE_CHECKING:
from models import SkillAssessment
class DatabaseProtocol(Protocol):
# Base mixin
redis: Redis
def _serialize(self, data) -> str: ...
def _deserialize(self, data: str): ...
def _serialize(self, data) -> str:
...
def _deserialize(self, data: str):
...
# Chat mixin
async def add_chat_message(self, session_id: str, message_data: Dict): ...
async def archive_chat_session(self, session_id: str): ...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]): ...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool: ...
async def delete_chat_messages(self, session_id: str): ...
async def delete_chat_session_completely(self, session_id: str): ...
async def delete_chat_session(self, session_id: str) -> bool: ...
async def add_chat_message(self, session_id: str, message_data: Dict):
...
async def archive_chat_session(self, session_id: str):
...
async def bulk_update_chat_sessions(self, session_updates: Dict[str, Dict]):
...
async def delete_chat_message(self, session_id: str, message_id: str) -> bool:
...
async def delete_chat_messages(self, session_id: str):
...
async def delete_chat_session_completely(self, session_id: str):
...
async def delete_chat_session(self, session_id: str) -> bool:
...
# Document mixin
async def add_document_to_candidate(self, candidate_id: str, document_id: str): ...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool): ...
async def add_document_to_candidate(self, candidate_id: str, document_id: str):
...
async def bulk_update_document_rag_status(self, candidate_id: str, document_ids: List[str], include_in_rag: bool):
...
# Job mixin
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int: ...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None: ...
async def bulk_delete_job_requirements(self, document_ids: List[str]) -> int:
...
async def cache_skill_match(self, cache_key: str, assessment: SkillAssessment) -> None:
...
# User mixin
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]: ...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]: ...
async def delete_employer(self, employer_id: str): ...
async def delete_guest(self, guest_id: str) -> bool: ...
async def delete_user(self, email: str): ...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]: ...
async def get_all_users(self) -> Dict[str, Any]: ...
async def get_all_viewers(self) -> Dict[str, Any]: ...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]: ...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]: ...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]: ...
async def get_employer(self, employer_id: str) -> Optional[Dict]: ...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]: ...
async def get_guest_statistics(self) -> Dict[str, Any]: ...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_user_by_username(self, username: str) -> Optional[Dict]: ...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]: ...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]: ...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]: ...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int: ...
async def invalidate_user_skill_cache(self, user_id: str) -> int: ...
async def set_candidate(self, candidate_id: str, candidate_data: Dict): ...
async def set_employer(self, employer_id: str, employer_data: Dict): ...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None: ...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool: ...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool: ...
async def update_user_rag_timestamp(self, user_id: str) -> bool: ...
async def user_exists_by_email(self, email: str) -> bool: ...
async def user_exists_by_username(self, username: str) -> bool: ...
async def delete_candidate_batch(self, candidate_ids: List[str]) -> Dict[str, Dict[str, int]]:
...
async def delete_candidate(self, candidate_id: str) -> Dict[str, int]:
...
async def delete_employer(self, employer_id: str):
...
async def delete_guest(self, guest_id: str) -> bool:
...
async def delete_user(self, email: str):
...
async def find_candidate_by_username(self, username: str) -> Optional[Dict]:
...
async def get_all_users(self) -> Dict[str, Any]:
...
async def get_all_viewers(self) -> Dict[str, Any]:
...
async def get_candidate_chat_summary(self, candidate_id: str) -> Dict[str, Any]:
...
async def get_candidate_documents(self, candidate_id: str) -> List[Dict]:
...
async def get_candidate(self, candidate_id: str) -> Optional[Dict]:
...
async def get_employer(self, employer_id: str) -> Optional[Dict]:
...
async def get_guest_by_session_id(self, session_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest(self, guest_id: str) -> Optional[Dict[str, Any]]:
...
async def get_guest_statistics(self) -> Dict[str, Any]:
...
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_user_by_username(self, username: str) -> Optional[Dict]:
...
async def get_user_rag_update_time(self, user_id: str) -> Optional[datetime]:
...
async def get_user_security_log(self, user_id: str, days: int = 7) -> List[Dict[str, Any]]:
...
async def get_user(self, login: str) -> Optional[Dict[str, Any]]:
...
async def invalidate_candidate_skill_cache(self, candidate_id: str) -> int:
...
async def invalidate_user_skill_cache(self, user_id: str) -> int:
...
async def set_candidate(self, candidate_id: str, candidate_data: Dict):
...
async def set_employer(self, employer_id: str, employer_data: Dict):
...
async def set_guest(self, guest_id: str, guest_data: Dict[str, Any]) -> None:
...
async def set_user_by_id(self, user_id: str, user_data: Dict[str, Any]) -> bool:
...
async def set_user(self, login: str, user_data: Dict[str, Any]) -> bool:
...
async def update_user_rag_timestamp(self, user_id: str) -> bool:
...
async def user_exists_by_email(self, email: str) -> bool:
...
async def user_exists_by_username(self, username: str) -> bool:
...
# Auth mixin
async def cleanup_expired_verification_tokens(self) -> int: ...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int: ...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int: ...
async def cleanup_orphaned_job_requirements(self) -> int: ...
async def clear_all_data(self: "DatabaseProtocol"): ...
async def clear_all_skill_match_cache(self) -> int: ...
async def cleanup_expired_verification_tokens(self) -> int:
...
async def cleanup_inactive_guests(self, inactive_hours: int = 24) -> int:
...
async def cleanup_old_chat_sessions(self, days_old: int = 90) -> int:
...
async def cleanup_orphaned_job_requirements(self) -> int:
...
async def clear_all_data(self: "DatabaseProtocol"):
...
async def clear_all_skill_match_cache(self) -> int:
...
# Resume mixin
async def delete_ai_parameters(self, param_id: str): ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int: ...
async def delete_all_resumes_for_user(self, user_id: str) -> int: ...
async def delete_authentication(self, user_id: str) -> bool: ...
async def delete_document(self, document_id: str): ...
async def delete_ai_parameters(self, param_id: str):
...
async def delete_job_application(self, application_id: str): ...
async def delete_job_requirements(self, document_id: str) -> bool: ...
async def delete_job(self, job_id: str): ...
async def delete_resume(self, user_id: str, resume_id: str) -> bool: ...
async def delete_viewer(self, viewer_id: str): ...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]: ...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]: ...
async def get_all_ai_parameters(self) -> Dict[str, Any]: ...
async def get_all_candidates(self) -> Dict[str, Any]: ...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]: ...
async def get_all_chat_sessions(self) -> Dict[str, Any]: ...
async def get_all_employers(self) -> Dict[str, Any]: ...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]: ...
async def get_all_job_applications(self) -> Dict[str, Any]: ...
async def get_all_job_requirements(self) -> Dict[str, Any]: ...
async def get_all_jobs(self) -> Dict[str, Any]: ...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]: ...
async def get_all_resumes(self) -> Dict[str, List[Dict]]: ...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]: ...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]: ...
async def get_chat_message_count(self, session_id: str) -> int: ...
async def get_chat_messages(self, session_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]: ...
async def get_chat_session(self, session_id: str) -> Optional[Dict]: ...
async def get_chat_statistics(self) -> Dict[str, Any]: ...
async def get_document_count_for_candidate(self, candidate_id: str) -> int: ...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]: ...
async def get_document(self, document_id: str) -> Optional[Dict]: ...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_job_application(self, application_id: str) -> Optional[Dict]: ...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]: ...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]: ...
async def get_job_requirements_stats(self) -> Dict[str, Any]: ...
async def get_job(self, job_id: str) -> Optional[Dict]: ...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]: ...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]: ...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_pending_verifications_count(self) -> int: ...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]: ...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]: ...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]: ...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]: ...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]: ...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]: ...
async def get_stats(self) -> Dict[str, int]: ...
async def get_verification_attempts_count(self, email: str) -> int: ...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]: ...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int: ...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool: ...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool: ...
async def mark_email_verified(self, token: str) -> bool: ...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool: ...
async def mark_password_reset_token_used(self, token: str) -> bool: ...
async def record_verification_attempt(self, email: str) -> bool: ...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str): ...
async def revoke_all_user_tokens(self, user_id: str) -> bool: ...
async def revoke_refresh_token(self, token: str) -> bool: ...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool: ...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]: ...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]: ...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]: ...
async def set_ai_parameters(self, param_id: str, param_data: Dict): ...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool: ...
async def set_chat_messages(self, session_id: str, messages: List[Dict]): ...
async def set_chat_session(self, session_id: str, session_data: Dict): ...
async def set_document(self, document_id: str, document_data: Dict): ...
async def set_job_application(self, application_id: str, application_data: Dict): ...
async def set_job(self, job_id: str, job_data: Dict): ...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool: ...
async def set_viewer(self, viewer_id: str, viewer_data: Dict): ...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool: ...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool: ...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool: ...
async def store_refresh_token(self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]) -> bool: ...
async def update_chat_session_activity(self, session_id: str): ...
async def update_document(self, document_id: str, updates: Dict)-> Dict[Any, Any] | None: ...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]: ...
async def delete_all_candidate_documents(self, candidate_id: str) -> int:
...
async def delete_all_resumes_for_user(self, user_id: str) -> int:
...
async def delete_authentication(self, user_id: str) -> bool:
...
async def delete_document(self, document_id: str):
...
async def delete_job_application(self, application_id: str):
...
async def delete_job_requirements(self, document_id: str) -> bool:
...
async def delete_job(self, job_id: str):
...
async def delete_resume(self, user_id: str, resume_id: str) -> bool:
...
async def delete_viewer(self, viewer_id: str):
...
async def find_verification_token_by_email(self, email: str) -> Optional[Dict[str, Any]]:
...
async def get_ai_parameters(self, param_id: str) -> Optional[Dict]:
...
async def get_all_ai_parameters(self) -> Dict[str, Any]:
...
async def get_all_candidates(self) -> Dict[str, Any]:
...
async def get_all_chat_messages(self) -> Dict[str, List[Dict]]:
...
async def get_all_chat_sessions(self) -> Dict[str, Any]:
...
async def get_all_employers(self) -> Dict[str, Any]:
...
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
...
async def get_all_job_applications(self) -> Dict[str, Any]:
...
async def get_all_job_requirements(self) -> Dict[str, Any]:
...
async def get_all_jobs(self) -> Dict[str, Any]:
...
async def get_all_resumes_for_user(self, user_id: str) -> List[Dict]:
...
async def get_all_resumes(self) -> Dict[str, List[Dict]]:
...
async def get_authentication(self, user_id: str) -> Optional[Dict[str, Any]]:
...
async def get_cached_skill_match(self, cache_key: str) -> Optional[SkillAssessment]:
...
async def get_chat_message_count(self, session_id: str) -> int:
...
async def get_chat_messages(self, session_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_chat_sessions_by_user(self, user_id: str) -> List[Dict]:
...
async def get_chat_session(self, session_id: str) -> Optional[Dict]:
...
async def get_chat_statistics(self) -> Dict[str, Any]:
...
async def get_document_count_for_candidate(self, candidate_id: str) -> int:
...
async def get_documents_by_rag_status(self, candidate_id: str, include_in_rag: bool = True) -> List[Dict]:
...
async def get_document(self, document_id: str) -> Optional[Dict]:
...
async def get_email_verification_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_job_application(self, application_id: str) -> Optional[Dict]:
...
async def get_job_requirements_by_candidate(self, candidate_id: str) -> List[Dict]:
...
async def get_job_requirements(self, document_id: str) -> Optional[Dict]:
...
async def get_job_requirements_stats(self) -> Dict[str, Any]:
...
async def get_job(self, job_id: str) -> Optional[Dict]:
...
async def get_mfa_code(self, email: str, device_id: str) -> Optional[Dict[str, Any]]:
...
async def get_multiple_candidates_by_usernames(self, usernames: List[str]) -> Dict[str, Dict]:
...
async def get_password_reset_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_pending_verifications_count(self) -> int:
...
async def get_recent_chat_messages(self, session_id: str, limit: int = 10) -> List[Dict]:
...
async def get_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
...
async def get_resumes_by_candidate(self, user_id: str, candidate_id: str) -> List[Dict]:
...
async def get_resumes_by_job(self, user_id: str, job_id: str) -> List[Dict]:
...
async def get_resume(self, user_id: str, resume_id: str) -> Optional[Dict]:
...
async def get_resume_statistics(self, user_id: str) -> Dict[str, Any]:
...
async def get_stats(self) -> Dict[str, int]:
...
async def get_verification_attempts_count(self, email: str) -> int:
...
async def get_viewer(self, viewer_id: str) -> Optional[Dict]:
...
async def increment_mfa_attempts(self, email: str, device_id: str) -> int:
...
async def invalidate_job_requirements_cache(self, document_id: str) -> bool:
...
async def log_security_event(self, user_id: str, event_type: str, details: Dict[str, Any]) -> bool:
...
async def mark_email_verified(self, token: str) -> bool:
...
async def mark_mfa_verified(self, email: str, device_id: str) -> bool:
...
async def mark_password_reset_token_used(self, token: str) -> bool:
...
async def record_verification_attempt(self, email: str) -> bool:
...
async def remove_document_from_candidate(self, candidate_id: str, document_id: str):
...
async def revoke_all_user_tokens(self, user_id: str) -> bool:
...
async def revoke_refresh_token(self, token: str) -> bool:
...
async def save_job_requirements(self, document_id: str, requirements: Dict) -> bool:
...
async def search_candidate_documents(self, candidate_id: str, query: str) -> List[Dict]:
...
async def search_chat_messages(self, session_id: str, query: str) -> List[Dict]:
...
async def search_resumes_for_user(self, user_id: str, query: str) -> List[Dict]:
...
async def set_ai_parameters(self, param_id: str, param_data: Dict):
...
async def set_authentication(self, user_id: str, auth_data: Dict[str, Any]) -> bool:
...
async def set_chat_messages(self, session_id: str, messages: List[Dict]):
...
async def set_chat_session(self, session_id: str, session_data: Dict):
...
async def set_document(self, document_id: str, document_data: Dict):
...
async def set_job_application(self, application_id: str, application_data: Dict):
...
async def set_job(self, job_id: str, job_data: Dict):
...
async def set_resume(self, user_id: str, resume_data: Dict) -> bool:
...
async def set_viewer(self, viewer_id: str, viewer_data: Dict):
...
async def store_email_verification_token(self, email: str, token: str, user_type: str, user_data: dict) -> bool:
...
async def store_mfa_code(self, email: str, code: str, device_id: str) -> bool:
...
async def store_password_reset_token(self, email: str, token: str, expires_at: datetime) -> bool:
...
async def store_refresh_token(
self, user_id: str, token: str, expires_at: datetime, device_info: Dict[str, str]
) -> bool:
...
async def update_chat_session_activity(self, session_id: str):
...
async def update_document(self, document_id: str, updates: Dict) -> Dict[Any, Any] | None:
...
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
...
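The whole class above is a `typing.Protocol`: the `...` bodies are never executed, they only declare the structural interface the mixins must satisfy, and ruff now places each ellipsis on its own line. A minimal sketch of how such a protocol is consumed (toy names, not from this file):

from typing import Protocol

class StoreProtocol(Protocol):
    def _serialize(self, data) -> str:
        ...

    async def get_job(self, job_id: str):
        ...

def accepts_store(store: StoreProtocol) -> None:
    # any class that provides these methods type-checks here,
    # without inheriting from StoreProtocol at all
    pass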

View File

@ -7,6 +7,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class ResumeMixin(DatabaseProtocol):
"""Mixin for resume-related database operations"""
@ -14,10 +15,10 @@ class ResumeMixin(DatabaseProtocol):
"""Save a resume for a user"""
try:
# Generate resume_id if not present
if 'id' not in resume_data:
if "id" not in resume_data:
raise ValueError("Resume data must include an 'id' field")
resume_id = resume_data['id']
resume_id = resume_data["id"]
# Store the resume data
key = f"{KEY_PREFIXES['resumes']}{user_id}:{resume_id}"
@ -25,7 +26,7 @@ class ResumeMixin(DatabaseProtocol):
# Add resume_id to user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
await self.redis.rpush(user_resumes_key, resume_id) # type: ignore
logger.info(f"📄 Saved resume {resume_id} for user {user_id}")
return True
@ -53,7 +54,7 @@ class ResumeMixin(DatabaseProtocol):
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
@ -73,7 +74,7 @@ class ResumeMixin(DatabaseProtocol):
resumes.append(resume_data)
else:
# Clean up orphaned resume ID
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
logger.warning(f"Removed orphaned resume ID {resume_id} for user {user_id}")
# Sort by created_at timestamp (most recent first)
@ -94,7 +95,7 @@ class ResumeMixin(DatabaseProtocol):
# Remove from user's resume list
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
await self.redis.lrem(user_resumes_key, 0, resume_id)# type: ignore
await self.redis.lrem(user_resumes_key, 0, resume_id) # type: ignore
if result > 0:
logger.info(f"🗑️ Deleted resume {resume_id} for user {user_id}")
@ -111,7 +112,7 @@ class ResumeMixin(DatabaseProtocol):
try:
# Get all resume IDs for this user
user_resumes_key = f"{KEY_PREFIXES['user_resumes']}{user_id}"
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1)# type: ignore
resume_ids = await self.redis.lrange(user_resumes_key, 0, -1) # type: ignore
if not resume_ids:
logger.info(f"📄 No resumes found for user {user_id}")
@ -159,7 +160,7 @@ class ResumeMixin(DatabaseProtocol):
for key, value in zip(keys, values):
if value:
# Extract user_id from key format: resume:{user_id}:{resume_id}
key_parts = key.replace(KEY_PREFIXES['resumes'], '').split(':', 1)
key_parts = key.replace(KEY_PREFIXES["resumes"], "").split(":", 1)
if len(key_parts) >= 1:
user_id = key_parts[0]
resume_data = self._deserialize(value)
@ -186,12 +187,14 @@ class ResumeMixin(DatabaseProtocol):
matching_resumes = []
for resume in all_resumes:
# Search in resume content, job_id, candidate_id, etc.
searchable_text = " ".join([
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", ""))
]).lower()
searchable_text = " ".join(
[
resume.get("resume", ""),
resume.get("job_id", ""),
resume.get("candidate_id", ""),
str(resume.get("created_at", "")),
]
).lower()
if query_lower in searchable_text:
matching_resumes.append(resume)
@ -206,10 +209,7 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific candidate created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
candidate_resumes = [
resume for resume in all_resumes
if resume.get("candidate_id") == candidate_id
]
candidate_resumes = [resume for resume in all_resumes if resume.get("candidate_id") == candidate_id]
logger.info(f"📄 Found {len(candidate_resumes)} resumes for candidate {candidate_id} by user {user_id}")
return candidate_resumes
@ -221,10 +221,7 @@ class ResumeMixin(DatabaseProtocol):
"""Get all resumes for a specific job created by a user"""
try:
all_resumes = await self.get_all_resumes_for_user(user_id)
job_resumes = [
resume for resume in all_resumes
if resume.get("job_id") == job_id
]
job_resumes = [resume for resume in all_resumes if resume.get("job_id") == job_id]
logger.info(f"📄 Found {len(job_resumes)} resumes for job {job_id} by user {user_id}")
return job_resumes
@ -242,7 +239,7 @@ class ResumeMixin(DatabaseProtocol):
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": []
"recent_resumes": [],
}
for resume in all_resumes:
@ -269,7 +266,13 @@ class ResumeMixin(DatabaseProtocol):
return stats
except Exception as e:
logger.error(f"❌ Error getting resume statistics for user {user_id}: {e}")
return {"total_resumes": 0, "resumes_by_candidate": {}, "resumes_by_job": {}, "creation_timeline": {}, "recent_resumes": []}
return {
"total_resumes": 0,
"resumes_by_candidate": {},
"resumes_by_job": {},
"creation_timeline": {},
"recent_resumes": [],
}
async def update_resume(self, user_id: str, resume_id: str, updates: Dict) -> Optional[Dict]:
"""Update specific fields of a resume"""

View File

@ -8,6 +8,7 @@ from .protocols import DatabaseProtocol
logger = logging.getLogger(__name__)
class SkillMixin(DatabaseProtocol):
"""Mixin for Skill-related database operations"""
@ -74,7 +75,9 @@ class SkillMixin(DatabaseProtocol):
# Cache for 1 hour by default
await self.redis.set(
cache_key,
json.dumps(assessment.model_dump(mode='json', by_alias=True), default=str) # Serialize with datetime handling
json.dumps(
assessment.model_dump(mode="json", by_alias=True), default=str
), # Serialize with datetime handling
)
logger.info(f"💾 Skill match cached: {cache_key}")
except Exception as e:

View File

@ -10,6 +10,7 @@ from ..constants import KEY_PREFIXES
logger = logging.getLogger(__name__)
class UserMixin(DatabaseProtocol):
"""Mixin for user operations"""
@ -23,13 +24,13 @@ class UserMixin(DatabaseProtocol):
guest_data["last_activity"] = datetime.now(UTC).isoformat()
# Store in Redis with both hash and individual key for redundancy
await self.redis.hset("guests", guest_id, json.dumps(guest_data))# type: ignore
await self.redis.hset("guests", guest_id, json.dumps(guest_data)) # type: ignore
# Also store with a longer TTL as backup
await self.redis.setex(
f"guest_backup:{guest_id}",
86400 * 7, # 7 days TTL
json.dumps(guest_data)
json.dumps(guest_data),
)
logger.info(f"💾 Guest stored with backup: {guest_id}")
@ -41,7 +42,7 @@ class UserMixin(DatabaseProtocol):
"""Get guest data with fallback to backup"""
try:
# Try primary storage first
data = await self.redis.hget("guests", guest_id)# type: ignore
data = await self.redis.hget("guests", guest_id) # type: ignore
if data:
guest_data = json.loads(data)
# Update last activity when accessed
@ -82,11 +83,8 @@ class UserMixin(DatabaseProtocol):
async def get_all_guests(self) -> Dict[str, Dict[str, Any]]:
"""Get all guests"""
try:
data = await self.redis.hgetall("guests")# type: ignore
return {
guest_id: json.loads(guest_json)
for guest_id, guest_json in data.items()
}
data = await self.redis.hgetall("guests") # type: ignore
return {guest_id: json.loads(guest_json) for guest_id, guest_json in data.items()}
except Exception as e:
logger.error(f"❌ Error getting all guests: {e}")
return {}
@ -94,7 +92,7 @@ class UserMixin(DatabaseProtocol):
async def delete_guest(self, guest_id: str) -> bool:
"""Delete a guest"""
try:
result = await self.redis.hdel("guests", guest_id)# type: ignore
result = await self.redis.hdel("guests", guest_id) # type: ignore
if result:
logger.info(f"🗑️ Guest deleted: {guest_id}")
return True
@ -120,7 +118,7 @@ class UserMixin(DatabaseProtocol):
# Skip cleanup if guest is very new (less than 1 hour old)
if created_at_str:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
if current_time - created_at < timedelta(hours=1):
preserved_count += 1
logger.info(f"🛡️ Preserving new guest: {guest_id}")
@ -130,7 +128,7 @@ class UserMixin(DatabaseProtocol):
should_delete = False
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity < cutoff_time:
should_delete = True
except ValueError:
@ -172,7 +170,7 @@ class UserMixin(DatabaseProtocol):
"active_last_day": 0,
"converted_guests": 0,
"by_ip": {},
"creation_timeline": {}
"creation_timeline": {},
}
hour_ago = current_time - timedelta(hours=1)
@ -183,7 +181,7 @@ class UserMixin(DatabaseProtocol):
last_activity_str = guest_data.get("last_activity")
if last_activity_str:
try:
last_activity = datetime.fromisoformat(last_activity_str.replace('Z', '+00:00'))
last_activity = datetime.fromisoformat(last_activity_str.replace("Z", "+00:00"))
if last_activity > hour_ago:
stats["active_last_hour"] += 1
if last_activity > day_ago:
@ -203,8 +201,8 @@ class UserMixin(DatabaseProtocol):
created_at_str = guest_data.get("created_at")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
date_key = created_at.strftime('%Y-%m-%d')
created_at = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
date_key = created_at.strftime("%Y-%m-%d")
stats["creation_timeline"][date_key] = stats["creation_timeline"].get(date_key, 0) + 1
except ValueError:
pass
@ -278,7 +276,7 @@ class UserMixin(DatabaseProtocol):
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get user lookup data by user ID"""
try:
data = await self.redis.hget("user_lookup_by_id", user_id)# type: ignore
data = await self.redis.hget("user_lookup_by_id", user_id) # type: ignore
if data:
return json.loads(data)
return None
@ -321,7 +319,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
email = key.replace(KEY_PREFIXES['users'], '')
email = key.replace(KEY_PREFIXES["users"], "")
logger.info(f"🔍 Found user key: {key}, type: {type(value)}")
if type(value) == str:
result[email] = value
@ -364,7 +362,6 @@ class UserMixin(DatabaseProtocol):
logger.error(f"❌ Error storing user {login}: {e}")
return False
# ================
# Employers
# ================
@ -394,7 +391,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
employer_id = key.replace(KEY_PREFIXES['employers'], '')
employer_id = key.replace(KEY_PREFIXES["employers"], "")
result[employer_id] = self._deserialize(value)
return result
@ -404,7 +401,6 @@ class UserMixin(DatabaseProtocol):
key = f"{KEY_PREFIXES['employers']}{employer_id}"
await self.redis.delete(key)
# ================
# Candidates
# ================
@ -435,7 +431,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
candidate_id = key.replace(KEY_PREFIXES['candidates'], '')
candidate_id = key.replace(KEY_PREFIXES["candidates"], "")
result[candidate_id] = self._deserialize(value)
return result
@ -456,7 +452,7 @@ class UserMixin(DatabaseProtocol):
"security_logs": 0,
"ai_parameters": 0,
"candidate_record": 0,
"resumes": 0
"resumes": 0,
}
logger.info(f"🗑️ Starting cascading delete for candidate {candidate_id}")
@ -495,7 +491,9 @@ class UserMixin(DatabaseProtocol):
deletion_stats["chat_sessions"] = len(candidate_sessions)
deletion_stats["chat_messages"] = messages_deleted
logger.info(f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {len(candidate_sessions)} chat sessions and {messages_deleted} messages for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting chat sessions: {e}")
@ -528,9 +526,11 @@ class UserMixin(DatabaseProtocol):
logger.info(f"🗑️ Deleted user record by email: {candidate_email}")
# Delete by username (if different from email)
if (candidate_username and
candidate_username != candidate_email and
await self.user_exists_by_username(candidate_username)):
if (
candidate_username
and candidate_username != candidate_email
and await self.user_exists_by_username(candidate_username)
):
await self.delete_user(candidate_username)
user_records_deleted += 1
logger.info(f"🗑️ Deleted user record by username: {candidate_username}")
@ -593,8 +593,7 @@ class UserMixin(DatabaseProtocol):
candidate_ai_params = []
for param_id, param_data in all_ai_params.items():
if (param_data.get("candidateId") == candidate_id or
param_data.get("userId") == candidate_id):
if param_data.get("candidateId") == candidate_id or param_data.get("userId") == candidate_id:
candidate_ai_params.append(param_id)
# Delete each AI parameter set
@ -630,7 +629,9 @@ class UserMixin(DatabaseProtocol):
break
if tokens_deleted > 0:
logger.info(f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}")
logger.info(
f"🗑️ Deleted {tokens_deleted} email verification tokens for candidate {candidate_id}"
)
except Exception as e:
logger.error(f"❌ Error deleting email verification tokens: {e}")
@ -717,8 +718,10 @@ class UserMixin(DatabaseProtocol):
# 15. Log the deletion as a security event (if we have admin/system user context)
try:
total_items_deleted = sum(deletion_stats.values())
logger.info(f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}")
logger.info(
f"✅ Completed cascading delete for candidate {candidate_id}. "
f"Total items deleted: {total_items_deleted}"
)
logger.info(f"📊 Deletion breakdown: {deletion_stats}")
except Exception as e:
logger.error(f"❌ Error logging deletion summary: {e}")
@ -774,8 +777,8 @@ class UserMixin(DatabaseProtocol):
"total_candidates_processed": len(candidate_ids),
"successful_deletions": len([r for r in batch_results.values() if "error" not in r]),
"failed_deletions": len([r for r in batch_results.values() if "error" in r]),
"total_items_deleted": sum(total_stats.values())
}
"total_items_deleted": sum(total_stats.values()),
},
}
except Exception as e:
@ -816,7 +819,7 @@ class UserMixin(DatabaseProtocol):
"total_sessions": 0,
"total_messages": 0,
"first_chat": None,
"last_chat": None
"last_chat": None,
}
total_messages = 0
@ -835,7 +838,7 @@ class UserMixin(DatabaseProtocol):
"total_messages": total_messages,
"first_chat": sessions_by_date[0].get("createdAt") if sessions_by_date else None,
"last_chat": sessions_by_date[-1].get("lastActivity") if sessions_by_date else None,
"recent_sessions": sessions[:5] # Last 5 sessions
"recent_sessions": sessions[:5], # Last 5 sessions
}
# ================
@ -868,7 +871,7 @@ class UserMixin(DatabaseProtocol):
result = {}
for key, value in zip(keys, values):
viewer_id = key.replace(KEY_PREFIXES['viewers'], '')
viewer_id = key.replace(KEY_PREFIXES["viewers"], "")
result[viewer_id] = self._deserialize(value)
return result
@ -877,4 +880,3 @@ class UserMixin(DatabaseProtocol):
"""Delete viewer"""
key = f"{KEY_PREFIXES['viewers']}{viewer_id}"
await self.redis.delete(key)

View File

@ -3,20 +3,20 @@ import os
ollama_api_url = "http://ollama:11434" # Default Ollama local endpoint
user_dir = "/opt/backstory/users"
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
user_info_file = "info.json" # Relative to "{user_dir}/{user}"
default_username = "jketreno"
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
rag_content_dir = "rag-content" # Relative to "{user_dir}/{user}"
# Path to candidate full resume
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}
resume_doc_dir = f"{rag_content_dir}/resume" # Relative to "{user_dir}/{user}
resume_doc = "resume.md"
persist_directory = "db" # Relative to "{user_dir}/{user}"
persist_directory = "db" # Relative to "{user_dir}/{user}"
# Model name License Notes
# model = "deepseek-r1:7b" # MIT Tool calls don"t work
# model = "gemma3:4b" # Gemma Requires newer ollama https://ai.google.dev/gemma/terms
# model = "llama3.2" # Llama Good results; qwen seems slightly better https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/LICENSE.txt
# model = "mistral:7b" # Apache 2.0 Tool calls don"t work
model = "qwen2.5:7b" # Apache 2.0 Good results
model = "qwen2.5:7b" # Apache 2.0 Good results
# model = "qwen3:8b" # Apache 2.0 Requires newer ollama
model = os.getenv("MODEL_NAME", model)
@ -40,7 +40,7 @@ logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
# RAG and Vector DB settings
## Where to read RAG content
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
chunk_buffer = 5 # Number of lines before and after chunk beyond the portion used in embedding (to return to callers)
# Maximum number of entries for ChromaDB to find
default_rag_top_k = 50
@ -60,7 +60,7 @@ cert_path = "/opt/backstory/keys/cert.pem"
host = os.getenv("BACKSTORY_HOST", "0.0.0.0")
port = int(os.getenv("BACKSTORY_PORT", "8911"))
api_prefix = "/api/1.0"
debug=os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
debug = os.getenv("BACKSTORY_DEBUG", "false").lower() in ("true", "1", "yes")
# Used for filtering tracebacks
app_path="/opt/backstory/src/backend"
app_path = "/opt/backstory/src/backend"

View File

@ -6,6 +6,7 @@ from datetime import datetime, timezone
from user_agents import parse
import json
class DeviceManager:
def __init__(self, database: RedisDatabase):
self.database = database
@ -13,7 +14,6 @@ class DeviceManager:
def generate_device_fingerprint(self, request: Request) -> str:
"""Generate device fingerprint from request"""
user_agent = request.headers.get("user-agent", "")
ip_address = request.client.host if request.client else "unknown"
accept_language = request.headers.get("accept-language", "")
# Create fingerprint
@ -35,7 +35,7 @@ class DeviceManager:
"os": user_agent.os.family,
"os_version": user_agent.os.version_string,
"ip_address": request.client.host if request.client else "unknown",
"user_agent": user_agent_string
"user_agent": user_agent_string,
}
async def is_trusted_device(self, user_id: str, device_id: str) -> bool:
@ -55,14 +55,14 @@ class DeviceManager:
device_data = {
**device_info,
"added_at": datetime.now(timezone.utc).isoformat(),
"last_used": datetime.now(timezone.utc).isoformat()
"last_used": datetime.now(timezone.utc).isoformat(),
}
# Store for 90 days
await self.database.redis.setex(
key,
90 * 24 * 60 * 60, # 90 days in seconds
json.dumps(device_data, default=str)
json.dumps(device_data, default=str),
)
logger.info(f"🔒 Added trusted device {device_id} for user {user_id}")
@ -80,7 +80,7 @@ class DeviceManager:
await self.database.redis.setex(
key,
90 * 24 * 60 * 60, # Reset 90 day expiry
json.dumps(device_info, default=str)
json.dumps(device_info, default=str),
)
except Exception as e:
logger.error(f"Error updating device last used: {e}")

View File

@ -10,6 +10,7 @@ from datetime import datetime, timezone, timedelta
import json
from database.manager import RedisDatabase
class EmailService:
def __init__(self):
# Configure these in your .env file
@ -30,36 +31,25 @@ class EmailService:
def _format_template(self, template: str, **kwargs) -> str:
"""Format template with provided variables"""
return template.format(
app_name=self.app_name,
from_name=self.from_name,
frontend_url=self.frontend_url,
**kwargs
app_name=self.app_name, from_name=self.from_name, frontend_url=self.frontend_url, **kwargs
)
async def send_verification_email(
self,
to_email: str,
verification_token: str,
user_name: str,
user_type: str = "user"
self, to_email: str, verification_token: str, user_name: str, user_type: str = "user"
):
"""Send email verification email using template"""
try:
template = self._get_template("verification")
verification_link = f"{self.frontend_url}/login/verify-email?token={verification_token}"
subject = self._format_template(
template["subject"],
user_name=user_name,
to_email=to_email
)
subject = self._format_template(template["subject"], user_name=user_name, to_email=to_email)
html_content = self._format_template(
template["html"],
user_name=user_name,
user_type=user_type,
to_email=to_email,
verification_link=verification_link
verification_link=verification_link,
)
await self._send_email(to_email, subject, html_content)
@ -70,12 +60,7 @@ class EmailService:
raise
async def send_mfa_email(
self,
to_email: str,
mfa_code: str,
device_name: str,
user_name: str,
ip_address: str = "Unknown"
self, to_email: str, mfa_code: str, device_name: str, user_name: str, ip_address: str = "Unknown"
):
"""Send MFA code email using template"""
try:
@ -91,7 +76,7 @@ class EmailService:
ip_address=ip_address,
login_time=login_time,
mfa_code=mfa_code,
to_email=to_email
to_email=to_email,
)
await self._send_email(to_email, subject, html_content)
@ -101,12 +86,7 @@ class EmailService:
logger.error(f"❌ Failed to send MFA email to {to_email}: {e}")
raise
async def send_password_reset_email(
self,
to_email: str,
reset_token: str,
user_name: str
):
async def send_password_reset_email(self, to_email: str, reset_token: str, user_name: str):
"""Send password reset email using template"""
try:
template = self._get_template("password_reset")
@ -115,10 +95,7 @@ class EmailService:
subject = self._format_template(template["subject"])
html_content = self._format_template(
template["html"],
user_name=user_name,
reset_link=reset_link,
to_email=to_email
template["html"], user_name=user_name, reset_link=reset_link, to_email=to_email
)
await self._send_email(to_email, subject, html_content)
@ -134,14 +111,14 @@ class EmailService:
if not self.email_user:
raise ValueError("Email user is not configured")
# Create message
msg = MIMEMultipart('alternative')
msg['From'] = f"{self.from_name} <{self.email_user}>"
msg['To'] = to_email
msg['Subject'] = subject
msg['Reply-To'] = self.email_user
msg = MIMEMultipart("alternative")
msg["From"] = f"{self.from_name} <{self.email_user}>"
msg["To"] = to_email
msg["Subject"] = subject
msg["Reply-To"] = self.email_user
# Add HTML content
html_part = MIMEText(html_content, 'html', 'utf-8')
html_part = MIMEText(html_content, "html", "utf-8")
msg.attach(html_part)
# Send email with connection pooling and retry logic
@ -157,7 +134,6 @@ class EmailService:
text = msg.as_string()
server.sendmail(self.email_user, to_email, text)
break # Success, exit retry loop
except smtplib.SMTPException as e:
if attempt == max_retries - 1: # Last attempt
raise
@ -170,6 +146,7 @@ class EmailService:
logger.error(f"❌ SMTP error sending to {to_email}: {e}")
raise
class EmailRateLimiter:
def __init__(self, database: RedisDatabase):
self.database = database
@ -191,10 +168,7 @@ class EmailRateLimiter:
email_records = json.loads(count_data)
# Filter out old records
recent_records = [
record for record in email_records
if datetime.fromisoformat(record) > window_start
]
recent_records = [record for record in email_records if datetime.fromisoformat(record) > window_start]
if len(recent_records) >= limit:
logger.warning(f"⚠️ Email rate limit exceeded for {email} ({email_type})")
@ -202,11 +176,7 @@ class EmailRateLimiter:
# Add current email to records
recent_records.append(current_time.isoformat())
await self.database.redis.setex(
key,
window_minutes * 60,
json.dumps(recent_records)
)
await self.database.redis.setex(key, window_minutes * 60, json.dumps(recent_records))
return True
@ -217,11 +187,8 @@ class EmailRateLimiter:
async def _record_email_sent(self, key: str, timestamp: datetime, ttl_minutes: int):
"""Record that an email was sent"""
await self.database.redis.setex(
key,
ttl_minutes * 60,
json.dumps([timestamp.isoformat()])
)
await self.database.redis.setex(key, ttl_minutes * 60, json.dumps([timestamp.isoformat()]))
class VerificationEmailRateLimiter:
def __init__(self, database: RedisDatabase):
@ -242,7 +209,10 @@ class VerificationEmailRateLimiter:
# Check daily limit
daily_count = await self.database.get_verification_attempts_count(email)
if daily_count >= self.max_attempts_per_day:
return False, f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day."
return (
False,
f"Daily limit reached. You can request up to {self.max_attempts_per_day} verification emails per day.",
)
# Check hourly limit
hour_ago = current_time - timedelta(hours=1)
@ -251,13 +221,13 @@ class VerificationEmailRateLimiter:
if data:
attempts_data = json.loads(data)
recent_attempts = [
attempt for attempt in attempts_data
if datetime.fromisoformat(attempt) > hour_ago
]
recent_attempts = [attempt for attempt in attempts_data if datetime.fromisoformat(attempt) > hour_ago]
if len(recent_attempts) >= self.max_attempts_per_hour:
return False, f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour."
return (
False,
f"Hourly limit reached. You can request up to {self.max_attempts_per_hour} verification emails per hour.",
)
# Check cooldown period
if recent_attempts:
@ -280,7 +250,4 @@ class VerificationEmailRateLimiter:
await self.database.record_verification_attempt(email)
email_service = EmailService()
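Both rate limiters above rely on the same sliding-window idea: keep timestamps of earlier sends and count only the ones that fall inside the window. A standalone sketch of that check follows; the function name and in-memory list are illustrative, while the real code persists the list in Redis with SETEX.

from datetime import datetime, timedelta, timezone


def allow_send(previous_sends: list[str], limit: int, window: timedelta) -> bool:
    # previous_sends holds ISO-8601 timestamps of earlier emails.
    cutoff = datetime.now(timezone.utc) - window
    recent = [ts for ts in previous_sends if datetime.fromisoformat(ts) > cutoff]
    return len(recent) < limit


if __name__ == "__main__":
    history = [(datetime.now(timezone.utc) - timedelta(minutes=10)).isoformat()]
    # True: one send in the last hour, against a limit of 3
    print(allow_send(history, limit=3, window=timedelta(hours=1)))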

View File

@ -129,9 +129,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"mfa": {
"subject": "Security Code for Backstory",
"html": """
@ -274,9 +273,8 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
""",
},
"password_reset": {
"subject": "Reset your Backstory password",
"html": """
@ -386,6 +384,6 @@ EMAIL_TEMPLATES = {
</div>
</body>
</html>
"""
}
""",
},
}

View File

@ -3,12 +3,13 @@ import weakref
from datetime import datetime, timedelta
from typing import Dict, Optional
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field # type: ignore
from pydantic import BaseModel # type: ignore
from models import Candidate
from agents.base import CandidateEntity
from database.manager import RedisDatabase
from prometheus_client import CollectorRegistry # type: ignore
from prometheus_client import CollectorRegistry # type: ignore
class EntityManager(BaseModel):
"""Manages lifecycle of CandidateEntity instances"""
@ -36,10 +37,7 @@ class EntityManager(BaseModel):
pass
self._cleanup_task = None
def initialize(
self,
prometheus_collector: CollectorRegistry,
database: RedisDatabase):
def initialize(self, prometheus_collector: CollectorRegistry, database: RedisDatabase):
"""Initialize the EntityManager with Prometheus collector"""
self._prometheus_collector = prometheus_collector
self._database = database
@ -58,9 +56,7 @@ class EntityManager(BaseModel):
raise ValueError("EntityManager has not been initialized with required components.")
entity = CandidateEntity(candidate=candidate)
await entity.initialize(
prometheus_collector=self._prometheus_collector,
database=self._database)
await entity.initialize(prometheus_collector=self._prometheus_collector, database=self._database)
# Store with reference tracking
self._entities[candidate.id] = entity
@ -105,10 +101,12 @@ class EntityManager(BaseModel):
def _on_entity_deleted(self, user_id: str):
"""Callback when entity is garbage collected"""
def cleanup_callback(weak_ref):
self._entities.pop(user_id, None)
self._weak_refs.pop(user_id, None)
print(f"Entity {user_id} garbage collected")
return cleanup_callback
async def release_entity(self, user_id: str):
@ -138,8 +136,7 @@ class EntityManager(BaseModel):
time_since_access = current_time - entity.last_accessed
# Remove if TTL exceeded and no active references
if (time_since_access > timedelta(minutes=self._ttl_minutes)
and entity.reference_count == 0):
if time_since_access > timedelta(minutes=self._ttl_minutes) and entity.reference_count == 0:
expired_entities.append(user_id)
for user_id in expired_entities:
@ -153,6 +150,7 @@ class EntityManager(BaseModel):
# Global entity manager instance
entity_manager = EntityManager(default_ttl_minutes=30)
@asynccontextmanager
async def get_candidate_entity(candidate: Candidate):
"""Context manager for safe entity access with automatic reference management"""
@ -164,4 +162,5 @@ async def get_candidate_entity(candidate: Candidate):
finally:
await entity_manager.release_entity(candidate.id)
EntityManager.model_rebuild()

View File

@ -6,10 +6,7 @@ without getting caught up in serialization format complexities
import sys
from datetime import datetime
from models import (
UserStatus, UserType, SkillLevel, EmploymentType,
Candidate, Employer, Location, Skill
)
from models import UserStatus, UserType, SkillLevel, EmploymentType, Candidate, Employer, Location, Skill
def test_model_creation():
@ -23,37 +20,38 @@ def test_model_creation():
# Create candidate
candidate = Candidate(
email="test@example.com",
user_type=UserType.CANDIDATE,
username="test_candidate",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
firstName="John",
lastName="Doe",
fullName="John Doe",
first_name="John",
last_name="Doe",
full_name="John Doe",
skills=[skill],
experience=[],
education=[],
preferredJobTypes=[EmploymentType.FULL_TIME],
preferred_job_types=[EmploymentType.FULL_TIME],
location=location,
languages=[],
certifications=[]
certifications=[],
)
# Create employer
employer = Employer(
firstName="Mary",
lastName="Smith",
fullName="Mary Smith",
user_type=UserType.EMPLOYER,
first_name="Mary",
last_name="Smith",
full_name="Mary Smith",
email="hr@company.com",
username="test_employer",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
companyName="Test Company",
company_name="Test Company",
industry="Technology",
companySize="50-200",
companyDescription="A test company",
location=location
company_size="50-200",
company_description="A test company",
location=location,
)
print(f"✅ Candidate: {candidate.first_name} {candidate.last_name}")
@ -62,6 +60,7 @@ def test_model_creation():
return candidate, employer
def test_json_api_format():
"""Test JSON serialization in API format (the most important use case)"""
print("\n📡 Testing JSON API format...")
@ -84,11 +83,12 @@ def test_json_api_format():
assert candidate_back.first_name == candidate.first_name
assert employer_back.company_name == employer.company_name
print(f"✅ JSON round-trip successful")
print(f"✅ Data integrity verified")
print("✅ JSON round-trip successful")
print("✅ Data integrity verified")
return True
def test_api_dict_format():
"""Test dictionary format with aliases (for API requests/responses)"""
print("\n📊 Testing API dictionary format...")
@ -105,8 +105,8 @@ def test_api_dict_format():
assert "createdAt" in candidate_dict
assert "companyName" in employer_dict
print(f"✅ API format dictionaries created")
print(f"✅ CamelCase aliases verified")
print("✅ API format dictionaries created")
print("✅ CamelCase aliases verified")
# Test deserializing from API format
candidate_back = Candidate.model_validate(candidate_dict)
@ -115,25 +115,27 @@ def test_api_dict_format():
assert candidate_back.email == candidate.email
assert employer_back.company_name == employer.company_name
print(f"✅ API format round-trip successful")
print("✅ API format round-trip successful")
return True
def test_validation_constraints():
"""Test that validation constraints work"""
print("\n🔒 Testing validation constraints...")
try:
# Create a candidate with invalid email
invalid_candidate = Candidate(
Candidate(
user_type=UserType.CANDIDATE,
email="invalid-email",
username="test_invalid",
createdAt=datetime.now(),
updatedAt=datetime.now(),
created_at=datetime.now(),
updated_at=datetime.now(),
status=UserStatus.ACTIVE,
firstName="Jane",
lastName="Doe",
fullName="Jane Doe"
first_name="Jane",
last_name="Doe",
full_name="Jane Doe",
)
print("❌ Validation should have failed but didn't")
return False
@ -141,6 +143,7 @@ def test_validation_constraints():
print(f"✅ Validation error caught: {e}")
return True
def test_enum_values():
"""Test that enum values work correctly"""
print("\n📋 Testing enum values...")
@ -155,11 +158,12 @@ def test_enum_values():
assert candidate_dict["userType"] == "candidate"
assert employer.user_type == UserType.EMPLOYER
print(f"✅ Enum values correctly serialized")
print("✅ Enum values correctly serialized")
print(f"✅ User types: candidate={candidate.user_type}, employer={employer.user_type}")
return True
def main():
"""Run all focused tests"""
print("🎯 Focused Pydantic Model Tests")
@ -172,7 +176,7 @@ def main():
test_validation_constraints()
test_enum_values()
print(f"\n🎉 All focused tests passed!")
print("\n🎉 All focused tests passed!")
print("=" * 40)
print("✅ Models work correctly")
print("✅ JSON API format works")
@ -185,10 +189,12 @@ def main():
except Exception as e:
print(f"\n❌ Test failed: {type(e).__name__}: {e}")
import traceback
traceback.print_exc()
print(f"\n{traceback.format_exc()}")
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@ -2,6 +2,7 @@ from pydantic import BaseModel
import json
from typing import Any, List, Set
def check_serializable(obj: Any, path: str = "", errors: List[str] = [], visited: Set[int] = set()) -> List[str]:
"""
Recursively check all fields in an object for non-JSON-serializable types, avoiding infinite recursion.

View File

@ -10,16 +10,19 @@ assert issubclass(CandidateAI, BaseUserWithType), "CandidateAI must inherit from
assert issubclass(Employer, BaseUserWithType), "Employer must inherit from BaseUserWithType"
assert issubclass(Guest, BaseUserWithType), "Guest must inherit from BaseUserWithType"
T = TypeVar('T', bound=BaseModel)
T = TypeVar("T", bound=BaseModel)
def cast_to_model(model_cls: Type[T], source: BaseModel) -> T:
data = {field: getattr(source, field) for field in model_cls.__fields__}
return model_cls(**data)
def cast_to_model_safe(model_cls: Type[T], source: BaseModel) -> T:
data = {field: copy.deepcopy(getattr(source, field)) for field in model_cls.__fields__}
return model_cls(**data)
def cast_to_base_user_with_type(user) -> BaseUserWithType:
"""
Casts a Candidate, CandidateAI, Employer, or Guest to BaseUserWithType.

View File

@ -7,7 +7,8 @@ from typing import Any
import torch
from diffusers import StableDiffusionPipeline, FluxPipeline
class ImageModelCache: # Stay loaded for 3 hours
class ImageModelCache: # Stay loaded for 3 hours
def __init__(self, timeout_seconds: float = 3 * 60 * 60):
self._pipe = None
self._model_name = None
@ -36,11 +37,11 @@ class ImageModelCache: # Stay loaded for 3 hours
cached_model_type = self._get_model_type(self._model_name) if self._model_name else None
if (
self._pipe is not None and
self._model_name == model and
self._device == device and
current_model_type == cached_model_type and
current_time - self._last_access_time < self._timeout_seconds
self._pipe is not None
and self._model_name == model
and self._device == device
and current_model_type == cached_model_type
and current_time - self._last_access_time < self._timeout_seconds
):
self._last_access_time = current_time
return self._pipe
@ -52,8 +53,10 @@ class ImageModelCache: # Stay loaded for 3 hours
model,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
def dummy_safety_checker(images, clip_input):
return images, [False] * len(images)
pipe.safety_checker = dummy_safety_checker
else:
pipe = FluxPipeline.from_pretrained(
@ -61,7 +64,7 @@ class ImageModelCache: # Stay loaded for 3 hours
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
try:
pipe.load_lora_weights('enhanceaiteam/Flux-uncensored', weight_name='lora.safetensors')
pipe.load_lora_weights("enhanceaiteam/Flux-uncensored", weight_name="lora.safetensors")
except Exception as e:
raise Exception(f"Failed to load LoRA weights: {str(e)}")
@ -89,10 +92,7 @@ class ImageModelCache: # Stay loaded for 3 hours
async def cleanup_if_expired(self):
async with self._lock:
if (
self._pipe is not None and
time.time() - self._last_access_time >= self._timeout_seconds
):
if self._pipe is not None and time.time() - self._last_access_time >= self._timeout_seconds:
await self._unload_model()
async def _periodic_cleanup(self):

View File

@ -29,6 +29,7 @@ TIME_ESTIMATES = {
}
}
class ImageRequest(BaseModel):
session_id: str
filepath: str
@ -39,18 +40,22 @@ class ImageRequest(BaseModel):
width: int = 256
guidance_scale: float = 7.5
# Global model cache instance
model_cache = ImageModelCache()
def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task_id: str):
"""Background worker for Flux image generation"""
try:
# Flux: Run generation in the background and yield progress updates
status_queue.put(ChatMessageStatus(
session_id=params.session_id,
content=f"Initializing image generation.",
activity=ApiActivityType.GENERATING_IMAGE,
))
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content="Initializing image generation.",
activity=ApiActivityType.GENERATING_IMAGE,
)
)
# Start the generation task
start_gen_time = time.time()
@ -58,13 +63,15 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
# Simulate your pipe call with progress updates
def status_callback(pipeline, step, timestep, callback_kwargs):
# Send progress updates
progress = int((step+1) / params.iterations * 100)
progress = int((step + 1) / params.iterations * 100)
status_queue.put(ChatMessageStatus(
session_id=params.session_id,
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
activity=ApiActivityType.GENERATING_IMAGE,
))
status_queue.put(
ChatMessageStatus(
session_id=params.session_id,
content=f"Processing step {step+1}/{params.iterations} ({progress}%)",
activity=ApiActivityType.GENERATING_IMAGE,
)
)
return callback_kwargs
# Replace this block with your actual Flux pipe call:
@ -84,22 +91,28 @@ def flux_worker(pipe: Any, params: ImageRequest, status_queue: queue.Queue, task
image.save(params.filepath)
# Final completion status
status_queue.put(ChatMessage(
session_id=params.session_id,
status=ApiStatusType.DONE,
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
))
status_queue.put(
ChatMessage(
session_id=params.session_id,
status=ApiStatusType.DONE,
content=f"Image generated in {gen_time:.1f} seconds, {per_step_time:.1f} per iteration.",
)
)
except Exception as e:
logger.error(traceback.format_exc())
logger.error(e)
status_queue.put(ChatMessageError(
session_id=params.session_id,
content=f"Error during image generation: {str(e)}",
))
status_queue.put(
ChatMessageError(
session_id=params.session_id,
content=f"Error during image generation: {str(e)}",
)
)
async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
async def async_generate_image(
pipe: Any, params: ImageRequest
) -> AsyncGenerator[ChatMessage | ChatMessageStatus | ChatMessageError, None]:
"""
Single async function that handles background Flux generation with status streaming
"""
@ -109,18 +122,14 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
try:
# Start background worker thread
worker_thread = Thread(
target=flux_worker,
args=(pipe, params, status_queue, task_id),
daemon=True
)
worker_thread = Thread(target=flux_worker, args=(pipe, params, status_queue, task_id), daemon=True)
worker_thread.start()
# Initial status
status_message = ChatMessageStatus(
session_id=params.session_id,
content=f"Starting image generation with task ID {task_id}",
activity=ApiActivityType.THINKING
activity=ApiActivityType.THINKING,
)
yield status_message
@ -177,18 +186,16 @@ async def async_generate_image(pipe: Any, params: ImageRequest) -> AsyncGenerato
yield final_status
except Exception as e:
error_status = ChatMessageError(
session_id=params.session_id,
content=f'Server error: {str(e)}'
)
error_status = ChatMessageError(session_id=params.session_id, content=f"Server error: {str(e)}")
logger.error(error_status)
yield error_status
finally:
# Cleanup: ensure thread completion
if worker_thread and 'worker_thread' in locals() and worker_thread.is_alive():
if worker_thread and "worker_thread" in locals() and worker_thread.is_alive():
worker_thread.join(timeout=1.0) # Wait up to 1 second for cleanup
def status(session_id: str, status: str) -> ChatMessageStatus:
"""Update chat message status and return it."""
chat_message = ChatMessageStatus(
@ -198,6 +205,7 @@ def status(session_id: str, status: str) -> ChatMessageStatus:
)
return chat_message
async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, None]:
"""Generate an image with specified dimensions and yield status updates with time estimates."""
session_id = request.session_id
@ -205,10 +213,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
try:
# Validate prompt
if not prompt:
error_message = ChatMessageError(
session_id=session_id,
content="Prompt cannot be empty."
)
error_message = ChatMessageError(session_id=session_id, content="Prompt cannot be empty.")
logger.error(error_message.content)
yield error_message
return
@ -216,15 +221,14 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
# Validate dimensions
if request.height <= 0 or request.width <= 0:
error_message = ChatMessageError(
session_id=session_id,
content="Height and width must be positive integers."
session_id=session_id, content="Height and width must be positive integers."
)
logger.error(error_message.content)
yield error_message
return
filedir = os.path.dirname(request.filepath)
filename = os.path.basename(request.filepath)
os.path.basename(request.filepath)
os.makedirs(filedir, exist_ok=True)
model_type = "flux"
@ -233,14 +237,17 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
# Get initial time estimate, scaled by resolution
estimates = TIME_ESTIMATES[model_type][device]
resolution_scale = (request.height * request.width) / (512 * 512)
estimated_total = estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
estimates["load"] + estimates["per_step"] * request.iterations * resolution_scale
# Initialize or get cached pipeline
start_time = time.time()
yield status(session_id, f"Loading generative image model...")
yield status(session_id, "Loading generative image model...")
pipe = await model_cache.get_pipeline(request.model, device)
load_time = time.time() - start_time
yield status(session_id, f"Model loaded in {load_time:.1f} seconds.",)
yield status(
session_id,
f"Model loaded in {load_time:.1f} seconds.",
)
progress = None
async for progress in async_generate_image(pipe, request):
@ -252,8 +259,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
if not progress:
error_message = ChatMessageError(
session_id=session_id,
content="Image generation failed to produce a valid response."
session_id=session_id, content="Image generation failed to produce a valid response."
)
logger.error(f"⚠️ {error_message.content}")
yield error_message
@ -269,10 +275,7 @@ async def generate_image(request: ImageRequest) -> AsyncGenerator[ChatMessage, N
yield chat_message
except Exception as e:
error_message = ChatMessageError(
session_id=session_id,
content=f"Error during image generation: {str(e)}"
)
error_message = ChatMessageError(session_id=session_id, content=f"Error during image generation: {str(e)}")
logger.error(traceback.format_exc())
logger.error(error_message.content)
yield error_message
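The generation flow above pairs a background worker thread with an async generator that polls a queue for status objects. A stripped-down, runnable sketch of that pattern follows; every name in it is illustrative rather than the project's API.

import asyncio
import queue
from threading import Thread
from typing import AsyncGenerator


def worker(status_queue: queue.Queue, steps: int) -> None:
    # Stands in for flux_worker: push progress items, then a terminal marker.
    for step in range(steps):
        status_queue.put(f"step {step + 1}/{steps}")
    status_queue.put("done")


async def progress(steps: int = 3) -> AsyncGenerator[str, None]:
    # Stands in for async_generate_image: start the thread, then poll the
    # queue without blocking the event loop and re-yield whatever arrives.
    status_queue: queue.Queue = queue.Queue()
    Thread(target=worker, args=(status_queue, steps), daemon=True).start()
    while True:
        try:
            message = status_queue.get_nowait()
        except queue.Empty:
            await asyncio.sleep(0.05)
            continue
        yield message
        if message == "done":
            break


async def main() -> None:
    async for message in progress():
        print(message)


if __name__ == "__main__":
    asyncio.run(main())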

View File

@ -2,6 +2,7 @@ import json
import re
from typing import List, Union
def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
"""
Extract JSON blocks from text, even if surrounded by markdown or noisy text.
@ -31,13 +32,14 @@ def extract_json_blocks(text: str, allow_multiple: bool = False) -> List[dict]:
return found
def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Union[dict, list]]:
"""Extract standalone JSON objects or arrays from text using proper brace counting."""
found = []
i = 0
while i < len(text):
if text[i] in '{[':
if text[i] in "{[":
# Found potential JSON start
json_str = _extract_complete_json_at_position(text, i)
if json_str:
@ -55,16 +57,17 @@ def _extract_standalone_json(text: str, allow_multiple: bool = False) -> List[Un
return found
def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
"""
Extract a complete JSON object or array starting at the given position.
Uses proper brace/bracket counting and string escape handling.
"""
if start_pos >= len(text) or text[start_pos] not in '{[':
if start_pos >= len(text) or text[start_pos] not in "{[":
return ""
start_char = text[start_pos]
end_char = '}' if start_char == '{' else ']'
end_char = "}" if start_char == "{" else "]"
count = 1
i = start_pos + 1
@ -76,7 +79,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
if escape_next:
escape_next = False
elif char == '\\' and in_string:
elif char == "\\" and in_string:
escape_next = True
elif char == '"' and not escape_next:
in_string = not in_string
@ -92,6 +95,7 @@ def _extract_complete_json_at_position(text: str, start_pos: int) -> str:
return text[start_pos:i]
return ""
def extract_json_from_text(text: str) -> str:
"""Extract JSON string from text that may contain other content."""
return json.dumps(extract_json_blocks(text, allow_multiple=False)[0])
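Assuming the module is importable as json_extractor (it is imported under that name elsewhere in this codebase), a small usage sketch of the behaviour the docstrings above describe; the sample text and expected result are illustrative.

import json_extractor

noisy = 'Model output: ```json\n{"answer": 42}\n``` plus trailing prose.'
blocks = json_extractor.extract_json_blocks(noisy)
# Expected, per the docstring: a one-element list containing the parsed dict.
print(blocks)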

View File

@ -2,11 +2,11 @@ import os
import warnings
import logging
import defines
def _setup_logging(level=defines.logging_level) -> logging.Logger:
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings(
"ignore", message="Overriding a previously registered kernel"
)
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@ -19,8 +19,7 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
# Create a custom formatter
formatter = logging.Formatter(
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
# Create a handler (e.g., StreamHandler for console output)
@ -50,5 +49,6 @@ def _setup_logging(level=defines.logging_level) -> logging.Logger:
logger = logging.getLogger(__name__)
return logger
logger = _setup_logging(level=defines.logging_level)
logger.debug(f"Logging initialized with level: {defines.logging_level}")

View File

@ -58,6 +58,7 @@ background_task_manager = None
prev_int = signal.getsignal(signal.SIGINT)
prev_term = signal.getsignal(signal.SIGTERM)
def signal_handler(signum, frame):
logger.info(f"⚠️ Received signal {signum!r}, shutting down…")
# now call the old handler (it might raise KeyboardInterrupt or exit)
@ -66,9 +67,11 @@ def signal_handler(signum, frame):
elif signum == signal.SIGTERM and callable(prev_term):
prev_term(signum, frame)
# Global background task manager
background_task_manager: Optional[BackgroundTaskManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
@ -116,6 +119,7 @@ async def lifespan(app: FastAPI):
if db_manager:
await db_manager.graceful_shutdown()
app = FastAPI(
lifespan=lifespan,
title="Backstory API",
@ -129,11 +133,9 @@ app = FastAPI(
ssl_enabled = os.getenv("SSL_ENABLED", "true").lower() == "true"
if ssl_enabled:
allow_origins = ["https://battle-linux.ketrenos.com:3000",
"https://backstory-beta.ketrenos.com"]
allow_origins = ["https://battle-linux.ketrenos.com:3000", "https://backstory-beta.ketrenos.com"]
else:
allow_origins = ["http://battle-linux.ketrenos.com:3000",
"http://backstory-beta.ketrenos.com"]
allow_origins = ["http://battle-linux.ketrenos.com:3000", "http://backstory-beta.ketrenos.com"]
# Add CORS middleware
app.add_middleware(
@ -144,12 +146,14 @@ app.add_middleware(
allow_headers=["*"],
)
# ============================
# Debug data type failures
# ============================
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Validation error {request.method} {request.url.path}: {str(exc)}")
@ -158,6 +162,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
content=json.dumps({"detail": str(exc)}),
)
# ============================
# Create API router with prefix
# ============================
@ -181,9 +186,10 @@ api_router.include_router(users.router)
# Health Check and Info Endpoints
# ============================
@app.get("/health")
async def health_check(
database = Depends(get_database),
database=Depends(get_database),
):
"""Health check endpoint"""
try:
@ -202,15 +208,12 @@ async def health_check(
return {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat(),
"database": {
"status": "connected",
"stats": stats
},
"database": {"status": "connected", "stats": stats},
"redis": {
"version": redis_info.get("redis_version", "unknown"),
"uptime": redis_info.get("uptime_in_seconds", 0),
"memory_used": redis_info.get("used_memory_human", "unknown")
}
"memory_used": redis_info.get("used_memory_human", "unknown"),
},
}
except RuntimeError as e:
@ -219,6 +222,7 @@ async def health_check(
logger.error(f"❌ Health check failed: {e}")
return {"status": "error", "message": str(e)}
# ============================
# Include Router in App
# ============================
@ -231,11 +235,14 @@ app.include_router(api_router)
# ============================
logger.info(f"Debug mode is {'enabled' if defines.debug else 'disabled'}")
@app.middleware("http")
async def log_requests(request: Request, call_next):
try:
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
logger.info(f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}")
logger.info(
f"📝 Request {request.method}: {request.url.path}, Remote: {request.client.host if request.client else ''}"
)
response = await call_next(request)
if defines.debug and not re.match(rf"{defines.api_prefix}/metrics", request.url.path):
if response.status_code < 200 or response.status_code >= 300:
@ -243,11 +250,13 @@ async def log_requests(request: Request, call_next):
return response
except Exception as e:
import traceback
logger.error(traceback.format_exc())
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Error processing request: {str(e)}, Path: {request.url.path}, Method: {request.method}")
return JSONResponse(status_code=400, content={"detail": "Invalid HTTP request"})
# ============================
# Request tracking middleware
# ============================
@ -266,6 +275,7 @@ async def track_requests(request, call_next):
finally:
db_manager.decrement_requests()
# ============================
# FastAPI Metrics
# ============================
@ -277,7 +287,7 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)
# Instrument the FastAPI app
@ -291,6 +301,7 @@ instrumentator.expose(app, endpoint=f"{defines.api_prefix}/metrics")
# Static File Serving
# ============================
@app.get("/{path:path}")
async def serve_static(path: str, request: Request):
full_path = os.path.join(defines.static_content, path)
@ -300,6 +311,7 @@ async def serve_static(path: str, request: Request):
return FileResponse(os.path.join(defines.static_content, "index.html"))
# Root endpoint when no static files
@app.get("/", include_in_schema=False)
async def root():
@ -309,9 +321,10 @@ async def root():
"version": "1.0.0",
"api_prefix": defines.api_prefix,
"documentation": f"{defines.api_prefix}/docs",
"health": f"{defines.api_prefix}/health"
"health": f"{defines.api_prefix}/health",
}
async def periodic_verification_cleanup():
"""Background task to periodically clean up expired verification tokens"""
try:
@ -324,6 +337,7 @@ async def periodic_verification_cleanup():
except Exception as e:
logger.error(f"❌ Error in periodic verification cleanup: {e}")
if __name__ == "__main__":
host = defines.host
port = defines.port

File diff suppressed because it is too large

View File

@ -0,0 +1,23 @@
[tool.ruff]
# Set line length (default is 88, Black-compatible)
line-length = 120
exclude = [
".git",
"__pycache__",
"migrations",
"venv",
"dist",
"build",
]
target-version = "py312"
[tool.ruff.format]
# Use Ruff's formatter (Black-compatible by default)
quote-style = "double" # Enforce double quotes
indent-style = "space" # Use spaces for indentation
skip-magic-trailing-comma = false # Add trailing commas where applicable
[tool.ruff.per-file-ignores]
# Ignore specific rules for certain files (e.g., tests or scripts)
"tests/**" = ["D100", "S101"] # Ignore missing docstrings and assert warnings in tests
"scripts/**" = ["E402"] # Ignore import order in scripts

View File

@ -1,7 +1,3 @@
from .rag import ChromaDBFileWatcher, start_file_watcher, RagEntry
__all__ = [
"ChromaDBFileWatcher",
"start_file_watcher",
"RagEntry"
]
__all__ = ["ChromaDBFileWatcher", "start_file_watcher", "RagEntry"]

View File

@ -7,10 +7,12 @@ import logging
import defines
class Chunk(TypedDict):
text: str
metadata: Dict[str, Any]
def clear_chunk(chunk: Chunk):
chunk["text"] = ""
chunk["metadata"] = {
@ -22,6 +24,7 @@ def clear_chunk(chunk: Chunk):
}
return chunk
class MarkdownChunker:
def __init__(self):
# Initialize the Markdown parser
@ -76,10 +79,7 @@ class MarkdownChunker:
# Initialize a chunk structure
chunk: Chunk = {
"text": "",
"metadata": {
"source_file": file_path,
"lines": total_lines
},
"metadata": {"source_file": file_path, "lines": total_lines},
}
clear_chunk(chunk)
@ -89,9 +89,7 @@ class MarkdownChunker:
return chunks
def _sanitize_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
return {
k: ("" if v is None else v) for k, v in metadata.items() if v is not None
}
return {k: ("" if v is None else v) for k, v in metadata.items() if v is not None}
def _extract_text_from_children(self, node: SyntaxTreeNode) -> str:
lines = []
@ -114,7 +112,7 @@ class MarkdownChunker:
chunks: List[Chunk],
chunk: Chunk,
level: int,
buffer: int = defines.chunk_buffer
buffer: int = defines.chunk_buffer,
) -> int:
is_list = False
# Handle heading nodes
@ -141,7 +139,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
@ -186,7 +184,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)
@ -198,9 +196,7 @@ class MarkdownChunker:
# Recursively process children
if not is_list:
for child in node.children:
level = self._process_node(
child, current_headings, chunks, chunk, level=level
)
level = self._process_node(child, current_headings, chunks, chunk, level=level)
# After root-level recursion, finalize any remaining chunk
if node.type == "document":
@ -211,7 +207,7 @@ class MarkdownChunker:
if node.nester_tokens:
opening, closing = node.nester_tokens
if opening and opening.map:
( begin, end ) = opening.map
(begin, end) = opening.map
metadata = chunk["metadata"]
metadata["chunk_begin"] = max(0, begin - buffer)
metadata["chunk_end"] = min(metadata["lines"], end + buffer)

View File

@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel, field_serializer, field_validator, model_validator, Field # type: ignore
from pydantic import BaseModel # type: ignore
from typing import List, Optional, Dict, Any
import os
import glob
@ -11,10 +11,10 @@ import json
import numpy as np # type: ignore
import traceback
import chromadb # type: ignore
import chromadb # type: ignore
from watchdog.observers import Observer # type: ignore
from watchdog.events import FileSystemEventHandler # type: ignore
import umap # type: ignore
import umap # type: ignore
from markitdown import MarkItDown # type: ignore
from chromadb.api.models.Collection import Collection # type: ignore
@ -33,11 +33,13 @@ __all__ = ["ChromaDBFileWatcher", "start_file_watcher"]
DEFAULT_CHUNK_SIZE = 750
DEFAULT_CHUNK_OVERLAP = 100
class RagEntry(BaseModel):
name: str
description: str = ""
enabled: bool = True
class ChromaDBFileWatcher(FileSystemEventHandler):
def __init__(
self,
@ -72,9 +74,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Path for storing file hash state
self.hash_state_path = os.path.join(
self.persist_directory, f"{collection_name}_hash_state.json"
)
self.hash_state_path = os.path.join(self.persist_directory, f"{collection_name}_hash_state.json")
# Flag to track if this is a new collection
self.is_new_collection = False
@ -158,9 +158,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
process_all: If True, process all files regardless of hash status
"""
# Check for new or modified files
file_paths = glob.glob(
os.path.join(self.watch_directory, "**/*"), recursive=True
)
file_paths = glob.glob(os.path.join(self.watch_directory, "**/*"), recursive=True)
files_checked = 0
files_processed = 0
files_to_process = []
@ -180,20 +178,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
continue
# If file is new, changed, or we're processing all files
if (
process_all
or file_path not in self.file_hashes
or self.file_hashes[file_path] != current_hash
):
if process_all or file_path not in self.file_hashes or self.file_hashes[file_path] != current_hash:
self.file_hashes[file_path] = current_hash
files_to_process.append(file_path)
logging.info(
f"File {'found' if process_all else 'changed'}: {file_path}"
)
logging.info(f"File {'found' if process_all else 'changed'}: {file_path}")
logging.info(
f"Found {len(files_to_process)} files to process after scanning {files_checked} files"
)
logging.info(f"Found {len(files_to_process)} files to process after scanning {files_checked} files")
# Check for deleted files
deleted_files = []
@ -201,9 +191,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not os.path.exists(file_path):
deleted_files.append(file_path)
# Schedule removal
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
# Don't block on result, just let it run
logging.info(f"File deleted: {file_path}")
@ -253,10 +241,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if not current_hash: # File might have been deleted or is inaccessible
return
if (
file_path in self.file_hashes
and self.file_hashes[file_path] == current_hash
):
if file_path in self.file_hashes and self.file_hashes[file_path] == current_hash:
# File hasn't actually changed in content
logging.info(f"Hash has not changed for {file_path}")
return
@ -289,9 +274,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
if results and "ids" in results and results["ids"]:
self.collection.delete(ids=results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
logging.info(
f"Removed {len(results['ids'])} chunks for deleted file: {file_path}"
)
logging.info(f"Removed {len(results['ids'])} chunks for deleted file: {file_path}")
# Remove from hash dictionary
if file_path in self.file_hashes:
@ -304,17 +287,15 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
def _update_umaps(self):
# Update the UMAP embeddings
self._umap_collection = ChromaDBGetResponse.model_validate(self._collection.get(
include=["embeddings", "documents", "metadatas"]
))
self._umap_collection = ChromaDBGetResponse.model_validate(
self._collection.get(include=["embeddings", "documents", "metadatas"])
)
if not self._umap_collection or not len(self._umap_collection.embeddings):
logging.warning("⚠️ No embeddings found in the collection.")
return
# During initialization
logging.info(
f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 2D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
vectors = np.array(self._umap_collection.embeddings)
self._umap_model_2d = umap.UMAP(
n_components=2,
@ -323,14 +304,12 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_neighbors=30,
min_dist=0.1,
)
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
self._umap_embedding_2d = self._umap_model_2d.fit_transform(vectors) # type: ignore
# logging.info(
# f"2D UMAP model n_components: {self._umap_model_2d.n_components}"
# ) # Should be 2
logging.info(
f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors"
)
logging.info(f"Updating 3D {self.collection_name} UMAP for {len(self._umap_collection.embeddings)} vectors")
self._umap_model_3d = umap.UMAP(
n_components=3,
random_state=8911,
@ -338,7 +317,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
n_neighbors=30,
min_dist=0.01,
)
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors)# type: ignore
self._umap_embedding_3d = self._umap_model_3d.fit_transform(vectors) # type: ignore
# logging.info(
# f"3D UMAP model n_components: {self._umap_model_3d.n_components}"
# ) # Should be 3
@ -373,9 +352,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
self.is_new_collection = True
logging.info(f"Recreating collection: {self.collection_name}")
return chroma_client.get_or_create_collection(
name=self.collection_name, metadata={"hnsw:space": "cosine"}
)
return chroma_client.get_or_create_collection(name=self.collection_name, metadata={"hnsw:space": "cosine"})
async def get_embedding(self, text: str) -> np.ndarray:
"""Generate and normalize an embedding for the given text."""
@ -419,9 +396,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Generate a more unique ID based on content and metadata
path_hash = ""
if "path" in metadata:
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[
:8
]
path_hash = hashlib.md5(metadata["source_file"].encode()).hexdigest()[:8]
content_hash = hashlib.md5(text.encode()).hexdigest()[:8]
chunk_id = f"{path_hash}_{i}_{content_hash}"
@ -438,7 +413,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
logging.error(traceback.format_exc())
logging.error(chunk)
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer)-> str | None:
def prepare_metadata(self, meta: Dict[str, Any], buffer=defines.chunk_buffer) -> str | None:
source_file = meta.get("source_file")
try:
source_file = meta["source_file"]
@ -541,9 +516,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
return
file_path = event.src_path
asyncio.run_coroutine_threadsafe(
self.remove_file_from_collection(file_path), self.loop
)
asyncio.run_coroutine_threadsafe(self.remove_file_from_collection(file_path), self.loop)
logging.info(f"File deleted: {file_path}")
def on_moved(self, event):
@ -571,11 +544,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
try:
# Remove existing entries for this file
existing_results = self.collection.get(where={"path": file_path})
if (
existing_results
and "ids" in existing_results
and existing_results["ids"]
):
if existing_results and "ids" in existing_results and existing_results["ids"]:
self.collection.delete(ids=existing_results["ids"])
await self.database.update_user_rag_timestamp(self.user_id)
@ -584,15 +553,11 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
p = Path(file_path)
p_as_md = p.with_suffix(".md")
if p_as_md.exists():
logging.info(
f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}"
)
logging.info(f"newer: {p.stat().st_mtime > p_as_md.stat().st_mtime}")
# If file_path.md doesn't exist or file_path is newer than file_path.md,
# fire off markitdown
if (not p_as_md.exists()) or (
p.stat().st_mtime > p_as_md.stat().st_mtime
):
if (not p_as_md.exists()) or (p.stat().st_mtime > p_as_md.stat().st_mtime):
self._markitdown(file_path, p_as_md)
return
@ -626,9 +591,7 @@ class ChromaDBFileWatcher(FileSystemEventHandler):
# Process all files regardless of hash state
num_processed = await self.scan_directory(process_all=True)
logging.info(
f"Vectorstore initialized with {self.collection.count()} documents"
)
logging.info(f"Vectorstore initialized with {self.collection.count()} documents")
self._update_umaps()
@ -676,7 +639,7 @@ def start_file_watcher(
persist_directory=persist_directory,
collection_name=collection_name,
recreate=recreate,
database=database
database=database,
)
# Process all files if:

View File

@ -13,14 +13,4 @@ from . import employers
from . import admin
from . import system
__all__ = [
"auth",
"candidates",
"resumes",
"jobs",
"chat",
"users",
"employers",
"admin",
"system"
]
__all__ = ["auth", "candidates", "resumes", "jobs", "chat", "users", "employers", "admin", "system"]

View File

@ -5,8 +5,18 @@ import json
from datetime import datetime, timezone, UTC
from fastapi import (
APIRouter, HTTPException, Depends, Body, Request, HTTPException,
Depends, Query, Path, Body, APIRouter, Request
APIRouter,
HTTPException,
Depends,
Body,
Request,
HTTPException,
Depends,
Query,
Path,
Body,
APIRouter,
Request,
)
from fastapi.responses import JSONResponse
@ -14,52 +24,51 @@ from fastapi.responses import JSONResponse
from utils.rate_limiter import RateLimiter, get_rate_limiter
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_current_admin, get_current_user_or_guest, get_database, background_task_manager
)
from utils.dependencies import get_current_admin, get_current_user_or_guest, get_database, background_task_manager
from utils.responses import (
create_paginated_response, create_success_response, create_error_response,
create_paginated_response,
create_success_response,
create_error_response,
)
# Create router for authentication endpoints
router = APIRouter(prefix="/admin", tags=["admin"])
@router.post("/tasks/cleanup-guests")
async def manual_guest_cleanup(
inactive_hours: int = Body(24, embed=True),
current_user = Depends(get_current_admin),
admin_user = Depends(get_current_admin)
current_user=Depends(get_current_admin),
admin_user=Depends(get_current_admin),
):
"""Manually trigger guest cleanup (admin only)"""
try:
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_inactive_guests(inactive_hours)
logger.info(f"🧹 Manual guest cleanup triggered by admin {admin_user.id}: {cleaned_count} guests cleaned")
return create_success_response({
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
return create_success_response(
{
"message": f"Guest cleanup completed. Removed {cleaned_count} inactive sessions.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual guest cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.post("/tasks/cleanup-tokens")
async def manual_token_cleanup(
admin_user = Depends(get_current_admin)
):
async def manual_token_cleanup(admin_user=Depends(get_current_admin)):
"""Manually trigger verification token cleanup (admin only)"""
try:
global background_task_manager
@ -67,31 +76,28 @@ async def manual_token_cleanup(
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_expired_verification_tokens()
logger.info(f"🧹 Manual token cleanup triggered by admin {admin_user.id}: {cleaned_count} tokens cleaned")
return create_success_response({
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
return create_success_response(
{
"message": f"Token cleanup completed. Removed {cleaned_count} expired tokens.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual token cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.post("/tasks/cleanup-rate-limits")
async def manual_rate_limit_cleanup(
days_old: int = Body(7, embed=True),
admin_user = Depends(get_current_admin)
):
async def manual_rate_limit_cleanup(days_old: int = Body(7, embed=True), admin_user=Depends(get_current_admin)):
"""Manually trigger rate limit data cleanup (admin only)"""
try:
global background_task_manager
@ -99,60 +105,55 @@ async def manual_rate_limit_cleanup(
if not background_task_manager:
return JSONResponse(
status_code=500,
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available")
content=create_error_response("TASK_MANAGER_NOT_AVAILABLE", "Background task manager not available"),
)
cleaned_count = await background_task_manager.cleanup_old_rate_limit_data(days_old)
logger.info(f"🧹 Manual rate limit cleanup triggered by admin {admin_user.id}: {cleaned_count} keys cleaned")
return create_success_response({
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id
})
return create_success_response(
{
"message": f"Rate limit cleanup completed. Removed {cleaned_count} old keys.",
"cleaned_count": cleaned_count,
"triggered_by": admin_user.id,
}
)
except Exception as e:
logger.error(f"❌ Manual rate limit cleanup error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
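# --- Illustrative usage sketch (not part of the diff above) ---
# A minimal admin-side client for the manual cleanup endpoints in this file.
# The base URL, the router mount prefix ("/admin"), the bearer-token handling,
# and the guest-cleanup route path/payload key are assumptions not shown here.
import httpx

BASE_URL = "http://localhost:8000/admin"   # assumption: deployment-specific
ADMIN_TOKEN = "<admin-access-token>"       # assumption: obtained via the auth routes

def trigger_cleanup(path: str, payload: dict | None = None) -> dict:
    """POST to one of the manual cleanup endpoints and return the JSON body."""
    response = httpx.post(
        f"{BASE_URL}{path}",
        json=payload or {},
        headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
        timeout=30.0,
    )
    response.raise_for_status()
    return response.json()

# trigger_cleanup("/tasks/cleanup-guests", {"inactive_hours": 48})   # path/key assumed
# trigger_cleanup("/tasks/cleanup-tokens")
# trigger_cleanup("/tasks/cleanup-rate-limits", {"days_old": 7})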
# ========================================
# System Health and Maintenance Endpoints
# ========================================
@router.get("/system/health")
async def get_system_health(
request: Request,
admin_user = Depends(get_current_admin)
):
async def get_system_health(request: Request, admin_user=Depends(get_current_admin)):
"""Get comprehensive system health status (admin only)"""
try:
# Database health
database_manager = getattr(request.app.state, 'database_manager', None)
database_manager = getattr(request.app.state, "database_manager", None)
db_health = {"status": "unavailable", "healthy": False}
if database_manager:
try:
database = database_manager.get_database()
database_manager.get_database()
from database.manager import redis_manager
redis_health = await redis_manager.health_check()
db_health = {
"status": redis_health.get("status", "unknown"),
"healthy": redis_health.get("status") == "healthy",
"details": redis_health
"details": redis_health,
}
except Exception as e:
db_health = {
"status": "error",
"healthy": False,
"error": str(e)
}
db_health = {"status": "error", "healthy": False, "error": str(e)}
# Background task health
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
task_health = {"status": "unavailable", "healthy": False}
if background_task_manager:
@ -166,39 +167,30 @@ async def get_system_health(
"healthy": task_status["running"] and failed_tasks == 0,
"running_tasks": running_tasks,
"failed_tasks": failed_tasks,
"total_tasks": task_status["task_count"]
"total_tasks": task_status["task_count"],
}
except Exception as e:
task_health = {
"status": "error",
"healthy": False,
"error": str(e)
}
task_health = {"status": "error", "healthy": False, "error": str(e)}
# Overall health
overall_healthy = db_health["healthy"] and task_health["healthy"]
return create_success_response({
"timestamp": datetime.now(UTC).isoformat(),
"overall_healthy": overall_healthy,
"components": {
"database": db_health,
"background_tasks": task_health
return create_success_response(
{
"timestamp": datetime.now(UTC).isoformat(),
"overall_healthy": overall_healthy,
"components": {"database": db_health, "background_tasks": task_health},
}
})
)
except Exception as e:
logger.error(f"❌ Error getting system health: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("HEALTH_CHECK_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("HEALTH_CHECK_ERROR", str(e)))
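# --- Illustrative monitoring sketch (not part of the diff above) ---
# Polls the system health endpoint and exits non-zero when any component is
# unhealthy. The base URL, the "/admin" prefix, the admin token, and the exact
# envelope produced by create_success_response are assumptions.
import sys
import httpx

def check_health(base_url: str = "http://localhost:8000/admin",
                 admin_token: str = "<admin-access-token>") -> int:
    body = httpx.get(
        f"{base_url}/system/health",
        headers={"Authorization": f"Bearer {admin_token}"},
        timeout=10.0,
    ).json()
    payload = body.get("data", body)   # envelope shape: assumption
    healthy = bool(payload.get("overall_healthy"))
    print(f"overall_healthy={healthy} components={payload.get('components')}")
    return 0 if healthy else 1

if __name__ == "__main__":
    sys.exit(check_health())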
@router.post("/maintenance/cleanup")
async def run_maintenance_cleanup(
request: Request,
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Run comprehensive maintenance cleanup (admin only)"""
try:
@ -217,50 +209,40 @@ async def run_maintenance_cleanup(
cleanup_results[operation_name] = {
"success": True,
"cleaned_count": result,
"message": f"Cleaned {result} items"
"message": f"Cleaned {result} items",
}
except Exception as e:
cleanup_results[operation_name] = {
"success": False,
"error": str(e),
"message": f"Failed: {str(e)}"
}
cleanup_results[operation_name] = {"success": False, "error": str(e), "message": f"Failed: {str(e)}"}
# Calculate totals
total_cleaned = sum(
result.get("cleaned_count", 0)
for result in cleanup_results.values()
if result.get("success", False)
result.get("cleaned_count", 0) for result in cleanup_results.values() if result.get("success", False)
)
successful_operations = len([
r for r in cleanup_results.values()
if r.get("success", False)
])
successful_operations = len([r for r in cleanup_results.values() if r.get("success", False)])
return create_success_response({
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
"total_cleaned": total_cleaned,
"successful_operations": successful_operations,
"details": cleanup_results
})
return create_success_response(
{
"message": f"Maintenance cleanup completed. {total_cleaned} items cleaned across {successful_operations} operations.",
"total_cleaned": total_cleaned,
"successful_operations": successful_operations,
"details": cleanup_results,
}
)
except Exception as e:
logger.error(f"❌ Error in maintenance cleanup: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
# ========================================
# Background Task Statistics
# ========================================
@router.get("/tasks/stats")
async def get_task_statistics(
request: Request,
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
request: Request, admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Get background task execution statistics (admin only)"""
try:
@ -268,7 +250,7 @@ async def get_task_statistics(
guest_stats = await database.get_guest_statistics()
# Get background task manager status
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
task_manager_stats = {}
if background_task_manager:
@ -276,7 +258,7 @@ async def get_task_statistics(
task_manager_stats = {
"running": task_status["running"],
"task_count": task_status["task_count"],
"task_breakdown": {}
"task_breakdown": {},
}
# Count tasks by status
@ -284,40 +266,35 @@ async def get_task_statistics(
status = task["status"]
task_manager_stats["task_breakdown"][status] = task_manager_stats["task_breakdown"].get(status, 0) + 1
return create_success_response({
"guest_statistics": guest_stats,
"task_manager": task_manager_stats,
"timestamp": datetime.now(UTC).isoformat()
})
return create_success_response(
{
"guest_statistics": guest_stats,
"task_manager": task_manager_stats,
"timestamp": datetime.now(UTC).isoformat(),
}
)
except Exception as e:
logger.error(f"❌ Error getting task statistics: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
# ========================================
# Background Task Status Endpoints
# ========================================
@router.get("/tasks/status")
async def get_background_task_status(
request: Request,
admin_user = Depends(get_current_admin)
):
async def get_background_task_status(request: Request, admin_user=Depends(get_current_admin)):
"""Get background task manager status (admin only)"""
try:
# Get background task manager from app state
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager:
return create_success_response({
"running": False,
"message": "Background task manager not initialized",
"tasks": [],
"task_count": 0
})
return create_success_response(
{"running": False, "message": "Background task manager not initialized", "tasks": [], "task_count": 0}
)
# Get comprehensive task status using the new method
task_status = await background_task_manager.get_task_status()
@ -325,91 +302,68 @@ async def get_background_task_status(
# Add additional system info
system_info = {
"uptime_seconds": None, # Could calculate from start time if stored
"last_cleanup": None, # Could track last cleanup time
"last_cleanup": None, # Could track last cleanup time
}
# Format the response
return create_success_response({
"running": task_status["running"],
"task_count": task_status["task_count"],
"loop_status": {
"main_loop_id": task_status["main_loop_id"],
"current_loop_id": task_status["current_loop_id"],
"loop_matches": task_status.get("loop_matches", False)
},
"tasks": task_status["tasks"],
"system_info": system_info
})
return create_success_response(
{
"running": task_status["running"],
"task_count": task_status["task_count"],
"loop_status": {
"main_loop_id": task_status["main_loop_id"],
"current_loop_id": task_status["current_loop_id"],
"loop_matches": task_status.get("loop_matches", False),
},
"tasks": task_status["tasks"],
"system_info": system_info,
}
)
except Exception as e:
logger.error(f"❌ Get task status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.post("/tasks/run/{task_name}")
async def run_background_task(
task_name: str,
request: Request,
admin_user = Depends(get_current_admin)
):
async def run_background_task(task_name: str, request: Request, admin_user=Depends(get_current_admin)):
"""Manually trigger a specific background task (admin only)"""
try:
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not background_task_manager:
return JSONResponse(
status_code=503,
content=create_error_response(
"MANAGER_UNAVAILABLE",
"Background task manager not initialized"
)
content=create_error_response("MANAGER_UNAVAILABLE", "Background task manager not initialized"),
)
# List of available tasks
available_tasks = [
"guest_cleanup",
"token_cleanup",
"guest_stats",
"rate_limit_cleanup",
"orphaned_cleanup"
]
available_tasks = ["guest_cleanup", "token_cleanup", "guest_stats", "rate_limit_cleanup", "orphaned_cleanup"]
if task_name not in available_tasks:
return JSONResponse(
status_code=400,
content=create_error_response(
"INVALID_TASK",
f"Unknown task: {task_name}. Available: {available_tasks}"
)
"INVALID_TASK", f"Unknown task: {task_name}. Available: {available_tasks}"
),
)
# Run the task
result = await background_task_manager.force_run_task(task_name)
return create_success_response({
"task_name": task_name,
"result": result,
"message": f"Task {task_name} completed successfully"
})
return create_success_response(
{"task_name": task_name, "result": result, "message": f"Task {task_name} completed successfully"}
)
except ValueError as e:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_TASK", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("INVALID_TASK", str(e)))
except Exception as e:
logger.error(f"❌ Error running task {task_name}: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("TASK_EXECUTION_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("TASK_EXECUTION_ERROR", str(e)))
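# --- Illustrative sketch (not part of the diff above) ---
# One way force_run_task() could dispatch by name so that an unknown task raises
# ValueError, which the endpoint above maps to a 400 INVALID_TASK response.
# Everything in this sketch other than the force_run_task name is hypothetical.
from typing import Any, Awaitable, Callable, Dict

class _TaskDispatchSketch:
    def __init__(self) -> None:
        self._tasks: Dict[str, Callable[[], Awaitable[Any]]] = {
            "guest_cleanup": self._cleanup_guests,
            "token_cleanup": self._cleanup_tokens,
        }

    async def force_run_task(self, task_name: str) -> Any:
        if task_name not in self._tasks:
            raise ValueError(f"Unknown task: {task_name}")
        return await self._tasks[task_name]()

    async def _cleanup_guests(self) -> int:
        return 0  # placeholder; the real manager performs the Redis cleanup

    async def _cleanup_tokens(self) -> int:
        return 0  # placeholder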
@router.get("/tasks/list")
async def list_available_tasks(
admin_user = Depends(get_current_admin)
):
async def list_available_tasks(admin_user=Depends(get_current_admin)):
"""List all available background tasks (admin only)"""
try:
tasks = [
@ -417,63 +371,51 @@ async def list_available_tasks(
"name": "guest_cleanup",
"description": "Clean up inactive guest sessions",
"interval": "6 hours",
"parameters": ["inactive_hours (default: 48)"]
"parameters": ["inactive_hours (default: 48)"],
},
{
"name": "token_cleanup",
"description": "Clean up expired email verification tokens",
"interval": "12 hours",
"parameters": []
"parameters": [],
},
{
"name": "guest_stats",
"description": "Update guest usage statistics",
"interval": "1 hour",
"parameters": []
"parameters": [],
},
{
"name": "rate_limit_cleanup",
"description": "Clean up old rate limiting data",
"interval": "24 hours",
"parameters": ["days_old (default: 7)"]
"parameters": ["days_old (default: 7)"],
},
{
"name": "orphaned_cleanup",
"description": "Clean up orphaned database records",
"interval": "6 hours",
"parameters": []
}
"parameters": [],
},
]
return create_success_response({
"total_tasks": len(tasks),
"tasks": tasks
})
return create_success_response({"total_tasks": len(tasks), "tasks": tasks})
except Exception as e:
logger.error(f"❌ Error listing tasks: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LIST_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LIST_ERROR", str(e)))
@router.post("/tasks/restart")
async def restart_background_tasks(
request: Request,
admin_user = Depends(get_current_admin)
):
async def restart_background_tasks(request: Request, admin_user=Depends(get_current_admin)):
"""Restart the background task manager (admin only)"""
try:
database_manager = getattr(request.app.state, 'database_manager', None)
background_task_manager = getattr(request.app.state, 'background_task_manager', None)
database_manager = getattr(request.app.state, "database_manager", None)
background_task_manager = getattr(request.app.state, "background_task_manager", None)
if not database_manager:
return JSONResponse(
status_code=503,
content=create_error_response(
"DATABASE_UNAVAILABLE",
"Database manager not available"
)
status_code=503, content=create_error_response("DATABASE_UNAVAILABLE", "Database manager not available")
)
# Stop existing background tasks
@ -483,6 +425,7 @@ async def restart_background_tasks(
# Create and start new background task manager
from background_tasks import BackgroundTaskManager
new_background_task_manager = BackgroundTaskManager(database_manager)
await new_background_task_manager.start()
@ -492,22 +435,20 @@ async def restart_background_tasks(
# Get status of new manager
status = await new_background_task_manager.get_task_status()
return create_success_response({
"message": "Background task manager restarted successfully",
"new_status": status
})
return create_success_response(
{"message": "Background task manager restarted successfully", "new_status": status}
)
except Exception as e:
logger.error(f"❌ Error restarting background tasks: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESTART_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESTART_ERROR", str(e)))
# ============================
# Task Monitoring and Metrics
# ============================
class TaskMetrics:
"""Collect metrics for background tasks"""
@ -544,36 +485,32 @@ class TaskMetrics:
metrics[task_name] = {
"total_runs": self.task_runs[task_name],
"total_errors": self.task_errors[task_name],
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name] if self.task_runs[task_name] > 0 else 0,
"success_rate": (self.task_runs[task_name] - self.task_errors[task_name]) / self.task_runs[task_name]
if self.task_runs[task_name] > 0
else 0,
"average_duration": avg_duration,
"last_runs": durations[-10:] if durations else []
"last_runs": durations[-10:] if durations else [],
}
return metrics
# Global task metrics
task_metrics = TaskMetrics()
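# --- Illustrative sketch (not part of the diff above) ---
# How a background task could be wrapped so each run feeds counters shaped like
# the ones get_metrics() reads above (runs, errors, duration history). The
# module-level dicts here are standalone stand-ins, not the real TaskMetrics
# attributes.
import time
from collections import defaultdict
from functools import wraps

_runs: dict[str, int] = defaultdict(int)
_errors: dict[str, int] = defaultdict(int)
_durations: dict[str, list] = defaultdict(list)

def timed_task(task_name: str):
    """Decorate an async task so every run updates the counters above."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            start = time.monotonic()
            _runs[task_name] += 1
            try:
                return await func(*args, **kwargs)
            except Exception:
                _errors[task_name] += 1
                raise
            finally:
                _durations[task_name].append(time.monotonic() - start)
        return wrapper
    return decorator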
@router.get("/tasks/metrics")
async def get_task_metrics(
admin_user = Depends(get_current_admin)
):
async def get_task_metrics(admin_user=Depends(get_current_admin)):
"""Get background task metrics (admin only)"""
try:
global task_metrics
metrics = task_metrics.get_metrics()
return create_success_response({
"metrics": metrics,
"timestamp": datetime.now(UTC).isoformat()
})
return create_success_response({"metrics": metrics, "timestamp": datetime.now(UTC).isoformat()})
except Exception as e:
logger.error(f"❌ Get task metrics error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("METRICS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("METRICS_ERROR", str(e)))
# ============================
@ -581,8 +518,7 @@ async def get_task_metrics(
# ============================
# @router.get("/verification-stats")
async def get_verification_statistics(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Get verification statistics (admin only)"""
try:
@ -591,22 +527,19 @@ async def get_verification_statistics(
stats = {
"pending_verifications": await database.get_pending_verifications_count(),
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens()
"expired_tokens_cleaned": await database.cleanup_expired_verification_tokens(),
}
return create_success_response(stats)
except Exception as e:
logger.error(f"❌ Error getting verification stats: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
@router.post("/cleanup-verifications")
async def cleanup_verification_tokens(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Manually trigger cleanup of expired verification tokens (admin only)"""
try:
@ -617,24 +550,24 @@ async def cleanup_verification_tokens(
logger.info(f"🧹 Manual cleanup completed by admin {current_user.id}: {cleaned_count} tokens cleaned")
return create_success_response({
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
"cleaned_count": cleaned_count
})
return create_success_response(
{
"message": f"Cleanup completed. Removed {cleaned_count} expired verification tokens.",
"cleaned_count": cleaned_count,
}
)
except Exception as e:
logger.error(f"❌ Error in manual cleanup: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CLEANUP_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("CLEANUP_ERROR", str(e)))
@router.get("/pending-verifications")
async def get_pending_verifications(
current_user = Depends(get_current_admin),
current_user=Depends(get_current_admin),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get list of pending email verifications (admin only)"""
try:
@ -656,14 +589,16 @@ async def get_pending_verifications(
if not verification_info.get("verified", False):
expires_at = datetime.fromisoformat(verification_info.get("expires_at", ""))
pending_verifications.append({
"email": verification_info.get("email"),
"user_type": verification_info.get("user_type"),
"created_at": verification_info.get("created_at"),
"expires_at": verification_info.get("expires_at"),
"is_expired": current_time > expires_at,
"resend_count": verification_info.get("resend_count", 0)
})
pending_verifications.append(
{
"email": verification_info.get("email"),
"user_type": verification_info.get("user_type"),
"created_at": verification_info.get("created_at"),
"expires_at": verification_info.get("expires_at"),
"is_expired": current_time > expires_at,
"resend_count": verification_info.get("resend_count", 0),
}
)
if cursor == 0:
break
@ -677,35 +612,27 @@ async def get_pending_verifications(
end = start + limit
paginated_verifications = pending_verifications[start:end]
paginated_response = create_paginated_response(
paginated_verifications,
page, limit, total
)
paginated_response = create_paginated_response(paginated_verifications, page, limit, total)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Error getting pending verifications: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.get("/rate-limits/info")
async def get_user_rate_limit_status(
current_user = Depends(get_current_user_or_guest),
current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get rate limit status for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(current_user.id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
is_admin = False
@ -725,27 +652,22 @@ async def get_user_rate_limit_status(
except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.get("/rate-limits/{user_id}")
async def get_anyone_rate_limit_status(
user_id: str = Path(...),
admin_user = Depends(get_current_admin),
admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get rate limit status for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(user_id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
is_admin = False
@ -765,47 +687,36 @@ async def get_anyone_rate_limit_status(
except Exception as e:
logger.error(f"❌ Get rate limit status error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATUS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATUS_ERROR", str(e)))
@router.post("/rate-limits/{user_id}/reset")
async def reset_user_rate_limits(
user_id: str = Path(...),
admin_user = Depends(get_current_admin),
admin_user=Depends(get_current_admin),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Reset rate limits for a user (admin only)"""
try:
# Get user to determine type
user_data = await database.get_user_by_id(user_id)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
user_type = user_data.get("type", "unknown")
success = await rate_limiter.reset_user_rate_limits(user_id, user_type)
if success:
logger.info(f"🔄 Rate limits reset for {user_type} {user_id} by admin {admin_user.id}")
return create_success_response({
"message": f"Rate limits reset for {user_type} {user_id}",
"resetBy": admin_user.id
})
return create_success_response(
{"message": f"Rate limits reset for {user_type} {user_id}", "resetBy": admin_user.id}
)
else:
return JSONResponse(
status_code=500,
content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
status_code=500, content=create_error_response("RESET_FAILED", "Failed to reset rate limits")
)
except Exception as e:
logger.error(f"❌ Reset rate limits error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))
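# --- Illustrative admin sketch (not part of the diff above) ---
# Inspecting and then resetting a user's rate limits through the two endpoints
# above; the base URL, the "/admin" prefix and token handling are assumptions.
import httpx

def reset_rate_limits(user_id: str, admin_token: str,
                      base_url: str = "http://localhost:8000/admin") -> dict:
    headers = {"Authorization": f"Bearer {admin_token}"}
    before = httpx.get(f"{base_url}/rate-limits/{user_id}", headers=headers, timeout=10.0).json()
    print("before reset:", before)
    result = httpx.post(f"{base_url}/rate-limits/{user_id}/reset", headers=headers, timeout=10.0)
    result.raise_for_status()
    return result.json()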

View File

@ -20,18 +20,25 @@ from device_manager import DeviceManager
from email_service import VerificationEmailRateLimiter, email_service
from logger import logger
from models import (
LoginRequest, CreateCandidateRequest, Candidate,
Employer, Guest, AuthResponse, MFARequest,
MFAData, MFAVerifyRequest, ResendVerificationRequest,
MFARequestResponse, MFARequestResponse
)
from utils.dependencies import (
get_current_admin, get_database, get_current_user, create_access_token
LoginRequest,
CreateCandidateRequest,
Candidate,
Employer,
Guest,
AuthResponse,
MFARequest,
MFAData,
MFAVerifyRequest,
ResendVerificationRequest,
MFARequestResponse,
MFARequestResponse,
)
from utils.dependencies import get_current_admin, get_database, get_current_user, create_access_token
from utils.responses import create_success_response, create_error_response
from utils.rate_limiter import get_rate_limiter
from utils.auth_utils import (
AuthenticationManager, SecurityConfig,
AuthenticationManager,
SecurityConfig,
validate_password_strength,
)
@ -43,28 +50,31 @@ if JWT_SECRET_KEY == "":
raise ValueError("JWT_SECRET_KEY environment variable is not set")
ALGORITHM = "HS256"
# ============================
# Password Reset Endpoints
# ============================
class PasswordResetRequest(BaseModel):
email: EmailStr
class PasswordResetConfirm(BaseModel):
token: str
new_password: str
@field_validator('new_password')
@field_validator("new_password")
def validate_password_strength(cls, v):
is_valid, issues = validate_password_strength(v)
if not is_valid:
raise ValueError('; '.join(issues))
raise ValueError("; ".join(issues))
return v
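# --- Illustrative usage sketch (not part of the diff above) ---
# How the field_validator above surfaces weak passwords: constructing the model
# raises pydantic.ValidationError carrying the joined issue list from
# validate_password_strength. Assumes this snippet lives in the same module as
# PasswordResetConfirm; the token/password values are placeholders.
from pydantic import ValidationError

def demo_password_validation() -> None:
    try:
        PasswordResetConfirm(token="reset-token", new_password="weak")
    except ValidationError as exc:
        print(exc.errors()[0]["msg"])  # e.g. "Value error, <joined issues>"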
@router.post("/guest")
async def create_guest_session_enhanced(
request: Request,
database: RedisDatabase = Depends(get_database),
rate_limiter: RateLimiter = Depends(get_rate_limiter)
rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
"""Create a guest session with enhanced validation and persistence"""
try:
@ -73,21 +83,15 @@ async def create_guest_session_enhanced(
# Check rate limits for guest session creation
rate_result = await rate_limiter.check_rate_limit(
user_id=ip_address,
user_type="guest_creation",
is_admin=False,
endpoint="/guest"
user_id=ip_address, user_type="guest_creation", is_admin=False, endpoint="/guest"
)
if not rate_result.allowed:
logger.warning(f"🚫 Guest creation rate limit exceeded for IP {ip_address}")
return JSONResponse(
status_code=429,
content=create_error_response(
"RATE_LIMITED",
rate_result.reason or "Too many guest sessions created"
),
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)}
content=create_error_response("RATE_LIMITED", rate_result.reason or "Too many guest sessions created"),
headers={"Retry-After": str(rate_result.retry_after_seconds or 300)},
)
# Generate unique guest identifier with timestamp for uniqueness
@ -126,7 +130,7 @@ async def create_guest_session_enhanced(
"user_agent": request.headers.get("user-agent", "Unknown"),
"converted_to_user_id": None,
"browser_session": True, # Mark as browser session
"persistent": True, # Mark as persistent
"persistent": True, # Mark as persistent
}
# Store guest with enhanced persistence
@ -139,7 +143,7 @@ async def create_guest_session_enhanced(
"email": guest_data["email"],
"username": guest_username,
"session_id": session_id,
"created_at": current_time.isoformat()
"created_at": current_time.isoformat(),
}
await database.set_user(guest_data["email"], user_auth_data)
@ -149,11 +153,11 @@ async def create_guest_session_enhanced(
# Create authentication tokens with longer expiry for guests
access_token = create_access_token(
data={"sub": guest_id, "type": "guest"},
expires_delta=timedelta(hours=48) # Longer expiry for guests
expires_delta=timedelta(hours=48), # Longer expiry for guests
)
refresh_token = create_access_token(
data={"sub": guest_id, "type": "refresh_guest"},
expires_delta=timedelta(days=14) # 2 weeks refresh for guests
expires_delta=timedelta(days=14), # 2 weeks refresh for guests
)
# Verify guest was stored correctly
@ -161,8 +165,7 @@ async def create_guest_session_enhanced(
if not verification:
logger.error(f"❌ Failed to verify guest storage: {guest_id}")
return JSONResponse(
status_code=500,
content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
status_code=500, content=create_error_response("STORAGE_ERROR", "Failed to create guest session")
)
# Create guest object for response
@ -178,7 +181,7 @@ async def create_guest_session_enhanced(
"user": guest.model_dump(by_alias=True),
"expiresAt": int((current_time + timedelta(hours=48)).timestamp()),
"userType": "guest",
"isGuest": True
"isGuest": True,
}
return create_success_response(auth_response)
@ -186,25 +189,25 @@ async def create_guest_session_enhanced(
except Exception as e:
logger.error(f"❌ Guest session creation error: {e}")
import traceback
logger.error(traceback.format_exc())
return JSONResponse(
status_code=500,
content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
status_code=500, content=create_error_response("GUEST_CREATION_FAILED", "Failed to create guest session")
)
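# --- Illustrative client sketch (not part of the diff above) ---
# Creating a guest session and reusing the returned bearer token on later
# requests. The base URL, the "/auth" prefix, the success-envelope shape and the
# "accessToken" key name are assumptions not shown in this hunk.
import httpx

def create_guest_session(base_url: str = "http://localhost:8000/auth") -> dict:
    body = httpx.post(f"{base_url}/guest", timeout=10.0).json()
    payload = body.get("data", body)          # envelope shape: assumption
    token = payload["accessToken"]            # key name: assumption
    # Later requests attach the guest token like any other bearer token.
    auth_headers = {"Authorization": f"Bearer {token}"}
    return {"payload": payload, "headers": auth_headers}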
@router.post("/guest/convert")
async def convert_guest_to_user(
registration_data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Convert a guest session to a permanent user account"""
try:
# Verify current user is a guest
if current_user.user_type != "guest":
return JSONResponse(
status_code=400,
content=create_error_response("NOT_GUEST", "Only guest users can be converted")
status_code=400, content=create_error_response("NOT_GUEST", "Only guest users can be converted")
)
guest: Guest = current_user
@ -215,25 +218,18 @@ async def convert_guest_to_user(
try:
candidate_request = CreateCandidateRequest.model_validate(registration_data)
except ValidationError as e:
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(e)))
# Check if email/username already exists
auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists(
candidate_request.email,
candidate_request.username
candidate_request.email, candidate_request.username
)
if user_exists:
return JSONResponse(
status_code=409,
content=create_error_response(
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
)
# Create candidate
@ -253,7 +249,7 @@ async def convert_guest_to_user(
"updated_at": current_time.isoformat(),
"status": "active",
"is_admin": False,
"converted_from_guest": guest.id
"converted_from_guest": guest.id,
}
candidate = Candidate.model_validate(candidate_data)
@ -269,7 +265,7 @@ async def convert_guest_to_user(
"id": candidate_id,
"type": "candidate",
"email": candidate.email,
"username": candidate.username
"username": candidate.username,
}
await database.set_user(candidate.email, user_auth_data)
@ -286,43 +282,45 @@ async def convert_guest_to_user(
access_token = create_access_token(data={"sub": candidate_id})
refresh_token = create_access_token(
data={"sub": candidate_id, "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
auth_response = AuthResponse(
access_token=access_token,
refresh_token=refresh_token,
user=candidate,
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int((current_time + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()),
)
logger.info(f"✅ Guest {guest.session_id} converted to candidate {candidate.username}")
return create_success_response({
"message": "Guest account successfully converted to candidate",
"auth": auth_response.model_dump(by_alias=True),
"conversionType": "candidate"
})
return create_success_response(
{
"message": "Guest account successfully converted to candidate",
"auth": auth_response.model_dump(by_alias=True),
"conversionType": "candidate",
}
)
else:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported")
content=create_error_response("INVALID_TYPE", "Only candidate conversion is currently supported"),
)
except Exception as e:
logger.error(f"❌ Guest conversion error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
status_code=500, content=create_error_response("CONVERSION_FAILED", "Failed to convert guest account")
)
@router.post("/logout")
async def logout(
access_token: str = Body(..., alias="accessToken"),
refresh_token: str = Body(..., alias="refreshToken"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Logout endpoint - revokes both access and refresh tokens"""
logger.info(f"🔑 User {current_user.id} is logging out")
@ -336,21 +334,18 @@ async def logout(
if not user_id or token_type != "refresh":
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
except jwt.PyJWTError as e:
logger.warning(f"⚠️ Invalid refresh token during logout: {e}")
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
# Verify that the refresh token belongs to the current user
if user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Token does not belong to current user")
status_code=403, content=create_error_response("FORBIDDEN", "Token does not belong to current user")
)
# Get Redis client
@ -362,12 +357,14 @@ async def logout(
await redis.setex(
f"blacklisted_token:{refresh_token}",
refresh_ttl,
json.dumps({
"user_id": user_id,
"token_type": "refresh",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
json.dumps(
{
"user_id": user_id,
"token_type": "refresh",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout",
}
),
)
logger.info(f"🔒 Blacklisted refresh token for user {user_id}")
@ -385,12 +382,14 @@ async def logout(
await redis.setex(
f"blacklisted_token:{access_token}",
access_ttl,
json.dumps({
"user_id": user_id,
"token_type": "access",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout"
})
json.dumps(
{
"user_id": user_id,
"token_type": "access",
"revoked_at": datetime.now(UTC).isoformat(),
"reason": "user_logout",
}
),
)
logger.info(f"🔒 Blacklisted access token for user {user_id}")
else:
@ -409,26 +408,20 @@ async def logout(
# )
logger.info(f"🔑 User {user_id} logged out successfully")
return create_success_response({
"message": "Logged out successfully",
"tokensRevoked": {
"refreshToken": True,
"accessToken": bool(access_token)
return create_success_response(
{
"message": "Logged out successfully",
"tokensRevoked": {"refreshToken": True, "accessToken": bool(access_token)},
}
})
)
except Exception as e:
logger.error(f"❌ Logout error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ERROR", str(e)))
@router.post("/logout-all")
async def logout_all_devices(
current_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
):
async def logout_all_devices(current_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)):
"""Logout from all devices by revoking all tokens for the user"""
try:
redis = redis_manager.get_client()
@ -437,25 +430,20 @@ async def logout_all_devices(
await redis.setex(
f"user_tokens_revoked:{current_user.id}",
int(timedelta(days=30).total_seconds()), # Max refresh token lifetime
datetime.now(UTC).isoformat()
datetime.now(UTC).isoformat(),
)
logger.info(f"🔒 All tokens revoked for user {current_user.id}")
return create_success_response({
"message": "Logged out from all devices successfully"
})
return create_success_response({"message": "Logged out from all devices successfully"})
except Exception as e:
logger.error(f"❌ Logout all devices error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGOUT_ALL_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("LOGOUT_ALL_ERROR", str(e)))
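# --- Illustrative sketch (not part of the diff above) ---
# The counterpart check a token-validation dependency would need in order to
# honour the two Redis keys written above: per-token "blacklisted_token:*"
# entries from /logout and the per-user "user_tokens_revoked:*" timestamp from
# /logout-all. The helper name and how the issued-at time is obtained are
# assumptions.
from datetime import datetime

async def is_token_revoked(redis, token: str, user_id: str, issued_at: datetime) -> bool:
    # 1) Token explicitly blacklisted by /logout.
    if await redis.exists(f"blacklisted_token:{token}"):
        return True
    # 2) All tokens revoked for this user by /logout-all: reject anything issued
    #    before the stored revocation timestamp.
    revoked_at_raw = await redis.get(f"user_tokens_revoked:{user_id}")
    if revoked_at_raw:
        if isinstance(revoked_at_raw, bytes):
            revoked_at_raw = revoked_at_raw.decode()
        if issued_at <= datetime.fromisoformat(revoked_at_raw):
            return True
    return False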
@router.post("/refresh")
async def refresh_token_endpoint(
refresh_token: str = Body(..., alias="refreshToken"),
database: RedisDatabase = Depends(get_database)
refresh_token: str = Body(..., alias="refreshToken"), database: RedisDatabase = Depends(get_database)
):
"""Refresh token endpoint"""
try:
@ -466,8 +454,7 @@ async def refresh_token_endpoint(
if not user_id or token_type != "refresh":
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
# Create new access token
@ -484,37 +471,29 @@ async def refresh_token_endpoint(
user = Employer.model_validate(employer_data)
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
auth_response = AuthResponse(
access_token=access_token,
refresh_token=refresh_token, # Keep same refresh token
user=user,
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp())
expires_at=int((datetime.now(UTC) + timedelta(hours=24)).timestamp()),
)
return create_success_response(auth_response.model_dump(by_alias=True))
except jwt.PyJWTError:
return JSONResponse(
status_code=401,
content=create_error_response("INVALID_TOKEN", "Invalid refresh token")
)
return JSONResponse(status_code=401, content=create_error_response("INVALID_TOKEN", "Invalid refresh token"))
except Exception as e:
logger.error(f"❌ Token refresh error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("REFRESH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("REFRESH_ERROR", str(e)))
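# --- Illustrative client sketch (not part of the diff above) ---
# Refreshing an access token with a stored refresh token. The "refreshToken"
# body key comes from the Body(..., alias="refreshToken") parameter above; the
# base URL, "/auth" prefix and envelope shape are assumptions.
import httpx

def refresh_access_token(refresh_token: str, base_url: str = "http://localhost:8000/auth") -> dict:
    response = httpx.post(
        f"{base_url}/refresh",
        json={"refreshToken": refresh_token},
        timeout=10.0,
    )
    response.raise_for_status()
    body = response.json()
    return body.get("data", body)   # contains the new access token and expiry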
@router.post("/resend-verification")
async def resend_verification_email(
request: ResendVerificationRequest,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Resend verification email with comprehensive rate limiting and validation"""
try:
@ -527,10 +506,7 @@ async def resend_verification_email(
can_send, reason = await rate_limiter.can_send_verification_email(email_lower)
if not can_send:
logger.warning(f"⚠️ Verification email rate limit exceeded for {email_lower}: {reason}")
return JSONResponse(
status_code=429,
content=create_error_response("RATE_LIMITED", reason)
)
return JSONResponse(status_code=429, content=create_error_response("RATE_LIMITED", reason))
# Clean up expired tokens first
await database.cleanup_expired_verification_tokens()
@ -541,9 +517,11 @@ async def resend_verification_email(
# User exists and is verified - don't reveal this for security
logger.info(f"🔍 Resend verification requested for already verified user: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({
"message": "If your email is in our system and pending verification, a new verification email has been sent."
})
return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent."
}
)
# Look for pending verification token
verification_data = await database.find_verification_token_by_email(email_lower)
@ -552,9 +530,11 @@ async def resend_verification_email(
# No pending verification found - don't reveal this for security
logger.info(f"🔍 Resend verification requested for non-existent pending verification: {email_lower}")
await rate_limiter.record_email_sent(email_lower) # Record attempt to prevent abuse
return create_success_response({
"message": "If your email is in our system and pending verification, a new verification email has been sent."
})
return create_success_response(
{
"message": "If your email is in our system and pending verification, a new verification email has been sent."
}
)
# Check if verification token has expired
expires_at = datetime.fromisoformat(verification_data["expires_at"])
@ -568,8 +548,8 @@ async def resend_verification_email(
status_code=400,
content=create_error_response(
"TOKEN_EXPIRED",
"Your verification link has expired. Please register again to create a new account."
)
"Your verification link has expired. Please register again to create a new account.",
),
)
# Generate new verification token (invalidate old one)
@ -577,20 +557,19 @@ async def resend_verification_email(
new_token = secrets.token_urlsafe(32)
# Update verification data with new token and reset attempts
verification_data.update({
"token": new_token,
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
"resent_at": current_time.isoformat(),
"resend_count": verification_data.get("resend_count", 0) + 1
})
verification_data.update(
{
"token": new_token,
"expires_at": (current_time + timedelta(hours=24)).isoformat(),
"resent_at": current_time.isoformat(),
"resend_count": verification_data.get("resend_count", 0) + 1,
}
)
# Store new token and remove old one
await database.redis.delete(f"email_verification:{old_token}")
await database.store_email_verification_token(
email_lower,
new_token,
verification_data["user_type"],
verification_data["user_data"]
email_lower, new_token, verification_data["user_type"], verification_data["user_data"]
)
# Get user name for email
@ -610,73 +589,66 @@ async def resend_verification_email(
await rate_limiter.record_email_sent(email_lower)
# Send new verification email in background
background_tasks.add_task(
email_service.send_verification_email,
email_lower,
new_token,
user_name,
user_type
)
background_tasks.add_task(email_service.send_verification_email, email_lower, new_token, user_name, user_type)
# Log security event
await database.log_security_event(
verification_data["user_data"].get("candidate_data", {}).get("id") or
verification_data["user_data"].get("employer_data", {}).get("id") or "unknown",
verification_data["user_data"].get("candidate_data", {}).get("id")
or verification_data["user_data"].get("employer_data", {}).get("id")
or "unknown",
"verification_resend",
{
"email": email_lower,
"user_type": user_type,
"resend_count": verification_data.get("resend_count", 1),
"old_token_invalidated": old_token[:8] + "...", # Log partial token for debugging
"ip_address": "unknown" # You can extract this from request if needed
"ip_address": "unknown", # You can extract this from request if needed
},
)
logger.info(
f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})"
)
return create_success_response(
{
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1),
}
)
logger.info(f"✅ Verification email resent to {email_lower} (attempt #{verification_data.get('resend_count', 1)})")
return create_success_response({
"message": "A new verification email has been sent to your email address. Please check your inbox and spam folder.",
"resendCount": verification_data.get("resend_count", 1)
})
except ValueError as ve:
logger.warning(f"⚠️ Invalid resend verification request: {ve}")
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
except Exception as e:
logger.error(f"❌ Resend verification email error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESEND_FAILED", "An error occurred while processing your request. Please try again later.")
content=create_error_response(
"RESEND_FAILED", "An error occurred while processing your request. Please try again later."
),
)
@router.post("/mfa/request")
async def request_mfa(
request: MFARequest,
background_tasks: BackgroundTasks,
http_request: Request,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Request MFA for login from new device"""
try:
# Verify credentials first
auth_manager = AuthenticationManager(database)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
request.email,
request.password
)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.email, request.password)
if not is_valid or not user_data:
return JSONResponse(
status_code=401,
content=create_error_response("AUTH_FAILED", "Invalid credentials")
)
return JSONResponse(status_code=401, content=create_error_response("AUTH_FAILED", "Invalid credentials"))
# Check if device is trusted
device_manager = DeviceManager(database)
device_info = device_manager.parse_device_info(http_request)
device_manager.parse_device_info(http_request)
is_trusted = await device_manager.is_trusted_device(user_data["id"], request.device_id)
@ -684,10 +656,7 @@ async def request_mfa(
# Device is trusted, proceed with normal login
await device_manager.update_device_last_used(user_data["id"], request.device_id)
return create_success_response({
"mfa_required": False,
"message": "Device is trusted, proceed with login"
})
return create_success_response({"mfa_required": False, "message": "Device is trusted, proceed with login"})
# Generate MFA code
mfa_code = f"{secrets.randbelow(1000000):06d}" # 6-digit code
@ -709,8 +678,7 @@ async def request_mfa(
if not email:
return JSONResponse(
status_code=400,
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
)
# Store MFA code
@ -718,13 +686,7 @@ async def request_mfa(
logger.info(f"🔐 MFA code generated for {email} on device {request.device_id}")
# Send MFA code via email
background_tasks.add_task(
email_service.send_mfa_email,
email,
mfa_code,
request.device_name,
user_name
)
background_tasks.add_task(email_service.send_mfa_email, email, mfa_code, request.device_name, user_name)
logger.info(f"🔐 MFA requested for {request.email} from new device {request.device_name}")
@ -735,25 +697,22 @@ async def request_mfa(
device_id=request.device_id,
device_name=request.device_name,
)
mfa_response = MFARequestResponse(
mfa_required=True,
mfa_data=mfa_data
)
mfa_response = MFARequestResponse(mfa_required=True, mfa_data=mfa_data)
return create_success_response(mfa_response)
except Exception as e:
logger.error(f"❌ MFA request error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
status_code=500, content=create_error_response("MFA_REQUEST_FAILED", "Failed to process MFA request")
)
@router.post("/login")
async def login(
request: LoginRequest,
http_request: Request,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""login with automatic MFA email sending for new devices"""
try:
@ -766,16 +725,12 @@ async def login(
device_id = device_info["device_id"]
# Verify credentials first
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(
request.login,
request.password
)
is_valid, user_data, error_message = await auth_manager.verify_user_credentials(request.login, request.password)
if not is_valid or not user_data:
logger.warning(f"⚠️ Failed login attempt for: {request.login}")
return JSONResponse(
status_code=401,
content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
status_code=401, content=create_error_response("AUTH_FAILED", error_message or "Invalid credentials")
)
# Check if device is trusted
@ -804,8 +759,7 @@ async def login(
if not email:
return JSONResponse(
status_code=400,
content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
status_code=400, content=create_error_response("EMAIL_NOT_FOUND", "User email not found for MFA")
)
# Store MFA code
@ -817,12 +771,7 @@ async def login(
# Send MFA code via email in background
background_tasks.add_task(
email_service.send_mfa_email,
email,
mfa_code,
device_info["device_name"],
user_name,
ip_address
email_service.send_mfa_email, email, mfa_code, device_info["device_name"], user_name, ip_address
)
# Log security event
@ -834,8 +783,8 @@ async def login(
"device_name": device_info["device_name"],
"ip_address": ip_address,
"user_agent": device_info.get("user_agent", ""),
"auto_sent": True
}
"auto_sent": True,
},
)
logger.info(f"🔐 MFA code automatically sent to {request.login} for device {device_info['device_name']}")
@ -847,8 +796,8 @@ async def login(
email=email,
device_id=device_id,
device_name=device_info["device_name"],
code_sent=mfa_code
)
code_sent=mfa_code,
),
)
return create_success_response(mfa_response.model_dump(by_alias=True))
@ -860,7 +809,7 @@ async def login(
access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
# Get user object
@ -876,8 +825,7 @@ async def login(
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User profile not found")
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
)
# Log successful login from trusted device
@ -888,8 +836,8 @@ async def login(
"device_id": device_id,
"device_name": device_info["device_name"],
"ip_address": http_request.client.host if http_request.client else "Unknown",
"trusted_device": True
}
"trusted_device": True,
},
)
# Create response
@ -897,7 +845,9 @@ async def login(
access_token=access_token,
refresh_token=refresh_token,
user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
)
logger.info(f"🔑 User {request.login} logged in successfully from trusted device")
@ -908,17 +858,12 @@ async def login(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Login error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("LOGIN_ERROR", "An error occurred during login")
status_code=500, content=create_error_response("LOGIN_ERROR", "An error occurred during login")
)
@router.post("/mfa/verify")
async def verify_mfa(
request: MFAVerifyRequest,
http_request: Request,
database: RedisDatabase = Depends(get_database)
):
async def verify_mfa(request: MFAVerifyRequest, http_request: Request, database: RedisDatabase = Depends(get_database)):
"""Verify MFA code and complete login with error handling"""
try:
# Get MFA data
@ -928,13 +873,17 @@ async def verify_mfa(
logger.warning(f"⚠️ No MFA session found for {request.email} on device {request.device_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NO_MFA_SESSION", "No active MFA session found. Please try logging in again.")
content=create_error_response(
"NO_MFA_SESSION", "No active MFA session found. Please try logging in again."
),
)
if mfa_data.get("verified"):
return JSONResponse(
status_code=400,
content=create_error_response("ALREADY_VERIFIED", "This MFA code has already been used. Please login again.")
content=create_error_response(
"ALREADY_VERIFIED", "This MFA code has already been used. Please login again."
),
)
# Check expiration
@ -944,7 +893,7 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse(
status_code=400,
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again.")
content=create_error_response("MFA_EXPIRED", "MFA code has expired. Please try logging in again."),
)
# Check attempts
@ -954,7 +903,9 @@ async def verify_mfa(
await database.redis.delete(f"mfa_code:{request.email.lower()}:{request.device_id}")
return JSONResponse(
status_code=429,
content=create_error_response("TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again.")
content=create_error_response(
"TOO_MANY_ATTEMPTS", "Too many incorrect attempts. Please try logging in again."
),
)
# Verify code
@ -965,9 +916,8 @@ async def verify_mfa(
return JSONResponse(
status_code=400,
content=create_error_response(
"INVALID_CODE",
f"Invalid MFA code. {remaining_attempts} attempts remaining."
)
"INVALID_CODE", f"Invalid MFA code. {remaining_attempts} attempts remaining."
),
)
# Mark as verified
@ -976,20 +926,13 @@ async def verify_mfa(
# Get user data
user_data = await database.get_user(request.email)
if not user_data:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("USER_NOT_FOUND", "User not found"))
# Add device to trusted devices if requested
if request.remember_device:
device_manager = DeviceManager(database)
device_info = device_manager.parse_device_info(http_request)
await device_manager.add_trusted_device(
user_data["id"],
request.device_id,
device_info
)
await device_manager.add_trusted_device(user_data["id"], request.device_id, device_info)
logger.info(f"🔒 Device {request.device_id} added to trusted devices for user {user_data['id']}")
# Update last login
@ -1000,7 +943,7 @@ async def verify_mfa(
access_token = create_access_token(data={"sub": user_data["id"]})
refresh_token = create_access_token(
data={"sub": user_data["id"], "type": "refresh"},
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS)
expires_delta=timedelta(days=SecurityConfig.REFRESH_TOKEN_EXPIRY_DAYS),
)
# Get user object
@ -1016,8 +959,7 @@ async def verify_mfa(
if not user:
return JSONResponse(
status_code=404,
content=create_error_response("USER_NOT_FOUND", "User profile not found")
status_code=404, content=create_error_response("USER_NOT_FOUND", "User profile not found")
)
# Log successful MFA verification and login
@ -1028,8 +970,8 @@ async def verify_mfa(
"device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown",
"device_remembered": request.remember_device,
"attempts_used": current_attempts + 1
}
"attempts_used": current_attempts + 1,
},
)
await database.log_security_event(
@ -1039,8 +981,8 @@ async def verify_mfa(
"device_id": request.device_id,
"ip_address": http_request.client.host if http_request.client else "Unknown",
"mfa_verified": True,
"new_device": True
}
"new_device": True,
},
)
# Clean up MFA session
@ -1051,7 +993,9 @@ async def verify_mfa(
access_token=access_token,
refresh_token=refresh_token,
user=user,
expires_at=int((datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp())
expires_at=int(
(datetime.now(timezone.utc) + timedelta(hours=SecurityConfig.TOKEN_EXPIRY_HOURS)).timestamp()
),
)
logger.info(f"✅ MFA verified and login completed for {request.email}")
@ -1062,15 +1006,12 @@ async def verify_mfa(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ MFA verification error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
status_code=500, content=create_error_response("MFA_VERIFICATION_FAILED", "Failed to verify MFA")
)
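# --- Illustrative client sketch (not part of the diff above) ---
# The two-step flow a client follows from an untrusted device: /login returns an
# MFA challenge, the user supplies the emailed code, and /mfa/verify completes
# authentication. The "login"/"password" keys match LoginRequest above; the
# camelCase aliases, the "code" field name, the "/auth" prefix and the envelope
# shape are assumptions.
import httpx

def login_with_mfa(email: str, password: str, device_id: str, code_from_email: str,
                   base_url: str = "http://localhost:8000/auth") -> dict:
    login_body = httpx.post(
        f"{base_url}/login",
        json={"login": email, "password": password},
        timeout=10.0,
    ).json()
    login_payload = login_body.get("data", login_body)
    if not login_payload.get("mfaRequired", login_payload.get("mfa_required")):
        return login_payload  # trusted device: tokens were issued directly

    verify_body = httpx.post(
        f"{base_url}/mfa/verify",
        json={
            "email": email,
            "deviceId": device_id,         # alias assumed
            "code": code_from_email,       # field name assumed
            "rememberDevice": True,        # alias assumed
        },
        timeout=10.0,
    ).json()
    return verify_body.get("data", verify_body)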
@router.post("/password-reset/request")
async def request_password_reset(
request: PasswordResetRequest,
database: RedisDatabase = Depends(get_database)
):
async def request_password_reset(request: PasswordResetRequest, database: RedisDatabase = Depends(get_database)):
"""Request password reset"""
try:
# Check if user exists
@ -1100,15 +1041,12 @@ async def request_password_reset(
except Exception as e:
logger.error(f"❌ Password reset request error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", "An error occurred processing the request")
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred processing the request")
)
@router.post("/password-reset/confirm")
async def confirm_password_reset(
request: PasswordResetConfirm,
database: RedisDatabase = Depends(get_database)
):
async def confirm_password_reset(request: PasswordResetConfirm, database: RedisDatabase = Depends(get_database)):
"""Confirm password reset with token"""
try:
# Find user by reset token
@ -1122,8 +1060,5 @@ async def confirm_password_reset(
except Exception as e:
logger.error(f"❌ Password reset confirm error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
status_code=500, content=create_error_response("RESET_ERROR", "An error occurred resetting the password")
)
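Every handler in these routes wraps its payload in create_success_response or create_error_response from utils.responses; the helpers themselves are not part of this diff. A minimal sketch of the envelope they appear to produce, based only on how the routes call them, with the field names as assumptions.

```python
from typing import Any

def create_success_response(data: Any) -> dict:
    # Plain dict so it can be returned directly or wrapped in a JSONResponse
    return {"success": True, "data": data}

def create_error_response(code: str, message: str) -> dict:
    # e.g. create_error_response("USER_NOT_FOUND", "User not found")
    return {"success": False, "error": {"code": code, "message": message}}
```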

File diff suppressed because it is too large

View File

@ -4,125 +4,69 @@ Chat routes
import json
import uuid
from datetime import datetime, UTC
from typing import (Dict, Any)
from typing import Dict, Any
from fastapi import (
APIRouter, Depends, Body, Depends, Query, Path,
Body, APIRouter
)
from fastapi import APIRouter, Depends, Body, Depends, Query, Path, Body, APIRouter
from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_database, get_current_user, get_current_user_or_guest
)
from utils.responses import (
create_success_response, create_error_response, create_paginated_response
)
from utils.helpers import (
stream_agent_response
)
from utils.dependencies import get_database, get_current_user, get_current_user_or_guest
from utils.responses import create_success_response, create_error_response, create_paginated_response
from utils.helpers import stream_agent_response
import backstory_traceback
import entities.entity_manager as entities
from models import (
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
Candidate, Employer, Guest, AuthResponse,
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
EmailVerificationRequest, ResendVerificationRequest,
# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
JobRequirements, JobRequirementsMessage, LoginRequest,
CreateCandidateRequest, CreateEmployerRequest,
# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
Authentication, AuthResponse, CandidateAI,
# Job models
JobApplication, ApplicationStatus,
# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ApiStatusType, ChatSenderType, ApiMessageType, ChatContextType,
ChatMessageRagSearch,
# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
# Supporting models
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
# Email
EmailVerificationRequest
)
from models import Candidate, ChatMessageUser, Candidate, BaseUserWithType, ChatSession, ChatMessage
# Create router for authentication endpoints
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("/sessions/{session_id}/archive")
async def archive_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
session_id: str = Path(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Archive a chat session"""
try:
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Check if user owns this session or is admin
if session_data.get("userId") != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot archive another user's session")
)
await database.archive_chat_session(session_id)
return create_success_response({
"message": "Chat session archived successfully",
"sessionId": session_id
})
return create_success_response({"message": "Chat session archived successfully", "sessionId": session_id})
except Exception as e:
logger.error(f"❌ Archive chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("ARCHIVE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("ARCHIVE_ERROR", str(e)))
@router.get("/statistics")
async def get_chat_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_chat_statistics(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get chat statistics (admin/analytics endpoint)"""
try:
stats = await database.get_chat_statistics()
return create_success_response(stats)
except Exception as e:
logger.error(f"❌ Get chat statistics error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("STATS_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("STATS_ERROR", str(e)))
@router.post("/sessions")
async def create_chat_session(
session_data: Dict[str, Any] = Body(...),
current_user: BaseUserWithType = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Create a new chat session with optional candidate username association"""
try:
@ -140,15 +84,14 @@ async def create_chat_session(
candidates_list = [Candidate.model_validate(data) for data in all_candidates_data.values()]
# Find candidate by username (case-insensitive)
matching_candidates = [
c for c in candidates_list
if c.username.lower() == username.lower()
]
matching_candidates = [c for c in candidates_list if c.username.lower() == username.lower()]
if not matching_candidates:
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found")
content=create_error_response(
"CANDIDATE_NOT_FOUND", f"Candidate with username '{username}' not found"
),
)
candidate_data = matching_candidates[0]
@ -177,7 +120,7 @@ async def create_chat_session(
"username": candidate_data.username,
"skills": [skill.name for skill in candidate_data.skills] if candidate_data.skills else [],
"experience": len(candidate_data.experience) if candidate_data.experience else 0,
"location": candidate_data.location.city if candidate_data.location else "Unknown"
"location": candidate_data.location.city if candidate_data.location else "Unknown",
}
context["additionalContext"] = additional_context
@ -191,8 +134,10 @@ async def create_chat_session(
chat_session = ChatSession.model_validate(session_data)
await database.set_chat_session(chat_session.id, chat_session.model_dump())
logger.info(f"✅ Chat session created: {chat_session.id} for user {current_user.id}" +
(f" about candidate {candidate_data.full_name}" if candidate_data else ""))
logger.info(
f"✅ Chat session created: {chat_session.id} for user {current_user.id}"
+ (f" about candidate {candidate_data.full_name}" if candidate_data else "")
)
return create_success_response(chat_session.model_dump(by_alias=True))
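When a session is created about a specific candidate, the route stores a compact candidate summary in the session context, and the streaming endpoint later reads it back as candidateInfo. A sketch of the implied shape; the exact nesting and the example values are assumptions, the field names come from the code above and below.

```python
# Producer side (create_chat_session): summary attached to the session context
additional_context = {
    "candidateInfo": {
        "id": "candidate-uuid",           # used later to load the full Candidate
        "name": "Jane Doe",               # candidate_data.full_name
        "username": "jdoe",
        "skills": ["python", "fastapi"],  # [skill.name for skill in candidate_data.skills]
        "experience": 3,                  # len(candidate_data.experience)
        "location": "Portland",           # candidate_data.location.city, else "Unknown"
    }
}
# context["additionalContext"] = additional_context

# Consumer side (post_chat_session_message_stream):
# candidate_info = chat_session.context.additional_context.get("candidateInfo", {})
# candidate_data = await database.get_candidate(candidate_info["id"])
```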
@ -200,49 +145,58 @@ async def create_chat_session(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat session creation error: {e}")
logger.info(json.dumps(session_data, indent=2))
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("/sessions/messages/stream")
async def post_chat_session_message_stream(
user_message: ChatMessageUser = Body(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Post a message to a chat session and stream the response with persistence"""
try:
chat_session_data = await database.get_chat_session(user_message.session_id)
if not chat_session_data:
logger.info("🔗 Chat session not found for session ID: " + user_message.session_id)
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
chat_session = ChatSession.model_validate(chat_session_data)
chat_type = chat_session.context.type
candidate_info = chat_session.context.additional_context.get("candidateInfo", {}) if chat_session.context and chat_session.context.additional_context else None
candidate_info = (
chat_session.context.additional_context.get("candidateInfo", {})
if chat_session.context and chat_session.context.additional_context
else None
)
# Get candidate info if this chat is about a specific candidate
if candidate_info:
logger.info(f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} about candidate {candidate_info['name']} accessed by user {current_user.id}"
)
else:
logger.info(f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}")
logger.info(
f"🔗 Chat session {user_message.session_id} type {chat_type} accessed by user {current_user.id}"
)
return JSONResponse(
status_code=400,
content=create_error_response("CANDIDATE_REQUIRED", "This chat session requires a candidate association")
content=create_error_response(
"CANDIDATE_REQUIRED", "This chat session requires a candidate association"
),
)
candidate_data = await database.get_candidate(candidate_info["id"]) if candidate_info else None
candidate : Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
candidate: Candidate | None = Candidate.model_validate(candidate_data) if candidate_data else None
if not candidate:
logger.info(f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}")
logger.info(
f"🔗 Candidate not found for chat session {user_message.session_id} with ID {candidate_info['id']}"
)
return JSONResponse(
status_code=404,
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session")
content=create_error_response("CANDIDATE_NOT_FOUND", "Candidate not found for this chat session"),
)
logger.info(f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}")
logger.info(
f"🔗 User {current_user.id} posting message to chat session {user_message.session_id} with query length: {len(user_message.content)}"
)
async with entities.get_candidate_entity(candidate=candidate) as candidate_entity:
# Entity automatically released when done
@ -251,7 +205,7 @@ async def post_chat_session_message_stream(
logger.info(f"🔗 No chat agent found for session {user_message.session_id} with type {chat_type}")
return JSONResponse(
status_code=400,
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type")
content=create_error_response("AGENT_NOT_FOUND", "No agent found for this chat type"),
)
# Persist user message to database
@ -269,30 +223,25 @@ async def post_chat_session_message_stream(
chat_session_data=chat_session_data,
)
except Exception as e:
except Exception:
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Chat message streaming error")
return JSONResponse(
status_code=500,
content=create_error_response("STREAMING_ERROR", "")
)
logger.error("❌ Chat message streaming error")
return JSONResponse(status_code=500, content=create_error_response("STREAMING_ERROR", ""))
@router.get("/sessions/{session_id}/messages")
async def get_chat_session_messages(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
current_user=Depends(get_current_user_or_guest),
page: int = Query(1, ge=1),
limit: int = Query(50, ge=1, le=100), # Increased default for chat messages
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get persisted chat messages for a session"""
try:
chat_session_data = await database.get_chat_session(session_id)
if not chat_session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
# Get messages from database
chat_messages = await database.get_chat_messages(session_id)
@ -317,43 +266,36 @@ async def get_chat_session_messages(
paginated_messages = messages_list[start:end]
paginated_response = create_paginated_response(
[m.model_dump(by_alias=True) for m in paginated_messages],
page, limit, total
[m.model_dump(by_alias=True) for m in paginated_messages], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get chat messages error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.patch("/sessions/{session_id}")
async def update_chat_session(
session_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Update a chat session's properties"""
try:
# Get the existing session
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only update their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's chat session")
)
# Validate and apply updates
@ -362,8 +304,7 @@ async def update_chat_session(
if not filtered_updates:
return JSONResponse(
status_code=400,
content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
status_code=400, content=create_error_response("INVALID_UPDATES", "No valid fields provided for update")
)
# Apply updates to session data
@ -417,40 +358,31 @@ async def update_chat_session(
except ValueError as ve:
logger.warning(f"⚠️ Validation error updating chat session: {ve}")
return JSONResponse(
status_code=400,
content=create_error_response("VALIDATION_ERROR", str(ve))
)
return JSONResponse(status_code=400, content=create_error_response("VALIDATION_ERROR", str(ve)))
except Exception as e:
logger.error(f"❌ Update chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("UPDATE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("UPDATE_ERROR", str(e)))
@router.delete("/sessions/{session_id}")
async def delete_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot delete another user's chat session")
)
# Delete all messages associated with this session
@ -469,42 +401,34 @@ async def delete_chat_session(
logger.info(f"🗑️ Chat session {session_id} deleted by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session deleted successfully",
"sessionId": session_id
})
return create_success_response(
{"success": True, "message": "Chat session deleted successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Delete chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", str(e)))
@router.patch("/sessions/{session_id}/reset")
async def reset_chat_session(
session_id: str = Path(...),
current_user = Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user_or_guest),
database: RedisDatabase = Depends(get_database),
):
"""Delete a chat session and all its messages"""
try:
# Get the session to verify it exists and check ownership
session_data = await database.get_chat_session(session_id)
if not session_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Chat session not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Chat session not found"))
session = ChatSession.model_validate(session_data)
# Check authorization - user can only delete their own sessions
if session.user_id != current_user.id:
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot reset another user's chat session")
)
# Delete all messages associated with this session
@ -518,20 +442,12 @@ async def reset_chat_session(
logger.warning(f"⚠️ Error deleting messages for session {session_id}: {e}")
# Continue with session deletion even if message deletion fails
logger.info(f"🗑️ Chat session {session_id} reset by user {current_user.id}")
return create_success_response({
"success": True,
"message": "Chat session reset successfully",
"sessionId": session_id
})
return create_success_response(
{"success": True, "message": "Chat session reset successfully", "sessionId": session_id}
)
except Exception as e:
logger.error(f"❌ Reset chat session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("RESET_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("RESET_ERROR", str(e)))

View File

@ -9,24 +9,21 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import (
get_current_admin, get_database
)
from utils.dependencies import get_current_admin, get_database
from utils.responses import create_success_response, create_error_response
# Create router for authentication endpoints
router = APIRouter(prefix="/auth", tags=["authentication"])
@router.get("/guest/{guest_id}")
async def debug_guest_session(
guest_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
guest_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Debug guest session issues (admin only)"""
try:
# Check primary storage
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
primary_data = await database.redis.hget("guests", guest_id) # type: ignore
primary_exists = primary_data is not None
# Check backup storage
@ -37,7 +34,7 @@ async def debug_guest_session(
user_lookup = await database.get_user_by_id(guest_id)
# Get TTL info
primary_ttl = await database.redis.ttl(f"guests")
primary_ttl = await database.redis.ttl("guests")
backup_ttl = await database.redis.ttl(f"guest_backup:{guest_id}")
debug_info = {
@ -45,22 +42,19 @@ async def debug_guest_session(
"primary_storage": {
"exists": primary_exists,
"data": json.loads(primary_data) if primary_data else None,
"ttl": primary_ttl
"ttl": primary_ttl,
},
"backup_storage": {
"exists": backup_exists,
"data": json.loads(backup_data) if backup_data else None,
"ttl": backup_ttl
"ttl": backup_ttl,
},
"user_lookup": user_lookup,
"timestamp": datetime.now(UTC).isoformat()
"timestamp": datetime.now(UTC).isoformat(),
}
return create_success_response(debug_info)
except Exception as e:
logger.error(f"❌ Debug guest session error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DEBUG_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("DEBUG_ERROR", str(e)))

View File

@ -10,12 +10,8 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
CreateEmployerRequest
)
from utils.dependencies import (
get_database
)
from models import CreateEmployerRequest
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
from email_service import email_service
from utils.auth_utils import AuthenticationManager
@ -23,29 +19,22 @@ from utils.auth_utils import AuthenticationManager
# Create router for job endpoints
router = APIRouter(prefix="/employers", tags=["employers"])
@router.post("")
async def create_employer_with_verification(
request: CreateEmployerRequest,
background_tasks: BackgroundTasks,
database: RedisDatabase = Depends(get_database)
request: CreateEmployerRequest, background_tasks: BackgroundTasks, database: RedisDatabase = Depends(get_database)
):
"""Create a new employer with email verification"""
try:
# Similar to candidate creation but for employer
auth_manager = AuthenticationManager(database)
user_exists, conflict_field = await auth_manager.check_user_exists(
request.email,
request.username
)
user_exists, conflict_field = await auth_manager.check_user_exists(request.email, request.username)
if user_exists and conflict_field:
return JSONResponse(
status_code=409,
content=create_error_response(
"USER_EXISTS",
f"A user with this {conflict_field} already exists"
)
content=create_error_response("USER_EXISTS", f"A user with this {conflict_field} already exists"),
)
employer_id = str(uuid.uuid4())
@ -64,12 +53,8 @@ async def create_employer_with_verification(
"updatedAt": current_time.isoformat(),
"status": "pending", # Not active until verified
"userType": "employer",
"location": {
"city": "",
"country": "",
"remote": False
},
"socialLinks": []
"location": {"city": "", "country": "", "remote": False},
"socialLinks": [],
}
verification_token = secrets.token_urlsafe(32)
@ -78,32 +63,25 @@ async def create_employer_with_verification(
request.email,
verification_token,
"employer",
{
"employer_data": employer_data,
"password": request.password,
"username": request.username
}
{"employer_data": employer_data, "password": request.password, "username": request.username},
)
background_tasks.add_task(
email_service.send_verification_email,
request.email,
verification_token,
request.company_name
email_service.send_verification_email, request.email, verification_token, request.company_name
)
logger.info(f"✅ Employer registration initiated for: {request.email}")
return create_success_response({
"message": "Registration successful! Please check your email to verify your account.",
"email": request.email,
"verificationRequired": True
})
return create_success_response(
{
"message": "Registration successful! Please check your email to verify your account.",
"email": request.email,
"verificationRequired": True,
}
)
except Exception as e:
logger.error(f"❌ Employer creation error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("CREATION_FAILED", "Failed to create employer account")
status_code=500, content=create_error_response("CREATION_FAILED", "Failed to create employer account")
)
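The registration flow above stores a pending employer record together with a verification token and then emails the token from a background task. A small sketch of the token handling; the verification-link format is an assumption, only token_urlsafe(32) and the send_verification_email arguments appear in the diff.

```python
import secrets

def make_verification_token() -> str:
    # 32 bytes of URL-safe randomness, as generated in the route above
    return secrets.token_urlsafe(32)

def make_verification_link(base_url: str, token: str) -> str:
    # Assumed URL shape; the real link lives inside email_service
    return f"{base_url}/verify-email?token={token}"

# background_tasks.add_task(
#     email_service.send_verification_email, request.email, verification_token, request.company_name
# )
```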

View File

@ -21,14 +21,24 @@ from utils.helpers import create_job_from_content, filter_and_paginate, get_docu
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ApiActivityType, ApiStatusType, ChatContextType, ChatMessage, ChatMessageError, ChatMessageStatus, DocumentType, Job, JobRequirementsMessage, Candidate, Employer
)
from utils.dependencies import (
get_current_admin, get_database, get_current_user
MOCK_UUID,
ApiActivityType,
ApiStatusType,
ChatContextType,
ChatMessage,
ChatMessageError,
ChatMessageStatus,
DocumentType,
Job,
JobRequirementsMessage,
Candidate,
Employer,
)
from utils.dependencies import get_current_admin, get_database, get_current_user
from utils.responses import create_paginated_response, create_success_response, create_error_response
import utils.llm_proxy as llm_manager
import entities.entity_manager as entities
# Create router for job endpoints
router = APIRouter(prefix="/jobs", tags=["jobs"])
@ -38,14 +48,14 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
content="Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING,
)
yield status_message
@ -58,7 +68,7 @@ async def reformat_as_markdown(database: RedisDatabase, candidate_entity: Candid
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
@ -66,16 +76,16 @@ Return only the markdown content, no other text. Make sure all content is includ
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
chat_message : ChatMessage = message
chat_message: ChatMessage = message
try:
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
except Exception as e:
except Exception:
pass
logger.info(f"✅ Successfully converted content to markdown")
logger.info("✅ Successfully converted content to markdown")
yield chat_message
return
@ -84,10 +94,10 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
await asyncio.sleep(0) # Let the status message propagate
async with entities.get_candidate_entity(candidate=current_user) as candidate_entity:
message = None
@ -98,7 +108,7 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to reformat job description"
content="Failed to reformat job description",
)
yield error_message
return
@ -108,23 +118,20 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="No agent found for job requirements chat type"
content="No agent found for job requirements chat type",
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
@ -132,23 +139,22 @@ async def create_job_from_content(database: RedisDatabase, current_user: Candida
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Job extraction did not convert successfully"
content="Job extraction did not convert successfully",
)
yield error_message
return
job_requirements : JobRequirementsMessage = message
job_requirements: JobRequirementsMessage = message
logger.info(f"✅ Successfully generated job requirements for job {job_requirements.id}")
yield job_requirements
return
@router.post("")
async def create_job(
job_data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
try:
@ -165,20 +171,17 @@ async def create_job(
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.post("")
async def create_candidate_job(
job_data: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new job"""
is_employer = isinstance(current_user, Employer)
isinstance(current_user, Employer)
try:
job = Job.model_validate(job_data)
@ -194,28 +197,22 @@ async def create_candidate_job(
except Exception as e:
logger.error(f"❌ Job creation error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("CREATION_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("CREATION_FAILED", str(e)))
@router.patch("/{job_id}")
async def update_job(
job_id: str = Path(...),
updates: Dict[str, Any] = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Update a candidate"""
try:
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Job not found for update: {job_data}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
job = Job.model_validate(job_data)
@ -223,8 +220,7 @@ async def update_job(
if current_user.is_admin is False and job.owner_id != current_user.id:
logger.warning(f"⚠️ Unauthorized update attempt by user {current_user.id} on job {job_id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Cannot update another user's job")
status_code=403, content=create_error_response("FORBIDDEN", "Cannot update another user's job")
)
# Apply updates
@ -239,25 +235,22 @@ async def update_job(
except Exception as e:
logger.error(f"❌ Update job error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("UPDATE_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("UPDATE_FAILED", str(e)))
@router.post("/from-content")
async def create_job_from_description(
content: str = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
content: str = Body(...), current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Upload a document for the current candidate"""
async def content_stream_generator(content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
@ -277,10 +270,11 @@ async def create_job_from_description(
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -304,18 +298,25 @@ async def create_job_from_description(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
@router.post("/upload")
async def create_job_from_file(
file: UploadFile = File(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Upload a job document for the current candidate and create a Job"""
# Check file size (limit to 10MB)
@ -324,55 +325,70 @@ async def create_job_from_file(
if len(file_content) > max_size:
logger.info(f"⚠️ File too large: {file.filename} ({len(file_content)} bytes)")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File size exceeds 10MB limit",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
if len(file_content) == 0:
logger.info(f"⚠️ File is empty: {file.filename}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty"
).model_dump(by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File is empty",
).model_dump(by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
"""Upload a document for the current candidate"""
async def upload_stream_generator(file_content):
# Verify user is a candidate
if current_user.user_type != "candidate":
logger.warning(f"⚠️ Unauthorized upload attempt by user type: {current_user.user_type}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Only candidates can upload documents"
content="Only candidates can upload documents",
)
yield error_message
return
file.filename = re.sub(r'^.*/', '', file.filename) if file.filename else '' # Sanitize filename
file.filename = re.sub(r"^.*/", "", file.filename) if file.filename else "" # Sanitize filename
if not file.filename or file.filename.strip() == "":
logger.warning("⚠️ File upload attempt with missing filename")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="File must have a valid filename"
content="File must have a valid filename",
)
yield error_message
return
logger.info(f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'")
logger.info(
f"📁 Received file upload: filename='{file.filename}', content_type='{file.content_type}', size='{len(file_content)} bytes'"
)
# Validate file type
allowed_types = ['.txt', '.md', '.docx', '.pdf', '.png', '.jpg', '.jpeg', '.gif']
allowed_types = [".txt", ".md", ".docx", ".pdf", ".png", ".jpg", ".jpeg", ".gif"]
file_extension = pathlib.Path(file.filename).suffix.lower() if file.filename else ""
if file_extension not in allowed_types:
logger.warning(f"⚠️ Invalid file type: {file_extension} for file {file.filename}")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}"
content=f"File type {file_extension} not supported. Allowed types: {', '.join(allowed_types)}",
)
yield error_message
return
@ -383,7 +399,7 @@ async def create_job_from_file(
status_message = ChatMessageStatus(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Converting content from {document_type}...",
activity=ApiActivityType.CONVERTING
activity=ApiActivityType.CONVERTING,
)
yield status_message
try:
@ -391,7 +407,7 @@ async def create_job_from_file(
stream = io.BytesIO(file_content)
stream_info = StreamInfo(
extension=file_extension, # e.g., ".pdf"
url=file.filename # optional, helps with logging and guessing
url=file.filename, # optional, helps with logging and guessing
)
result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
file_content = result.text_content
@ -405,15 +421,18 @@ async def create_job_from_file(
logger.error(f"❌ Error converting {file.filename} to Markdown: {e}")
return
async for message in create_job_from_content(database=database, current_user=current_user, content=file_content):
async for message in create_job_from_content(
database=database, current_user=current_user, content=file_content
):
yield message
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -437,26 +456,27 @@ async def create_job_from_file(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Document upload error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document"
).model_dump(mode='json', by_alias=True)).encode("utf-8")]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to upload document",
).model_dump(mode="json", by_alias=True)
).encode("utf-8")
]
),
media_type="text/event-stream",
)
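The upload handler above converts the raw upload bytes to Markdown before the job-requirements extraction runs. A condensed sketch of that conversion step, reusing the MarkItDown calls shown in the diff; the top-level import of StreamInfo and the helper name are assumptions.

```python
import io
import pathlib

from markitdown import MarkItDown, StreamInfo  # import location assumed

def file_bytes_to_markdown(filename: str, file_content: bytes) -> str:
    """Convert an uploaded document (.pdf, .docx, .md, ...) into Markdown text."""
    md = MarkItDown()
    stream = io.BytesIO(file_content)
    stream_info = StreamInfo(
        extension=pathlib.Path(filename).suffix.lower(),  # e.g. ".pdf"
        url=filename,                                     # optional, helps logging and guessing
    )
    result = md.convert_stream(stream, stream_info=stream_info, output_format="markdown")
    return result.text_content
```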
@router.get("/{job_id}")
async def get_job(
job_id: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_job(job_id: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a job by ID"""
try:
job_data = await database.get_job(job_id)
if not job_data:
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Increment view count
job_data["views"] = job_data.get("views", 0) + 1
@ -467,10 +487,8 @@ async def get_job(
except Exception as e:
logger.error(f"❌ Get job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))
@router.get("")
async def get_jobs(
@ -479,7 +497,7 @@ async def get_jobs(
sortBy: Optional[str] = Query(None, alias="sortBy"),
sortOrder: str = Query("desc", pattern="^(asc|desc)$", alias="sortOrder"),
filters: Optional[str] = Query(None),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Get paginated list of jobs"""
try:
@ -493,23 +511,18 @@ async def get_jobs(
for job in all_jobs_data.values():
jobs_list.append(Job.model_validate(job))
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, sortBy, sortOrder, filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, sortBy, sortOrder, filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Get jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("FETCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("FETCH_FAILED", str(e)))
@router.get("/search")
async def search_jobs(
@ -517,7 +530,7 @@ async def search_jobs(
filters: Optional[str] = Query(None),
page: int = Query(1, ge=1),
limit: int = Query(20, ge=1, le=100),
database: RedisDatabase = Depends(get_database)
database: RedisDatabase = Depends(get_database),
):
"""Search jobs"""
try:
@ -532,69 +545,52 @@ async def search_jobs(
if query:
query_lower = query.lower()
jobs_list = [
j for j in jobs_list
if ((j.title and query_lower in j.title.lower()) or
(j.description and query_lower in j.description.lower()) or
any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or []))
j
for j in jobs_list
if (
(j.title and query_lower in j.title.lower())
or (j.description and query_lower in j.description.lower())
or any(query_lower in skill.lower() for skill in getattr(j, "skills", []) or [])
)
]
paginated_jobs, total = filter_and_paginate(
jobs_list, page, limit, filters=filter_dict
)
paginated_jobs, total = filter_and_paginate(jobs_list, page, limit, filters=filter_dict)
paginated_response = create_paginated_response(
[j.model_dump(by_alias=True) for j in paginated_jobs],
page, limit, total
[j.model_dump(by_alias=True) for j in paginated_jobs], page, limit, total
)
return create_success_response(paginated_response)
except Exception as e:
logger.error(f"❌ Search jobs error: {e}")
return JSONResponse(
status_code=400,
content=create_error_response("SEARCH_FAILED", str(e))
)
return JSONResponse(status_code=400, content=create_error_response("SEARCH_FAILED", str(e)))
@router.delete("/{job_id}")
async def delete_job(
job_id: str = Path(...),
admin_user = Depends(get_current_admin),
database: RedisDatabase = Depends(get_database)
job_id: str = Path(...), admin_user=Depends(get_current_admin), database: RedisDatabase = Depends(get_database)
):
"""Delete a Job"""
try:
# Check if admin user
if not admin_user.is_admin:
logger.warning(f"⚠️ Unauthorized delete attempt by user {admin_user.id}")
return JSONResponse(
status_code=403,
content=create_error_response("FORBIDDEN", "Only admins can delete")
)
return JSONResponse(status_code=403, content=create_error_response("FORBIDDEN", "Only admins can delete"))
# Get candidate data
job_data = await database.get_job(job_id)
if not job_data:
logger.warning(f"⚠️ Candidate not found for deletion: {job_id}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "Job not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "Job not found"))
# Delete job from database
await database.delete_job(job_id)
logger.info(f"🗑️ Job deleted: {job_id} by admin {admin_user.id}")
return create_success_response({
"message": "Job deleted successfully",
"jobId": job_id
})
return create_success_response({"message": "Job deleted successfully", "jobId": job_id})
except Exception as e:
logger.error(f"❌ Delete job error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("DELETE_ERROR", "Failed to delete job")
)
return JSONResponse(status_code=500, content=create_error_response("DELETE_ERROR", "Failed to delete job"))

View File

@ -6,6 +6,7 @@ from utils.llm_proxy import LLMProvider, get_llm
router = APIRouter(prefix="/providers", tags=["providers"])
@router.get("/models")
async def list_models(provider: Optional[str] = None):
"""List available models for a provider"""
@ -17,29 +18,25 @@ async def list_models(provider: Optional[str] = None):
try:
provider_enum = LLMProvider(provider.lower())
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Unsupported provider: {provider}"
)
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
models = await llm.list_models(provider_enum)
return {
"provider": provider_enum.value if provider_enum else llm.default_provider.value,
"models": models
}
return {"provider": provider_enum.value if provider_enum else llm.default_provider.value, "models": models}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("")
async def list_providers():
"""List all configured providers"""
llm = get_llm()
return {
"providers": [provider.value for provider in llm._initialized_providers],
"default": llm.default_provider.value
"default": llm.default_provider.value,
}
@router.post("/{provider}/set-default")
async def set_default_provider(provider: str):
"""Set the default provider"""
@ -51,6 +48,7 @@ async def set_default_provider(provider: str):
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# Health check endpoint
@router.get("/health")
async def health_check():
@ -59,5 +57,5 @@ async def health_check():
return {
"status": "healthy",
"providers_configured": len(llm._initialized_providers),
"default_provider": llm.default_provider.value
"default_provider": llm.default_provider.value,
}
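A quick illustration of exercising the provider routes above from an httpx client; the base URL (and any extra mount prefix) is an assumption, and provider_name must match one of the configured LLMProvider values.

```python
import httpx

async def show_providers(base_url: str, provider_name: str) -> None:
    async with httpx.AsyncClient(base_url=base_url) as client:
        # GET /providers -> {"providers": [...], "default": "..."}
        listing = (await client.get("/providers")).json()
        print(listing["providers"], listing["default"])

        # GET /providers/models?provider=... -> {"provider": "...", "models": [...]}
        models = (await client.get("/providers/models", params={"provider": provider_name})).json()
        print(models["provider"], len(models["models"]))
```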

View File

@ -11,26 +11,24 @@ from fastapi.responses import StreamingResponse
import backstory_traceback as backstory_traceback
from database.manager import RedisDatabase
from logger import logger
from models import (
MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
)
from utils.dependencies import (
get_database, get_current_user
)
from models import MOCK_UUID, ChatMessageError, Job, Candidate, Resume, ResumeMessage
from utils.dependencies import get_database, get_current_user
from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/resumes", tags=["resumes"])
@router.post("/{candidate_id}/{job_id}")
async def create_candidate_resume(
candidate_id: str = Path(..., description="ID of the candidate"),
job_id: str = Path(..., description="ID of the job"),
resume_content: str = Body(...),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Create a new resume for a candidate/job combination"""
async def message_stream_generator():
logger.info(f"🔍 Looking up candidate and job details for {candidate_id}/{job_id}")
@ -39,7 +37,7 @@ async def create_candidate_resume(
logger.error(f"❌ Candidate with ID '{candidate_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Candidate with ID '{candidate_id}' not found"
content=f"Candidate with ID '{candidate_id}' not found",
)
yield error_message
return
@ -50,13 +48,15 @@ async def create_candidate_resume(
logger.error(f"❌ Job with ID '{job_id}' not found")
error_message = ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content=f"Job with ID '{job_id}' not found"
content=f"Job with ID '{job_id}' not found",
)
yield error_message
return
job = Job.model_validate(job_data)
logger.info(f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'")
logger.info(
f"📄 Saving resume for candidate {candidate.first_name} {candidate.last_name} for job '{job.title}'"
)
# Job and Candidate are valid. Save the resume
resume = Resume(
@ -66,16 +66,13 @@ async def create_candidate_resume(
)
resume_message: ResumeMessage = ResumeMessage(
session_id=MOCK_UUID, # No session ID for document uploads
resume=resume
resume=resume,
)
# Save to database
success = await database.set_resume(current_user.id, resume.model_dump())
if not success:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to save resume to database"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to save resume to database")
yield error_message
return
@ -84,10 +81,11 @@ async def create_candidate_resume(
return
try:
async def to_json(method):
try:
async for message in method:
json_data = message.model_dump(mode='json', by_alias=True)
json_data = message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n".encode("utf-8")
except Exception as e:
@ -111,22 +109,26 @@ async def create_candidate_resume(
logger.error(backstory_traceback.format_exc())
logger.error(f"❌ Resume creation error: {e}")
return StreamingResponse(
iter([json.dumps(ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume"
).model_dump(mode='json', by_alias=True))]),
media_type="text/event-stream"
iter(
[
json.dumps(
ChatMessageError(
session_id=MOCK_UUID, # No session ID for document uploads
content="Failed to create resume",
).model_dump(mode="json", by_alias=True)
)
]
),
media_type="text/event-stream",
)
@router.get("")
async def get_user_resumes(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
):
async def get_user_resumes(current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)):
"""Get all resumes for the current user"""
try:
resumes_data = await database.get_all_resumes_for_user(current_user.id)
resumes : List[Resume] = [Resume.model_validate(data) for data in resumes_data]
resumes: List[Resume] = [Resume.model_validate(data) for data in resumes_data]
for resume in resumes:
job_data = await database.get_job(resume.job_id)
if job_data:
@ -135,19 +137,17 @@ async def get_user_resumes(
if candidate_data:
resume.candidate = Candidate.model_validate(candidate_data)
resumes.sort(key=lambda x: x.updated_at, reverse=True) # Sort by creation date
return create_success_response({
"resumes": resumes,
"count": len(resumes)
})
return create_success_response({"resumes": resumes, "count": len(resumes)})
except Exception as e:
logger.error(f"❌ Error retrieving resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resumes")
@router.get("/{resume_id}")
async def get_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get a specific resume by ID"""
try:
@ -155,21 +155,19 @@ async def get_resume(
if not resume:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"resume": resume
}
return {"success": True, "resume": resume}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error retrieving resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume")
@router.delete("/{resume_id}")
async def delete_resume(
resume_id: str = Path(..., description="ID of the resume"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Delete a specific resume"""
try:
@ -177,102 +175,82 @@ async def delete_resume(
if not success:
raise HTTPException(status_code=404, detail="Resume not found")
return {
"success": True,
"message": f"Resume {resume_id} deleted successfully"
}
return {"success": True, "message": f"Resume {resume_id} deleted successfully"}
except HTTPException:
raise
except Exception as e:
logger.error(f"❌ Error deleting resume {resume_id} for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to delete resume")
@router.get("/candidate/{candidate_id}")
async def get_resumes_by_candidate(
candidate_id: str = Path(..., description="ID of the candidate"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific candidate"""
try:
resumes = await database.get_resumes_by_candidate(current_user.id, candidate_id)
return {
"success": True,
"candidate_id": candidate_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "candidate_id": candidate_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for candidate {candidate_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve candidate resumes")
@router.get("/job/{job_id}")
async def get_resumes_by_job(
job_id: str = Path(..., description="ID of the job"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Get all resumes for a specific job"""
try:
resumes = await database.get_resumes_by_job(current_user.id, job_id)
return {
"success": True,
"job_id": job_id,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "job_id": job_id, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error retrieving resumes for job {job_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve job resumes")
@router.get("/search")
async def search_resumes(
q: str = Query(..., description="Search query"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Search resumes by content"""
try:
resumes = await database.search_resumes_for_user(current_user.id, q)
return {
"success": True,
"query": q,
"resumes": resumes,
"count": len(resumes)
}
return {"success": True, "query": q, "resumes": resumes, "count": len(resumes)}
except Exception as e:
logger.error(f"❌ Error searching resumes for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to search resumes")
@router.get("/stats")
async def get_resume_statistics(
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user), database: RedisDatabase = Depends(get_database)
):
"""Get resume statistics for the current user"""
try:
stats = await database.get_resume_statistics(current_user.id)
return {
"success": True,
"statistics": stats
}
return {"success": True, "statistics": stats}
except Exception as e:
logger.error(f"❌ Error retrieving resume statistics for user {current_user.id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve resume statistics")
@router.put("/{resume_id}")
async def update_resume(
resume_id: str = Path(..., description="ID of the resume"),
resume: str = Body(..., description="Updated resume content"),
current_user = Depends(get_current_user),
database: RedisDatabase = Depends(get_database)
current_user=Depends(get_current_user),
database: RedisDatabase = Depends(get_database),
):
"""Update the content of a specific resume"""
try:
updates = {
"resume": resume,
"updated_at": datetime.now(UTC).isoformat()
}
updates = {"resume": resume, "updated_at": datetime.now(UTC).isoformat()}
updated_resume_data = await database.update_resume(current_user.id, resume_id, updates)
if not updated_resume_data:
@ -280,11 +258,9 @@ async def update_resume(
raise HTTPException(status_code=404, detail="Resume not found")
updated_resume = Resume.model_validate(updated_resume_data) if updated_resume_data else None
return create_success_response({
"success": True,
"message": f"Resume {resume_id} updated successfully",
"resume": updated_resume
})
return create_success_response(
{"success": True, "message": f"Resume {resume_id} updated successfully", "resume": updated_resume}
)
except HTTPException:
raise
except Exception as e:

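The resume endpoints above all share one shape: path/query parameters declared with Path or Query, the current user and Redis-backed database injected with Depends, a flat {"success": True, ...} payload, and a 500 on unexpected errors. A minimal sketch of that pattern, assuming the same get_current_user/get_database dependencies from utils.dependencies and a hypothetical database.get_resume helper (the /resumes prefix is also assumed for illustration):

from fastapi import APIRouter, Depends, HTTPException, Path
from database.manager import RedisDatabase
from logger import logger
from utils.dependencies import get_current_user, get_database

router = APIRouter(prefix="/resumes", tags=["resumes"])  # prefix assumed for illustration

@router.get("/{resume_id}")
async def get_resume(
    resume_id: str = Path(..., description="ID of the resume"),
    current_user=Depends(get_current_user),
    database: RedisDatabase = Depends(get_database),
):
    """Fetch a single resume owned by the current user"""
    try:
        resume = await database.get_resume(current_user.id, resume_id)  # hypothetical helper
        if not resume:
            raise HTTPException(status_code=404, detail="Resume not found")
        return {"success": True, "resume": resume}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Error retrieving resume {resume_id}: {e}")
        raise HTTPException(status_code=500, detail="Failed to retrieve resume")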
View File

@ -10,6 +10,7 @@ from utils.responses import create_success_response
# Create router for authentication endpoints
router = APIRouter(prefix="/system", tags=["system"])
async def get_redis() -> Redis:
"""Dependency to get Redis client"""
return redis_manager.get_client()
@ -19,9 +20,11 @@ async def get_redis() -> Redis:
async def get_system_info(request: Request):
"""Get system information"""
from system_info import system_info # Import system_info function from system_info module
system = system_info()
return create_success_response(system.model_dump(mode='json'))
return create_success_response(system.model_dump(mode="json"))
@router.get("/redis/stats")
async def redis_stats(redis: Redis = Depends(get_redis)):
@ -33,7 +36,7 @@ async def redis_stats(redis: Redis = Depends(get_redis)):
"total_commands_processed": info.get("total_commands_processed"),
"keyspace_hits": info.get("keyspace_hits"),
"keyspace_misses": info.get("keyspace_misses"),
"uptime_in_seconds": info.get("uptime_in_seconds")
"uptime_in_seconds": info.get("uptime_in_seconds"),
}
except Exception as e:
raise HTTPException(status_code=503, detail=f"Redis stats unavailable: {e}")

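For context, the Redis stats endpoint above just surfaces a few fields of the async client's INFO reply. A small sketch of that pattern, adding a derived keyspace hit ratio (the ratio itself is not in the original):

from redis.asyncio import Redis

async def redis_summary(redis: Redis) -> dict:
    """Summarize a few INFO fields and derive a keyspace hit ratio"""
    info = await redis.info()
    hits = int(info.get("keyspace_hits", 0))
    misses = int(info.get("keyspace_misses", 0))
    total = hits + misses
    return {
        "uptime_in_seconds": info.get("uptime_in_seconds"),
        "total_commands_processed": info.get("total_commands_processed"),
        "hit_ratio": round(hits / total, 3) if total else None,
    }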
View File

@ -7,23 +7,17 @@ from fastapi.responses import JSONResponse
from database.manager import RedisDatabase
from logger import logger
from models import (
BaseUserWithType
)
from utils.dependencies import (
get_database
)
from models import BaseUserWithType
from utils.dependencies import get_database
from utils.responses import create_success_response, create_error_response
# Create router for job endpoints
router = APIRouter(prefix="/users", tags=["users"])
# reference can be candidateId, username, or email
@router.get("/users/{reference}")
async def get_user(
reference: str = Path(...),
database: RedisDatabase = Depends(get_database)
):
async def get_user(reference: str = Path(...), database: RedisDatabase = Depends(get_database)):
"""Get a candidate by username"""
try:
# Normalize reference to lowercase for case-insensitive search
@ -31,41 +25,36 @@ async def get_user(
all_candidate_data = await database.get_all_candidates()
if not all_candidate_data:
logger.warning(f"⚠️ No users found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
logger.warning("⚠️ No users found in database")
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
user_data = None
for user in all_candidate_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
if not user_data:
all_guest_data = await database.get_all_guests()
if not all_guest_data:
logger.warning(f"⚠️ No guests found in database")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "No users found")
)
logger.warning("⚠️ No guests found in database")
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "No users found"))
for user in all_guest_data.values():
if (user.get("id", "").lower() == query_lower or
user.get("username", "").lower() == query_lower or
user.get("email", "").lower() == query_lower):
if (
user.get("id", "").lower() == query_lower
or user.get("username", "").lower() == query_lower
or user.get("email", "").lower() == query_lower
):
user_data = user
break
if not user_data:
logger.warning(f"⚠️ Neither user nor guest found for reference: {reference}")
return JSONResponse(
status_code=404,
content=create_error_response("NOT_FOUND", "User not found")
)
return JSONResponse(status_code=404, content=create_error_response("NOT_FOUND", "User not found"))
user = BaseUserWithType.model_validate(user_data)
@ -73,8 +62,4 @@ async def get_user(
except Exception as e:
logger.error(f"❌ Get user error: {e}")
return JSONResponse(
status_code=500,
content=create_error_response("FETCH_ERROR", str(e))
)
return JSONResponse(status_code=500, content=create_error_response("FETCH_ERROR", str(e)))

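The lookup above compares a reference against id, username, and email case-insensitively, first over candidates and then over guests. The core comparison could be factored into a small helper; a sketch under the assumption that records are the plain dicts returned by the database:

from typing import Any, Dict, Optional

def find_user_record(records: Dict[str, Dict[str, Any]], reference: str) -> Optional[Dict[str, Any]]:
    """Return the first record whose id, username, or email equals the reference (case-insensitive)"""
    query_lower = reference.strip().lower()
    for record in records.values():
        if query_lower in (
            record.get("id", "").lower(),
            record.get("username", "").lower(),
            record.get("email", "").lower(),
        ):
            return record
    return None

The route would then call this once over the candidate records and, if that returns None, again over the guests.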
View File

@ -4,6 +4,7 @@ import subprocess
import math
from models import SystemInfo
def get_installed_ram():
try:
with open("/proc/meminfo", "r") as f:
@ -19,9 +20,7 @@ def get_graphics_cards():
gpus = []
try:
# Run the ze-monitor utility
result = subprocess.run(
["ze-monitor"], capture_output=True, text=True, check=True
)
result = subprocess.run(["ze-monitor"], capture_output=True, text=True, check=True)
# Clean up the output (remove leading/trailing whitespace and newlines)
output = result.stdout.strip()
@ -71,6 +70,7 @@ def get_cpu_info():
except Exception as e:
return f"Error retrieving CPU info: {e}"
def system_info() -> SystemInfo:
"""
Collects system information including RAM, GPU, CPU, LLM model, embedding model, and context length.

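get_installed_ram above reads /proc/meminfo, but the parsing itself falls outside this hunk. A typical way to pull MemTotal out of that file, offered as an illustrative assumption rather than the repo's exact code:

import re

def read_mem_total_gib() -> float:
    """Parse MemTotal (reported in kB) from /proc/meminfo and convert to GiB"""
    with open("/proc/meminfo", "r") as f:
        for line in f:
            match = re.match(r"MemTotal:\s+(\d+)\s+kB", line)
            if match:
                return int(match.group(1)) / (1024 * 1024)
    raise RuntimeError("MemTotal not found in /proc/meminfo")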
View File

@ -122,9 +122,7 @@ def get_forecast(grid_endpoint):
# Process the forecast data into a simpler format
forecast = {
"location": data["properties"]
.get("relativeLocation", {})
.get("properties", {}),
"location": data["properties"].get("relativeLocation", {}).get("properties", {}),
"updated": data["properties"].get("updated", ""),
"periods": [],
}
@ -181,7 +179,7 @@ def get_forecast(grid_endpoint):
def TickerValue(ticker_symbols):
api_key = os.getenv("TWELVEDATA_API_KEY", "")
if not api_key:
return {"error": f"Error fetching data: No API key for TwelveData"}
return {"error": "Error fetching data: No API key for TwelveData"}
results = []
for ticker_symbol in ticker_symbols.split(","):
@ -189,9 +187,7 @@ def TickerValue(ticker_symbols):
if ticker_symbol == "":
continue
url = (
f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
)
url = f"https://api.twelvedata.com/price?symbol={ticker_symbol}&apikey={api_key}"
response = requests.get(url)
data = response.json()
@ -244,9 +240,7 @@ def yfTickerValue(ticker_symbols):
logging.error(f"Error fetching data for {ticker_symbol}: {e}")
logging.error(traceback.format_exc())
results.append(
{"error": f"Error fetching data for {ticker_symbol}: {str(e)}"}
)
results.append({"error": f"Error fetching data for {ticker_symbol}: {str(e)}"})
return results[0] if len(results) == 1 else results
@ -278,8 +272,10 @@ def DateTime(timezone="America/Los_Angeles"):
except Exception as e:
return {"error": f"Invalid timezone {timezone}: {str(e)}"}
async def GenerateImage(llm, model: str, prompt: str):
return { "image_id": "image-a830a83-bd831" }
return {"image_id": "image-a830a83-bd831"}
async def AnalyzeSite(llm, model: str, url: str, question: str):
"""
@ -347,7 +343,6 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
return f"Error processing the website content: {str(e)}"
# %%
class Function(BaseModel):
name: str
@ -355,171 +350,181 @@ class Function(BaseModel):
parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel):
type: str
function: Function
tools : List[Tool] = [
# Tool.model_validate({
# "type": "function",
# "function": {
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
# "type": "object",
# "properties": {
# "prompt": {
# "type": "string",
# "description": "Detailed image description including style, colors, subject, composition"
# }
# },
# "required": ["prompt"]
# },
# "returns": {
# "type": "object",
# "properties": {
# "image_id": {
# "type": "string",
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
# }
# }
# }
# }
# }),
Tool.model_validate({
"type": "function",
"function": {
"name": "TickerValue",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
tools: List[Tool] = [
# Tool.model_validate({
# "type": "function",
# "function": {
# "name": "GenerateImage",
# "description": """\
# CRITICAL INSTRUCTIONS FOR IMAGE GENERATION:
# 1. Call this tool when users request images, drawings, or visual content
# 2. This tool returns an image_id (e.g., "img_abc123")
# 3. MANDATORY: You must respond with EXACTLY this format: <GenerateImage id={image_id}/>
# 4. FORBIDDEN: DO NOT use markdown image syntax ![](url)
# 5. FORBIDDEN: DO NOT create fake URLs or file paths
# 6. FORBIDDEN: DO NOT use any other image embedding format
# CORRECT EXAMPLE:
# User: "Draw a cat"
# Tool returns: {"image_id": "img_xyz789"}
# Your response: "Here's your cat image: <GenerateImage id=img_xyz789/>"
# WRONG EXAMPLES (DO NOT DO THIS):
# - ![](https://example.com/...)
# - ![Cat image](any_url)
# - <img src="...">
# The <GenerateImage id={image_id}/> format is the ONLY way to display images in this system.
# """,
# "parameters": {
# "type": "object",
# "properties": {
# "prompt": {
# "type": "string",
# "description": "Detailed image description including style, colors, subject, composition"
# }
# },
# "required": ["prompt"]
# },
# "returns": {
# "type": "object",
# "properties": {
# "image_id": {
# "type": "string",
# "description": "Unique identifier for the generated image. Use this EXACTLY in <GenerateImage id={this_value}/>"
# }
# }
# }
# }
# }),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "TickerValue",
"description": "Get the current stock price of one or more ticker symbols. Returns an array of objects with 'symbol' and 'price' fields. Call this whenever you need to know the latest value of stock ticker symbols, for example when a user asks 'How much is Intel trading at?' or 'What are the prices of AAPL and MSFT?'",
"parameters": {
"type": "object",
"properties": {
"ticker": {
"type": "string",
"description": "The company stock ticker symbol. For multiple tickers, provide a comma-separated list (e.g., 'AAPL,MSFT,GOOGL').",
},
},
"required": ["ticker"],
"additionalProperties": False,
},
"required": ["ticker"],
"additionalProperties": False,
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "AnalyzeSite",
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The website URL to download and process",
},
"question": {
"type": "string",
"description": "The question to ask the second LLM about the content",
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "AnalyzeSite",
"description": "Downloads the requested site and asks a second LLM agent to answer the question based on the site content. For example if the user says 'What are the top headlines on cnn.com?' you would use AnalyzeSite to get the answer. Only use this if the user asks about a specific site or company.",
"parameters": {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The website URL to download and process",
},
"question": {
"type": "string",
"description": "The question to ask the second LLM about the content",
},
},
"required": ["url", "question"],
"additionalProperties": False,
},
"required": ["url", "question"],
"additionalProperties": False,
},
"returns": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Identifier for the source LLM",
},
"content": {
"type": "string",
"description": "The complete response from the second LLM",
},
"metadata": {
"type": "object",
"description": "Additional information about the response",
"returns": {
"type": "object",
"properties": {
"source": {
"type": "string",
"description": "Identifier for the source LLM",
},
"content": {
"type": "string",
"description": "The complete response from the second LLM",
},
"metadata": {
"type": "object",
"description": "Additional information about the response",
},
},
},
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "DateTime",
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
}
},
"required": [],
},
},
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "WeatherForecast",
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
"minLength": 2,
},
"state": {
"type": "string",
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
"minLength": 2,
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "DateTime",
"description": "Get the current date and time in a specified timezone. For example if a user asks 'What time is it in Poland?' you would pass the Warsaw timezone to DateTime.",
"parameters": {
"type": "object",
"properties": {
"timezone": {
"type": "string",
"description": "Timezone name (e.g., 'UTC', 'America/New_York', 'Europe/London', 'America/Los_Angeles'). Default is 'America/Los_Angeles'.",
}
},
"required": [],
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
}),
}
),
Tool.model_validate(
{
"type": "function",
"function": {
"name": "WeatherForecast",
"description": "Get the full weather forecast as structured data for a given CITY and STATE location in the United States. For example, if the user asks 'What is the weather in Portland?' or 'What is the forecast for tomorrow?' use the provided data to answer the question.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').",
"minLength": 2,
},
"state": {
"type": "string",
"description": "State to find the weather forecast (e.g., 'OR', 'WA').",
"minLength": 2,
},
},
"required": ["city", "state"],
"additionalProperties": False,
},
},
}
),
]
class ToolEntry(BaseModel):
enabled: bool = True
tool: Tool
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled is True]
def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled is True]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite", "GenerateImage"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]

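Downstream, llm_tools and enabled_tools above turn the ToolEntry list into the JSON-serializable schemas an LLM chat call expects. A short usage sketch; the final chat call is left out because the llm_proxy interface is not shown in this file:

entries = all_tools()                        # every registered Tool, enabled by default
entries[0].enabled = False                   # e.g. switch off one tool for this request
schemas = llm_tools(enabled_tools(entries))  # list of dicts, one per enabled tool
# `schemas` is what would be passed as the tools/functions argument of a chat-completion call.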
View File

@ -14,6 +14,7 @@ from pydantic import BaseModel
logger = logging.getLogger(__name__)
class PasswordSecurity:
"""Handles password hashing and verification using bcrypt"""
@ -32,9 +33,9 @@ class PasswordSecurity:
salt = bcrypt.gensalt()
# Hash the password
password_hash = bcrypt.hashpw(password.encode('utf-8'), salt)
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)
return password_hash.decode('utf-8'), salt.decode('utf-8')
return password_hash.decode("utf-8"), salt.decode("utf-8")
@staticmethod
def verify_password(password: str, password_hash: str) -> bool:
@ -49,10 +50,7 @@ class PasswordSecurity:
True if password matches, False otherwise
"""
try:
return bcrypt.checkpw(
password.encode('utf-8'),
password_hash.encode('utf-8')
)
return bcrypt.checkpw(password.encode("utf-8"), password_hash.encode("utf-8"))
except Exception as e:
logger.error(f"Password verification error: {e}")
return False
@ -62,8 +60,10 @@ class PasswordSecurity:
"""Generate a cryptographically secure random token"""
return secrets.token_urlsafe(length)
class AuthenticationRecord(BaseModel):
"""Authentication record for storing user credentials"""
user_id: str
password_hash: str
salt: str
@ -78,18 +78,19 @@ class AuthenticationRecord(BaseModel):
locked_until: Optional[datetime] = None
class Config:
json_encoders = {
datetime: lambda v: v.isoformat() if v else None
}
json_encoders = {datetime: lambda v: v.isoformat() if v else None}
class SecurityConfig:
"""Security configuration constants"""
MAX_LOGIN_ATTEMPTS = 5
ACCOUNT_LOCKOUT_DURATION_MINUTES = 15
PASSWORD_MIN_LENGTH = 8
TOKEN_EXPIRY_HOURS = 24
REFRESH_TOKEN_EXPIRY_DAYS = 30
class AuthenticationManager:
"""Manages authentication operations with security features"""
@ -120,7 +121,7 @@ class AuthenticationManager:
password_hash=password_hash,
salt=salt,
last_password_change=datetime.now(timezone.utc),
login_attempts=0
login_attempts=0,
)
# Store in database
@ -129,7 +130,9 @@ class AuthenticationManager:
logger.info(f"🔐 Created authentication record for user {user_id}")
return auth_record
async def verify_user_credentials(self, login: str, password: str) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
async def verify_user_credentials(
self, login: str, password: str
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
"""
Verify user credentials with security checks
@ -164,7 +167,11 @@ class AuthenticationManager:
seconds = int(total_seconds % 60)
time_until_unlock_str = f"{minutes}m {seconds}s"
logger.warning(f"🔒 Account is locked for user {login} for another {time_until_unlock_str}.")
return False, None, f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}"
return (
False,
None,
f"Account is temporarily locked due to too many failed attempts. Retry after {time_until_unlock_str}",
)
# Verify password
if not self.password_security.verify_password(password, auth_data.password_hash):
@ -176,7 +183,9 @@ class AuthenticationManager:
auth_data.locked_until = datetime.now(timezone.utc) + timedelta(
minutes=SecurityConfig.ACCOUNT_LOCKOUT_DURATION_MINUTES
)
logger.warning(f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts")
logger.warning(
f"🔒 Account locked for user {login} after {auth_data.login_attempts} failed attempts"
)
# Update authentication record
await self.database.set_authentication(user_data["id"], auth_data.model_dump())
@ -238,6 +247,7 @@ class AuthenticationManager:
except Exception as e:
logger.error(f"❌ Error updating last login for user {user_id}: {e}")
# Utility functions for common operations
def validate_password_strength(password: str) -> Tuple[bool, list]:
"""
@ -270,6 +280,7 @@ def validate_password_strength(password: str) -> Tuple[bool, list]:
return len(issues) == 0, issues
def sanitize_login_input(login: str) -> str:
"""Sanitize login input (email or username)"""
return login.strip().lower() if login else ""

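The PasswordSecurity methods above are thin wrappers over bcrypt; a minimal round trip showing the same hash-then-verify flow:

import bcrypt

password = "correct horse battery staple"
salt = bcrypt.gensalt()
password_hash = bcrypt.hashpw(password.encode("utf-8"), salt)

assert bcrypt.checkpw(password.encode("utf-8"), password_hash)   # correct password verifies
assert not bcrypt.checkpw(b"wrong-password", password_hash)      # wrong password does not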
View File

@ -19,7 +19,7 @@ from models import BaseUserWithType, Candidate, CandidateAI, Employer, Guest
from logger import logger
from background_tasks import BackgroundTaskManager
#from . rate_limiter import RateLimiter
# from . rate_limiter import RateLimiter
# Security
security = HTTPBearer()
@ -33,11 +33,13 @@ background_task_manager: Optional[BackgroundTaskManager] = None
# Global database manager reference
db_manager = None
def set_db_manager(manager: DatabaseManager):
"""Set the global database manager reference"""
global db_manager
db_manager = manager
def get_database() -> RedisDatabase:
"""
Safe database dependency that checks for availability
@ -47,26 +49,18 @@ def get_database() -> RedisDatabase:
if db_manager is None:
logger.error("Database manager not initialized")
raise HTTPException(
status_code=503,
detail="Database not available - service starting up"
)
raise HTTPException(status_code=503, detail="Database not available - service starting up")
if db_manager.is_shutting_down:
logger.warning("Database is shutting down")
raise HTTPException(
status_code=503,
detail="Service is shutting down"
)
raise HTTPException(status_code=503, detail="Service is shutting down")
try:
return db_manager.get_database()
except RuntimeError as e:
logger.error(f"Database not available: {e}")
raise HTTPException(
status_code=503,
detail="Database connection not available"
)
raise HTTPException(status_code=503, detail="Database connection not available")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
to_encode = data.copy()
@ -78,6 +72,7 @@ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Enhanced token verification with guest session recovery"""
try:
@ -125,9 +120,9 @@ async def verify_token_with_blacklist(credentials: HTTPAuthorizationCredentials
logger.error(f"❌ Token verification error: {e}")
raise HTTPException(status_code=401, detail="Token verification failed")
async def get_current_user(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user from database"""
try:
@ -135,6 +130,7 @@ async def get_current_user(
candidate_data = await database.get_candidate(user_id)
if candidate_data:
from helpers.model_cast import cast_to_base_user_with_type
if candidate_data.get("is_AI"):
return cast_to_base_user_with_type(CandidateAI.model_validate(candidate_data))
else:
@ -152,16 +148,20 @@ async def get_current_user(
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_user_or_guest(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
"""Get current user (including guests) from database"""
try:
# Check candidates first
candidate_data = await database.get_candidate(user_id)
if candidate_data:
return Candidate.model_validate(candidate_data) if not candidate_data.get("is_AI") else CandidateAI.model_validate(candidate_data)
return (
Candidate.model_validate(candidate_data)
if not candidate_data.get("is_AI")
else CandidateAI.model_validate(candidate_data)
)
# Check employers
employer_data = await database.get_employer(user_id)
@ -180,9 +180,9 @@ async def get_current_user_or_guest(
logger.error(f"❌ Error getting current user: {e}")
raise HTTPException(status_code=404, detail="User not found")
async def get_current_admin(
user_id: str = Depends(verify_token_with_blacklist),
database: RedisDatabase = Depends(get_database)
user_id: str = Depends(verify_token_with_blacklist), database: RedisDatabase = Depends(get_database)
) -> BaseUserWithType:
user = await get_current_user(user_id=user_id, database=database)
if isinstance(user, Candidate) and user.is_admin:
@ -193,6 +193,7 @@ async def get_current_admin(
logger.warning(f"⚠️ User {user_id} is not an admin")
raise HTTPException(status_code=403, detail="Admin access required")
prometheus_collector = CollectorRegistry()
# Keep the Instrumentator instance alive
@ -201,5 +202,5 @@ instrumentator = Instrumentator(
should_ignore_untemplated=True,
should_group_untemplated=True,
excluded_handlers=[f"{defines.api_prefix}/metrics"],
registry=prometheus_collector
registry=prometheus_collector,
)

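create_access_token above encodes a payload with an expiry via PyJWT, and verify_token_with_blacklist decodes it before consulting the blacklist. A stripped-down sketch of that round trip, with SECRET_KEY and ALGORITHM as placeholders for whatever the app's configuration actually supplies:

from datetime import datetime, timedelta, timezone
import jwt  # PyJWT

SECRET_KEY = "change-me"   # placeholder; the real key comes from configuration
ALGORITHM = "HS256"        # assumed algorithm for illustration

def make_token(user_id: str, ttl_minutes: int = 60) -> str:
    payload = {"sub": user_id, "exp": datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes)}
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)

def read_token(token: str) -> str:
    # jwt.decode raises ExpiredSignatureError / InvalidTokenError for stale or tampered tokens
    payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    return payload["sub"]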
View File

@ -1,5 +1,6 @@
from typing import List, Dict
from models import (Job)
from models import Job
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
@ -7,56 +8,56 @@ def get_requirements_list(job: Job) -> List[Dict[str, str]]:
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
]
)
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend([
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
]
)
return requirements

View File

@ -13,44 +13,13 @@ from fastapi.responses import StreamingResponse
import defines
from logger import logger
from models import DocumentType
from models import (
LoginRequest, CreateCandidateRequest, CreateEmployerRequest,
Candidate, Employer, Guest, AuthResponse,
MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest,
EmailVerificationRequest, ResendVerificationRequest,
# API
MOCK_UUID, ApiActivityType, ChatMessageError, ChatMessageResume,
ChatMessageSkillAssessment, ChatMessageStatus, ChatMessageStreaming,
ChatMessageUser, DocumentMessage, DocumentOptions, Job,
JobRequirements, JobRequirementsMessage, LoginRequest,
CreateCandidateRequest, CreateEmployerRequest,
# User models
Candidate, Employer, BaseUserWithType, BaseUser, Guest,
Authentication, AuthResponse, CandidateAI,
# Job models
JobApplication, ApplicationStatus,
# Chat models
ChatSession, ChatMessage, ChatContext, ChatQuery, ChatSenderType, ApiMessageType, ChatContextType,
ChatMessageRagSearch,
# Document models
Document, DocumentType, DocumentListResponse, DocumentUpdateRequest, DocumentContentResponse,
# Supporting models
Location, MFARequest, MFAData, MFARequestResponse, MFAVerifyRequest, RagContentMetadata, RagContentResponse, ResendVerificationRequest, Resume, ResumeMessage, Skill, SkillAssessment, SystemInfo, UserType, WorkExperience, Education,
# Email
EmailVerificationRequest,
ApiStatusType
)
from models import Job, ChatMessage, DocumentType, ApiStatusType
from typing import List, Dict
from models import (Job)
from models import Job
import utils.llm_proxy as llm_manager
async def get_last_item(generator):
"""Get the last item from an async generator"""
last_item = None
@ -65,7 +34,7 @@ def filter_and_paginate(
limit: int = 20,
sort_by: Optional[str] = None,
sort_order: str = "desc",
filters: Optional[Dict] = None
filters: Optional[Dict] = None,
) -> Tuple[List[Any], int]:
"""Filter, sort, and paginate items"""
filtered_items = items.copy()
@ -76,8 +45,7 @@ def filter_and_paginate(
if isinstance(filtered_items[0], dict) and key in filtered_items[0]:
filtered_items = [item for item in filtered_items if item.get(key) == value]
elif hasattr(filtered_items[0], key) if filtered_items else False:
filtered_items = [item for item in filtered_items
if getattr(item, key, None) == value]
filtered_items = [item for item in filtered_items if getattr(item, key, None) == value]
# Sort items
if sort_by and filtered_items:
@ -101,9 +69,9 @@ def filter_and_paginate(
async def stream_agent_response(chat_agent, user_message, chat_session_data=None, database=None) -> StreamingResponse:
"""Stream agent response with proper formatting"""
async def message_stream_generator():
"""Generator to stream messages with persistence"""
last_log = None
final_message = None
import utils.llm_proxy as llm_manager
@ -127,12 +95,15 @@ async def stream_agent_response(chat_agent, user_message, chat_session_data=None
# metadata and other unnecessary fields for streaming
if generated_message.status != ApiStatusType.DONE:
from models import ChatMessageStreaming, ChatMessageStatus
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(generated_message, ChatMessageStatus):
if not isinstance(generated_message, ChatMessageStreaming) and not isinstance(
generated_message, ChatMessageStatus
):
raise TypeError(
f"Expected ChatMessageStreaming or ChatMessageStatus, got {type(generated_message)}"
)
json_data = generated_message.model_dump(mode='json', by_alias=True)
json_data = generated_message.model_dump(mode="json", by_alias=True)
json_str = json.dumps(json_data)
yield f"data: {json_str}\n\n"
@ -177,16 +148,16 @@ def get_document_type_from_filename(filename: str) -> DocumentType:
extension = pathlib.Path(filename).suffix.lower()
type_mapping = {
'.pdf': DocumentType.PDF,
'.docx': DocumentType.DOCX,
'.doc': DocumentType.DOCX,
'.txt': DocumentType.TXT,
'.md': DocumentType.MARKDOWN,
'.markdown': DocumentType.MARKDOWN,
'.png': DocumentType.IMAGE,
'.jpg': DocumentType.IMAGE,
'.jpeg': DocumentType.IMAGE,
'.gif': DocumentType.IMAGE,
".pdf": DocumentType.PDF,
".docx": DocumentType.DOCX,
".doc": DocumentType.DOCX,
".txt": DocumentType.TXT,
".md": DocumentType.MARKDOWN,
".markdown": DocumentType.MARKDOWN,
".png": DocumentType.IMAGE,
".jpg": DocumentType.IMAGE,
".jpeg": DocumentType.IMAGE,
".gif": DocumentType.IMAGE,
}
return type_mapping.get(extension, DocumentType.TXT)
@ -206,17 +177,12 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="No agent found for job requirements chat type")
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Reformatting job description as markdown...",
activity=ApiActivityType.CONVERTING
session_id=MOCK_UUID, content="Reformatting job description as markdown...", activity=ApiActivityType.CONVERTING
)
yield status_message
@ -229,26 +195,23 @@ async def reformat_as_markdown(database, candidate_entity, content: str):
system_prompt="""
You are a document editor. Take the provided job description and reformat as legible markdown.
Return only the markdown content, no other text. Make sure all content is included.
"""
""",
):
pass
if not message or not isinstance(message, ChatMessage):
logger.error("❌ Failed to reformat job description to markdown")
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
chat_message: ChatMessage = message
try:
chat_message.content = chat_agent.extract_markdown_from_text(chat_message.content)
except Exception as e:
except Exception:
pass
logger.info(f"✅ Successfully converted content to markdown")
logger.info("✅ Successfully converted content to markdown")
yield chat_message
return
@ -256,14 +219,19 @@ Return only the markdown content, no other text. Make sure all content is includ
async def create_job_from_content(database, current_user, content: str):
"""Create a job from content using AI analysis"""
from models import (
MOCK_UUID, ApiStatusType, ChatMessageError, ChatMessageStatus,
ApiActivityType, ChatContextType, JobRequirementsMessage
MOCK_UUID,
ApiStatusType,
ChatMessageError,
ChatMessageStatus,
ApiActivityType,
ChatContextType,
JobRequirementsMessage,
)
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Initiating connection with {current_user.first_name}'s AI agent...",
activity=ApiActivityType.INFO
activity=ApiActivityType.INFO,
)
yield status_message
await asyncio.sleep(0) # Let the status message propagate
@ -278,10 +246,7 @@ async def create_job_from_content(database, current_user, content: str):
yield message
if not message or not isinstance(message, ChatMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Failed to reformat job description"
)
error_message = ChatMessageError(session_id=MOCK_UUID, content="Failed to reformat job description")
yield error_message
return
@ -290,33 +255,28 @@ async def create_job_from_content(database, current_user, content: str):
chat_agent = candidate_entity.get_or_create_agent(agent_type=ChatContextType.JOB_REQUIREMENTS)
if not chat_agent:
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="No agent found for job requirements chat type"
session_id=MOCK_UUID, content="No agent found for job requirements chat type"
)
yield error_message
return
status_message = ChatMessageStatus(
session_id=MOCK_UUID,
content=f"Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING
content="Analyzing document for company and requirement details...",
activity=ApiActivityType.SEARCHING,
)
yield status_message
message = None
async for message in chat_agent.generate(
llm=llm_manager.get_llm(),
model=defines.model,
session_id=MOCK_UUID,
prompt=markdown_message.content
llm=llm_manager.get_llm(), model=defines.model, session_id=MOCK_UUID, prompt=markdown_message.content
):
if message.status != ApiStatusType.DONE:
yield message
if not message or not isinstance(message, JobRequirementsMessage):
error_message = ChatMessageError(
session_id=MOCK_UUID,
content="Job extraction did not convert successfully"
session_id=MOCK_UUID, content="Job extraction did not convert successfully"
)
yield error_message
return
@ -326,62 +286,63 @@ async def create_job_from_content(database, current_user, content: str):
yield job_requirements
return
def get_requirements_list(job: Job) -> List[Dict[str, str]]:
requirements: List[Dict[str, str]] = []
if job.requirements:
if job.requirements.technical_skills:
if job.requirements.technical_skills.required:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (required)"}
for req in job.requirements.technical_skills.required
]
)
if job.requirements.technical_skills.preferred:
requirements.extend([
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Technical Skills (preferred)"}
for req in job.requirements.technical_skills.preferred
]
)
if job.requirements.experience_requirements:
if job.requirements.experience_requirements.required:
requirements.extend([
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (required)"}
for req in job.requirements.experience_requirements.required
]
)
if job.requirements.experience_requirements.preferred:
requirements.extend([
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
])
requirements.extend(
[
{"requirement": req, "domain": "Experience (preferred)"}
for req in job.requirements.experience_requirements.preferred
]
)
if job.requirements.soft_skills:
requirements.extend([
{"requirement": req, "domain": "Soft Skills"}
for req in job.requirements.soft_skills
])
requirements.extend([{"requirement": req, "domain": "Soft Skills"} for req in job.requirements.soft_skills])
if job.requirements.experience:
requirements.extend([
{"requirement": req, "domain": "Experience"}
for req in job.requirements.experience
])
requirements.extend([{"requirement": req, "domain": "Experience"} for req in job.requirements.experience])
if job.requirements.education:
requirements.extend([
{"requirement": req, "domain": "Education"}
for req in job.requirements.education
])
requirements.extend([{"requirement": req, "domain": "Education"} for req in job.requirements.education])
if job.requirements.certifications:
requirements.extend([
{"requirement": req, "domain": "Certifications"}
for req in job.requirements.certifications
])
requirements.extend(
[{"requirement": req, "domain": "Certifications"} for req in job.requirements.certifications]
)
if job.requirements.preferred_attributes:
requirements.extend([
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
])
requirements.extend(
[
{"requirement": req, "domain": "Preferred Attributes"}
for req in job.requirements.preferred_attributes
]
)
return requirements

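stream_agent_response above frames each intermediate message as a server-sent event: a single "data: <json>" line followed by a blank line, wrapped in a StreamingResponse. The skeleton of that framing, with the agent messages stubbed out and the media type assumed (the original response headers are not visible in this hunk):

import json
from fastapi.responses import StreamingResponse

def sse_response(messages: list) -> StreamingResponse:
    """Stream already-serializable messages using the same data:-framed SSE format"""
    async def event_stream():
        for message in messages:
            yield f"data: {json.dumps(message)}\n\n"
    return StreamingResponse(event_stream(), media_type="text/event-stream")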
File diff suppressed because it is too large

View File

@ -1,6 +1,7 @@
from prometheus_client import Counter, Histogram # type: ignore
from threading import Lock
def singleton(cls):
instance = None
lock = Lock()

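The singleton decorator above is cut off right after it creates its Lock; a typical double-checked-locking completion, offered as an assumption about intent rather than the repo's code, would look like:

from threading import Lock

def singleton(cls):
    instance = None
    lock = Lock()

    def get_instance(*args, **kwargs):
        nonlocal instance
        if instance is None:
            with lock:                 # only take the lock on the slow path
                if instance is None:   # re-check once the lock is held
                    instance = cls(*args, **kwargs)
        return instance

    return get_instance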
View File

@ -6,56 +6,68 @@ from functools import wraps
from datetime import datetime, timedelta, UTC
from typing import Callable, Dict, Optional, Any
from fastapi import Depends, HTTPException, Request
from pydantic import BaseModel # type: ignore
from pydantic import BaseModel # type: ignore
from database.manager import RedisDatabase
from logger import logger
from . dependencies import get_current_user_or_guest, get_database
from .dependencies import get_current_user_or_guest, get_database
async def get_rate_limiter(database: RedisDatabase = Depends(get_database)) -> RateLimiter:
"""Dependency to get rate limiter instance"""
return RateLimiter(database)
class RateLimitConfig(BaseModel):
"""Rate limit configuration"""
requests_per_minute: int
requests_per_hour: int
requests_per_day: int
burst_limit: int # Maximum requests in a short burst
burst_window_seconds: int = 60 # Window for burst detection
class GuestRateLimitConfig(RateLimitConfig):
"""Rate limits for guest users - more restrictive"""
requests_per_minute: int = 10
requests_per_hour: int = 100
requests_per_day: int = 500
burst_limit: int = 15
burst_window_seconds: int = 60
class AuthenticatedUserRateLimitConfig(RateLimitConfig):
"""Rate limits for authenticated users - more generous"""
requests_per_minute: int = 60
requests_per_hour: int = 1000
requests_per_day: int = 10000
burst_limit: int = 100
burst_window_seconds: int = 60
class PremiumUserRateLimitConfig(RateLimitConfig):
"""Rate limits for premium/admin users - most generous"""
requests_per_minute: int = 120
requests_per_hour: int = 5000
requests_per_day: int = 50000
burst_limit: int = 200
burst_window_seconds: int = 60
class RateLimitResult(BaseModel):
"""Result of rate limit check"""
allowed: bool
reason: Optional[str] = None
retry_after_seconds: Optional[int] = None
remaining_requests: Dict[str, int] = {}
reset_times: Dict[str, datetime] = {}
class RateLimiter:
"""Rate limiter using Redis for distributed rate limiting"""
@ -78,11 +90,7 @@ class RateLimiter:
return self.user_config
async def check_rate_limit(
self,
user_id: str,
user_type: str,
is_admin: bool = False,
endpoint: Optional[str] = None
self, user_id: str, user_type: str, is_admin: bool = False, endpoint: Optional[str] = None
) -> RateLimitResult:
"""
Check if user has exceeded rate limits
@ -105,7 +113,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
# Add endpoint-specific limiting if provided
@ -125,7 +133,7 @@ class RateLimiter:
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
# Check limits
@ -133,7 +141,7 @@ class RateLimiter:
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
# Check each limit
@ -146,18 +154,22 @@ class RateLimiter:
elif window == "hour":
retry_after = 3600 - (current_time.minute * 60 + current_time.second)
elif window == "day":
retry_after = 86400 - (current_time.hour * 3600 + current_time.minute * 60 + current_time.second)
retry_after = 86400 - (
current_time.hour * 3600 + current_time.minute * 60 + current_time.second
)
else: # burst
retry_after = config.burst_window_seconds
logger.warning(f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}")
logger.warning(
f"🚫 Rate limit exceeded for {user_type} {user_id}: {current_count}/{limit} {window}"
)
return RateLimitResult(
allowed=False,
reason=f"Rate limit exceeded: {current_count}/{limit} requests per {window}",
retry_after_seconds=retry_after,
remaining_requests={k: max(0, limits[k] - v) for k, v in current_counts.items()},
reset_times=self._calculate_reset_times(current_time)
reset_times=self._calculate_reset_times(current_time),
)
# If we get here, request is allowed - increment counters
@ -182,17 +194,12 @@ class RateLimiter:
await pipe.execute()
# Calculate remaining requests
remaining = {
k: max(0, limits[k] - (current_counts[k] + 1))
for k in current_counts.keys()
}
remaining = {k: max(0, limits[k] - (current_counts[k] + 1)) for k in current_counts.keys()}
logger.debug(f"✅ Rate limit check passed for {user_type} {user_id}")
return RateLimitResult(
allowed=True,
remaining_requests=remaining,
reset_times=self._calculate_reset_times(current_time)
allowed=True, remaining_requests=remaining, reset_times=self._calculate_reset_times(current_time)
)
except Exception as e:
@ -206,18 +213,9 @@ class RateLimiter:
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
next_day = current_time.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1)
return {
"minute": next_minute,
"hour": next_hour,
"day": next_day
}
return {"minute": next_minute, "hour": next_hour, "day": next_day}
async def get_user_rate_limit_status(
self,
user_id: str,
user_type: str,
is_admin: bool = False
) -> Dict[str, Any]:
async def get_user_rate_limit_status(self, user_id: str, user_type: str, is_admin: bool = False) -> Dict[str, Any]:
"""Get current rate limit status for a user"""
config = self.get_config_for_user(user_type, is_admin)
current_time = datetime.now(UTC)
@ -227,7 +225,7 @@ class RateLimiter:
"minute": f"{base_key}:minute:{current_time.strftime('%Y%m%d%H%M')}",
"hour": f"{base_key}:hour:{current_time.strftime('%Y%m%d%H')}",
"day": f"{base_key}:day:{current_time.strftime('%Y%m%d')}",
"burst": f"{base_key}:burst"
"burst": f"{base_key}:burst",
}
try:
@ -240,14 +238,14 @@ class RateLimiter:
"minute": int(results[0] or 0),
"hour": int(results[1] or 0),
"day": int(results[2] or 0),
"burst": int(results[3] or 0)
"burst": int(results[3] or 0),
}
limits = {
"minute": config.requests_per_minute,
"hour": config.requests_per_hour,
"day": config.requests_per_day,
"burst": config.burst_limit
"burst": config.burst_limit,
}
return {
@ -258,7 +256,7 @@ class RateLimiter:
"limits": limits,
"remaining": {k: max(0, limits[k] - current_counts[k]) for k in limits.keys()},
"reset_times": self._calculate_reset_times(current_time),
"config": config.model_dump()
"config": config.model_dump(),
}
except Exception as e:
@ -295,11 +293,9 @@ class RateLimiter:
# Rate Limited Decorator
# ============================
def rate_limited(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120,
endpoint_specific: bool = True
guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120, endpoint_specific: bool = True
):
"""
Decorator to easily apply rate limiting to endpoints
@ -320,12 +316,14 @@ def rate_limited(
):
return {"message": "Rate limited endpoint"}
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
# Extract dependencies from function signature
import inspect
sig = inspect.signature(func)
inspect.signature(func)
# Get request, current_user, and rate_limiter from kwargs or args
request = None
@ -336,7 +334,7 @@ def rate_limited(
for param_name, param_value in kwargs.items():
if isinstance(param_value, Request):
request = param_value
elif hasattr(param_value, 'user_type'): # User-like object
elif hasattr(param_value, "user_type"): # User-like object
current_user = param_value
elif isinstance(param_value, RateLimiter):
rate_limiter = param_value
@ -350,30 +348,33 @@ def rate_limited(
# Apply rate limiting if we have the required components
if request and current_user and rate_limiter:
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
# Call the original function
return await func(*args, **kwargs)
return wrapper
return decorator
async def apply_custom_rate_limiting(
request: Request,
current_user,
rate_limiter: RateLimiter,
guest_per_minute: int,
user_per_minute: int,
admin_per_minute: int
admin_per_minute: int,
):
"""Apply custom rate limiting with specified limits"""
try:
# Determine user info
user_id = current_user.id
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
is_admin = getattr(current_user, 'is_admin', False)
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
is_admin = getattr(current_user, "is_admin", False)
# Determine appropriate limit
if is_admin:
@ -385,13 +386,17 @@ async def apply_custom_rate_limiting(
# Create custom rate limit key
current_time = datetime.now(UTC)
custom_key = f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
custom_key = (
f"custom_rate_limit:{request.url.path}:{user_type}:{user_id}:minute:{current_time.strftime('%Y%m%d%H%M')}"
)
# Check current usage
current_count = int(await rate_limiter.redis.get(custom_key) or 0)
if current_count >= requests_per_minute:
logger.warning(f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}")
logger.warning(
f"🚫 Custom rate limit exceeded for {user_type} {user_id}: {current_count}/{requests_per_minute}"
)
raise HTTPException(
status_code=429,
detail={
@ -399,9 +404,9 @@ async def apply_custom_rate_limiting(
"message": f"Custom rate limit exceeded: {current_count}/{requests_per_minute} requests per minute",
"retryAfter": 60 - current_time.second,
"userType": user_type,
"endpoint": request.url.path
"endpoint": request.url.path,
},
headers={"Retry-After": str(60 - current_time.second)}
headers={"Retry-After": str(60 - current_time.second)},
)
# Increment counter
@ -410,7 +415,9 @@ async def apply_custom_rate_limiting(
pipe.expire(custom_key, 120) # 2 minutes TTL
await pipe.execute()
logger.debug(f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}")
logger.debug(
f"✅ Custom rate limit check passed for {user_type} {user_id}: {current_count + 1}/{requests_per_minute}"
)
except HTTPException:
raise
@ -418,15 +425,13 @@ async def apply_custom_rate_limiting(
logger.error(f"❌ Custom rate limiting error: {e}")
# Fail open
# ============================
# Alternative: FastAPI Dependency-Based Rate Limiting
# ============================
def create_rate_limit_dependency(
guest_per_minute: int = 10,
user_per_minute: int = 60,
admin_per_minute: int = 120
):
def create_rate_limit_dependency(guest_per_minute: int = 10, user_per_minute: int = 60, admin_per_minute: int = 120):
"""
Create a FastAPI dependency for rate limiting
@ -441,23 +446,25 @@ def create_rate_limit_dependency(
):
return {"message": "Rate limited endpoint"}
"""
async def rate_limit_dependency(
request: Request,
current_user = Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter)
current_user=Depends(get_current_user_or_guest),
rate_limiter: RateLimiter = Depends(get_rate_limiter),
):
await apply_custom_rate_limiting(
request, current_user, rate_limiter,
guest_per_minute, user_per_minute, admin_per_minute
request, current_user, rate_limiter, guest_per_minute, user_per_minute, admin_per_minute
)
return True
return rate_limit_dependency
# ============================
# Rate Limiting Utilities
# ============================
class EndpointRateLimiter:
"""Utility class for endpoint-specific rate limiting"""
@ -477,9 +484,11 @@ class EndpointRateLimiter:
return True # No custom limits set
limits = self.custom_limits[endpoint]
user_type = current_user.user_type.value if hasattr(current_user.user_type, 'value') else str(current_user.user_type)
user_type = (
current_user.user_type.value if hasattr(current_user.user_type, "value") else str(current_user.user_type)
)
if getattr(current_user, 'is_admin', False):
if getattr(current_user, "is_admin", False):
user_type = "admin"
limit = limits.get(user_type, limits.get("default", 60))
@ -491,8 +500,7 @@ class EndpointRateLimiter:
if current_count >= limit:
raise HTTPException(
status_code=429,
detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
status_code=429, detail=f"Endpoint rate limit exceeded: {current_count}/{limit} for {endpoint}"
)
# Increment counter
@ -501,9 +509,11 @@ class EndpointRateLimiter:
return True
# Global endpoint rate limiter instance
endpoint_rate_limiter = None
def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limiter)) -> EndpointRateLimiter:
"""Get endpoint rate limiter instance"""
global endpoint_rate_limiter
@ -511,15 +521,14 @@ def get_endpoint_rate_limiter(rate_limiter: RateLimiter = Depends(get_rate_limit
endpoint_rate_limiter = EndpointRateLimiter(rate_limiter)
# Configure endpoint-specific limits
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/chat/sessions/*/messages/stream", {
"guest": 5, "candidate": 30, "employer": 30, "admin": 100
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/candidates/documents/upload", {
"guest": 2, "candidate": 10, "employer": 10, "admin": 50
})
endpoint_rate_limiter.set_endpoint_limits("/api/1.0/jobs", {
"guest": 1, "candidate": 5, "employer": 20, "admin": 50
})
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/chat/sessions/*/messages/stream", {"guest": 5, "candidate": 30, "employer": 30, "admin": 100}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/candidates/documents/upload", {"guest": 2, "candidate": 10, "employer": 10, "admin": 50}
)
endpoint_rate_limiter.set_endpoint_limits(
"/api/1.0/jobs", {"guest": 1, "candidate": 5, "employer": 20, "admin": 50}
)
return endpoint_rate_limiter

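The limiter above counts requests in fixed Redis windows (minute, hour, day, plus a burst bucket) and exposes both a decorator and a dependency factory. A usage sketch for the dependency route, with the endpoint path and limits chosen only for illustration:

from fastapi import APIRouter, Depends

router = APIRouter()
chat_rate_limit = create_rate_limit_dependency(guest_per_minute=5, user_per_minute=30, admin_per_minute=100)

@router.post("/chat/messages")
async def post_message(_: bool = Depends(chat_rate_limit)):
    return {"message": "Rate limited endpoint"}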
View File

@ -3,37 +3,17 @@ Response utility functions for consistent API responses
"""
from typing import Any, Optional, Dict, List
def create_success_response(data: Any, meta: Optional[Dict] = None) -> Dict:
return {
"success": True,
"data": data,
"meta": meta
}
return {"success": True, "data": data, "meta": meta}
def create_error_response(code: str, message: str, details: Any = None) -> Dict:
return {
"success": False,
"error": {
"code": code,
"message": message,
"details": details
}
}
return {"success": False, "error": {"code": code, "message": message, "details": details}}
def create_paginated_response(
data: List[Any],
page: int,
limit: int,
total: int
) -> Dict:
def create_paginated_response(data: List[Any], page: int, limit: int, total: int) -> Dict:
total_pages = (total + limit - 1) // limit
has_more = page < total_pages
return {
"data": data,
"total": total,
"page": page,
"limit": limit,
"totalPages": total_pages,
"hasMore": has_more
}
return {"data": data, "total": total, "page": page, "limit": limit, "totalPages": total_pages, "hasMore": has_more}