Working with new ui

parent 4e3208190a
commit 705bbf146e
@@ -44,6 +44,7 @@ export type {
     ChatQueryInterface,
     Query,
     ChatSubmitQueryInterface,
+    Tunables,
 };

 export {
@@ -133,7 +133,7 @@ interface BackstoryPageContainerProps {
 const BackstoryPageContainer = (props : BackstoryPageContainerProps) => {
     const { children, sx } = props;
     return (
-        <Container maxWidth="xl" sx={{ mt: 2, mb: 2, ...sx }}>
+        <Container maxWidth="xl" sx={{ mt: 2, mb: 2, height: "calc(1024px - 72px)", ...sx }}>
             <Paper
                 elevation={2}
                 sx={{
@@ -6,7 +6,7 @@ import MuiMarkdown from 'mui-markdown';

 import { BackstoryPageProps } from '../Components//BackstoryTab';
 import { Conversation, ConversationHandle } from '../Components/Conversation';
-import { ChatQuery } from '../Components/ChatQuery';
+import { ChatQuery, Tunables } from '../Components/ChatQuery';
 import { MessageList } from '../Components/Message';

 import { connectionBase } from '../Global';
@@ -17,7 +17,10 @@ type UserData = {
     last_name: string;
     full_name: string;
     contact_info: Record<string, string>;
-    questions: string[];
+    questions: [{
+        question: string;
+        tunables?: Tunables
+    }]
 };

 const HomePage = forwardRef<ConversationHandle, BackstoryPageProps>((props: BackstoryPageProps, ref) => {
@@ -52,8 +55,8 @@ What would you like to know about ${user.first_name}?

     setQuestions([
         <Box sx={{ display: "flex", flexDirection: isMobile ? "column" : "row" }}>
-            {user.questions.map((q: string, i: number) =>
-                <ChatQuery key={i} query={{ prompt: q, tunables: { enable_tools: false } }} submitQuery={submitQuery} />
+            {user.questions.map(({ question, tunables }, i: number) =>
+                <ChatQuery key={i} query={{ prompt: question, tunables: tunables }} submitQuery={submitQuery} />
             )}
         </Box>,
         <Box sx={{ p: 1 }}>
@@ -444,8 +444,8 @@ class WebServer:
                         response["rags"] = [ r.model_dump(mode="json") for r in context.rags ]
                     case "tools":
                         logger.info(f"Resetting {reset_operation}")
-                        context.tools = Tools.enabled_tools(Tools.tools)
-                        response["tools"] = context.tools
+                        context.tools = Tools.all_tools()
+                        response["tools"] = Tools.llm_tools(context.tools)
                     case "history":
                         reset_map = {
                             "job_description": (
@@ -507,8 +507,8 @@ class WebServer:
                 match k:
                     case "tools":
                         from typing import Any
-                        # { "tools": [{ "tool": tool?.name, "enabled": tool.enabled }] }
-                        tools: list[dict[str, Any]] = data[k]
+                        # { "tools": [{ "tool": tool.name, "enabled": tool.enabled }] }
+                        tools: List[Dict[str, Any]] = data[k]
                         if not tools:
                             return JSONResponse(
                                 {
@@ -518,20 +518,15 @@ class WebServer:
                             )
                         for tool in tools:
                             for context_tool in context.tools:
-                                if context_tool["function"]["name"] == tool["name"]:
-                                    context_tool["enabled"] = tool["enabled"]
+                                if context_tool.tool.function.name == tool["name"]:
+                                    context_tool.enabled = tool.get("enabled", True)
                         self.save_context(context_id)
-                        return JSONResponse(
-                            {
-                                "tools": [
-                                    {
-                                        **t["function"],
-                                        "enabled": t["enabled"],
-                                    }
-                                    for t in context.tools
-                                ]
-                            }
-                        )
+                        return JSONResponse({
+                            "tools": [{
+                                **t.function.model_dump(mode='json'),
+                                "enabled": t.enabled,
+                            } for t in context.tools]
+                        })

                     case "rags":
                         from typing import Any
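With context.tools now holding ToolEntry objects rather than plain dicts, the handler above matches on entry.tool.function.name and rebuilds the response from each entry's function schema plus its enabled flag. A minimal sketch of those shapes outside the real server code follows; the helper apply_tool_update and the field defaults are hypothetical, while the model names mirror the ones introduced in basetools later in this diff:

    # Illustrative only -- apply_tool_update is a hypothetical helper, not part of the commit.
    from typing import Any, Dict, List
    from pydantic import BaseModel

    class Function(BaseModel):
        name: str
        description: str = ""
        parameters: Dict[str, Any] = {}

    class Tool(BaseModel):
        type: str = "function"
        function: Function

    class ToolEntry(BaseModel):
        enabled: bool = True
        tool: Tool

    def apply_tool_update(entries: List[ToolEntry], payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Toggle "enabled" from a {"tools": [{"name": ..., "enabled": ...}]} payload,
        # then build the JSON body returned to the client.
        for patch in payload.get("tools", []):
            for entry in entries:
                if entry.tool.function.name == patch["name"]:
                    entry.enabled = patch.get("enabled", True)
        return [
            {**e.tool.function.model_dump(mode="json"), "enabled": e.enabled}
            for e in entries
        ]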
@@ -581,7 +576,7 @@ class WebServer:
                     "last_name": user.last_name,
                     "full_name": user.full_name,
                     "contact_info": user.contact_info,
-                    "questions": user.user_questions,
+                    "questions": [ q.model_dump(mode='json') for q in user.user_questions],
                 }
                 return JSONResponse(user_data)

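Since user.user_questions now holds Question models, the questions field in this payload becomes a list of objects instead of bare strings, matching the UserData type change on the frontend. A hypothetical excerpt of the serialized shape (values invented; field names follow the diff):

    # Hypothetical payload excerpt -- values invented for illustration.
    user_data_example = {
        "questions": [
            {"question": "Tell me about this candidate.", "tunables": {"enable_tools": False}},
            {"question": "What are their professional strengths?", "tunables": {"enable_tools": True}},
        ],
    }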
@@ -600,13 +595,10 @@ class WebServer:
                 {
                     "system_prompt": agent.system_prompt,
                     "rags": [ r.model_dump(mode="json") for r in context.rags ],
-                    "tools": [
-                        {
-                            **t["function"],
-                            "enabled": t["enabled"],
-                        }
-                        for t in context.tools
-                    ],
+                    "tools": [{
+                        **t.function.model_dump(mode='json'),
+                        "enabled": t.enabled,
+                    } for t in context.tools],
                 }
             )

@@ -689,19 +681,13 @@ class WebServer:
                     result = json.dumps(result) + "\n"
                     message.network_packets += 1
                     message.network_bytes += len(result)
-                    yield result
-                    if await request.is_disconnected():
-                        logger.info("Disconnect detected. Aborting generation.")
-                        context.processing = False
-                        # Save context on completion or error
-                        message.prompt = query.prompt
-                        message.status = "error"
-                        message.response = (
-                            "Client disconnected during generation."
-                        )
-                        agent.conversation.add(message)
-                        self.save_context(context_id)
-                        return
+                    disconnected = await request.is_disconnected()
+                    if disconnected:
+                        logger.info("Disconnect detected. Continuing generation to store in cache.")
+                        disconnected = True
+
+                    if not disconnected:
+                        yield result

                     current_time = time.perf_counter()
                     if current_time - start_time > LLM_TIMEOUT:
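The streaming loop now notes a client disconnect, keeps driving the generator so the finished response can still be cached, and simply stops yielding instead of aborting with an error message. A minimal sketch of that pattern, using hypothetical generate/cache stand-ins rather than the real handler:

    # Sketch of "keep generating, stop streaming" -- generate() and cache() are stand-ins.
    async def stream(request, generate, cache):
        disconnected = False
        async for chunk in generate():
            if not disconnected and await request.is_disconnected():
                # Remember the disconnect but keep consuming the generator so
                # the complete result is still produced and cached.
                disconnected = True
            if not disconnected:
                yield chunk
            cache(chunk)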
@@ -711,6 +697,7 @@ class WebServer:
                 logger.info(message.response + " Ending session")
                 result = message.model_dump(by_alias=True, mode="json")
                 result = json.dumps(result) + "\n"
+                if not disconnected:
                     yield result

                 if message.status == "error":
@@ -801,8 +788,8 @@ class WebServer:
             modify = data["tool"]
             enabled = data["enabled"]
             for tool in context.tools:
-                if modify == tool["function"]["name"]:
-                    tool["enabled"] = enabled
+                if modify == tool.function.name:
+                    tool.enabled = enabled
             self.save_context(context_id)
             return JSONResponse(context.tools)
         return JSONResponse(
@@ -971,7 +958,7 @@ class WebServer:
                 id=context_id,
                 user=user,
                 rags=[ rag.model_copy() for rag in user.rags ],
-                tools=Tools.enabled_tools(Tools.tools)
+                tools=Tools.all_tools()
             )

         return self.contexts[context_id]
@@ -1005,13 +992,13 @@ class WebServer:
                     id=context_id,
                     user=user,
                     rags=[ rag.model_copy() for rag in user.rags ],
-                    tools=Tools.enabled_tools(Tools.tools)
+                    tools=Tools.all_tools()
                 )
             else:
                 context = Context(
                     user=user,
                     rags=[ rag.model_copy() for rag in user.rags ],
-                    tools=Tools.enabled_tools(Tools.tools)
+                    tools=Tools.all_tools()
                 )
         except ValidationError as e:
             logger.error(e)
@@ -36,13 +36,11 @@ from ..metrics import Metrics
 from ..tools import TickerValue, WeatherForecast, AnalyzeSite, DateTime, llm_tools # type: ignore -- dynamically added to __all__
 from ..conversation import Conversation


 class LLMMessage(BaseModel):
     role: str = Field(default="")
     content: str = Field(default="")
     tool_calls: Optional[List[Dict]] = Field(default={}, exclude=True)


 class Agent(BaseModel, ABC):
     """
     Base class for all agent types.
@@ -9,8 +9,8 @@ from uuid import uuid4
 import traceback

 from . rag import RagEntry
-from . import tools as Tools
-from .agents import AnyAgent
+from . tools import ToolEntry
+from . agents import AnyAgent
 from . import User

 if TYPE_CHECKING:
@@ -28,7 +28,7 @@ class Context(BaseModel):
         default_factory=lambda: str(uuid4()),
         pattern=r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$",
     )
-    tools: List[dict]
+    tools: List[ToolEntry]
     rags: List[RagEntry]
     username: str = "__invalid__"

@@ -1,6 +1,6 @@
 import importlib

-from .basetools import tools, llm_tools, enabled_tools, tool_functions
+from .basetools import all_tools, ToolEntry, llm_tools, enabled_tools, tool_functions
 from ..setup_logging import setup_logging
 from .. import defines

@@ -11,4 +11,4 @@ module = importlib.import_module(".basetools", package=__package__)
 for name in tool_functions:
     globals()[name] = getattr(module, name)

-__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
+__all__ = ["all_tools", "ToolEntry", "llm_tools", "enabled_tools", "tool_functions"]
@@ -1,4 +1,6 @@
 import os
+from pydantic import BaseModel, Field, model_validator # type: ignore
+from typing import List, Optional, Generator, ClassVar, Any, Dict
 from datetime import datetime
 from typing import (
     Any,
@@ -345,8 +347,18 @@ async def AnalyzeSite(llm, model: str, url: str, question: str):


 # %%
-tools = [
-    {
+class Function(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+    returns: Optional[Dict[str, Any]] = {}
+
+class Tool(BaseModel):
+    type: str
+    function: Function
+
+tools : List[Tool] = [
+    Tool.model_validate({
         "type": "function",
         "function": {
             "name": "TickerValue",
@@ -363,8 +375,8 @@ tools = [
                 "additionalProperties": False,
             },
         },
-    },
-    {
+    }),
+    Tool.model_validate({
         "type": "function",
         "function": {
             "name": "AnalyzeSite",
@@ -402,8 +414,8 @@ tools = [
                 },
             },
         },
-    },
-    {
+    }),
+    Tool.model_validate({
         "type": "function",
         "function": {
             "name": "DateTime",
@@ -419,8 +431,8 @@ tools = [
                 "required": [],
             },
         },
-    },
-    {
+    }),
+    Tool.model_validate({
         "type": "function",
         "function": {
             "name": "WeatherForecast",
@@ -443,18 +455,22 @@ tools = [
                 "additionalProperties": False,
             },
         },
-    },
+    }),
 ]

-def llm_tools(tools):
-    return [tool for tool in tools if tool.get("enabled", False) == True]
+class ToolEntry(BaseModel):
+    enabled: bool = True
+    tool: Tool
+
+def llm_tools(tools: List[ToolEntry]) -> List[Dict[str, Any]]:
+    return [entry.tool.model_dump(mode='json') for entry in tools if entry.enabled == True]

-def enabled_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
-    return [{**tool, "enabled": True} for tool in tools]
+def all_tools() -> List[ToolEntry]:
+    return [ToolEntry(tool=tool) for tool in tools]
+
+def enabled_tools(tools: List[ToolEntry]) -> List[ToolEntry]:
+    return [ToolEntry(tool=entry.tool) for entry in tools if entry.enabled == True]


 tool_functions = ["DateTime", "WeatherForecast", "TickerValue", "AnalyzeSite"]
-__all__ = ["tools", "llm_tools", "enabled_tools", "tool_functions"]
+__all__ = ["ToolEntry", "all_tools", "llm_tools", "enabled_tools", "tool_functions"]
 # __all__.extend(__tool_functions__) # type: ignore
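Taken together, the reworked basetools module gives a typed path from the static tools list to what the LLM client receives: all_tools() wraps each Tool in a ToolEntry (enabled by default), enabled_tools() filters the entries, and llm_tools() flattens the enabled ones back into plain dicts. A short usage sketch, written as if it ran inside the module so the names above are in scope:

    # Usage sketch -- assumes all_tools/enabled_tools/llm_tools from this module are in scope.
    entries = all_tools()            # List[ToolEntry], one per Tool in `tools`, all enabled
    entries[0].enabled = False       # toggle one tool off for this context
    active = enabled_tools(entries)  # ToolEntry list containing only the still-enabled tools
    payload = llm_tools(entries)     # JSON-ready tool dicts for the enabled entries only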
@@ -24,6 +24,11 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 from .rag import RagEntry
+from .message import Tunables
+
+class Question(BaseModel):
+    question: str
+    tunables: Tunables = Field(default_factory=Tunables)

 class User(BaseModel):
     model_config = {"arbitrary_types_allowed": True} # Allow ChromaDBFileWatcher, etc
@@ -35,7 +40,7 @@ class User(BaseModel):
     last_name: str = ""
     full_name: str = ""
     contact_info : Dict[str, str] = {}
-    user_questions : List[str] = []
+    user_questions : List[Question] = []

     #context: Optional[List[Context]] = []
     # file_watcher : ChromaDBFileWatcher = set by initialize
@@ -193,7 +198,16 @@ class User(BaseModel):
         self.last_name = info.get("last_name", "")
         self.full_name = info.get("full_name", f"{self.first_name} {self.last_name}")
         self.contact_info = info.get("contact_info", {})
-        self.user_questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
+        questions = info.get("questions", [ f"Tell me about {self.first_name}.", f"What are {self.first_name}'s professional strengths?"])
+        self.user_questions = []
+        for question in questions:
+            if type(question) == str:
+                self.user_questions.append(Question(question=question, tunables=Tunables(enable_tools=False)))
+            else:
+                try:
+                    self.user_questions.append(Question.model_validate(question))
+                except Exception as e:
+                    logger.info(f"Unable to initialize all questions from {user_info}")

         os.makedirs(persist_directory, exist_ok=True)
         os.makedirs(watch_directory, exist_ok=True)
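The new parsing accepts both legacy string questions and dict entries: strings are wrapped in a Question with tools disabled, while dicts are validated as full Question objects. An illustrative input and its expected result (values invented; the model names come from this diff):

    # Illustrative only -- mixed legacy/new question entries as they might appear in user info.
    raw_questions = [
        "Tell me about this candidate.",                # legacy string form
        {"question": "Summarize recent projects.",      # new dict form
         "tunables": {"enable_tools": True}},
    ]
    # After parsing, user_questions would hold roughly:
    #   Question(question="Tell me about this candidate.", tunables=Tunables(enable_tools=False))
    #   Question(question="Summarize recent projects.", tunables=Tunables(enable_tools=True))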