diff --git a/frontend/src/Components/ChatQuery.tsx b/frontend/src/Components/ChatQuery.tsx index 949defd..07d4d46 100644 --- a/frontend/src/Components/ChatQuery.tsx +++ b/frontend/src/Components/ChatQuery.tsx @@ -44,6 +44,7 @@ export type { ChatQueryInterface, Query, ChatSubmitQueryInterface, + Tunables, }; export { diff --git a/frontend/src/NewApp/BackstoryApp.tsx b/frontend/src/NewApp/BackstoryApp.tsx index 88b2a21..52bbd5a 100644 --- a/frontend/src/NewApp/BackstoryApp.tsx +++ b/frontend/src/NewApp/BackstoryApp.tsx @@ -133,7 +133,7 @@ interface BackstoryPageContainerProps { const BackstoryPageContainer = (props : BackstoryPageContainerProps) => { const { children, sx } = props; return ( - + ; - questions: string[]; + questions: [{ + question: string; + tunables?: Tunables + }] }; const HomePage = forwardRef((props: BackstoryPageProps, ref) => { @@ -52,8 +55,8 @@ What would you like to know about ${user.first_name}? setQuestions([ - {user.questions.map((q: string, i: number) => - + {user.questions.map(({ question, tunables }, i: number) => + )} , diff --git a/src/server.py b/src/server.py index 4f0dcd5..b9b0415 100644 --- a/src/server.py +++ b/src/server.py @@ -444,8 +444,8 @@ class WebServer: response["rags"] = [ r.model_dump(mode="json") for r in context.rags ] case "tools": logger.info(f"Resetting {reset_operation}") - context.tools = Tools.enabled_tools(Tools.tools) - response["tools"] = context.tools + context.tools = Tools.all_tools() + response["tools"] = Tools.llm_tools(context.tools) case "history": reset_map = { "job_description": ( @@ -507,8 +507,8 @@ class WebServer: match k: case "tools": from typing import Any - # { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] } - tools: list[dict[str, Any]] = data[k] + # { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] } + tools: List[Dict[str, Any]] = data[k] if not tools: return JSONResponse( { @@ -518,20 +518,15 @@ class WebServer: ) for tool in tools: for context_tool in context.tools: - if context_tool["function"]["name"] == tool["name"]: - context_tool["enabled"] = tool["enabled"] + if context_tool.tool.function.name == tool["name"]: + context_tool.enabled = tool.get("enabled", True) self.save_context(context_id) - return JSONResponse( - { - "tools": [ - { - **t["function"], - "enabled": t["enabled"], - } - for t in context.tools - ] - } - ) + return JSONResponse({ + "tools": [{ + **t.function.model_dump(mode='json'), + "enabled": t.enabled, + } for t in context.tools] + }) case "rags": from typing import Any @@ -581,7 +576,7 @@ class WebServer: "last_name": user.last_name, "full_name": user.full_name, "contact_info": user.contact_info, - "questions": user.user_questions, + "questions": [ q.model_dump(mode='json') for q in user.user_questions], } return JSONResponse(user_data) @@ -600,13 +595,10 @@ class WebServer: { "system_prompt": agent.system_prompt, "rags": [ r.model_dump(mode="json") for r in context.rags ], - "tools": [ - { - **t["function"], - "enabled": t["enabled"], - } - for t in context.tools - ], + "tools": [{ + **t.function.model_dump(mode='json'), + "enabled": t.enabled, + } for t in context.tools], } ) @@ -689,19 +681,13 @@ class WebServer: result = json.dumps(result) + "\n" message.network_packets += 1 message.network_bytes += len(result) - yield result - if await request.is_disconnected(): - logger.info("Disconnect detected. 
Aborting generation.") - context.processing = False - # Save context on completion or error - message.prompt = query.prompt - message.status = "error" - message.response = ( - "Client disconnected during generation." - ) - agent.conversation.add(message) - self.save_context(context_id) - return + disconnected = await request.is_disconnected() + if disconnected: + logger.info("Disconnect detected. Continuing generation to store in cache.") + disconnected = True + + if not disconnected: + yield result current_time = time.perf_counter() if current_time - start_time > LLM_TIMEOUT: @@ -711,7 +697,8 @@ class WebServer: logger.info(message.response + " Ending session") result = message.model_dump(by_alias=True, mode="json") result = json.dumps(result) + "\n" - yield result + if not disconnected: + yield result if message.status == "error": context.processing = False @@ -801,8 +788,8 @@ class WebServer: modify = data["tool"] enabled = data["enabled"] for tool in context.tools: - if modify == tool["function"]["name"]: - tool["enabled"] = enabled + if modify == tool.function.name: + tool.enabled = enabled self.save_context(context_id) return JSONResponse(context.tools) return JSONResponse( @@ -971,7 +958,7 @@ class WebServer: id=context_id, user=user, rags=[ rag.model_copy() for rag in user.rags ], - tools=Tools.enabled_tools(Tools.tools) + tools=Tools.all_tools() ) return self.contexts[context_id] @@ -1005,13 +992,13 @@ class WebServer: id=context_id, user=user, rags=[ rag.model_copy() for rag in user.rags ], - tools=Tools.enabled_tools(Tools.tools) + tools=Tools.all_tools() ) else: context = Context( user=user, rags=[ rag.model_copy() for rag in user.rags ], - tools=Tools.enabled_tools(Tools.tools) + tools=Tools.all_tools() ) except ValidationError as e: logger.error(e) diff --git a/src/utils/agents/base.py b/src/utils/agents/base.py index 127e121..2e6045b 100644 --- a/src/utils/agents/base.py +++ b/src/utils/agents/base.py @@ -36,13 +36,11 @@ from ..metrics import Metrics from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__ from ..conversation import Conversation - class LLMMessage(BaseModel): role: str = Field(default="") content: str = Field(default="") tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True) - class Agent(BaseModel, ABC): """ Base class for all agent types. diff --git a/src/utils/context.py b/src/utils/context.py index e98f469..eb43139 100644 --- a/src/utils/context.py +++ b/src/utils/context.py @@ -9,8 +9,8 @@ from uuid import uuid4 import traceback from . rag import RagEntry -from . import tools as Tools -from .agents import AnyAgent +from . tools import ToolEntry +from . agents import AnyAgent from . import User if TYPE_CHECKING: @@ -28,7 +28,7 @@ class Context(BaseModel): default_factory=lambda: str(uuid4()), pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", ) - tools: List[dict] + tools: List[ToolEntry] rags: List[RagEntry] username: str = "__invalid__" diff --git a/src/utils/tools/__init__.py b/src/utils/tools/__init__.py index c03439c..1f8c20a 100644 --- a/src/utils/tools/__init__.py +++ b/src/utils/tools/__init__.py @@ -1,6 +1,6 @@ import importlib -from .basetools import tools, llm_tools, enabled_tools, tool_functions +from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions from ..setup_logging import setup_logging from .. 
import defines @@ -11,4 +11,4 @@ module = importlib.import_module(".basetools", package=__package__) for name in tool_functions: globals()[name] = getattr(module, name) -__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"] +__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"] diff --git a/src/utils/tools/basetools.py b/src/utils/tools/basetools.py index 7044a75..c43cbea 100644 --- a/src/utils/tools/basetools.py +++ b/src/utils/tools/basetools.py @@ -1,4 +1,6 @@ import os +from pydantic import BaseModel, Field, model_validator # type: ignore +from typing import List, Optional, Generator, ClassVar, Any, Dict from datetime import datetime from typing import ( Any, @@ -345,8 +347,18 @@ async def AnalyzeSite(llm, model: str, url: str, question: str): # %% -tools = [ - { +class Function(BaseModel): + name: str + description: str + parameters: Dict[str, Any] + returns: Optional[Dict[str, Any]] = {} + +class Tool(BaseModel): + type: str + function: Function + +tools : List[Tool] = [ + Tool.model_validate({ "type": "function", "function": { "name": "TickerValue", @@ -363,8 +375,8 @@ tools = [ "additionalProperties": False, }, }, - }, - { + }), + Tool.model_validate({ "type": "function", "function": { "name": "AnalyzeSite", @@ -402,8 +414,8 @@ tools = [ }, }, }, - }, - { + }), + Tool.model_validate({ "type": "function", "function": { "name": "DateTime", @@ -419,8 +431,8 @@ tools = [ "required": [], }, }, - }, - { + }), + Tool.model_validate({ "type": "function", "function": { "name": "WeatherForecast", @@ -443,18 +455,22 @@ tools = [ "additionalProperties": False, }, }, - }, + }), ] +class ToolEntry(BaseModel): + enabled: bool = True + tool: Tool -def llm_tools(tools): - return [tool for tool in tools if tool.get("enabled", False) == True] +def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]: + return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True] +def all_tools() -> List[ToolEntry]: + return [ToolEntry(tool=tool) for tool in tools] -def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{**tool, "enabled": True} for tool in tools] - +def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]: + return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True] tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"] -__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"] +__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"] # __all__.extend(__tool_functions__) # type: ignore diff --git a/src/utils/user.py b/src/utils/user.py index dd731af..627f284 100644 --- a/src/utils/user.py +++ b/src/utils/user.py @@ -24,6 +24,11 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) from .rag import RagEntry +from .message import Tunables + +class Question(BaseModel): + question: str + tunables: Tunables = Field(default_factory=Tunables) class User(BaseModel): model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc @@ -35,7 +40,7 @@ class User(BaseModel): last_name: str = "" full_name: str = "" contact_info : Dict[str, str] = {} - user_questions : List[str] = [] + user_questions : List[Question] = [] #context: Optional[List[Context]] = [] # file_watcher : ChromaDBFileWatcher = set by initialize @@ -193,7 +198,16 @@ class User(BaseModel): self.last_name = info.get("last_name", "") self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}") 
         self.contact_info = info.get("contact_info", {})
-        self.user_questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
+        questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
+        self.user_questions = []
+        for question in questions:
+            if isinstance(question, str):
+                self.user_questions.append(Question(question=question, tunables=Tunables(enable_tools=False)))
+            else:
+                try:
+                    self.user_questions.append(Question.model_validate(question))
+                except Exception as e:
+                    logger.info(f"Unable to initialize question {question}: {e}")
 
         os.makedirs(persist_directory, exist_ok=True)
         os.makedirs(watch_directory, exist_ok=True)
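
Reviewer note (not part of the patch): the streaming handler in src/server.py now keeps generating after a client disconnect so the finished response can still be stored, and only stops yielding to the socket. Below is a toy, self-contained illustration of that control flow; FakeRequest and the chunk loop are invented stand-ins for the real FastAPI request and agent stream, not code from this repository.

```python
import asyncio
import json
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class FakeRequest:
    """Stand-in for the FastAPI Request; reports a disconnect after two checks."""
    def __init__(self):
        self.checks = 0

    async def is_disconnected(self) -> bool:
        self.checks += 1
        return self.checks > 2

async def stream(request):
    disconnected = False
    for i in range(5):  # stands in for the agent's token/message stream
        result = json.dumps({"chunk": i}) + "\n"
        if await request.is_disconnected() and not disconnected:
            logger.info("Disconnect detected. Continuing generation to store in cache.")
            disconnected = True
        if not disconnected:
            yield result  # only emit while the client is still connected
        # generation (and any caching of `result`) continues regardless

async def main():
    async for chunk in stream(FakeRequest()):
        print(chunk, end="")

asyncio.run(main())
```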
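
For orientation, here is a minimal, self-contained sketch of the typed tool registry that replaces the old list-of-dicts in src/utils/tools/basetools.py. The model names and helpers mirror the patch; the single DateTime schema is abbreviated for illustration and is not the full definition from the diff.

```python
from typing import Any, Dict, List, Optional
from pydantic import BaseModel

class Function(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]
    returns: Optional[Dict[str, Any]] = {}

class Tool(BaseModel):
    type: str
    function: Function

class ToolEntry(BaseModel):
    enabled: bool = True
    tool: Tool

# Module-level catalog of Tool definitions (abbreviated to one entry here).
tools: List[Tool] = [
    Tool.model_validate({
        "type": "function",
        "function": {
            "name": "DateTime",
            "description": "Returns the current date and time.",
            "parameters": {"type": "object", "properties": {}, "required": []},
        },
    }),
]

def all_tools() -> List[ToolEntry]:
    # Fresh, per-context entries; everything starts enabled.
    return [ToolEntry(tool=tool) for tool in tools]

def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
    # Serialize only the enabled tools back to the plain dicts the LLM API expects.
    return [entry.tool.model_dump(mode="json") for entry in tools if entry.enabled]

# Per-context usage, as in server.py: toggle by typed attribute instead of dict keys.
entries = all_tools()
entries[0].enabled = False
assert llm_tools(entries) == []
```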
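
Finally, the questions migration in src/utils/user.py: legacy string entries are wrapped in the new Question model with tools disabled, while dict entries are validated directly. The sketch below is self-contained; Tunables is reduced to the one field the patch relies on (enable_tools), and the sample data is invented for illustration.

```python
from typing import List
from pydantic import BaseModel, Field

class Tunables(BaseModel):
    # Only the field this patch relies on; the real model lives in utils/message.py.
    enable_tools: bool = True

class Question(BaseModel):
    question: str
    tunables: Tunables = Field(default_factory=Tunables)

# Mixed legacy (plain string) and new (dict) entries, as might appear in the user info file.
raw = [
    "Tell me about this candidate.",
    {"question": "What are their professional strengths?", "tunables": {"enable_tools": True}},
]

questions: List[Question] = []
for q in raw:
    if isinstance(q, str):
        # Legacy strings default to tools off, matching the patch.
        questions.append(Question(question=q, tunables=Tunables(enable_tools=False)))
    else:
        questions.append(Question.model_validate(q))

print([q.model_dump(mode="json") for q in questions])
```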