Working with new ui
This commit is contained in:
parent
4e3208190a
commit
705bbf146e
@ -44,6 +44,7 @@ export type {
|
||||
ChatQueryInterface,
|
||||
Query,
|
||||
ChatSubmitQueryInterface,
|
||||
Tunables,
|
||||
};
|
||||
|
||||
export {
|
||||
|
@ -133,7 +133,7 @@ interface BackstoryPageContainerProps {
|
||||
const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
|
||||
const { children, sx } = props;
|
||||
return (
|
||||
<Container maxWidth="xl" sx={{ mt: 2, mb: 2, ...sx }}>
|
||||
<Container maxWidth="xl" sx={{ mt: 2, mb: 2, height: "calc(1024px - 72px)", ...sx }}>
|
||||
<Paper
|
||||
elevation={2}
|
||||
sx={{
|
||||
|
@ -6,7 +6,7 @@ import MuiMarkdown from 'mui-markdown';
|
||||
|
||||
import { BackstoryPageProps } from '../Components//BackstoryTab';
|
||||
import { Conversation, ConversationHandle } from '../Components/Conversation';
|
||||
import { ChatQuery } from '../Components/ChatQuery';
|
||||
import { ChatQuery, Tunables } from '../Components/ChatQuery';
|
||||
import { MessageList } from '../Components/Message';
|
||||
|
||||
import { connectionBase } from '../Global';
|
||||
@ -17,7 +17,10 @@ type UserData = {
|
||||
last_name: string;
|
||||
full_name: string;
|
||||
contact_info: Record<string, string>;
|
||||
questions: string[];
|
||||
questions: [{
|
||||
question: string;
|
||||
tunables?: Tunables
|
||||
}]
|
||||
};
|
||||
|
||||
const HomePage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
|
||||
@ -52,8 +55,8 @@ What would you like to know about ${user.first_name}?
|
||||
|
||||
setQuestions([
|
||||
<Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
|
||||
{user.questions.map((q: string, i: number) =>
|
||||
<ChatQuery key={i} query={{ prompt: q, tunables: { enable_tools: false } }} submitQuery={submitQuery} />
|
||||
{user.questions.map(({ question, tunables }, i: number) =>
|
||||
<ChatQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
|
||||
)}
|
||||
</Box>,
|
||||
<Box sx={{ p: 1 }}>
|
||||
|
@ -444,8 +444,8 @@ class WebServer:
|
||||
response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
|
||||
case "tools":
|
||||
logger.info(f"Resetting {reset_operation}")
|
||||
context.tools = Tools.enabled_tools(Tools.tools)
|
||||
response["tools"] = context.tools
|
||||
context.tools = Tools.all_tools()
|
||||
response["tools"] = Tools.llm_tools(context.tools)
|
||||
case "history":
|
||||
reset_map = {
|
||||
"job_description": (
|
||||
@ -507,8 +507,8 @@ class WebServer:
|
||||
match k:
|
||||
case "tools":
|
||||
from typing import Any
|
||||
# { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
|
||||
tools: list[dict[str, Any]] = data[k]
|
||||
# { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
|
||||
tools: List[Dict[str, Any]] = data[k]
|
||||
if not tools:
|
||||
return JSONResponse(
|
||||
{
|
||||
@ -518,20 +518,15 @@ class WebServer:
|
||||
)
|
||||
for tool in tools:
|
||||
for context_tool in context.tools:
|
||||
if context_tool["function"]["name"] == tool["name"]:
|
||||
context_tool["enabled"] = tool["enabled"]
|
||||
if context_tool.tool.function.name == tool["name"]:
|
||||
context_tool.enabled = tool.get("enabled", True)
|
||||
self.save_context(context_id)
|
||||
return JSONResponse(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
**t["function"],
|
||||
"enabled": t["enabled"],
|
||||
}
|
||||
for t in context.tools
|
||||
]
|
||||
}
|
||||
)
|
||||
return JSONResponse({
|
||||
"tools": [{
|
||||
**t.function.model_dump(mode='json'),
|
||||
"enabled": t.enabled,
|
||||
} for t in context.tools]
|
||||
})
|
||||
|
||||
case "rags":
|
||||
from typing import Any
|
||||
@ -581,7 +576,7 @@ class WebServer:
|
||||
"last_name": user.last_name,
|
||||
"full_name": user.full_name,
|
||||
"contact_info": user.contact_info,
|
||||
"questions": user.user_questions,
|
||||
"questions": [ q.model_dump(mode='json') for q in user.user_questions],
|
||||
}
|
||||
return JSONResponse(user_data)
|
||||
|
||||
@ -600,13 +595,10 @@ class WebServer:
|
||||
{
|
||||
"system_prompt": agent.system_prompt,
|
||||
"rags": [ r.model_dump(mode="json") for r in context.rags ],
|
||||
"tools": [
|
||||
{
|
||||
**t["function"],
|
||||
"enabled": t["enabled"],
|
||||
}
|
||||
for t in context.tools
|
||||
],
|
||||
"tools": [{
|
||||
**t.function.model_dump(mode='json'),
|
||||
"enabled": t.enabled,
|
||||
} for t in context.tools],
|
||||
}
|
||||
)
|
||||
|
||||
@ -689,19 +681,13 @@ class WebServer:
|
||||
result = json.dumps(result) + "\n"
|
||||
message.network_packets += 1
|
||||
message.network_bytes += len(result)
|
||||
disconnected = await request.is_disconnected()
|
||||
if disconnected:
|
||||
logger.info("Disconnect detected. Continuing generation to store in cache.")
|
||||
disconnected = True
|
||||
|
||||
if not disconnected:
|
||||
yield result
|
||||
if await request.is_disconnected():
|
||||
logger.info("Disconnect detected. Aborting generation.")
|
||||
context.processing = False
|
||||
# Save context on completion or error
|
||||
message.prompt = query.prompt
|
||||
message.status = "error"
|
||||
message.response = (
|
||||
"Client disconnected during generation."
|
||||
)
|
||||
agent.conversation.add(message)
|
||||
self.save_context(context_id)
|
||||
return
|
||||
|
||||
current_time = time.perf_counter()
|
||||
if current_time - start_time > LLM_TIMEOUT:
|
||||
@ -711,6 +697,7 @@ class WebServer:
|
||||
logger.info(message.response + " Ending session")
|
||||
result = message.model_dump(by_alias=True, mode="json")
|
||||
result = json.dumps(result) + "\n"
|
||||
if not disconnected:
|
||||
yield result
|
||||
|
||||
if message.status == "error":
|
||||
@ -801,8 +788,8 @@ class WebServer:
|
||||
modify = data["tool"]
|
||||
enabled = data["enabled"]
|
||||
for tool in context.tools:
|
||||
if modify == tool["function"]["name"]:
|
||||
tool["enabled"] = enabled
|
||||
if modify == tool.function.name:
|
||||
tool.enabled = enabled
|
||||
self.save_context(context_id)
|
||||
return JSONResponse(context.tools)
|
||||
return JSONResponse(
|
||||
@ -971,7 +958,7 @@ class WebServer:
|
||||
id=context_id,
|
||||
user=user,
|
||||
rags=[ rag.model_copy() for rag in user.rags ],
|
||||
tools=Tools.enabled_tools(Tools.tools)
|
||||
tools=Tools.all_tools()
|
||||
)
|
||||
|
||||
return self.contexts[context_id]
|
||||
@ -1005,13 +992,13 @@ class WebServer:
|
||||
id=context_id,
|
||||
user=user,
|
||||
rags=[ rag.model_copy() for rag in user.rags ],
|
||||
tools=Tools.enabled_tools(Tools.tools)
|
||||
tools=Tools.all_tools()
|
||||
)
|
||||
else:
|
||||
context = Context(
|
||||
user=user,
|
||||
rags=[ rag.model_copy() for rag in user.rags ],
|
||||
tools=Tools.enabled_tools(Tools.tools)
|
||||
tools=Tools.all_tools()
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.error(e)
|
||||
|
@ -36,13 +36,11 @@ from ..metrics import Metrics
|
||||
from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
|
||||
from ..conversation import Conversation
|
||||
|
||||
|
||||
class LLMMessage(BaseModel):
|
||||
role: str = Field(default="")
|
||||
content: str = Field(default="")
|
||||
tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)
|
||||
|
||||
|
||||
class Agent(BaseModel, ABC):
|
||||
"""
|
||||
Base class for all agent types.
|
||||
|
@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
import traceback
|
||||
|
||||
from . rag import RagEntry
|
||||
from . import tools as Tools
|
||||
from .agents import AnyAgent
|
||||
from . tools import ToolEntry
|
||||
from . agents import AnyAgent
|
||||
from . import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -28,7 +28,7 @@ class Context(BaseModel):
|
||||
default_factory=lambda: str(uuid4()),
|
||||
pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
|
||||
)
|
||||
tools: List[dict]
|
||||
tools: List[ToolEntry]
|
||||
rags: List[RagEntry]
|
||||
username: str = "__invalid__"
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import importlib
|
||||
|
||||
from .basetools import tools, llm_tools, enabled_tools, tool_functions
|
||||
from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions
|
||||
from ..setup_logging import setup_logging
|
||||
from .. import defines
|
||||
|
||||
@ -11,4 +11,4 @@ module = importlib.import_module(".basetools", package=__package__)
|
||||
for name in tool_functions:
|
||||
globals()[name] = getattr(module, name)
|
||||
|
||||
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
|
@ -1,4 +1,6 @@
|
||||
import os
|
||||
from pydantic import BaseModel, Field, model_validator # type: ignore
|
||||
from typing import List, Optional, Generator, ClassVar, Any, Dict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
@ -345,8 +347,18 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):
|
||||
|
||||
|
||||
# %%
|
||||
tools = [
|
||||
{
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any]
|
||||
returns: Optional[Dict[str, Any]] = {}
|
||||
|
||||
class Tool(BaseModel):
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
tools : List[Tool] = [
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "TickerValue",
|
||||
@ -363,8 +375,8 @@ tools = [
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "AnalyzeSite",
|
||||
@ -402,8 +414,8 @@ tools = [
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "DateTime",
|
||||
@ -419,8 +431,8 @@ tools = [
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
}),
|
||||
Tool.model_validate({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "WeatherForecast",
|
||||
@ -443,18 +455,22 @@ tools = [
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
]
|
||||
|
||||
class ToolEntry(BaseModel):
|
||||
enabled: bool = True
|
||||
tool: Tool
|
||||
|
||||
def llm_tools(tools):
|
||||
return [tool for tool in tools if tool.get("enabled", False) == True]
|
||||
def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
|
||||
return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]
|
||||
|
||||
def all_tools() -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=tool) for tool in tools]
|
||||
|
||||
def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
return [{**tool, "enabled": True} for tool in tools]
|
||||
|
||||
def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
|
||||
return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]
|
||||
|
||||
tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
|
||||
__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
|
||||
# __all__.extend(__tool_functions__) # type: ignore
|
||||
|
@ -24,6 +24,11 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .rag import RagEntry
|
||||
from .message import Tunables
|
||||
|
||||
class Question(BaseModel):
|
||||
question: str
|
||||
tunables: Tunables = Field(default_factory=Tunables)
|
||||
|
||||
class User(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
|
||||
@ -35,7 +40,7 @@ class User(BaseModel):
|
||||
last_name: str = ""
|
||||
full_name: str = ""
|
||||
contact_info : Dict[str, str] = {}
|
||||
user_questions : List[str] = []
|
||||
user_questions : List[Question] = []
|
||||
|
||||
#context: Optional[List[Context]] = []
|
||||
# file_watcher : ChromaDBFileWatcher = set by initialize
|
||||
@ -193,7 +198,16 @@ class User(BaseModel):
|
||||
self.last_name = info.get("last_name", "")
|
||||
self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}")
|
||||
self.contact_info = info.get("contact_info", {})
|
||||
self.user_questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
|
||||
questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
|
||||
self.user_questions = []
|
||||
for question in questions:
|
||||
if type(question) == str:
|
||||
self.user_questions.append(Question(question=question, tunables=Tunables(enable_tools=False)))
|
||||
else:
|
||||
try:
|
||||
self.user_questions.append(Question.model_validate(question))
|
||||
except Exception as e:
|
||||
logger.info(f"Unable to initialize all questions from {user_info}")
|
||||
|
||||
os.makedirs(persist_directory, exist_ok=True)
|
||||
os.makedirs(watch_directory, exist_ok=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user