Working with new UI

James Ketr 2025-05-19 10:09:00 -07:00
parent 4e3208190a
commit 705bbf146e
9 changed files with 92 additions and 73 deletions

View File

@ -44,6 +44,7 @@ export type {
ChatQueryInterface,
Query,
ChatSubmitQueryInterface,
Tunables,
};
export {

View File

@ -133,7 +133,7 @@ interface BackstoryPageContainerProps {
const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
const { children, sx } = props;
return (
<Container maxWidth="xl" sx={{ mt: 2, mb: 2, ...sx }}>
<Container maxWidth="xl" sx={{ mt: 2, mb: 2, height: "calc(1024px - 72px)", ...sx }}>
<Paper
elevation={2}
sx={{

View File

@ -6,7 +6,7 @@ import MuiMarkdown from 'mui-markdown';
import { BackstoryPageProps } from '../Components/BackstoryTab';
import { Conversation, ConversationHandle } from '../Components/Conversation';
import { ChatQuery } from '../Components/ChatQuery';
import { ChatQuery, Tunables } from '../Components/ChatQuery';
import { MessageList } from '../Components/Message';
import { connectionBase } from '../Global';
@ -17,7 +17,10 @@ type UserData = {
last_name: string;
full_name: string;
contact_info: Record<string, string>;
questions: string[];
questions: {
question: string;
tunables?: Tunables;
}[];
};
const HomePage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
@ -52,8 +55,8 @@ What would you like to know about ${user.first_name}?
setQuestions([
<Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
{user.questions.map((q: string, i: number) =>
<ChatQuery key={i} query={{ prompt: q, tunables: { enable_tools: false } }} submitQuery={submitQuery} />
{user.questions.map(({ question, tunables }, i: number) =>
<ChatQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
)}
</Box>,
<Box sx={{ p: 1 }}>

View File

@ -444,8 +444,8 @@ class WebServer:
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
case "tools":
logger.info(f"Resetting {reset_operation}")
context.tools = Tools.enabled_tools(Tools.tools)
response["tools"] = context.tools
context.tools = Tools.all_tools()
response["tools"] = Tools.llm_tools(context.tools)
case "history":
reset_map = {
"job_description": (
@ -507,8 +507,8 @@ class WebServer:
match k:
case "tools":
from typing import Any
# { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
tools: list[dict[str, Any]] = data[k]
# { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
tools: List[Dict[str, Any]] = data[k]
if not tools:
return JSONResponse(
{
@ -518,20 +518,15 @@ class WebServer:
)
for tool in tools:
for context_tool in context.tools:
if context_tool["function"]["name"] == tool["name"]:
context_tool["enabled"] = tool["enabled"]
if context_tool.tool.function.name == tool["name"]:
context_tool.enabled = tool.get("enabled", True)
self.save_context(context_id)
return JSONResponse(
{
"tools": [
{
**t["function"],
"enabled": t["enabled"],
}
for t in context.tools
]
}
)
return JSONResponse({
"tools": [{
**t.function.model_dump(mode='json'),
"enabled": t.enabled,
} for t in context.tools]
})
case "rags":
from typing import Any
@ -581,7 +576,7 @@ class WebServer:
"last_name": user.last_name,
"full_name": user.full_name,
"contact_info": user.contact_info,
"questions": user.user_questions,
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
}
return JSONResponse(user_data)
@ -600,13 +595,10 @@ class WebServer:
{
"system_prompt": agent.system_prompt,
"rags": [ r.model_dump(mode="json") for r in context.rags ],
"tools": [
{
**t["function"],
"enabled": t["enabled"],
}
for t in context.tools
],
"tools": [{
**t.function.model_dump(mode='json'),
"enabled": t.enabled,
} for t in context.tools],
}
)
@ -689,19 +681,13 @@ class WebServer:
result = json.dumps(result) + "\n"
message.network_packets += 1
message.network_bytes += len(result)
disconnected = await request.is_disconnected()
if disconnected:
logger.info("Disconnect detected. Continuing generation to store in cache.")
disconnected = True
if not disconnected:
yield result
if await request.is_disconnected():
logger.info("Disconnect detected. Aborting generation.")
context.processing = False
# Save context on completion or error
message.prompt = query.prompt
message.status = "error"
message.response = (
"Client disconnected during generation."
)
agent.conversation.add(message)
self.save_context(context_id)
return
current_time = time.perf_counter()
if current_time - start_time > LLM_TIMEOUT:
@ -711,6 +697,7 @@ class WebServer:
logger.info(message.response + " Ending session")
result = message.model_dump(by_alias=True, mode="json")
result = json.dumps(result) + "\n"
if not disconnected:
yield result
if message.status == "error":
@ -801,8 +788,8 @@ class WebServer:
modify = data["tool"]
enabled = data["enabled"]
for tool in context.tools:
if modify == tool["function"]["name"]:
tool["enabled"] = enabled
if modify == tool.function.name:
tool.enabled = enabled
self.save_context(context_id)
return JSONResponse([ { **t.function.model_dump(mode="json"), "enabled": t.enabled } for t in context.tools ])
return JSONResponse(
@ -971,7 +958,7 @@ class WebServer:
id=context_id,
user=user,
rags=[ rag.model_copy() for rag in user.rags ],
tools=Tools.enabled_tools(Tools.tools)
tools=Tools.all_tools()
)
return self.contexts[context_id]
@ -1005,13 +992,13 @@ class WebServer:
id=context_id,
user=user,
rags=[ rag.model_copy() for rag in user.rags ],
tools=Tools.enabled_tools(Tools.tools)
tools=Tools.all_tools()
)
else:
context = Context(
user=user,
rags=[ rag.model_copy() for rag in user.rags ],
tools=Tools.enabled_tools(Tools.tools)
tools=Tools.all_tools()
)
except ValidationError as e:
logger.error(e)
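
The streaming handler above also changes its disconnect policy: instead of finishing generation to warm the cache, it now aborts as soon as the client goes away. A minimal, self-contained sketch of that pattern (FakeRequest and the token list are stand-ins, not code from this repository):

import asyncio
import json

class FakeRequest:  # stand-in for starlette.Request
    def __init__(self) -> None:
        self.gone = False

    async def is_disconnected(self) -> bool:
        return self.gone

async def stream(request: FakeRequest, tokens):
    for token in tokens:
        yield json.dumps({"chunk": token}) + "\n"
        if await request.is_disconnected():
            # the real handler also marks the message as an error and
            # saves the context before returning
            print("Disconnect detected. Aborting generation.")
            return

async def main():
    request, received = FakeRequest(), []
    async for line in stream(request, ["a", "b", "c"]):
        received.append(line)
        request.gone = True  # client drops after the first chunk
    assert len(received) == 1  # generation stopped early

asyncio.run(main())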

View File

@ -36,13 +36,11 @@ from ..metrics import Metrics
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
from ..conversation import Conversation
class LLMMessage(BaseModel):
role: str = Field(default="")
content: str = Field(default="")
tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)
class Agent(BaseModel, ABC):
"""
Base class for all agent types.

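A short sketch of what the exclude=True marker on tool_calls buys, using a self-contained copy of LLMMessage (the sample tool call is invented for illustration):

from typing import Dict, List, Optional
from pydantic import BaseModel, Field

class LLMMessage(BaseModel):
    role: str = Field(default="")
    content: str = Field(default="")
    tool_calls: Optional[List[Dict]] = Field(default=None, exclude=True)

msg = LLMMessage(
    role="assistant",
    content="Here is the forecast.",
    tool_calls=[{"function": {"name": "WeatherForecast", "arguments": {}}}],
)

# tool_calls stays available in memory for the agent loop...
assert msg.tool_calls
# ...but is dropped whenever the message is serialized or persisted.
assert "tool_calls" not in msg.model_dump()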
View File

@ -9,7 +9,7 @@ from uuid import uuid4
import traceback
from . rag import RagEntry
from . import tools as Tools
from . tools import ToolEntry
from . agents import AnyAgent
from . import User
@ -28,7 +28,7 @@ class Context(BaseModel):
default_factory=lambda: str(uuid4()),
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
)
tools: List[dict]
tools: List[ToolEntry]
rags: List[RagEntry]
username: str = "__invalid__"
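
Switching Context.tools from List[dict] to List[ToolEntry] also means contexts persisted in the old flat format no longer validate, which is presumably what the ValidationError fallback in WebServer guards against. A self-contained sketch with local mirrors of the basetools models:

from typing import Any, Dict
from pydantic import BaseModel, ValidationError

class Function(BaseModel):
    name: str
    description: str = ""
    parameters: Dict[str, Any] = {}

class Tool(BaseModel):
    type: str = "function"
    function: Function

class ToolEntry(BaseModel):
    enabled: bool = True
    tool: Tool

old_shape = {"type": "function", "function": {"name": "DateTime"}, "enabled": True}
try:
    ToolEntry.model_validate(old_shape)  # missing the required "tool" field
except ValidationError:
    print("stale context format -> rebuild with Tools.all_tools()")

new_shape = {"enabled": True, "tool": old_shape}  # the nested shape Context now stores
assert ToolEntry.model_validate(new_shape).tool.function.name == "DateTime"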

View File

@ -1,6 +1,6 @@
import importlib
from .basetools import tools, llm_tools, enabled_tools, tool_functions
from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions
from ..setup_logging import setup_logging
from .. import defines
@ -11,4 +11,4 @@ module = importlib.import_module(".basetools", package=__package__)
for name in tool_functions:
globals()[name] = getattr(module, name)
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"]

View File

@ -1,4 +1,6 @@
import os
from pydantic import BaseModel, Field, model_validator # type: ignore
from typing import List, Optional, Generator, ClassVar, Any, Dict
from datetime import datetime
from typing import (
Any,
@ -345,8 +347,18 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
# %%
tools = [
{
class Function(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
returns: Optional[Dict[str, Any]] = {}
class Tool(BaseModel):
type: str
function: Function
tools : List[Tool] = [
Tool.model_validate({
"type": "function",
"function": {
"name": "TickerValue",
@ -363,8 +375,8 @@ tools = [
"additionalProperties": False,
},
},
},
{
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "AnalyzeSite",
@ -402,8 +414,8 @@ tools = [
},
},
},
},
{
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "DateTime",
@ -419,8 +431,8 @@ tools = [
"required": [],
},
},
},
{
}),
Tool.model_validate({
"type": "function",
"function": {
"name": "WeatherForecast",
@ -443,18 +455,22 @@ tools = [
"additionalProperties": False,
},
},
},
}),
]
class ToolEntry(BaseModel):
enabled: bool = True
tool: Tool
def llm_tools(tools):
return [tool for tool in tools if tool.get("enabled", False) == True]
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
def all_tools() -> List[ToolEntry]:
return [ToolEntry(tool=tool) for tool in tools]
def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return [{**tool, "enabled": True} for tool in tools]
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
# __all__.extend(__tool_functions__) # type: ignore
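
A usage sketch of the new entry points, assuming it runs with the definitions above in scope (for example, appended temporarily to basetools.py):

entries = all_tools()            # every Tool wrapped in an enabled ToolEntry
entries[0].enabled = False       # e.g. the UI switches TickerValue off

payload = llm_tools(entries)     # plain dicts, ready for the LLM tool-call API
assert all(isinstance(t, dict) for t in payload)
assert len(payload) == len(entries) - 1  # disabled entries are filtered out

active = enabled_tools(entries)  # fresh ToolEntry copies of only the enabled tools
assert len(active) == len(entries) - 1
assert all(e.enabled for e in active)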

View File

@ -24,6 +24,11 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from .rag import RagEntry
from .message import Tunables
class Question(BaseModel):
question: str
tunables: Tunables = Field(default_factory=Tunables)
class User(BaseModel):
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@ -35,7 +40,7 @@ class User(BaseModel):
last_name: str = ""
full_name: str = ""
contact_info : Dict[str, str] = {}
user_questions : List[str] = []
user_questions : List[Question] = []
#context: Optional[List[Context]] = []
# file_watcher : ChromaDBFileWatcher = set by initialize
@ -193,7 +198,16 @@ class User(BaseModel):
self.last_name = info.get("last_name", "")
self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}")
self.contact_info = info.get("contact_info", {})
self.user_questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
self.user_questions = []
for question in questions:
if isinstance(question, str):
self.user_questions.append(Question(question=question, tunables=Tunables(enable_tools=False)))
else:
try:
self.user_questions.append(Question.model_validate(question))
except Exception as e:
logger.info(f"Unable to initialize question {question}: {e}")
os.makedirs(persist_directory, exist_ok=True)
os.makedirs(watch_directory, exist_ok=True)
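
A hedged sketch of the backward-compatible question handling above, assuming the Question model in this file and a Tunables model (from message.py) with an enable_tools flag; the example questions mirror the defaults above:

import json

legacy = "Tell me about James."                      # old config: bare string
structured = {"question": "What are James's professional strengths?",
              "tunables": {"enable_tools": True}}    # new structured form

questions = [
    Question(question=legacy, tunables=Tunables(enable_tools=False)),
    Question.model_validate(structured),
]

# This is the shape the HomePage UserData type now consumes from the user-info response:
print(json.dumps([q.model_dump(mode="json") for q in questions], indent=2))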