Streaming working
This commit is contained in:
parent
3c7584eadb
commit
7a06bfc102
@ -380,16 +380,32 @@ class LocalProvider(AIProvider):
|
|||||||
"num_predict": self.config.max_tokens
|
"num_predict": self.config.max_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
logger.info(f"LocalProvider generate payload: {payload}")
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/api/chat",
|
f"{self.base_url}/chat/completions",
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
result = await resp.json()
|
result = await resp.json()
|
||||||
|
|
||||||
|
# Handle OpenAI-compatible format
|
||||||
|
if "choices" in result and len(result["choices"]) > 0:
|
||||||
|
choice = result["choices"][0]
|
||||||
|
if "message" in choice and "content" in choice["message"]:
|
||||||
|
response_text = choice["message"]["content"]
|
||||||
|
elif "text" in choice:
|
||||||
|
# Some APIs use "text" instead of "message.content"
|
||||||
|
response_text = choice["text"]
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response format: {result}")
|
||||||
|
# Fallback to original format for backward compatibility
|
||||||
|
elif "message" in result and "content" in result["message"]:
|
||||||
response_text = result["message"]["content"]
|
response_text = result["message"]["content"]
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unexpected response format: {result}")
|
||||||
|
|
||||||
context.add_message(MessageRole.ASSISTANT, response_text)
|
context.add_message(MessageRole.ASSISTANT, response_text)
|
||||||
return response_text
|
return response_text
|
||||||
else:
|
else:
|
||||||
@ -411,25 +427,51 @@ class LocalProvider(AIProvider):
|
|||||||
"messages": context.get_context_messages(),
|
"messages": context.get_context_messages(),
|
||||||
"stream": True
|
"stream": True
|
||||||
}
|
}
|
||||||
|
logger.info(f"LocalProvider stream payload: {self.base_url} {payload}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.base_url}/api/chat",
|
f"{self.base_url}/chat/completions",
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 200:
|
if resp.status == 200:
|
||||||
full_response = ""
|
full_response = ""
|
||||||
|
import json
|
||||||
async for line in resp.content:
|
async for line in resp.content:
|
||||||
if line:
|
if line:
|
||||||
import json
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(line.decode())
|
# Decode the line
|
||||||
if "message" in data and "content" in data["message"]:
|
line_str = line.decode('utf-8').strip()
|
||||||
content = data["message"]["content"]
|
|
||||||
|
# Skip empty lines
|
||||||
|
if not line_str:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle Server-Sent Events format
|
||||||
|
if line_str.startswith('data: '):
|
||||||
|
line_str = line_str[6:] # Remove 'data: ' prefix
|
||||||
|
|
||||||
|
# Skip end-of-stream marker
|
||||||
|
if line_str == '[DONE]':
|
||||||
|
break
|
||||||
|
|
||||||
|
data = json.loads(line_str)
|
||||||
|
|
||||||
|
# Handle OpenAI-compatible format
|
||||||
|
if "choices" in data and len(data["choices"]) > 0:
|
||||||
|
choice = data["choices"][0]
|
||||||
|
if "delta" in choice and "content" in choice["delta"]:
|
||||||
|
content = choice["delta"]["content"]
|
||||||
|
if content: # Only yield non-empty content
|
||||||
full_response += content
|
full_response += content
|
||||||
yield content
|
yield content
|
||||||
except json.JSONDecodeError:
|
|
||||||
|
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||||
|
logger.debug(f"Skipping invalid line: {line} (error: {e})")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Unexpected error processing line: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
context.add_message(MessageRole.ASSISTANT, full_response)
|
context.add_message(MessageRole.ASSISTANT, full_response)
|
||||||
@ -437,7 +479,7 @@ class LocalProvider(AIProvider):
|
|||||||
raise Exception(f"Local API returned status {resp.status}")
|
raise Exception(f"Local API returned status {resp.status}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Local provider streaming failed: {e}")
|
logger.error(f"Local provider streaming failed {self.base_url}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def health_check(self) -> bool:
|
async def health_check(self) -> bool:
|
||||||
@ -479,8 +521,8 @@ class AIProviderManager:
|
|||||||
),
|
),
|
||||||
AIProviderType.LOCAL: AIProviderConfig(
|
AIProviderType.LOCAL: AIProviderConfig(
|
||||||
provider_type=AIProviderType.LOCAL,
|
provider_type=AIProviderType.LOCAL,
|
||||||
base_url=os.getenv("LOCAL_MODEL_URL", "http://localhost:11434"),
|
base_url=os.getenv("OPENAI_BASE_URL", "http://localhost:11434"),
|
||||||
model=os.getenv("LOCAL_MODEL_NAME", "llama2"),
|
model=os.getenv("OPENAI_MODEL", "llama2"),
|
||||||
max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
|
max_tokens=int(os.getenv("LOCAL_MAX_TOKENS", "1000")),
|
||||||
temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
|
temperature=float(os.getenv("LOCAL_TEMPERATURE", "0.7"))
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,5 @@
|
|||||||
|
|
||||||
from . import synthetic_media
|
from . import synthetic_media
|
||||||
from . import whisper
|
from . import whisper
|
||||||
from . import chatbot
|
|
||||||
|
|
||||||
__all__ = ["synthetic_media", "whisper", "chatbot"]
|
__all__ = ["synthetic_media", "whisper"]
|
||||||
|
@ -10,7 +10,8 @@ This bot demonstrates the advanced capabilities including:
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Optional, Callable, Awaitable, Any, Union
|
import secrets
|
||||||
|
from typing import Dict, Optional, Callable, Awaitable, Any, Union, AsyncGenerator
|
||||||
from aiortc import MediaStreamTrack
|
from aiortc import MediaStreamTrack
|
||||||
|
|
||||||
# Import system modules
|
# Import system modules
|
||||||
@ -85,6 +86,15 @@ class EnhancedAIChatbot:
|
|||||||
self.conversation_context = None
|
self.conversation_context = None
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
|
||||||
|
# Instance configuration variables
|
||||||
|
self.bot_personality = BOT_PERSONALITY
|
||||||
|
self.bot_ai_provider = BOT_AI_PROVIDER
|
||||||
|
self.bot_streaming = BOT_STREAMING
|
||||||
|
self.bot_memory_enabled = BOT_MEMORY_ENABLED
|
||||||
|
|
||||||
|
# Per-lobby configurations
|
||||||
|
self.lobby_configs: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
# Initialize advanced features if available
|
# Initialize advanced features if available
|
||||||
if AI_PROVIDERS_AVAILABLE:
|
if AI_PROVIDERS_AVAILABLE:
|
||||||
self._initialize_ai_features()
|
self._initialize_ai_features()
|
||||||
@ -95,18 +105,18 @@ class EnhancedAIChatbot:
|
|||||||
"""Initialize AI provider, personality, and context management."""
|
"""Initialize AI provider, personality, and context management."""
|
||||||
try:
|
try:
|
||||||
# Initialize personality
|
# Initialize personality
|
||||||
self.personality = personality_manager.create_personality_from_template(BOT_PERSONALITY)
|
self.personality = personality_manager.create_personality_from_template(self.bot_personality)
|
||||||
if not self.personality:
|
if not self.personality:
|
||||||
logger.warning(f"Personality template '{BOT_PERSONALITY}' not found, using default")
|
logger.warning(f"Personality template '{self.bot_personality}' not found, using default")
|
||||||
self.personality = personality_manager.create_personality_from_template("helpful_assistant")
|
self.personality = personality_manager.create_personality_from_template("helpful_assistant")
|
||||||
|
|
||||||
# Initialize AI provider
|
# Initialize AI provider
|
||||||
provider_type = AIProviderType(BOT_AI_PROVIDER)
|
provider_type = AIProviderType(self.bot_ai_provider)
|
||||||
self.ai_provider = ai_provider_manager.create_provider(provider_type)
|
self.ai_provider = ai_provider_manager.create_provider(provider_type)
|
||||||
ai_provider_manager.register_provider(f"{AGENT_NAME}_{self.session_id}", self.ai_provider)
|
ai_provider_manager.register_provider(f"{AGENT_NAME}_{self.session_id}", self.ai_provider)
|
||||||
|
|
||||||
# Initialize conversation context if memory is enabled
|
# Initialize conversation context if memory is enabled
|
||||||
if BOT_MEMORY_ENABLED:
|
if self.bot_memory_enabled:
|
||||||
self.conversation_context = context_manager.get_or_create_context(
|
self.conversation_context = context_manager.get_or_create_context(
|
||||||
session_id=self.session_id,
|
session_id=self.session_id,
|
||||||
bot_name=AGENT_NAME,
|
bot_name=AGENT_NAME,
|
||||||
@ -114,7 +124,7 @@ class EnhancedAIChatbot:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.initialized = True
|
self.initialized = True
|
||||||
logger.info(f"Enhanced AI chatbot initialized: provider={BOT_AI_PROVIDER}, personality={BOT_PERSONALITY}, memory={BOT_MEMORY_ENABLED}")
|
logger.info(f"Enhanced AI chatbot initialized: provider={self.bot_ai_provider}, personality={self.bot_personality}, memory={self.bot_memory_enabled}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize AI features: {e}")
|
logger.error(f"Failed to initialize AI features: {e}")
|
||||||
@ -153,7 +163,7 @@ class EnhancedAIChatbot:
|
|||||||
ai_context.add_message(MessageRole.SYSTEM, self.personality.generate_system_prompt())
|
ai_context.add_message(MessageRole.SYSTEM, self.personality.generate_system_prompt())
|
||||||
|
|
||||||
# Generate response
|
# Generate response
|
||||||
if BOT_STREAMING:
|
if self.bot_streaming:
|
||||||
# For streaming, collect the full response
|
# For streaming, collect the full response
|
||||||
response_parts = []
|
response_parts = []
|
||||||
async for chunk in self.ai_provider.stream_response(ai_context, message):
|
async for chunk in self.ai_provider.stream_response(ai_context, message):
|
||||||
@ -168,8 +178,8 @@ class EnhancedAIChatbot:
|
|||||||
conversation_id=self.conversation_context.conversation_id,
|
conversation_id=self.conversation_context.conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
bot_response=response,
|
bot_response=response,
|
||||||
context_used={"ai_provider": BOT_AI_PROVIDER, "personality": BOT_PERSONALITY},
|
context_used={"ai_provider": self.bot_ai_provider, "personality": self.bot_personality},
|
||||||
metadata={"timestamp": time.time(), "streaming": BOT_STREAMING}
|
metadata={"timestamp": time.time(), "streaming": self.bot_streaming}
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -227,6 +237,63 @@ class EnhancedAIChatbot:
|
|||||||
|
|
||||||
return health
|
return health
|
||||||
|
|
||||||
|
async def generate_streaming_response(self, message: str) -> AsyncGenerator[str, None]:
|
||||||
|
"""Generate a streaming response, yielding partial responses as chunks arrive."""
|
||||||
|
if not self.initialized or not self.ai_provider:
|
||||||
|
yield self._get_fallback_response(message)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Prepare conversation context (same as generate_response)
|
||||||
|
if self.conversation_context:
|
||||||
|
# Create a new AI conversation context with personality
|
||||||
|
ai_context = ConversationContext(
|
||||||
|
session_id=self.session_id,
|
||||||
|
bot_name=AGENT_NAME,
|
||||||
|
personality_prompt=self.personality.generate_system_prompt() if self.personality else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add personality system message
|
||||||
|
if self.personality:
|
||||||
|
ai_context.add_message(MessageRole.SYSTEM, self.personality.generate_system_prompt())
|
||||||
|
|
||||||
|
# Add conversation history context
|
||||||
|
context_summary = context_manager.get_context_for_response(self.conversation_context.conversation_id)
|
||||||
|
if context_summary:
|
||||||
|
ai_context.add_message(MessageRole.SYSTEM, f"Conversation context: {context_summary}")
|
||||||
|
else:
|
||||||
|
# Simple context without memory
|
||||||
|
ai_context = ConversationContext(
|
||||||
|
session_id=self.session_id,
|
||||||
|
bot_name=AGENT_NAME
|
||||||
|
)
|
||||||
|
if self.personality:
|
||||||
|
ai_context.add_message(MessageRole.SYSTEM, self.personality.generate_system_prompt())
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
accumulated_response = ""
|
||||||
|
chunk_count = 0
|
||||||
|
async for chunk in self.ai_provider.stream_response(ai_context, message):
|
||||||
|
accumulated_response += chunk
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"AI provider yielded chunk {chunk_count}: '{chunk}' (accumulated: {len(accumulated_response)} chars)")
|
||||||
|
yield accumulated_response
|
||||||
|
|
||||||
|
# Store conversation turn in context manager after streaming is complete
|
||||||
|
if self.conversation_context:
|
||||||
|
context_manager.add_conversation_turn(
|
||||||
|
conversation_id=self.conversation_context.conversation_id,
|
||||||
|
user_message=message,
|
||||||
|
bot_response=accumulated_response,
|
||||||
|
context_used={"ai_provider": self.bot_ai_provider, "personality": self.bot_personality},
|
||||||
|
metadata={"timestamp": time.time(), "streaming": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI streaming response generation failed: {e}")
|
||||||
|
yield self._get_fallback_response(message, error=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Global bot instance
|
# Global bot instance
|
||||||
_bot_instance: Optional[EnhancedAIChatbot] = None
|
_bot_instance: Optional[EnhancedAIChatbot] = None
|
||||||
@ -268,7 +335,13 @@ async def handle_chat_message(
|
|||||||
_bot_instance = EnhancedAIChatbot(chat_message.sender_name)
|
_bot_instance = EnhancedAIChatbot(chat_message.sender_name)
|
||||||
logger.info(f"Initialized enhanced AI chatbot for session: {chat_message.sender_name}")
|
logger.info(f"Initialized enhanced AI chatbot for session: {chat_message.sender_name}")
|
||||||
|
|
||||||
# Generate response
|
if _bot_instance.bot_streaming:
|
||||||
|
# Handle streaming response
|
||||||
|
logger.info(f"Using streaming response path, bot_streaming={_bot_instance.bot_streaming}")
|
||||||
|
return await _handle_streaming_response(chat_message, send_message_func)
|
||||||
|
else:
|
||||||
|
# Generate non-streaming response
|
||||||
|
logger.info(f"Using non-streaming response path, bot_streaming={_bot_instance.bot_streaming}")
|
||||||
response = await _bot_instance.generate_response(chat_message.message)
|
response = await _bot_instance.generate_response(chat_message.message)
|
||||||
|
|
||||||
# Send response
|
# Send response
|
||||||
@ -285,6 +358,94 @@ async def handle_chat_message(
|
|||||||
return error_response
|
return error_response
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_streaming_response(
|
||||||
|
chat_message: ChatMessageModel,
|
||||||
|
send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Handle streaming response by sending updates as chunks arrive."""
|
||||||
|
global _bot_instance
|
||||||
|
|
||||||
|
logger.info("Starting _handle_streaming_response")
|
||||||
|
message_id = None
|
||||||
|
try:
|
||||||
|
# Generate a unique message ID for this streaming response
|
||||||
|
message_id = secrets.token_hex(8)
|
||||||
|
|
||||||
|
# Get the client's session_id from the bound method
|
||||||
|
client = getattr(send_message_func, '__self__', None)
|
||||||
|
client_session_id = client.session_id if client else chat_message.sender_session_id
|
||||||
|
|
||||||
|
# Send initial empty message to establish the message in the chat
|
||||||
|
initial_message = ChatMessageModel(
|
||||||
|
id=str(message_id),
|
||||||
|
message="",
|
||||||
|
sender_name=_bot_instance.session_name if _bot_instance else "AI Chatbot",
|
||||||
|
sender_session_id=client_session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=chat_message.lobby_id,
|
||||||
|
)
|
||||||
|
await send_message_func(initial_message)
|
||||||
|
logger.info(f"Started streaming response with message ID: {message_id}")
|
||||||
|
|
||||||
|
# Check if bot instance exists
|
||||||
|
if not _bot_instance:
|
||||||
|
error_msg = "Bot instance not available for streaming"
|
||||||
|
update_message = ChatMessageModel(
|
||||||
|
id=str(message_id),
|
||||||
|
message=error_msg,
|
||||||
|
sender_name="AI Chatbot",
|
||||||
|
sender_session_id=client_session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=chat_message.lobby_id,
|
||||||
|
)
|
||||||
|
await send_message_func(update_message)
|
||||||
|
return error_msg
|
||||||
|
|
||||||
|
# Stream the response
|
||||||
|
final_response = ""
|
||||||
|
chunk_count = 0
|
||||||
|
async for partial_response in _bot_instance.generate_streaming_response(chat_message.message):
|
||||||
|
final_response = partial_response
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Sending streaming chunk {chunk_count}: {partial_response[:50]}...")
|
||||||
|
|
||||||
|
update_message = ChatMessageModel(
|
||||||
|
id=str(message_id),
|
||||||
|
message=partial_response,
|
||||||
|
sender_name=_bot_instance.session_name,
|
||||||
|
sender_session_id=client_session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=chat_message.lobby_id,
|
||||||
|
)
|
||||||
|
await send_message_func(update_message)
|
||||||
|
|
||||||
|
logger.info(f"Completed streaming response to {chat_message.sender_name}: {final_response[:100]}...")
|
||||||
|
return final_response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in streaming response: {e}")
|
||||||
|
error_response = "I apologize, but I encountered an error. Please try again."
|
||||||
|
|
||||||
|
# Try to update the existing message with the error, or send a new one
|
||||||
|
try:
|
||||||
|
client = getattr(send_message_func, '__self__', None)
|
||||||
|
client_session_id = client.session_id if client else chat_message.sender_session_id
|
||||||
|
error_message = ChatMessageModel(
|
||||||
|
id=str(message_id),
|
||||||
|
message=error_response,
|
||||||
|
sender_name=_bot_instance.session_name if _bot_instance else "AI Chatbot",
|
||||||
|
sender_session_id=client_session_id,
|
||||||
|
timestamp=time.time(),
|
||||||
|
lobby_id=chat_message.lobby_id,
|
||||||
|
)
|
||||||
|
await send_message_func(error_message)
|
||||||
|
except (NameError, TypeError, AttributeError):
|
||||||
|
# If message_id is not defined or other issues, send as string
|
||||||
|
await send_message_func(error_response)
|
||||||
|
|
||||||
|
return error_response
|
||||||
|
|
||||||
|
|
||||||
async def get_bot_status() -> Dict[str, Any]:
|
async def get_bot_status() -> Dict[str, Any]:
|
||||||
"""Get detailed bot status and health information."""
|
"""Get detailed bot status and health information."""
|
||||||
global _bot_instance
|
global _bot_instance
|
||||||
@ -323,17 +484,17 @@ async def get_bot_status() -> Dict[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
# Additional helper functions for advanced features
|
# Additional helper functions for advanced features
|
||||||
async def switch_personality(personality_id: str) -> bool:
|
async def switch_personality(bot_instance: EnhancedAIChatbot, personality_id: str) -> bool:
|
||||||
"""Switch bot personality at runtime."""
|
"""Switch bot personality at runtime."""
|
||||||
global _bot_instance
|
|
||||||
|
|
||||||
if not AI_PROVIDERS_AVAILABLE or not _bot_instance:
|
if not AI_PROVIDERS_AVAILABLE or not bot_instance:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
new_personality = personality_manager.create_personality_from_template(personality_id)
|
new_personality = personality_manager.create_personality_from_template(personality_id)
|
||||||
if new_personality:
|
if new_personality:
|
||||||
_bot_instance.personality = new_personality
|
bot_instance.personality = new_personality
|
||||||
|
bot_instance.bot_personality = personality_id
|
||||||
logger.info(f"Switched to personality: {personality_id}")
|
logger.info(f"Switched to personality: {personality_id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -342,9 +503,8 @@ async def switch_personality(personality_id: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def switch_ai_provider(provider_type: str) -> bool:
|
async def switch_ai_provider(bot_instance: EnhancedAIChatbot, provider_type: str) -> bool:
|
||||||
"""Switch AI provider at runtime."""
|
"""Switch AI provider at runtime."""
|
||||||
global _bot_instance, BOT_AI_PROVIDER
|
|
||||||
|
|
||||||
logger.info(f"Switching AI provider to: {provider_type}")
|
logger.info(f"Switching AI provider to: {provider_type}")
|
||||||
|
|
||||||
@ -353,17 +513,13 @@ async def switch_ai_provider(provider_type: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Always update the global default first
|
|
||||||
old_provider = BOT_AI_PROVIDER
|
|
||||||
BOT_AI_PROVIDER = provider_type
|
|
||||||
logger.info(f"Updated global BOT_AI_PROVIDER from {old_provider} to {provider_type}")
|
|
||||||
|
|
||||||
# If instance exists, switch its provider
|
# If instance exists, switch its provider
|
||||||
if _bot_instance:
|
if bot_instance:
|
||||||
logger.info("Switching existing bot instance provider")
|
logger.info("Switching existing bot instance provider")
|
||||||
provider_enum = AIProviderType(provider_type)
|
provider_enum = AIProviderType(provider_type)
|
||||||
new_provider = ai_provider_manager.create_provider(provider_enum)
|
new_provider = ai_provider_manager.create_provider(provider_enum)
|
||||||
_bot_instance.ai_provider = new_provider
|
bot_instance.ai_provider = new_provider
|
||||||
|
bot_instance.bot_ai_provider = provider_type
|
||||||
logger.info(f"Switched existing instance to AI provider: {provider_type}")
|
logger.info(f"Switched existing instance to AI provider: {provider_type}")
|
||||||
else:
|
else:
|
||||||
logger.info("No existing bot instance to switch")
|
logger.info("No existing bot instance to switch")
|
||||||
@ -371,8 +527,6 @@ async def switch_ai_provider(provider_type: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to switch AI provider to {provider_type}: {e}")
|
logger.error(f"Failed to switch AI provider to {provider_type}: {e}")
|
||||||
# Revert the global change on failure
|
|
||||||
BOT_AI_PROVIDER = old_provider
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -481,17 +635,21 @@ async def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) ->
|
|||||||
try:
|
try:
|
||||||
logger.info(f"Updating config for lobby {lobby_id}: {config_values}")
|
logger.info(f"Updating config for lobby {lobby_id}: {config_values}")
|
||||||
|
|
||||||
|
# Get the bot instance (create if doesn't exist)
|
||||||
|
if _bot_instance is None:
|
||||||
|
_bot_instance = EnhancedAIChatbot("AI Chatbot")
|
||||||
|
|
||||||
# Apply configuration changes
|
# Apply configuration changes
|
||||||
config_applied = False
|
config_applied = False
|
||||||
|
|
||||||
if "personality" in config_values:
|
if "personality" in config_values:
|
||||||
success = await switch_personality(config_values["personality"])
|
success = await switch_personality(_bot_instance, config_values["personality"])
|
||||||
if success:
|
if success:
|
||||||
config_applied = True
|
config_applied = True
|
||||||
logger.info(f"Applied personality: {config_values['personality']}")
|
logger.info(f"Applied personality: {config_values['personality']}")
|
||||||
|
|
||||||
if "ai_provider" in config_values:
|
if "ai_provider" in config_values:
|
||||||
success = await switch_ai_provider(config_values["ai_provider"])
|
success = await switch_ai_provider(_bot_instance, config_values["ai_provider"])
|
||||||
if success:
|
if success:
|
||||||
config_applied = True
|
config_applied = True
|
||||||
logger.info(f"Applied AI provider: {config_values['ai_provider']}")
|
logger.info(f"Applied AI provider: {config_values['ai_provider']}")
|
||||||
@ -499,19 +657,16 @@ async def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) ->
|
|||||||
logger.warning(f"Failed to apply AI provider: {config_values['ai_provider']}")
|
logger.warning(f"Failed to apply AI provider: {config_values['ai_provider']}")
|
||||||
|
|
||||||
if "streaming" in config_values:
|
if "streaming" in config_values:
|
||||||
global BOT_STREAMING
|
_bot_instance.bot_streaming = bool(config_values["streaming"])
|
||||||
BOT_STREAMING = bool(config_values["streaming"])
|
|
||||||
config_applied = True
|
config_applied = True
|
||||||
logger.info(f"Applied streaming: {BOT_STREAMING}")
|
logger.info(f"Applied streaming: {_bot_instance.bot_streaming}")
|
||||||
|
|
||||||
if "memory_enabled" in config_values:
|
if "memory_enabled" in config_values:
|
||||||
global BOT_MEMORY_ENABLED
|
_bot_instance.bot_memory_enabled = bool(config_values["memory_enabled"])
|
||||||
BOT_MEMORY_ENABLED = bool(config_values["memory_enabled"])
|
|
||||||
config_applied = True
|
config_applied = True
|
||||||
logger.info(f"Applied memory: {BOT_MEMORY_ENABLED}")
|
logger.info(f"Applied memory: {_bot_instance.bot_memory_enabled}")
|
||||||
|
|
||||||
# Store other configuration values for use in response generation
|
# Store other configuration values for use in response generation
|
||||||
if _bot_instance:
|
|
||||||
if not hasattr(_bot_instance, 'lobby_configs'):
|
if not hasattr(_bot_instance, 'lobby_configs'):
|
||||||
_bot_instance.lobby_configs = {}
|
_bot_instance.lobby_configs = {}
|
||||||
|
|
||||||
|
@ -1,89 +0,0 @@
|
|||||||
"""Simple chatbot agent that demonstrates chat message handling.
|
|
||||||
|
|
||||||
This bot shows how to create an agent that primarily uses chat functionality
|
|
||||||
rather than media streams.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, Optional, Callable, Awaitable, Union
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
from shared.logger import logger
|
|
||||||
from aiortc import MediaStreamTrack
|
|
||||||
|
|
||||||
# Import shared models for chat functionality
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
||||||
from shared.models import ChatMessageModel
|
|
||||||
|
|
||||||
|
|
||||||
AGENT_NAME = "chatbot"
|
|
||||||
AGENT_DESCRIPTION = "Simple chatbot that responds to chat messages"
|
|
||||||
|
|
||||||
# Simple response database
|
|
||||||
RESPONSES = {
|
|
||||||
"hello": ["Hello!", "Hi there!", "Hey!", "Greetings!"],
|
|
||||||
"how are you": ["I'm doing well, thank you!", "Great, thanks for asking!", "I'm fine!"],
|
|
||||||
"goodbye": ["Goodbye!", "See you later!", "Bye!", "Take care!"],
|
|
||||||
"help": ["I can respond to simple greetings and questions. Try saying hello, asking how I am, or say goodbye!"],
|
|
||||||
"time": ["Let me check... it's currently {time}"],
|
|
||||||
"joke": [
|
|
||||||
"Why don't scientists trust atoms? Because they make up everything!",
|
|
||||||
"I told my wife she was drawing her eyebrows too high. She seemed surprised.",
|
|
||||||
"What do you call a fish wearing a crown? A king fish!",
|
|
||||||
"Why don't eggs tell jokes? They'd crack each other up!"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def agent_info() -> Dict[str, str]:
|
|
||||||
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION, "has_media": "false"}
|
|
||||||
|
|
||||||
|
|
||||||
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
|
||||||
"""Chatbot doesn't provide media tracks - it's chat-only."""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]) -> Optional[str]:
|
|
||||||
"""Handle incoming chat messages and provide responses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_message: The received chat message
|
|
||||||
send_message_func: Function to send messages back to the lobby
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional response message to send back to the lobby
|
|
||||||
"""
|
|
||||||
message_lower = chat_message.message.lower().strip()
|
|
||||||
sender = chat_message.sender_name
|
|
||||||
|
|
||||||
logger.info(f"Chatbot received message from {sender}: {chat_message.message}")
|
|
||||||
|
|
||||||
# Skip messages from ourselves
|
|
||||||
if sender.lower() == AGENT_NAME.lower():
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Look for keywords in the message
|
|
||||||
for keyword, responses in RESPONSES.items():
|
|
||||||
if keyword in message_lower:
|
|
||||||
response = random.choice(responses)
|
|
||||||
# Handle special formatting
|
|
||||||
if "{time}" in response:
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
response = response.format(time=current_time)
|
|
||||||
|
|
||||||
logger.info(f"Chatbot responding with: {response}")
|
|
||||||
return response
|
|
||||||
|
|
||||||
# If we get a direct mention or question, provide a generic response
|
|
||||||
if any(word in message_lower for word in ["bot", "chatbot", "?"]):
|
|
||||||
responses = [
|
|
||||||
f"Hi {sender}! I'm a simple chatbot. Say 'help' to see what I can do!",
|
|
||||||
f"Hello {sender}! I heard you mention me. How can I help?",
|
|
||||||
"I'm here and listening! Try asking me about the time or tell me a greeting!"
|
|
||||||
]
|
|
||||||
return random.choice(responses)
|
|
||||||
|
|
||||||
# Default: don't respond to unrecognized messages
|
|
||||||
return None
|
|
Loading…
x
Reference in New Issue
Block a user