diff --git a/src/ketr-chat/src/App.css b/src/ketr-chat/src/App.css index d35b652..ae8fed1 100644 --- a/src/ketr-chat/src/App.css +++ b/src/ketr-chat/src/App.css @@ -9,14 +9,13 @@ div { height: 100%; flex-direction: column; } - + .ChatBox { display: flex; flex-direction: column; flex-grow: 1; - border: 1px solid red; - max-width: 800px; - margin: 0 auto; + max-width: 800px; + margin: 0 auto; } .Controls { @@ -30,10 +29,16 @@ div { box-sizing: border-box; overflow-x: visible; min-width: 10rem; + width: 100%; flex-grow: 1; - border: 1px solid magenta; } +@media (min-width: 768px) { + .Controls { + width: 600px; /* or whatever you prefer for a desktop */ + max-width: 80vw; /* Optional: Prevent it from taking up too much space */ + } +} .Conversation { display: flex; @@ -54,7 +59,7 @@ div { margin-left: 1rem; border-radius: 0.25rem; min-width: 80%; - max-width: 80%; + max-width: 80%; justify-self: right; display: flex; white-space: pre-wrap; @@ -62,7 +67,7 @@ div { word-break: break-word; flex-direction: column; align-items: self-end; - align-self: end; + align-self: end; } .assistant-message { @@ -95,6 +100,7 @@ div { font-size: 0.75rem; padding: 0.125rem; } + /* Reduce general whitespace in markdown content */ .markdown-content p { margin-top: 0.5rem; diff --git a/src/ketr-chat/src/App.tsx b/src/ketr-chat/src/App.tsx index d8a3994..bece700 100644 --- a/src/ketr-chat/src/App.tsx +++ b/src/ketr-chat/src/App.tsx @@ -1,7 +1,8 @@ -import React, { useState, useEffect, useRef, useCallback } from 'react'; +import React, { useState, useEffect, useRef } from 'react'; import FormGroup from '@mui/material/FormGroup'; import FormControlLabel from '@mui/material/FormControlLabel'; import Switch from '@mui/material/Switch'; +import Divider from '@mui/material/Divider'; import Snackbar, { SnackbarCloseReason } from '@mui/material/Snackbar'; import Alert from '@mui/material/Alert'; import TextField from '@mui/material/TextField'; @@ -46,7 +47,15 @@ const getConnectionBase = (loc: any): string => { } type Tool = { - label: string, + type: string, + function?: { + name: string, + description: string, + parameters?: any, + returns?: any + }, + name?: string, + description?: string, enabled: boolean }; @@ -56,11 +65,14 @@ interface ControlsParams { sessionId: string, connectionBase: string, setSnack: (snackMessage: string, snackSeverity?: SeverityType) => void + onClearHistory: () => void }; -const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { +const Controls = ({ sessionId, connectionBase, setSnack, onClearHistory }: ControlsParams) => { const [systemPrompt, setSystemPrompt] = useState(""); const [editSystemPrompt, setEditSystemPrompt] = useState(systemPrompt); + const [tools, setTools] = useState([]); + const [rags, setRags] = useState([]); useEffect(() => { if (systemPrompt !== "") { @@ -76,34 +88,99 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { }, }); const data = await response.json(); - setSystemPrompt(data["system-prompt"]); - setEditSystemPrompt(data["system-prompt"]); + const systemPrompt = data["system-prompt"].trim(); + setSystemPrompt(systemPrompt); + setEditSystemPrompt(systemPrompt); } fetchSystemPrompt(); }, [sessionId, systemPrompt, setSystemPrompt, setEditSystemPrompt, connectionBase]); - const tools: Tool[] = [{ - label: "get_stock_price", - enabled: true, - }, { - label: "get_weather", - enabled: true - }, { - label: "site_summary", - enabled: true - }]; + useEffect(() => { + if (tools.length) { + return; + } + const fetchTools = async () => { + try { + // Make the fetch request with proper headers + const response = await fetch(connectionBase + `/api/tools/${sessionId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + }); + if (!response.ok) { + throw Error(); + } + const tools = await response.json(); + setTools(tools); + } catch (error: any) { + setSnack("Unable to fetch tools", "error"); + console.error(error); + } + } - const rags: Tool[] = [{ - label: "RAG JPK", - enabled: false - }, { - label: "RAG LKML", - enabled: false - }]; + fetchTools(); + }, [sessionId, tools, setTools, setSnack, connectionBase]); - const toggleTool = (event: any) => { - console.log(`${event.target.value} clicked`) + useEffect(() => { + if (rags.length) { + return; + } + const fetchRags = async () => { + try { + // Make the fetch request with proper headers + const response = await fetch(connectionBase + `/api/rags/${sessionId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + }); + if (!response.ok) { + throw Error(); + } + const rags = await response.json(); + setRags(rags); + } catch (error: any) { + setSnack("Unable to fetch RAGs", "error"); + console.error(error); + } + } + + fetchRags(); + }, [sessionId, rags, setRags, setSnack, connectionBase]); + + const toggleTool = async (type: string, index: number) => { + switch (type) { + case "rag": + setSnack("RAG backend not yet implemented", "warning"); + // rags[index].enabled = !rags[index].enabled + // setRags([...rags]) + break; + case "tool": + const tool = tools[index]; + tool.enabled = !tool.enabled + try { + const response = await fetch(connectionBase + `/api/tools/${sessionId}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + body: JSON.stringify({ "tool": tool?.function?.name, "enabled": tool.enabled }), + }); + + const tools = await response.json(); + setTools([...tools]) + setSnack(`${tool?.function?.name} ${tool.enabled ? "enabled" : "disabled"}`); + } catch (error) { + console.error('Fetch error:', error); + setSnack(`${tool?.function?.name} ${tool.enabled ? "enabling" : "disabling"} failed.`, "error"); + tool.enabled = !tool.enabled + } + }; }; const sendSystemPrompt = async () => { @@ -123,7 +200,6 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { }); const data = await response.json(); - console.log(data); if (data["system-prompt"] !== systemPrompt) { setSystemPrompt(data["system-prompt"].trim()); setSnack("System prompt updated", "success"); @@ -134,7 +210,7 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { } }; - const resetSystemPrompt = async () => { + const onResetSystemPrompt = async () => { try { const response = await fetch(connectionBase + `/api/system-prompt/${sessionId}`, { method: 'PUT', @@ -155,6 +231,54 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { } }; + const onResetToDefaults = async (event: any) => { + try { + const response = await fetch(connectionBase + `/api/reset/${sessionId}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + body: JSON.stringify({ "reset": ["rag", "tools"] }), + }); + + if (response.ok) { + await response.json(); + setSnack("Defaults restored", "success"); + } else { + throw Error(`${{ status: response.status, message: response.statusText }}`); + } + } catch (error) { + console.error('Fetch error:', error); + setSnack("Unable to restore defaults", "error"); + } + }; + + const onClearContext = async (event: any) => { + try { + const response = await fetch(connectionBase + `/api/reset/${sessionId}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + 'Accept': 'application/json', + }, + body: JSON.stringify({ "reset": ["history"] }), + }); + + if (response.ok) { + await response.json(); + setSnack("Chat history cleared", "success"); + onClearHistory(); + } else { + throw Error(`${{ status: response.status, message: response.statusText }}`); + } + } catch (error) { + console.error('Fetch error:', error); + setSnack("Unable to clear chat history", "error"); + } + }; + + const handleKeyPress = (event: any) => { if (event.key === 'Enter' && event.ctrlKey) { switch (event.target.id) { @@ -166,42 +290,10 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { }; return (
- - } - > - Tools - - - These tools can be made available to the LLM for obtaining real-time information from the Internet. - - - - { - tools.map((tool, index) => { - return (} onChange={toggleTool} label={tool.label} />); - }) - } - - - - } - > - RAG - - - These RAG databases can be enabled / disabled for adding additional context based on the chat request. - - - - { - rags.map((rag, index) => { - return (} onChange={toggleTool} label={rag.label} />); - }) - } - - + + You can change the information available to the LLM by adjusting the following settings: + + } @@ -222,11 +314,57 @@ const Controls = ({ sessionId, connectionBase, setSnack }: ControlsParams) => { id="SystemPromptInput" />
- - + +
+ + } + > + Tools + + + These tools can be made available to the LLM for obtaining real-time information from the Internet. The description provided to the LLM is provided for reference. + + + + { + tools.map((tool, index) => + + + } onChange={() => toggleTool("tool", index)} label={tool?.function?.name} /> + {tool?.function?.description} + + ) + } + + + + } + > + RAG + + + These RAG databases can be enabled / disabled for adding additional context based on the chat request. + + + + { + rags.map((rag, index) => + + + } onChange={() => toggleTool("rag", index)} label={rag?.name} /> + {rag?.description} + + ) + } + + + +
); } @@ -243,14 +381,6 @@ const App = () => { const [snackMessage, setSnackMessage] = useState(""); const [snackSeverity, setSnackSeverity] = useState("success"); - useEffect(() => { - if (snackMessage === "") { - setSnackOpen(false); - } else { - setSnackOpen(true); - } - }, [snackMessage, setSnackOpen]); - // Scroll to bottom of conversation when conversation updates useEffect(() => { if (conversationRef.current) { @@ -295,6 +425,7 @@ const App = () => { const setSnack = (message: string, severity: SeverityType = "success") => { setSnackMessage(message); setSnackSeverity(severity); + setSnackOpen(true); } const handleDrawerClose = () => { @@ -312,9 +443,13 @@ const App = () => { } }; + const onClearHistory = () => { + setConversation([welcomeMessage]); + }; + const drawer = ( <> - {sessionId !== undefined && } + {sessionId !== undefined && } ); @@ -399,36 +534,22 @@ const App = () => { const decoder = new TextDecoder(); let buffer = ''; - // Debug message to console - console.log('Starting to process stream'); - while (true) { const { done, value } = await reader.read(); - if (done) { - console.log('Stream complete'); break; } - // Convert chunk to text and debug const chunk = decoder.decode(value, { stream: true }); - console.log('Received chunk:', chunk); - - // Add to buffer and process lines - buffer += chunk; - - // Process complete lines - let lines = buffer.split('\n'); - buffer = lines.pop() || ''; // Keep incomplete line in buffer - - console.log(`Processing ${lines.length} complete lines`); // Process each complete line immediately + buffer += chunk; + let lines = buffer.split('\n'); + buffer = lines.pop() || ''; // Keep incomplete line in buffer for (const line of lines) { if (!line.trim()) continue; try { - console.log('Processing line:', line); const update = JSON.parse(line); // Force an immediate state update based on the message type @@ -457,6 +578,7 @@ const App = () => { ]); } } catch (e) { + setSnack("Error processing query", "error") console.error('Error parsing JSON:', e, line); } } @@ -465,7 +587,6 @@ const App = () => { // Process any remaining buffer content if (buffer.trim()) { try { - console.log('Processing final buffer:', buffer); const update = JSON.parse(buffer); if (update.status === 'done') { @@ -475,7 +596,7 @@ const App = () => { ]); } } catch (e) { - console.error('Error parsing final buffer:', e); + setSnack("Error processing query", "error") } } @@ -495,8 +616,6 @@ const App = () => { event: React.SyntheticEvent | Event, reason?: SnackbarCloseReason, ) => { - setSnackMessage(""); - if (reason === 'clickaway') { return; } diff --git a/src/server.py b/src/server.py index 139598f..101d5ae 100644 --- a/src/server.py +++ b/src/server.py @@ -58,6 +58,11 @@ from tools import ( tools ) +rags = [ + { "name": "JPK", "enabled": False, "description": "Expert data about James Ketrenos, including work history, personal hobbies, and projects." }, + { "name": "LKML", "enabled": False, "description": "Full associative data for entire LKML mailing list archive." }, +] + # %% # Defaults OLLAMA_API_URL = "http://ollama:11434" # Default Ollama local endpoint @@ -328,6 +333,13 @@ def is_valid_uuid(value): except (ValueError, TypeError): return False + +def default_tools(tools): + return [{**tool, "enabled": True} for tool in tools] + +def llm_tools(tools): + return [tool for tool in tools if tool.get("enabled", False) == True] + # %% class WebServer: def __init__(self, logging, client, model=MODEL_NAME): @@ -356,6 +368,35 @@ class WebServer: return RedirectResponse(url=f"/{context['id']}", status_code=307) #return JSONResponse({"redirect": f"/{context['id']}"}) + @self.app.put('/api/reset/{context_id}') + async def put_reset(context_id: str, request: Request): + if not is_valid_uuid(context_id): + logging.warning(f"Invalid context_id: {context_id}") + return JSONResponse({"error": "Invalid context_id"}, status_code=400) + context = self.upsert_context(context_id) + data = await request.json() + try: + for reset in data["reset"]: + match reset: + case "system-prompt": + context["system"] = [{"role": "system", "content": system_message}] + return JSONResponse(context["system"]) + case "rag": + context["rag"] = rags.copy() + return JSONResponse(context["rag"]) + case "tools": + context["tools"] = default_tools(tools) + return JSONResponse(context["tools"]) + case "history": + context["history"] = [] + return JSONResponse(context["history"]) + + return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"}), 405 + + except: + return JSONResponse({ "error": "Usage: { reset: rag|tools|history|system-prompt}"}), 405 + + @self.app.put('/api/system-prompt/{context_id}') async def put_system_prompt(context_id: str, request: Request): if not is_valid_uuid(context_id): @@ -365,7 +406,7 @@ class WebServer: data = await request.json() system_prompt = data["system-prompt"].strip() if not system_prompt: - system_prompt = system_message + return JSONResponse({ "status": "error", "message": "System prompt can not be empty." }), 405 context["system"] = [{"role": "system", "content": system_prompt}] return JSONResponse({ "system-prompt": system_prompt }) @@ -398,6 +439,7 @@ class WebServer: "X-Accel-Buffering": "no" # Prevents Nginx buffering if you're using it } ) + @self.app.post('/api/context') async def create_context(): context = self.create_context() @@ -414,6 +456,29 @@ class WebServer: context = self.upsert_context(context_id) return JSONResponse(context["tools"]) + @self.app.put('/api/tools/{context_id}') + async def put_tools(context_id: str, request: Request): + if not is_valid_uuid(context_id): + logging.warning(f"Invalid context_id: {context_id}") + return JSONResponse({"error": "Invalid context_id"}, status_code=400) + context = self.upsert_context(context_id) + try: + data = await request.json() + modify = data["tool"] + enabled = data["enabled"] + for tool in context["tools"]: + if modify == tool["function"]["name"]: + tool["enabled"] = enabled + return JSONResponse(context["tools"]) + return JSONResponse({ "status": f"{modify} not found in tools." }), 404 + except: + return JSONResponse({ "status": "error" }), 405 + + @self.app.get('/api/rags/{context_id}') + async def get_rags(context_id: str): + context = self.upsert_context(context_id) + return JSONResponse(context["rags"]) + @self.app.get('/api/health') async def health_check(): return JSONResponse({"status": "healthy"}) @@ -434,7 +499,8 @@ class WebServer: "id": context_id, "system": [{"role": "system", "content": system_message}], "history": [], - "tools": [] + "tools": default_tools(tools), + "rags": rags.copy() } logging.info(f"{context_id} created and added to sessions.") self.contexts[context_id] = context @@ -472,7 +538,7 @@ class WebServer: yield {"status": "processing", "message": "Processing request..."} # Use the async generator in an async for loop - response = self.client.chat(model=self.model, messages=messages, tools=tools) + response = self.client.chat(model=self.model, messages=messages, tools=llm_tools(context["tools"])) tools_used = [] yield {"status": "processing", "message": "Initial response received"} diff --git a/src/tools.py b/src/tools.py index 95ed853..e224a0c 100644 --- a/src/tools.py +++ b/src/tools.py @@ -352,11 +352,13 @@ tools = [ { "properties": { "city": { "type": "string", - "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle')." + "description": "City to find the weather forecast (e.g., 'Portland', 'Seattle').", + "minLength": 2 }, "state": { "type": "string", - "description": "State to find the weather forecast (e.g., 'OR', 'WA')." + "description": "State to find the weather forecast (e.g., 'OR', 'WA').", + "minLength": 2 } }, "required": [ "city", "state" ],