// backstory/frontend/src/components/VectorVisualizer.tsx

import React, { useEffect, useState, useRef } from 'react';
import Box from '@mui/material/Box';
import Paper from '@mui/material/Paper';
import Plot from 'react-plotly.js';
import TextField from '@mui/material/TextField';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import SendIcon from '@mui/icons-material/Send';
import FormControlLabel from '@mui/material/FormControlLabel';
import Switch from '@mui/material/Switch';
import useMediaQuery from '@mui/material/useMediaQuery';
import { SxProps, useTheme } from '@mui/material/styles';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableCell from '@mui/material/TableCell';
import TableContainer from '@mui/material/TableContainer';
import TableRow from '@mui/material/TableRow';
import { Scrollable } from './Scrollable';
import './VectorVisualizer.css';
import { BackstoryPageProps } from './BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import * as Types from 'types/types';
import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext';
import { useNavigate } from 'react-router-dom';
/**
 * Props for {@link VectorVisualizer}.
 */
interface VectorVisualizerProps extends BackstoryPageProps {
  /** When true, the query echo and the query input row are not rendered. */
  inline?: boolean;
  /** Initial query set (similarity-search result) to highlight; falls back to an empty set. */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  rag?: any;
}
/** Per-point metadata attached to a plotted document. */
interface Metadata {
  /** Document id. */
  id: string;
  /** Document category (e.g. 'resume', 'query'); selects the marker colour and emoji. */
  docType: string;
  /** Text that was embedded for this document. */
  content: string;
  /** Distance to the active query; present only for similarity-search hits. */
  distance?: number;
}
/**
 * Query set used when no `rag` prop is supplied: no ids, documents,
 * metadata or embeddings, i.e. nothing is highlighted.
 */
const emptyQuerySet = {
  ids: [],
  documents: [],
  metadatas: [],
  embeddings: [],
};
/**
 * Column-oriented description of a set of scatter points: parallel arrays
 * where entry `i` of every field describes point `i`.
 */
interface PlotData {
  /** Horizontal coordinates. */
  x: number[];
  /** Vertical coordinates. */
  y: number[];
  /** Depth coordinates; only present for a 3D plot. */
  z?: number[];
  /** Marker colours. */
  colors: string[];
  /** Per-point label text. */
  text: string[];
  /** Marker sizes. */
  sizes: number[];
  /** Metadata carried with each point. */
  customdata: Metadata[];
}
/**
 * Plotly configuration for the visualizer.
 *
 * - Resizes with its container (`responsive`, `autosizable`).
 * - Hides the Plotly logo and the "send to cloud" button.
 * - Keeps the plot interactive (`staticPlot: false`) but turns off
 *   scroll-wheel zoom and the double-click interaction.
 * - Removes the lasso and box-select buttons from the mode bar. The other
 *   removable button names are enumerated by the `ModeBarDefaultButtons`
 *   type in @types/plotly.js.
 */
const config: Partial<Plotly.Config> = {
  displaylogo: false,
  showSendToCloud: false,
  responsive: true,
  autosizable: true,
  staticPlot: false,
  frameMargins: 0,
  scrollZoom: false,
  doubleClick: false,
  modeBarButtonsToRemove: ['lasso2d', 'select2d'],
};
/**
 * Base Plotly layout shared by the 2D and 3D views. `autosize` is off; the
 * component passes explicit width/height alongside this object when rendering.
 */
const layout: Partial<Plotly.Layout> = {
  autosize: false,
  clickmode: 'event+select',
  hovermode: 'closest',
  margin: { r: 0, b: 0, l: 0, t: 0 },

  // White paper and plot area, charcoal-black Roboto text.
  paper_bgcolor: '#FFFFFF',
  plot_bgcolor: '#FFFFFF',
  font: {
    family: 'Roboto, sans-serif',
    color: '#2E2E2E',
  },

  // 2D axes: light grey grid and zero lines.
  xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
  yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },

  // 3D scene: white background; only the z axis is styled here.
  scene: {
    bgcolor: '#FFFFFF',
    zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
  },

  // Legend laid out horizontally ('h'; 'v' would stack it vertically), with its
  // top-left corner at x = 0.8, y = 0 in normalized paper coordinates.
  showlegend: true,
  legend: {
    x: 0.8,
    y: 0,
    xanchor: 'left',
    yanchor: 'top',
    orientation: 'h',
  },
};
/**
 * Linearly rescales `arr` into [0, 1] (minimum → 0, maximum → 1).
 *
 * - A constant input (zero range) maps every value to 0.5, so the points sit
 *   in the middle of the axis instead of dividing by zero.
 * - An empty input yields an empty array.
 * - NaN entries are ignored when finding the bounds (and stay NaN in the output).
 *
 * The bounds are found with a single loop rather than `Math.min(...arr)`:
 * spreading a very large array into call arguments can throw a RangeError.
 *
 * @param arr values along one embedding dimension (not modified)
 * @returns a new array of normalized values, same length and order as `arr`
 */
const normalizeDimension = (arr: readonly number[]): number[] => {
  if (arr.length === 0) return [];
  let min = Infinity;
  let max = -Infinity;
  for (const v of arr) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  const range = max - min;
  if (range === 0) return arr.map(() => 0.5); // flat dimension
  return arr.map(v => (v - min) / range);
};
/** Emoji shown next to a node's type in the details table, keyed by docType. */
const emojiMap: Record<string, string> = {
  'query': '🔍',
  'resume': '📄',
  'performance-reviews': '📄',
  'projects': '📁',
  'jobs': '📁',
  'news': '📰',
};
/** Marker colour per docType; types not listed here fall back to a dark grey at the call site. */
const colorMap: Record<string, string> = {
  'query': '#D4A017',               // golden ochre — strongest highlight
  'resume': '#4A7A7D',              // dusty teal
  'projects': '#1A2536',            // midnight blue
  'news': '#D3CDBF',                // warm grey
  'performance-reviews': '#8FD0D0', // light teal / aqua
  'jobs': '#F3aD8F',                // light peach / salmon
};
/** Marker size for every point when no query is active; also the base size for query hits. */
const DEFAULT_SIZE = 6;
/** Marker size for points outside the query set while a query is active. */
const DEFAULT_UNFOCUS_SIZE = 2;
/**
 * The point currently selected in the scatter plot, as shown in the details panel.
 */
type Node = {
  /** Document id. */
  id: string;
  /** Portion of content that was used for the embedding. */
  content: string;
  /** Embedded portion plus surrounding buffer; filled in once fetched. */
  fullContent: string | undefined;
  /** Emoji displayed next to the docType. */
  emoji: string;
  /** Document category. */
  docType: string;
  /** Source file path; only the base name is displayed. */
  source_file: string;
  /** Distance to the active query, when the node is a similarity-search hit. */
  distance: number | undefined;
  /** Section path shown in the details table. */
  path: string;
  // NOTE(review): the four fields below appear to be line offsets — chunkBegin offsets the
  // displayed line numbers of fullContent, lineBegin/lineEnd bound the highlighted range.
  // Confirm the exact semantics against the backend that produces them.
  chunkBegin: number;
  lineBegin: number;
  chunkEnd: number;
  lineEnd: number;
  /** Style overrides for the node badge. */
  sx: SxProps;
};
/**
 * Builds the two Plotly scatter traces for the visualizer.
 *
 * Every document in `result` lands in exactly one trace:
 *  - "All data": documents that are not in `querySet` (drawn at the unfocused
 *    size while a query is active, so the hits stand out);
 *  - "Query": documents returned by the similarity search, plus a synthetic
 *    'query' point when `querySet` carries a UMAP embedding for the active
 *    dimensionality. Marker size grows with similarity (1 - distance).
 *
 * Both traces are normalized together so they share one coordinate frame.
 * Neither `result` nor `querySet` is mutated.
 *
 * @returns the traces, or null when there is nothing to plot or the
 *   embeddings in `result` do not match the requested dimensionality.
 */
const buildPlotTraces = (
  result: Types.ChromaDBGetResponse,
  querySet: Types.ChromaDBGetResponse,
  view2D: boolean,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
): any[] | null => {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  type PlotGroup = { ids: string[]; embeddings: number[][]; metadatas: any[] };

  const embeddings: number[][] = result.embeddings || [];
  if (embeddings.length === 0) {
    return null;
  }
  const is2D = embeddings.every((v: number[]) => v.length === 2);
  const is3D = embeddings.every((v: number[]) => v.length === 3);
  // Checked first: otherwise the dimensionality check below always returns before this warning.
  if (!is2D && !is3D) {
    console.warn('Modified vectors are neither 2D nor 3D');
    return null;
  }
  // A response for the other dimensionality (e.g. arriving after the 3D switch was toggled) is ignored.
  if (view2D !== is2D) {
    return null;
  }

  // Position of each similarity-search hit in querySet, by id (first occurrence wins, like indexOf).
  const queryIndex = new Map<string, number>();
  (querySet.ids || []).forEach((id: string, index: number) => {
    if (!queryIndex.has(id)) {
      queryIndex.set(id, index);
    }
  });

  const filtered: PlotGroup = { ids: [], embeddings: [], metadatas: [] };
  const query: PlotGroup = { ids: [], embeddings: [], metadatas: [] };
  const documents = result.documents || [];
  const metadatas = result.metadatas || [];

  (result.ids || []).forEach((id: string, index: number) => {
    const foundIndex = queryIndex.get(id);
    // Build a fresh object: result.metadatas is React state and must not be mutated.
    const metadata = {
      ...metadatas[index],
      id,
      content: documents[index],
      // Only similarity-search hits carry a distance to the query.
      distance: foundIndex !== undefined && querySet.distances ? querySet.distances[foundIndex] : undefined,
    };
    const group = foundIndex !== undefined ? query : filtered;
    group.ids.push(id);
    group.embeddings.push(embeddings[index]);
    group.metadatas.push(metadata);
  });

  const queryEmbedding = view2D ? querySet.umapEmbedding2D : querySet.umapEmbedding3D;
  if (queryEmbedding && queryEmbedding.length) {
    query.ids.unshift('query');
    query.metadatas.unshift({ id: 'query', docType: 'query', content: querySet.query || '', distance: 0 });
    query.embeddings.unshift(queryEmbedding);
  }

  // Normalize both groups over their union so the query hits stay where they are relative to the rest.
  const allEmbeddings = [...filtered.embeddings, ...query.embeddings];
  const splitAt = filtered.embeddings.length;
  const normalizeAxis = (axis: number) => {
    const values = normalizeDimension(allEmbeddings.map((v: number[]) => v[axis]));
    return { filtered: values.slice(0, splitAt), query: values.slice(splitAt) };
  };
  const x = normalizeAxis(0);
  const y = normalizeAxis(1);
  const z = is3D ? normalizeAxis(2) : undefined;

  const colorOf = (m: { docType?: string }) => colorMap[m.docType || 'unknown'] || '#4d4d4d';
  const hasQuery = query.metadatas.length > 0;
  const filteredSizes = filtered.metadatas.map(() => (hasQuery ? DEFAULT_UNFOCUS_SIZE : DEFAULT_SIZE));
  const querySizes = query.metadatas.map((m: { distance?: number }) => {
    // `??` so an exact match (distance 0) is drawn largest; clamp so the size never goes negative
    // NOTE(review): assumes distances are roughly in [0, 1]; larger distances get the base size.
    const similarity = Math.min(1, Math.max(0, 1 - (m.distance ?? 1)));
    return DEFAULT_SIZE + 2 * DEFAULT_SIZE * Math.pow(similarity, 3);
  });

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const traces: any[] = [{
    name: 'All data',
    x: x.filtered,
    y: y.filtered,
    mode: 'markers',
    marker: {
      size: filteredSizes,
      symbol: 'circle',
      color: filtered.metadatas.map(colorOf),
      opacity: 1
    },
    text: filtered.ids,
    customdata: filtered.metadatas,
    type: is3D ? 'scatter3d' : 'scatter',
    hovertemplate: '&nbsp;',
  }, {
    name: 'Query',
    x: x.query,
    y: y.query,
    mode: 'markers',
    marker: {
      size: querySizes,
      symbol: 'circle',
      color: query.metadatas.map(colorOf),
      opacity: 1
    },
    text: query.ids,
    customdata: query.metadatas,
    type: is3D ? 'scatter3d' : 'scatter',
    hovertemplate: '%{text}',
  }];
  if (z) {
    traces[0].z = z.filtered;
    traces[1].z = z.query;
  }
  return traces;
};

/**
 * Interactive 2D/3D scatter plot of the signed-in candidate's vector store.
 *
 * Fetches the candidate's reduced embeddings, lets the user run a similarity
 * query (unless `inline`), and shows details plus surrounding content for a
 * clicked point.
 */
const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualizerProps) => {
  const { user, apiClient } = useAuth();
  const { rag, inline, sx } = props;
  const { setSnack } = useAppState();
  const [plotData, setPlotData] = useState<Plotly.Data[]>([]);
  const [newQuery, setNewQuery] = useState<string>('');
  const [querySet, setQuerySet] = useState<Types.ChromaDBGetResponse>(rag || emptyQuerySet);
  const [result, setResult] = useState<Types.ChromaDBGetResponse | null>(null);
  const [view2D, setView2D] = useState<boolean>(true);
  const plotlyRef = useRef(null);
  const boxRef = useRef<HTMLElement>(null);
  // Id of the most recently clicked point; a slower fetch for an earlier click is discarded.
  const selectedNodeIdRef = useRef<string | null>(null);
  const [node, setNode] = useState<Node | null>(null);
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('md'));
  const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
  const navigate = useNavigate();
  const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;

  /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
   * off screen (eg., the VectorVisualizer is not on the tab the app loads to) */
  useEffect(() => {
    const box = boxRef.current;
    if (!box) {
      return;
    }
    const frame = requestAnimationFrame(() => {
      // Scoped to this component so another Plotly chart on the page is not picked up instead.
      const plotContainer = box.querySelector<HTMLElement>('.plot-container');
      const svgContainer = box.querySelector<HTMLElement>('.svg-container');
      if (!plotContainer || !svgContainer) {
        return;
      }
      const { width, height } = plotContainer.getBoundingClientRect();
      svgContainer.style.width = `${width}px`;
      svgContainer.style.height = `${height}px`;
      // Functional update: no stale closure, and returning prev skips a re-render when unchanged.
      setPlotDimensions(prev => (prev.width === width && prev.height === height ? prev : { width, height }));
    });
    return () => cancelAnimationFrame(frame);
  });

  // Get the collection to visualize (re-run when the 2D/3D switch clears `result`).
  useEffect(() => {
    if (result || !candidate) {
      return;
    }
    let cancelled = false;
    const fetchCollection = async () => {
      try {
        const collection = await apiClient.getCandidateVectors(view2D ? 2 : 3);
        if (!cancelled) {
          setResult(collection);
        }
      } catch (error) {
        if (cancelled) {
          return;
        }
        console.error('Error obtaining collection information:', error);
        setSnack("Unable to obtain collection information.", "error");
      }
    };
    void fetchCollection();
    // Drop a response that arrives after the view changed or the component unmounted.
    return () => {
      cancelled = true;
    };
  }, [result, candidate, apiClient, setSnack, view2D]);

  useEffect(() => {
    if (!result) {
      return;
    }
    const traces = buildPlotTraces(result, querySet, view2D);
    if (traces) {
      setPlotData(traces);
    }
  }, [result, querySet, view2D]);

  const sendQuery = async (query: string) => {
    if (!query.trim()) return;
    setNewQuery('');
    try {
      const similar = await apiClient.getCandidateSimilarContent(query);
      setQuerySet(similar);
    } catch (error) {
      const msg = `Error obtaining similar content to ${query}.`;
      console.error(msg, error);
      setSnack(msg, "error");
    }
  };

  const handleKeyPress = (event: React.KeyboardEvent) => {
    if (event.key === 'Enter') {
      void sendQuery(newQuery);
    }
  };

  // Checked before `result`: without a candidate nothing is fetched, so `result` would stay null forever.
  if (!candidate) return (
    <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
      <div>No candidate selected. Please <Button onClick={() => navigate('/find-a-candidate')}>select a candidate</Button> first.</div>
    </Box>
  );

  if (!result) return (
    <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
      <div>Loading visualization...</div>
    </Box>
  );

  const fetchRAGMeta = async (selected: Node) => {
    try {
      const ragContent = await apiClient.getCandidateRAGContent(selected.id);
      // The user may have clicked a different point while this request was in flight.
      if (selectedNodeIdRef.current !== selected.id) {
        return;
      }
      setNode({ ...selected, fullContent: ragContent.content });
    } catch (error) {
      const msg = `Error obtaining content for ${selected.id}.`;
      console.error(msg, error);
      setSnack(msg, "error");
    }
  };

  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const onNodeSelected = (metadata: any) => {
    if (!metadata) {
      return;
    }
    selectedNodeIdRef.current = metadata.id;
    if (metadata.docType === 'query') {
      setNode({
        ...metadata,
        content: `Similarity results for the query **${querySet.query || ''}**
The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '2' : '3'}-dimensional space. Larger dots represent relative similarity in N-dimensional space.
`,
        emoji: emojiMap[metadata.docType],
        sx: {
          m: 0.5,
          p: 2,
          width: '3rem',
          display: "flex",
          alignContent: "center",
          justifyContent: "center",
          flexGrow: 0,
          flexWrap: "wrap",
          backgroundColor: colorMap[metadata.docType] || '#ff8080',
        }
      });
      return;
    }
    const selected: Node = {
      content: `Loading...`,
      ...metadata,
      emoji: emojiMap[metadata.docType] || '❓',
    };
    setNode(selected);
    void fetchRAGMeta(selected);
  };

  return (
    <Box className="VectorVisualizer"
      ref={boxRef}
      sx={{ ...sx }}>
      <Box sx={{ p: 0, m: 0, gap: 0 }}>
        <Paper sx={{
          p: 0.5, m: 0,
          display: "flex",
          flexGrow: 0,
          height: "auto",
          minHeight: "auto",
          maxHeight: "auto",
          position: "relative",
          flexDirection: "column"
        }}>
          <FormControlLabel
            sx={{
              display: "flex",
              position: "relative",
              width: "fit-content",
              ml: 1,
              mb: '-2.5rem',
              zIndex: 100,
              flexBasis: 0,
              flexGrow: 0
            }}
            control={<Switch checked={!view2D} />} onChange={() => { setView2D(!view2D); setResult(null); }} label="3D" />
          <Plot
            ref={plotlyRef}
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
            onClick={(event: any) => { onNodeSelected(event?.points?.[0]?.customdata); }}
            data={plotData}
            useResizeHandler={true}
            config={config}
            style={{
              display: "flex",
              flexGrow: 1,
              minHeight: '240px',
              padding: 0,
              margin: 0,
              width: "100%",
              height: "100%",
              overflow: "hidden",
            }}
            layout={{ ...layout, width: plotDimensions.width, height: plotDimensions.height }}
          />
        </Paper>
        <Paper sx={{ display: "flex", flexDirection: isMobile ? "column" : "row", mt: 0.5, p: 0.5, flexGrow: 1, minHeight: "fit-content" }}>
          {node !== null &&
            <Box sx={{ display: "flex", fontSize: "0.75rem", flexDirection: "column", flexGrow: 1, maxWidth: "100%", flexBasis: 1, maxHeight: "min-content" }}>
              <TableContainer component={Paper} sx={{ mb: isMobile ? 1 : 0, mr: isMobile ? 0 : 1 }}>
                <Table size="small" sx={{ tableLayout: 'fixed' }}>
                  <TableBody sx={{ '& td': { verticalAlign: "top", fontSize: "0.75rem", }, '& td:first-of-type': { whiteSpace: "nowrap", width: "1rem" } }}>
                    <TableRow>
                      <TableCell>Type</TableCell>
                      <TableCell>{node.emoji} {node.docType}</TableCell>
                    </TableRow>
                    {node.source_file !== undefined && <TableRow>
                      <TableCell>File</TableCell>
                      <TableCell>{node.source_file.replace(/^.*\//, '')}</TableCell>
                    </TableRow>}
                    {node.path !== undefined && <TableRow>
                      <TableCell>Section</TableCell>
                      <TableCell>{node.path}</TableCell>
                    </TableRow>}
                    {node.distance !== undefined && <TableRow>
                      <TableCell>Distance</TableCell>
                      <TableCell>{node.distance}</TableCell>
                    </TableRow>}
                  </TableBody>
                </Table>
              </TableContainer>
              {node.content !== "" && node.content !== undefined &&
                <Paper elevation={6} sx={{ display: "flex", flexDirection: "column", border: "1px solid #808080", minHeight: "fit-content", mt: 1 }}>
                  <Box sx={{ display: "flex", background: "#404040", p: 1, color: "white" }}>Vector Embedded Content</Box>
                  <Box sx={{ display: "flex", p: 1, flexGrow: 1 }}>{node.content}</Box>
                </Paper>
              }
            </Box>
          }
          <Box sx={{ display: "flex", flexDirection: "column", flexGrow: 2, flexBasis: 0, flexShrink: 1 }}>
            {node === null &&
              <Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
                Click a point in the scatter-graph to see information about that node.
              </Paper>
            }
            {node !== null && node.fullContent &&
              <Scrollable
                autoscroll={false}
                sx={{
                  display: 'flex',
                  flexDirection: 'column',
                  m: 0,
                  p: 0.5,
                  pl: 1,
                  flexShrink: 1,
                  position: "relative",
                  maxWidth: "100%",
                }}
              >
                {
                  node.fullContent.split('\n').map((line, offset) => {
                    // Displayed line number, shifted by the chunk's starting line.
                    const lineNumber = offset + 1 + node.chunkBegin;
                    const bgColor = (lineNumber > node.lineBegin && lineNumber <= node.lineEnd) ? '#f0f0f0' : 'auto';
                    return <Box key={lineNumber} sx={{ display: "flex", flexDirection: "row", borderBottom: '1px solid #d0d0d0', ':first-of-type': { borderTop: '1px solid #d0d0d0' }, backgroundColor: bgColor }}>
                      <Box sx={{ fontFamily: 'courier', fontSize: "0.8rem", minWidth: "2rem", pt: "0.1rem", align: "left", verticalAlign: "top" }}>{lineNumber}</Box>
                      <pre style={{ margin: 0, padding: 0, border: "none", minHeight: "1rem", overflow: "hidden" }} >{line || " "}</pre>
                    </Box>;
                  })
                }
                {!node.lineBegin && <pre style={{ margin: 0, padding: 0, border: "none" }}>{node.content}</pre>}
              </Scrollable>
            }
          </Box>
        </Paper>
        {!inline && querySet.query !== undefined && querySet.query !== '' &&
          <Paper sx={{ display: 'flex', flexDirection: 'column', justifyContent: 'center', flexGrow: 0, minHeight: '2.5rem', maxHeight: '2.5rem', height: '2.5rem', alignItems: 'center', mt: 1, pb: 0 }}>
            {`Query: ${querySet.query}`}
            {querySet.ids.length === 0 && "Enter query below to perform a similarity search."}
          </Paper>
        }
        {
          !inline &&
          <Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>
            <TextField
              variant="outlined"
              fullWidth
              type="text"
              value={newQuery}
              onChange={(e) => setNewQuery(e.target.value)}
              onKeyDown={handleKeyPress}
              placeholder="Enter query to find related documents..."
              id="QueryInput"
            />
            <Tooltip title="Send">
              <Button sx={{ m: 1 }} variant="contained" onClick={() => { void sendQuery(newQuery); }}><SendIcon /></Button>
            </Tooltip>
          </Box>
        }
      </Box>
    </Box>
  );
};
// Public surface of this module: the component and its props type.
export { VectorVisualizer };
export type { VectorVisualizerProps };