// backstory/frontend/src/Controls.tsx (444 lines, 15 KiB, TypeScript)

import React, { useState, useEffect, ReactElement } from 'react';
import FormGroup from '@mui/material/FormGroup';
import FormControlLabel from '@mui/material/FormControlLabel';
import Switch from '@mui/material/Switch';
import Divider from '@mui/material/Divider';
import TextField from '@mui/material/TextField';
import Accordion from '@mui/material/Accordion';
import AccordionActions from '@mui/material/AccordionActions';
import AccordionSummary from '@mui/material/AccordionSummary';
import AccordionDetails from '@mui/material/AccordionDetails';
import Typography from '@mui/material/Typography';
import Button from '@mui/material/Button';
import Box from '@mui/material/Box';
import ResetIcon from '@mui/icons-material/History';
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import { SetSnackType } from './Snack';
/**
 * A tool or RAG source entry that can be switched on or off for a session.
 * When toggled, an entry is identified to the server by its `name`.
 */
type Tool = {
  type: string;
  enabled: boolean;
  name: string;
  description: string;
  // Not read anywhere in this file.
  parameters?: any;
  // Not read anywhere in this file.
  returns?: any;
};

/**
 * Per-session settings exchanged with the `/api/tunables/{sessionId}` endpoint.
 */
interface ServerTunables {
  system_prompt: string;
  message_history_length: number;
  tools: Tool[];
  rags: Tool[];
}
/** Props accepted by the `Controls` settings panel. */
interface ControlsParams {
  /** Base URL that every API path is appended to. */
  connectionBase: string;
  /** Session identifier interpolated into API paths; may be undefined. */
  sessionId: string | undefined;
  /** Callback used to show success / error notifications. */
  setSnack: SetSnackType;
}
/** One graphics adapter as reported by the server. */
type GPUInfo = {
  name: string;
  // Shown as `memory / 1024^3` "GB", so presumably bytes — confirm with the server.
  memory: number;
  // Non-discrete adapters are displayed as "(integrated)" with no memory size.
  discrete: boolean;
};

/**
 * Hardware summary returned by `/api/system-info/{sessionId}`.
 * The keys themselves are rendered as row labels.
 */
type SystemInfo = {
  "CPU": string;
  "Graphics Card": GPUInfo[];
  "Installed RAM": string;
};
/** Replaces "(R)", "(C)" and "(TM)" with the ®, © and ™ symbols. */
const convertToSymbols = (text: string): string =>
  text
    .replace(/\(R\)/g, '®')
    .replace(/\(C\)/g, '©')
    .replace(/\(TM\)/g, '™');

/**
 * Renders the server's hardware summary as label/value rows.
 *
 * Array-valued entries (e.g. "Graphics Card") produce one row per element,
 * labelled with the entry name followed by the element index. Discrete GPUs
 * show their memory rounded to whole GB; others show "(integrated)".
 * Nothing is rendered inside the container while `systemInfo` is undefined.
 */
const SystemInfoComponent: React.FC<{ systemInfo: SystemInfo | undefined }> = ({ systemInfo }) => {
  // Rows are derived from props during render instead of being copied into
  // state by an effect, which avoided an extra render and an empty first paint.
  const systemElements = React.useMemo<ReactElement[]>(() => {
    if (systemInfo === undefined) {
      return [];
    }
    return Object.entries(systemInfo).flatMap(([k, v]) => {
      if (Array.isArray(v)) {
        // Keys are prefixed with the entry name so rows from several
        // array-valued entries cannot collide on the bare index.
        return v.map((card, index) => (
          <div key={`${k}-${index}`} className="SystemInfoItem">
            <div>{convertToSymbols(k)} {index}</div>
            <div>{convertToSymbols(card.name)} {card.discrete ? `w/ ${Math.round(card.memory / (1024 * 1024 * 1024))}GB RAM` : "(integrated)"}</div>
          </div>
        ));
      }
      return (
        <div key={k} className="SystemInfoItem">
          <div>{convertToSymbols(k)}</div>
          <div>{convertToSymbols(String(v))}</div>
        </div>
      );
    });
  }, [systemInfo]);
  return <div className="SystemInfo">{systemElements}</div>;
};
/**
 * PUTs a partial tunables update for one session and returns the tunables the
 * server sends back.
 *
 * @param connectionBase - Base URL the API path is appended to.
 * @param sessionId - Session whose tunables are updated.
 * @param update - Fields to change, e.g. `{ system_prompt: "..." }`.
 * @returns The parsed JSON response body.
 * @throws Error when the request fails, the status is not 2xx, or the body
 *   carries an `error` field.
 */
const putTunables = async (
  connectionBase: string,
  sessionId: string,
  update: Record<string, unknown>,
): Promise<ServerTunables> => {
  const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
    method: 'PUT',
    headers: {
      'Content-Type': 'application/json',
      'Accept': 'application/json',
    },
    body: JSON.stringify(update),
  });
  if (!response.ok) {
    throw new Error(`HTTP ${response.status}: ${response.statusText}`);
  }
  const data = await response.json();
  // Same in-band error check that `reset` applies to /api/reset responses.
  if (data?.error) {
    throw new Error(String(data.error));
  }
  return data;
};

/**
 * Settings panel for a chat session.
 *
 * The panel covers:
 * - the system prompt;
 * - the message-history length;
 * - tool and RAG toggles;
 * - hardware information;
 * - reset actions.
 *
 * `serverTunables` holds the last values confirmed by the server. The effects
 * below push any local system prompt or history length that differs from it
 * to `/api/tunables/{sessionId}`, then record the server's answer as the new
 * baseline.
 */
const Controls = ({ sessionId, setSnack, connectionBase }: ControlsParams) => {
  const [editSystemPrompt, setEditSystemPrompt] = useState<string>("");
  const [systemInfo, setSystemInfo] = useState<SystemInfo | undefined>(undefined);
  const [tools, setTools] = useState<Tool[]>([]);
  const [rags, setRags] = useState<Tool[]>([]);
  const [systemPrompt, setSystemPrompt] = useState<string>("");
  const [messageHistoryLength, setMessageHistoryLength] = useState<number>(5);
  // Raw text of the message-history field. It is kept separately so the field
  // can be cleared while typing without committing an invalid number.
  const [messageHistoryInput, setMessageHistoryInput] = useState<string>("5");
  const [serverTunables, setServerTunables] = useState<ServerTunables | undefined>(undefined);

  // Push the system prompt whenever it diverges from the server's copy.
  useEffect(() => {
    if (serverTunables === undefined || systemPrompt === serverTunables.system_prompt || !systemPrompt.trim() || sessionId === undefined) {
      return;
    }
    const session = sessionId;
    const sendSystemPrompt = async (prompt: string) => {
      try {
        const tunables = await putTunables(connectionBase, session, { "system_prompt": prompt });
        // Record the accepted prompt as the new baseline (without mutating
        // state) so this effect does not send it again.
        setServerTunables(prev => prev && { ...prev, system_prompt: tunables.system_prompt });
        setSystemPrompt(tunables.system_prompt);
        setSnack("System prompt updated", "success");
      } catch (error) {
        console.error('Fetch error:', error);
        setSnack("System prompt update failed", "error");
      }
    };
    sendSystemPrompt(systemPrompt);
  }, [systemPrompt, connectionBase, sessionId, setSnack, serverTunables]);

  // Push the message-history length whenever it diverges from the server's copy.
  useEffect(() => {
    if (serverTunables === undefined || messageHistoryLength === serverTunables.message_history_length || sessionId === undefined) {
      return;
    }
    // 0 is a valid setting ("0 = All"), so only non-integers and negatives
    // are rejected here.
    if (!Number.isInteger(messageHistoryLength) || messageHistoryLength < 0) {
      return;
    }
    const session = sessionId;
    const sendMessageHistoryLength = async (length: number) => {
      try {
        const tunables = await putTunables(connectionBase, session, { "message_history_length": length });
        const newLength = tunables.message_history_length;
        // Update the baseline so that, e.g., 5 -> 10 -> 5 still sends the final 5.
        setServerTunables(prev => prev && { ...prev, message_history_length: newLength });
        setMessageHistoryLength(newLength);
        setSnack("Message history length updated", "success");
      } catch (error) {
        console.error('Fetch error:', error);
        setSnack("Message history length update failed", "error");
      }
    };
    sendMessageHistoryLength(messageHistoryLength);
  }, [messageHistoryLength, connectionBase, sessionId, setSnack, serverTunables]);

  /**
   * Asks the server to reset the given settings and applies the values it
   * returns.
   *
   * @param types - Which settings to reset.
   * @param message - Success notification text.
   */
  const reset = async (types: ("rags" | "tools" | "history" | "system_prompt" | "message_history_length")[], message: string = "Update successful.") => {
    if (sessionId === undefined) {
      setSnack("Unable to restore defaults", "error");
      return;
    }
    try {
      const response = await fetch(connectionBase + `/api/reset/${sessionId}`, {
        method: 'PUT',
        headers: {
          'Content-Type': 'application/json',
          'Accept': 'application/json',
        },
        body: JSON.stringify({ "reset": types }),
      });
      if (!response.ok) {
        throw new Error(`HTTP ${response.status}: ${response.statusText}`);
      }
      const data = await response.json();
      if (data.error) {
        throw new Error(String(data.error));
      }
      for (const [key, value] of Object.entries(data)) {
        switch (key) {
          case "rags":
            setRags(value as Tool[]);
            break;
          case "tools":
            setTools(value as Tool[]);
            break;
          case "system_prompt": {
            // NOTE(review): accepts either a bare string or an object carrying
            // `system_prompt`; confirm which shape /api/reset returns.
            const prompt = (typeof value === "string" ? value : (value as ServerTunables).system_prompt).trim();
            // The server now holds the default, so adopt it as the baseline too;
            // otherwise the prompt effect would PUT it straight back.
            setServerTunables(prev => prev && { ...prev, system_prompt: prompt });
            setSystemPrompt(prompt);
            break;
          }
          case "message_history_length":
            // Only applied when the server reports a plain number.
            if (typeof value === "number") {
              setServerTunables(prev => prev && { ...prev, message_history_length: value });
              setMessageHistoryLength(value);
            }
            break;
          case "history":
            console.log('TODO: handle history reset');
            break;
        }
      }
      setSnack(message, "success");
    } catch (error) {
      console.error('Fetch error:', error);
      setSnack("Unable to restore defaults", "error");
    }
  };

  // Fetch the server's hardware description the first time a session id is available.
  useEffect(() => {
    if (systemInfo !== undefined || sessionId === undefined) {
      return;
    }
    const fetchSystemInfo = async () => {
      try {
        const response = await fetch(connectionBase + `/api/system-info/${sessionId}`, {
          method: 'GET',
          headers: {
            'Content-Type': 'application/json',
          },
        });
        if (!response.ok) {
          throw new Error(`HTTP ${response.status}: ${response.statusText}`);
        }
        setSystemInfo(await response.json());
      } catch (error) {
        console.error('Error obtaining system information:', error);
        setSnack("Unable to obtain system information.", "error");
      }
    };
    fetchSystemInfo();
  }, [systemInfo, connectionBase, setSnack, sessionId]);

  // Mirror committed values into their editable fields.
  useEffect(() => {
    setEditSystemPrompt(systemPrompt);
  }, [systemPrompt]);
  useEffect(() => {
    setMessageHistoryInput(String(messageHistoryLength));
  }, [messageHistoryLength]);

  // Load the session's tunables once; they become the baseline for the push effects.
  useEffect(() => {
    if (serverTunables !== undefined || sessionId === undefined) {
      return;
    }
    const fetchTunables = async () => {
      try {
        const response = await fetch(connectionBase + `/api/tunables/${sessionId}`, {
          method: 'GET',
          headers: {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
          },
        });
        if (!response.ok) {
          throw new Error(`HTTP ${response.status}: ${response.statusText}`);
        }
        const data: ServerTunables = await response.json();
        console.log("Server tunables: ", data);
        // Set the local copies first and the baseline last. This way the push
        // effects never compare the new baseline against stale local values.
        setSystemPrompt(data.system_prompt);
        setMessageHistoryLength(data.message_history_length);
        setTools(data.tools);
        setRags(data.rags);
        setServerTunables(data);
      } catch (error) {
        console.error('Fetch error:', error);
        setSnack("Unable to obtain server tunables.", "error");
      }
    };
    fetchTunables();
  }, [sessionId, connectionBase, setSnack, serverTunables]);

  /**
   * Flips one tool or RAG entry on the server.
   *
   * The list is refreshed from the server's response rather than by mutating
   * the entry held in state.
   */
  const toggleEntry = async (kind: "rags" | "tools", entry: Tool) => {
    if (sessionId === undefined) {
      return;
    }
    const enabled = !entry.enabled;
    try {
      const tunables = await putTunables(connectionBase, sessionId, { [kind]: [{ "name": entry.name, "enabled": enabled }] });
      if (kind === "rags") {
        setRags(tunables.rags);
      } else {
        setTools(tunables.tools);
      }
      setSnack(`${entry.name} ${enabled ? "enabled" : "disabled"}`);
    } catch (error) {
      console.error('Fetch error:', error);
      setSnack(`${entry.name} ${enabled ? "enabling" : "disabling"} failed.`, "error");
    }
  };

  /** Switch handler: `type` is "rag" or "tool"; `index` selects the entry. */
  const toggle = async (type: string, index: number) => {
    switch (type) {
      case "rag": {
        const rag = rags?.[index];
        if (rag !== undefined) {
          await toggleEntry("rags", rag);
        }
        break;
      }
      case "tool": {
        const tool = tools?.[index];
        if (tool !== undefined) {
          await toggleEntry("tools", tool);
        }
        break;
      }
    }
  };

  // Ctrl+Enter in the system-prompt editor commits the edited prompt.
  const handleKeyPress = (event: React.KeyboardEvent) => {
    if (event.key === 'Enter' && event.ctrlKey && (event.target as HTMLElement).id === 'SystemPromptInput') {
      setSystemPrompt(editSystemPrompt);
    }
  };

  return (<div className="Controls">
    <Typography component="span" sx={{ mb: 1 }}>
      You can change the information available to the LLM by adjusting the following settings:
    </Typography>
    <Accordion>
      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
        <Typography component="span">System Prompt</Typography>
      </AccordionSummary>
      <AccordionActions style={{ flexDirection: "column" }}>
        <TextField
          variant="outlined"
          fullWidth
          multiline
          type="text"
          value={editSystemPrompt}
          onChange={(e) => setEditSystemPrompt(e.target.value)}
          onKeyDown={handleKeyPress}
          placeholder="Enter the new system prompt.."
          id="SystemPromptInput"
        />
        <div style={{ display: "flex", flexDirection: "row", gap: "8px", paddingTop: "8px" }}>
          <Button variant="contained" disabled={editSystemPrompt === systemPrompt} onClick={() => { setSystemPrompt(editSystemPrompt); }}>Set</Button>
          <Button variant="outlined" onClick={() => { reset(["system_prompt"], "System prompt reset."); }} color="error">Reset</Button>
        </div>
      </AccordionActions>
    </Accordion>
    <Accordion>
      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
        <Typography component="span">Tunables</Typography>
      </AccordionSummary>
      <AccordionActions style={{ flexDirection: "column" }}>
        <TextField
          id="outlined-number"
          label="Message history"
          type="number"
          helperText="Only use this many messages as context. 0 = All. Keeping this low will reduce context growth and improve performance."
          value={messageHistoryInput}
          onChange={(e) => {
            const text = e.target.value;
            setMessageHistoryInput(text);
            // Commit only complete, non-negative integers. An empty field is
            // left uncommitted rather than being read as 0 ("All").
            const parsed = Number(text);
            if (text.trim() !== "" && Number.isInteger(parsed) && parsed >= 0) {
              setMessageHistoryLength(parsed);
            }
          }}
          slotProps={{
            htmlInput: {
              min: 0
            },
            inputLabel: {
              shrink: true,
            },
          }}
        />
      </AccordionActions>
    </Accordion>
    <Accordion>
      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
        <Typography component="span">Tools</Typography>
      </AccordionSummary>
      <AccordionDetails>
        These tools can be made available to the LLM for obtaining real-time information from the Internet. The description provided to the LLM is provided for reference.
      </AccordionDetails>
      <AccordionActions>
        <FormGroup sx={{ p: 1 }}>
          {
            (tools ?? []).map((tool, index) =>
              <Box key={index}>
                <Divider />
                <FormControlLabel control={<Switch checked={tool.enabled} />} onChange={() => toggle("tool", index)} label={tool.name} />
                <Typography sx={{ fontSize: "0.8rem", mb: 1 }}>{tool.description}</Typography>
              </Box>
            )
          }</FormGroup>
      </AccordionActions>
    </Accordion>
    <Accordion>
      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
        <Typography component="span">RAG</Typography>
      </AccordionSummary>
      <AccordionDetails>
        These RAG databases can be enabled / disabled for adding additional context based on the chat request.
      </AccordionDetails>
      <AccordionActions>
        <FormGroup sx={{ p: 1, flexGrow: 1, justifyContent: "flex-start" }}>
          {
            (rags ?? []).map((rag, index) =>
              <Box key={index} sx={{ display: "flex", flexGrow: 1, flexDirection: "column" }}>
                <Divider />
                <FormControlLabel
                  control={<Switch checked={rag.enabled} />}
                  onChange={() => toggle("rag", index)} label={rag.name}
                />
                <Typography>{rag.description}</Typography>
              </Box>
            )
          }</FormGroup>
      </AccordionActions>
    </Accordion>
    <Accordion>
      <AccordionSummary expandIcon={<ExpandMoreIcon />}>
        <Typography component="span">System Information</Typography>
      </AccordionSummary>
      <AccordionDetails>
        The server is running on the following hardware:
      </AccordionDetails>
      <AccordionActions>
        <SystemInfoComponent systemInfo={systemInfo} />
      </AccordionActions>
    </Accordion>
    <Button startIcon={<ResetIcon />} onClick={() => { reset(["history"], "History cleared."); }}>Delete Backstory History</Button>
    <Button onClick={() => { reset(["rags", "tools", "system_prompt", "message_history_length"], "Default settings restored.") }}>Reset system prompt, tunables, and RAG to defaults</Button>
  </div>);
}
// Public surface of this module: the settings panel and its props type.
export { Controls };
export type { ControlsParams };