386 lines
15 KiB
Python
386 lines
15 KiB
Python
"""
|
|
Conversation Context Management for Advanced Bot Management.
|
|
|
|
This module manages conversation context, memory, and state for enhanced
|
|
bot interactions with persistent conversation awareness.
|
|
"""
|
|
|
|
import json
import os
import re
import time
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

from logger import logger
|
|
|
|
|
|
@dataclass
class ConversationTurn:
    """Individual turn (one user message and the bot's reply) in a conversation.

    Attributes:
        turn_id: Identifier of the turn.
        timestamp: Time of the turn as a float (the manager in this module
            passes ``time.time()``).
        user_message: Text sent by the user.
        bot_response: Text returned by the bot.
        context_used: Free-form data describing the context used for the reply.
        metadata: Free-form extra data; ``ConversationMemory`` reads the
            ``"detected_topics"`` key from it.
    """

    turn_id: str
    timestamp: float
    user_message: str
    bot_response: str
    context_used: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for serialization.

        Returns:
            A dict keyed by field name, in declaration order. The
            ``context_used`` and ``metadata`` values are the turn's own dict
            objects, not copies.
        """
        # Derive the keys from the dataclass definition so the serialized form
        # cannot drift from the declared fields.
        return {f.name: getattr(self, f.name) for f in fields(self)}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'ConversationTurn':
        """Create a turn from a dictionary such as the one made by ``to_dict``.

        Keys that are not fields of the dataclass are ignored, so a persisted
        turn carrying an extra key does not make its whole context unloadable.

        Raises:
            TypeError: If a required field is missing from ``data``.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in data.items() if key in known})
|
|
|
|
|
|
class ConversationMemory(BaseModel):
    """Memory system for a single conversation's context.

    Holds the turn history together with facts, preferences, topics and the
    emotional state that are extracted from user messages by simple keyword
    rules (see ``_extract_context_from_turn``).
    """

    # Core conversation data
    session_id: str
    bot_name: str
    user_name: Optional[str] = None  # filled in when the user says "my name is ..."
    conversation_id: str

    # Conversation history
    turns: List[ConversationTurn] = Field(default_factory=list)
    created_at: float = Field(default_factory=time.time)
    last_updated: float = Field(default_factory=time.time)  # refreshed by add_turn

    # Memory components
    facts_learned: Dict[str, Any] = Field(default_factory=dict)  # Facts about user/context
    preferences: Dict[str, Any] = Field(default_factory=dict)  # User preferences
    topics_discussed: List[str] = Field(default_factory=list)  # Topics covered
    emotional_context: Dict[str, Any] = Field(default_factory=dict)  # Emotional state

    # Conversation state
    current_topic: Optional[str] = None
    conversation_stage: str = "greeting"  # greeting, discussion, conclusion
    user_intent: Optional[str] = None
    bot_goals: List[str] = Field(default_factory=list)

    # Memory management
    max_turns: int = 100  # add_turn keeps only this many most recent turns
    max_facts: int = 50  # NOTE(review): not enforced anywhere in this class
    memory_decay_factor: float = 0.95  # How quickly old memories fade (not read in this class)

    # NOTE(review): pydantic v2 prefers ``model_config``; the class-based form
    # is kept here unchanged.
    class Config:
        arbitrary_types_allowed = True

    def add_turn(self, turn: ConversationTurn) -> None:
        """Add a conversation turn to memory.

        Appends the turn, refreshes ``last_updated``, extracts context from
        the user's message and then drops the oldest turns beyond
        ``max_turns``.
        """
        self.turns.append(turn)
        self.last_updated = time.time()

        # Extract facts and context from the turn
        self._extract_context_from_turn(turn)

        # Trim history if needed. The overflow is computed explicitly because
        # ``turns[-0:]`` would keep the whole list when max_turns is 0.
        overflow = len(self.turns) - max(self.max_turns, 0)
        if overflow > 0:
            self.turns = self.turns[overflow:]

    def _extract_context_from_turn(self, turn: ConversationTurn) -> None:
        """Extract contextual information from a conversation turn.

        Simple keyword-based extraction (can be enhanced with NLP). Updates
        ``preferences``, ``facts_learned``/``user_name``, ``topics_discussed``
        and ``emotional_context`` in place.
        """
        lowered = turn.user_message.lower()

        # Extract preferences: store the lower-cased text from the phrase up to
        # the next full stop. When both phrases occur, the later one is used.
        preference_start = max(lowered.find("i like"), lowered.find("i love"))
        if preference_start != -1:
            preference_text = lowered[preference_start:].split('.')[0]
            self.preferences[f"preference_{len(self.preferences)}"] = preference_text

        # Extract facts: match against the original message so the name keeps
        # its casing. No match (e.g. the message ends with "my name is") simply
        # records nothing instead of raising.
        name_match = re.search(r"my name is\s+(\S+)", turn.user_message, re.IGNORECASE)
        if name_match:
            # Drop surrounding punctuation/quotes such as in "Bob." or '"Bob",'.
            name = name_match.group(1).strip(".,!?;:\"'()")
            if name:
                self.facts_learned["user_name"] = name
                self.user_name = name

        # Topic tracking: keep first-seen order without duplicates
        for topic in turn.metadata.get("detected_topics") or []:
            if topic not in self.topics_discussed:
                self.topics_discussed.append(topic)

        # Emotional context (simple sentiment analysis)
        emotional_indicators = {
            "happy": ["happy", "great", "wonderful", "excited", "joy"],
            "sad": ["sad", "unhappy", "disappointed", "depressed"],
            "frustrated": ["frustrated", "annoyed", "angry", "upset"],
            "confused": ["confused", "don't understand", "unclear", "puzzled"],
            "satisfied": ["good", "thanks", "helpful", "satisfied"]
        }

        # An indicator must start at a word boundary; otherwise "unhappy" would
        # be reported as "happy" and "dissatisfied" as "satisfied". Suffixes
        # are still allowed ("joyful" matches "joy"). The first emotion with a
        # matching indicator wins.
        for emotion, indicators in emotional_indicators.items():
            if any(
                re.search(r"\b" + re.escape(indicator), lowered)
                for indicator in indicators
            ):
                self.emotional_context["current_emotion"] = emotion
                self.emotional_context["last_emotion_update"] = time.time()
                break

    def get_recent_context(self, turns: int = 5) -> List[ConversationTurn]:
        """Get recent conversation turns for context.

        Args:
            turns: Maximum number of most recent turns to return.

        Returns:
            A new list with up to ``turns`` turns, oldest first; empty when
            ``turns`` is zero or negative or there is no history.
        """
        # Guard explicitly: ``self.turns[-0:]`` is the whole history.
        if turns <= 0:
            return []
        return self.turns[-turns:]

    def get_relevant_facts(self, query: str) -> Dict[str, Any]:
        """Get facts relevant to a query.

        A fact is relevant when its value is a string that contains (case
        insensitively) at least one whitespace-separated word of ``query``.
        Facts with non-string values are never returned.
        """
        query_words = query.lower().split()
        return {
            key: value
            for key, value in self.facts_learned.items()
            if isinstance(value, str)
            and any(word in value.lower() for word in query_words)
        }

    def get_conversation_summary(self) -> str:
        """Generate a one-line summary of the conversation.

        Returns:
            ``"No conversation history."`` when there are no turns, otherwise
            the known parts (user, up to 5 topics, up to 3 preferences, mood,
            turn count) joined by ``" | "``.
        """
        if not self.turns:
            return "No conversation history."

        summary_parts = []

        if self.user_name:
            summary_parts.append(f"User: {self.user_name}")

        if self.topics_discussed:
            # str() because topics come from free-form turn metadata
            topics_str = ", ".join(str(topic) for topic in self.topics_discussed[:5])
            summary_parts.append(f"Topics discussed: {topics_str}")

        if self.preferences:
            # Preference values are typed Any; str() avoids a TypeError in join
            prefs = [str(pref) for pref in list(self.preferences.values())[:3]]
            summary_parts.append(f"User preferences: {'; '.join(prefs)}")

        if self.emotional_context.get("current_emotion"):
            summary_parts.append(f"Current mood: {self.emotional_context['current_emotion']}")

        summary_parts.append(f"Conversation turns: {len(self.turns)}")

        return " | ".join(summary_parts)
|
|
|
|
|
|
class ConversationContextManager:
    """Manager for conversation contexts and memory.

    Keeps every known ``ConversationMemory`` in ``active_contexts`` (keyed by
    conversation id), maintains a per-bot index of conversation ids, and
    persists each context as ``<conversation_id>.json`` under
    ``storage_path``.
    """

    def __init__(self, storage_path: Optional[str] = None):
        """Create the storage directory if needed and load saved contexts.

        Args:
            storage_path: Directory for the JSON files; defaults to
                ``"./conversation_contexts"``.
        """
        self.storage_path = storage_path or "./conversation_contexts"
        self.active_contexts: Dict[str, ConversationMemory] = {}
        self.context_index: Dict[str, List[str]] = defaultdict(list)  # bot_name -> conversation_ids

        # Ensure storage directory exists
        os.makedirs(self.storage_path, exist_ok=True)

        # Load existing contexts
        self._load_existing_contexts()

    def _context_file_path(self, conversation_id: str) -> str:
        """Return the JSON file path used to persist ``conversation_id``."""
        # NOTE(review): the id is interpolated into the file name as-is; ids
        # containing path separators are not sanitized here.
        return os.path.join(self.storage_path, f"{conversation_id}.json")

    def _register_context(self, context: ConversationMemory) -> None:
        """Store ``context`` and index it under its bot, without duplicates."""
        conversation_id = context.conversation_id
        self.active_contexts[conversation_id] = context
        indexed_ids = self.context_index[context.bot_name]
        if conversation_id not in indexed_ids:
            indexed_ids.append(conversation_id)

    def _load_existing_contexts(self) -> None:
        """Load existing conversation contexts from storage.

        Best effort: a file that cannot be read or parsed is logged and
        skipped; a failure to list the directory is logged and aborts loading.
        """
        try:
            # Sorted so the load order does not depend on directory order
            context_files = sorted(
                name for name in os.listdir(self.storage_path) if name.endswith('.json')
            )
        except OSError as e:
            logger.error(f"Failed to load conversation contexts: {e}")
            return

        for file_name in context_files:
            try:
                file_path = os.path.join(self.storage_path, file_name)
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Convert turn data back to ConversationTurn objects
                data['turns'] = [
                    ConversationTurn.from_dict(turn_data)
                    for turn_data in data.get('turns', [])
                ]
                context = ConversationMemory(**data)
            except Exception as e:
                # Deliberately broad: one bad file must not stop the others
                logger.warning(f"Failed to load context from {file_name}: {e}")
                continue

            self._register_context(context)

        logger.info(f"Loaded {len(self.active_contexts)} conversation contexts")

    def get_or_create_context(
        self,
        session_id: str,
        bot_name: str,
        conversation_id: Optional[str] = None
    ) -> ConversationMemory:
        """Get existing context or create a new one.

        Args:
            session_id: Session the conversation belongs to.
            bot_name: Bot taking part in the conversation.
            conversation_id: Id to look up or to use for the new context;
                when omitted an id is generated from the session, bot and
                current time.

        Returns:
            The existing context for ``conversation_id`` or a newly created
            and registered one.
        """
        if conversation_id and conversation_id in self.active_contexts:
            return self.active_contexts[conversation_id]

        # Create new conversation ID if not provided
        if not conversation_id:
            base_id = f"{session_id}_{bot_name}_{int(time.time())}"
            conversation_id = base_id
            # The timestamp has one-second resolution; add a counter so a
            # second call in the same second does not overwrite the first
            # context.
            suffix = 1
            while conversation_id in self.active_contexts:
                conversation_id = f"{base_id}_{suffix}"
                suffix += 1

        # Create new context
        context = ConversationMemory(
            session_id=session_id,
            bot_name=bot_name,
            conversation_id=conversation_id
        )
        self._register_context(context)

        logger.info(f"Created new conversation context: {conversation_id}")
        return context

    def save_context(self, conversation_id: str) -> None:
        """Save a conversation context to storage.

        The JSON is written to a temporary file and then moved over the
        target with ``os.replace``, so a failed dump never leaves a truncated
        context file behind. Failures are logged, not raised.
        """
        if conversation_id not in self.active_contexts:
            logger.warning(f"Context {conversation_id} not found for saving")
            return

        context = self.active_contexts[conversation_id]
        file_path = self._context_file_path(conversation_id)
        # Does not end in ".json", so the loader ignores a leftover temp file
        temp_path = f"{file_path}.tmp"

        try:
            # Convert to dict for serialization
            data = context.model_dump()
            # Convert ConversationTurn objects to dicts
            data['turns'] = [turn.to_dict() for turn in context.turns]

            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(temp_path, file_path)

            logger.debug(f"Saved context: {conversation_id}")

        except Exception as e:
            logger.error(f"Failed to save context {conversation_id}: {e}")
            # Best-effort removal of the partial temp file
            try:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove temporary file {temp_path}: {cleanup_error}")

    @staticmethod
    def _next_turn_index(context: ConversationMemory) -> int:
        """Return the sequence number for the next turn of ``context``.

        Continues from the numeric suffix of the last turn id so ids stay
        unique after old turns have been trimmed; falls back to the number of
        stored turns when there is no usable suffix.
        """
        stored = len(context.turns)
        if not context.turns:
            return stored
        suffix = context.turns[-1].turn_id.rsplit("_", 1)[-1]
        try:
            return max(int(suffix) + 1, stored)
        except ValueError:
            return stored

    def add_conversation_turn(
        self,
        conversation_id: str,
        user_message: str,
        bot_response: str,
        context_used: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Add a conversation turn to the specified context and save it.

        Logs a warning and does nothing when ``conversation_id`` is unknown.

        Args:
            conversation_id: Id of the context to extend.
            user_message: Text sent by the user.
            bot_response: Text returned by the bot.
            context_used: Optional data stored on the turn (defaults to ``{}``).
            metadata: Optional data stored on the turn (defaults to ``{}``).
        """
        context = self.active_contexts.get(conversation_id)
        if context is None:
            logger.warning(f"Context {conversation_id} not found")
            return

        turn = ConversationTurn(
            turn_id=f"{conversation_id}_{self._next_turn_index(context)}",
            timestamp=time.time(),
            user_message=user_message,
            bot_response=bot_response,
            context_used=context_used or {},
            metadata=metadata or {}
        )
        context.add_turn(turn)

        # Auto-save after each turn
        self.save_context(conversation_id)

    def get_context_for_response(self, conversation_id: str) -> Optional[str]:
        """Get formatted context for generating bot responses.

        Returns:
            Newline-joined text made of the conversation summary, the last
            three turns, up to three known facts and the current mood, or
            ``None`` when the context is unknown or nothing is available.
        """
        if conversation_id not in self.active_contexts:
            return None

        context = self.active_contexts[conversation_id]
        context_parts = []

        # Add conversation summary
        summary = context.get_conversation_summary()
        if summary != "No conversation history.":
            context_parts.append(f"Conversation context: {summary}")

        # Add recent turns for immediate context
        recent_turns = context.get_recent_context(3)
        if recent_turns:
            context_parts.append("Recent conversation:")
            for turn in recent_turns:
                context_parts.append(f"User: {turn.user_message}")
                context_parts.append(f"Bot: {turn.bot_response}")

        # Add relevant facts
        if context.facts_learned:
            first_facts = list(context.facts_learned.items())[:3]
            facts_str = "; ".join(f"{k}: {v}" for k, v in first_facts)
            context_parts.append(f"Known facts: {facts_str}")

        # Add emotional context
        current_emotion = context.emotional_context.get("current_emotion")
        if current_emotion:
            context_parts.append(f"User's current mood: {current_emotion}")

        return "\n".join(context_parts) if context_parts else None

    def get_contexts_for_bot(self, bot_name: str) -> List[ConversationMemory]:
        """Get all contexts for a specific bot."""
        # .get() so an unknown bot does not create an empty index entry
        conversation_ids = self.context_index.get(bot_name, [])
        return [self.active_contexts[cid] for cid in conversation_ids if cid in self.active_contexts]

    def cleanup_old_contexts(self, max_age_days: int = 30) -> None:
        """Clean up old conversation contexts.

        Removes every context whose ``last_updated`` is more than
        ``max_age_days`` days in the past from the index, from disk and from
        memory.
        """
        current_time = time.time()
        max_age_seconds = max_age_days * 24 * 60 * 60

        # Collect first: the dict must not change size while it is iterated
        contexts_to_remove = [
            conversation_id
            for conversation_id, context in self.active_contexts.items()
            if current_time - context.last_updated > max_age_seconds
        ]

        for conversation_id in contexts_to_remove:
            context = self.active_contexts[conversation_id]

            # Remove from index (membership test avoids creating an entry)
            if context.bot_name in self.context_index:
                indexed_ids = self.context_index[context.bot_name]
                if conversation_id in indexed_ids:
                    indexed_ids.remove(conversation_id)

            # Remove context file
            try:
                file_path = self._context_file_path(conversation_id)
                if os.path.exists(file_path):
                    os.remove(file_path)
            except OSError as e:
                logger.warning(f"Failed to remove context file {conversation_id}: {e}")

            # Remove from active contexts
            del self.active_contexts[conversation_id]

        if contexts_to_remove:
            logger.info(f"Cleaned up {len(contexts_to_remove)} old conversation contexts")

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about conversation contexts.

        Returns:
            A dict with ``total_contexts``, ``total_turns``,
            ``average_turns_per_context`` (0 when there are no contexts) and
            ``bot_statistics`` (per bot: ``active_conversations`` and
            ``total_turns``).
        """
        total_contexts = len(self.active_contexts)
        total_turns = sum(len(context.turns) for context in self.active_contexts.values())

        bot_stats = {}
        for bot_name, conversation_ids in self.context_index.items():
            active_conversations = [cid for cid in conversation_ids if cid in self.active_contexts]
            bot_stats[bot_name] = {
                "active_conversations": len(active_conversations),
                "total_turns": sum(len(self.active_contexts[cid].turns) for cid in active_conversations)
            }

        return {
            "total_contexts": total_contexts,
            "total_turns": total_turns,
            "average_turns_per_context": total_turns / total_contexts if total_contexts > 0 else 0,
            "bot_statistics": bot_stats
        }
|
|
|
|
|
|
# Global context manager instance
|
|
# Shared module-level instance using the default storage path.
# NOTE: this runs at import time, so importing the module creates the
# "./conversation_contexts" directory if it is missing and loads any
# saved contexts found there.
context_manager = ConversationContextManager()
|