// backstory/frontend/src/components/VectorVisualizer.tsx
//
// 822 lines
// 24 KiB
// TypeScript

import React, { useEffect, useState, useRef } from 'react';
import Box from '@mui/material/Box';
import Paper from '@mui/material/Paper';
import Plot from 'react-plotly.js';
import TextField from '@mui/material/TextField';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import SendIcon from '@mui/icons-material/Send';
import FormControlLabel from '@mui/material/FormControlLabel';
import Switch from '@mui/material/Switch';
import useMediaQuery from '@mui/material/useMediaQuery';
import { SxProps, useTheme } from '@mui/material/styles';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableCell from '@mui/material/TableCell';
import TableContainer from '@mui/material/TableContainer';
import TableRow from '@mui/material/TableRow';
import { Scrollable } from './Scrollable';
import './VectorVisualizer.css';
import { BackstoryPageProps } from './BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import * as Types from 'types/types';
import { useAppState } from 'hooks/GlobalContext';
import { useNavigate } from 'react-router-dom';
/**
 * Props accepted by the VectorVisualizer component, on top of the shared
 * BackstoryPageProps.
 */
interface VectorVisualizerProps extends BackstoryPageProps {
/** When truthy, the query summary banner and the query input row are not rendered. */
inline?: boolean;
/** Optional query result used as the initial query set; an empty set is used when omitted. */
rag?: Types.ChromaDBGetResponse;
}
/**
 * Free-form metadata attached to every plotted point and handed to Plotly as
 * `customdata`.
 *
 * Keys this file reads or writes itself:
 *  - `id`       : id of the vector store entry (or 'query' for the query marker)
 *  - `docType`  : used to pick the marker color and emoji
 *  - `content`  : the document text that was embedded
 *  - `distance` : distance to the active query, when the entry is part of the query set
 *
 * Any other keys are passed through as received (presumably from the server — verify
 * against the API types).
 */
type Metadata = Record<string, string | number>;
/** Placeholder query set used as the initial state when no `rag` prop is supplied. */
const emptyQuerySet: Types.ChromaDBGetResponse = {
ids: [],
documents: [],
metadatas: [],
embeddings: [],
distances: [],
name: 'Empty',
size: 0,
dimensions: 2,
// An empty query string keeps the "Query: ..." banner hidden.
query: '',
};
/**
 * Shape of a single trace passed to the Plot component. The component builds two
 * of these: one for the background points and one for the query matches.
 */
interface PlotData {
/** Legend label of the trace. */
name: string;
mode: string;
/** Plotly trace type; this file uses 'scatter' for 2D and 'scatter3d' for 3D. */
type: string;
x: number[];
y: number[];
/** Only assigned for 3D traces. */
z?: number[];
/** Per-point text; this file fills it with the entry ids. */
text: string[];
marker: {
/** Per-point colors. */
color: string[];
/** Per-point marker sizes. */
size: number[];
symbol: string;
opacity: number;
};
/** Per-point metadata, returned on click events to identify the selected node. */
customdata: Metadata[];
hovertemplate: string;
}
/** Static Plotly configuration shared by every render of the plot. */
const config: Partial<Plotly.Config> = {
responsive: true,
autosizable: true,
displaylogo: false,
showSendToCloud: false,
staticPlot: false,
frameMargins: 0,
scrollZoom: false,
doubleClick: false,
// Reference list of mode bar button names that can be placed in
// `modeBarButtonsToRemove` below:
// | "lasso2d"
// | "select2d"
// | "sendDataToCloud"
// | "zoom2d"
// | "pan2d"
// | "zoomIn2d"
// | "zoomOut2d"
// | "autoScale2d"
// | "resetScale2d"
// | "hoverClosestCartesian"
// | "hoverCompareCartesian"
// | "zoom3d"
// | "pan3d"
// | "orbitRotation"
// | "tableRotation"
// | "handleDrag3d"
// | "resetCameraDefault3d"
// | "resetCameraLastSave3d"
// | "hoverClosest3d"
// | "zoomInGeo"
// | "zoomOutGeo"
// | "resetGeo"
// | "hoverClosestGeo"
// | "hoverClosestGl2d"
// | "hoverClosestPie"
// | "toggleHover"
// | "toImage"
// | "resetViews"
// | "toggleSpikelines"
// | "zoomInMapbox"
// | "zoomOutMapbox"
// | "resetViewMapbox"
// | "togglespikelines"
// | "togglehover"
// | "hovercompare"
// | "hoverclosest"
// | "v1hovermode";
modeBarButtonsToRemove: ['lasso2d', 'select2d'],
};
/**
 * Base Plotly layout. The component spreads this object and adds the measured
 * `width` and `height` on every render.
 */
const layout: Partial<Plotly.Layout> = {
// Sizing is driven by the explicit width/height supplied by the component.
autosize: false,
clickmode: 'event+select',
paper_bgcolor: '#FFFFFF', // white
plot_bgcolor: '#FFFFFF', // white plot background
font: {
family: 'Roboto, sans-serif',
color: '#2E2E2E', // charcoal black
},
hovermode: 'closest',
scene: {
bgcolor: '#FFFFFF', // 3D plot background
zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
},
xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
margin: { r: 0, b: 0, l: 0, t: 0 },
legend: {
x: 0.8, // Horizontal position (0 to 1, 0 is left, 1 is right)
y: 0, // Vertical position (0 to 1, 0 is bottom, 1 is top)
xanchor: 'left',
yanchor: 'top',
orientation: 'h', // 'h' lays the legend out horizontally ('v' would be vertical)
},
showlegend: true, // Show the legend
};
/**
 * Linearly rescales a list of coordinates into the range [0, 1].
 *
 * The minimum and maximum are found with a single pass rather than
 * `Math.min(...arr)` / `Math.max(...arr)`: spreading passes every element as a
 * call argument, which throws a RangeError once the collection is large enough.
 *
 * @param arr - Raw coordinate values for one dimension.
 * @returns A new array of the same length. An empty input yields an empty array;
 *          when every value is identical (zero range) each entry is 0.5 so the
 *          points sit in the middle of the axis.
 */
const normalizeDimension = (arr: number[]): number[] => {
  if (arr.length === 0) return [];
  let min = Infinity;
  let max = -Infinity;
  for (const value of arr) {
    if (value < min) min = value;
    if (value > max) max = value;
  }
  const range = max - min;
  if (range === 0) return arr.map(() => 0.5); // flat dimension
  return arr.map(v => (v - min) / range);
};
/** Emoji shown next to a node's document type; unknown types fall back to '❓' at the call site. */
const emojiMap: Record<string, string> = {
query: '🔍',
resume: '📄',
projects: '📁',
jobs: '📁',
'performance-reviews': '📄',
news: '📰',
};
/** Marker color per document type; unknown types fall back to '#4d4d4d' at the call site. */
const colorMap: Record<string, string> = {
query: '#D4A017', // Golden Ochre — strong highlight
resume: '#4A7A7D', // Dusty Teal — secondary theme color
projects: '#1A2536', // Midnight Blue — rich and deep
news: '#D3CDBF', // Warm Gray — soft and neutral
'performance-reviews': '#8FD0D0', // Pale Teal — light and cool
jobs: '#F3aD8F', // Peach — warm accent
};
/** Base marker size; also the smallest size given to a query match by the size formula's base term. */
const DEFAULT_SIZE = 6;
/** Marker size for points outside the query set while a query is active, so matches stand out. */
const DEFAULT_UNFOCUS_SIZE = 2;
/**
 * Details of the point currently selected in the scatter plot.
 *
 * NOTE(review): this name shadows the global DOM `Node` type within this module.
 * Instances are produced by casting plot metadata, so fields other than `id`,
 * `content`, `emoji` and `docType` may be absent at runtime despite their types —
 * the rendering code checks several of them against `undefined`.
 */
type Node = {
id: string;
content: string; // Portion of content that was used for embedding
fullContent: string | undefined; // Portion of content plus/minus buffer
emoji: string;
docType: string;
source_file: string;
distance: number | undefined;
path: string;
// Added to the zero-based line index when numbering the lines of `fullContent`.
chunkBegin: number;
// Lines numbered in (lineBegin, lineEnd] are highlighted in the content view.
lineBegin: number;
chunkEnd: number;
lineEnd: number;
sx: SxProps;
};
/**
 * Interactive scatter plot of a candidate's vector store.
 *
 * The component fetches the candidate's vectors reduced to 2 or 3 dimensions,
 * plots them, and lets the user run a similarity query. Entries that belong to
 * the query result are drawn as a separate "Query" trace whose marker size grows
 * as the distance to the query shrinks. Clicking a point shows its metadata and
 * loads the surrounding document content.
 *
 * @param props.rag    Optional initial query result to highlight.
 * @param props.inline When truthy, hides the query banner and the query input.
 * @param props.sx     Extra styles applied to the root Box.
 */
const VectorVisualizer: React.FC<VectorVisualizerProps> = (props: VectorVisualizerProps) => {
  const { user, apiClient } = useAuth();
  const { rag, inline, sx } = props;
  const { setSnack } = useAppState();
  const [plotData, setPlotData] = useState<PlotData[] | null>(null);
  const [newQuery, setNewQuery] = useState<string>('');
  const [querySet, setQuerySet] = useState<Types.ChromaDBGetResponse>(rag || emptyQuerySet);
  const [result, setResult] = useState<Types.ChromaDBGetResponse | null>(null);
  const [view2D, setView2D] = useState<boolean>(true);
  const plotlyRef = useRef(null);
  const boxRef = useRef<HTMLElement>(null);
  const [node, setNode] = useState<Node | null>(null);
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('md'));
  const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
  const navigate = useNavigate();
  const candidate: Types.Candidate | null =
    user?.userType === 'candidate' ? (user as Types.Candidate) : null;

  /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
   * off screen (eg., the VectorVisualizer is not on the tab the app loads to).
   * Deliberately has no dependency list: the measurement is repeated after every render. */
  useEffect(() => {
    const container = boxRef.current;
    if (!container) {
      return;
    }
    const frame = requestAnimationFrame(() => {
      // Search inside this component only, so a second visualizer on the page is not resized.
      const plotContainer = container.querySelector<HTMLElement>('.plot-container');
      const svgContainer = container.querySelector<HTMLElement>('.svg-container');
      if (!plotContainer || !svgContainer) {
        return;
      }
      const { width, height } = plotContainer.getBoundingClientRect();
      svgContainer.style.width = `${width}px`;
      svgContainer.style.height = `${height}px`;
      // Returning the previous object when nothing changed lets React skip the re-render.
      setPlotDimensions(previous =>
        previous.width === width && previous.height === height ? previous : { width, height }
      );
    });
    // Drop the pending frame on re-render/unmount so it never touches stale DOM or state.
    return (): void => {
      cancelAnimationFrame(frame);
    };
  });

  // Get the collection to visualize
  useEffect(() => {
    if (result || !candidate) {
      return;
    }
    // Set by the cleanup so a response that arrives after a 2D/3D toggle (or unmount) is ignored.
    let cancelled = false;
    const fetchCollection = async (): Promise<void> => {
      try {
        const vectors = await apiClient.getCandidateVectors(view2D ? 2 : 3);
        if (!cancelled) {
          setResult(vectors);
        }
      } catch (error) {
        if (cancelled) {
          return;
        }
        console.error('Error obtaining collection information:', error);
        setSnack('Unable to obtain collection information.', 'error');
      }
    };
    void fetchCollection();
    return (): void => {
      cancelled = true;
    };
  }, [result, setSnack, view2D, apiClient, candidate]);

  // Build the Plotly traces from the fetched collection and the active query set
  useEffect(() => {
    if (!result || !result.embeddings || result.embeddings.length === 0) {
      return;
    }
    const embeddings: number[][] = result.embeddings;
    const is2D = embeddings.every((v: number[]) => v.length === 2);
    const is3D = embeddings.every((v: number[]) => v.length === 3);
    if (!is2D && !is3D) {
      console.warn('Modified vectors are neither 2D nor 3D');
      return;
    }
    if ((view2D && !is2D) || (!view2D && !is3D)) {
      // The collection for the other view is still loaded; wait for the matching fetch.
      return;
    }

    const ids = result.ids || [];
    const documents = result.documents || [];
    const metadatas = result.metadatas || [];

    const emptyGroup = (): Types.ChromaDBGetResponse => ({
      ids: [],
      documents: [],
      embeddings: [],
      metadatas: [],
      distances: [],
      query: '',
      size: 0,
      dimensions: 2,
      name: '',
    });
    const query = emptyGroup();
    const filtered = emptyGroup();

    /* Position of each id within the query set; the first occurrence wins, matching indexOf. */
    const queryIndexById = new Map<string, number>();
    (querySet.ids || []).forEach((id: string, index: number) => {
      if (!queryIndexById.has(id)) {
        queryIndexById.set(id, index);
      }
    });

    /* Loop through all items and divide into two groups:
     *   filtered is for any item not in the querySet
     *   query is for any item that is in the querySet
     */
    ids.forEach((id: string, index: number) => {
      const foundIndex = queryIndexById.get(id);
      /* Copy the metadata (never mutate objects owned by the `result` state) and add the
       * doc content and id. Only items in the query set carry a distance to the query. */
      const metadata = {
        ...(metadatas[index] || {}),
        id,
        content: documents[index],
        distance:
          foundIndex !== undefined && querySet.distances
            ? querySet.distances[foundIndex]
            : undefined,
      };
      const group = foundIndex !== undefined ? query : filtered;
      group.ids.push(id);
      group.documents.push(documents[index]);
      group.embeddings.push(embeddings[index]);
      group.metadatas.push(metadata);
    });

    /* Add the query itself as the first point of the query trace, using the projection that
     * matches the current view. */
    const queryEmbedding = view2D ? querySet.umapEmbedding2D : querySet.umapEmbedding3D;
    if (queryEmbedding && queryEmbedding.length) {
      query.ids.unshift('query');
      query.documents.unshift(querySet.query || '');
      query.metadatas.unshift({
        id: 'query',
        docType: 'query',
        content: querySet.query || '',
        distance: 0,
      });
      query.embeddings.unshift(queryEmbedding);
    }

    const filteredDocTypes = filtered.metadatas.map(m => m.docType || 'unknown');
    const queryDocTypes = query.metadatas.map(m => m.docType || 'unknown');
    const hasQuery = query.metadatas.length > 0;
    const filteredSizes = filtered.metadatas.map(() =>
      hasQuery ? DEFAULT_UNFOCUS_SIZE : DEFAULT_SIZE
    );
    const filteredColors = filteredDocTypes.map(type => colorMap[type] || '#4d4d4d');
    /* `??` rather than `||`: a distance of 0 is the closest possible match and must give the
     * largest marker. The lower bound keeps large distances from producing negative sizes. */
    const querySizes = query.metadatas.map(m =>
      Math.max(
        DEFAULT_UNFOCUS_SIZE,
        DEFAULT_SIZE + 2 * DEFAULT_SIZE * Math.pow(1 - (m.distance ?? 1), 3)
      )
    );
    const queryColors = queryDocTypes.map(type => colorMap[type] || '#4d4d4d');

    /* Normalize both traces together so they share one coordinate system; normalizing them
     * separately would stretch the query matches across the whole plot. */
    const normalizeAxis = (axis: number): [number[], number[]] => {
      const filteredValues = filtered.embeddings.map((v: number[]) => v[axis]);
      const queryValues = query.embeddings.map((v: number[]) => v[axis]);
      const normalized = normalizeDimension([...filteredValues, ...queryValues]);
      return [
        normalized.slice(0, filteredValues.length),
        normalized.slice(filteredValues.length),
      ];
    };
    const [filteredX, queryX] = normalizeAxis(0);
    const [filteredY, queryY] = normalizeAxis(1);

    const data: PlotData[] = [
      {
        name: 'All data',
        x: filteredX,
        y: filteredY,
        mode: 'markers',
        marker: {
          size: filteredSizes,
          symbol: 'circle',
          color: filteredColors,
          opacity: 1,
        },
        text: filtered.ids,
        customdata: filtered.metadatas,
        type: is3D ? 'scatter3d' : 'scatter',
        hovertemplate: '&nbsp;',
      },
      {
        name: 'Query',
        x: queryX,
        y: queryY,
        mode: 'markers',
        marker: {
          size: querySizes,
          symbol: 'circle',
          color: queryColors,
          opacity: 1,
        },
        text: query.ids,
        customdata: query.metadatas,
        type: is3D ? 'scatter3d' : 'scatter',
        hovertemplate: '%{text}',
      },
    ];
    if (is3D) {
      const [filteredZ, queryZ] = normalizeAxis(2);
      data[0].z = filteredZ;
      data[1].z = queryZ;
    }
    setPlotData(data);
  }, [result, querySet, view2D]);

  /** Runs a similarity search and stores the result as the active query set. */
  const sendQuery = async (query: string): Promise<void> => {
    if (!query.trim()) return;
    setNewQuery('');
    try {
      const similar = await apiClient.getCandidateSimilarContent(query);
      setQuerySet(similar);
    } catch (error) {
      const msg = `Error obtaining similar content to ${query}.`;
      console.error(msg, error);
      setSnack(msg, 'error');
    }
  };

  /** Submits the query when Enter is pressed in the query field. */
  const handleKeyPress = (event: React.KeyboardEvent<HTMLDivElement>): void => {
    if (event.key === 'Enter') {
      void sendQuery(newQuery);
    }
  };

  /* The candidate check comes first: without a candidate the collection is never fetched,
   * so `result` stays null and the loading message would otherwise hide this prompt forever. */
  if (!candidate)
    return (
      <Box
        sx={{
          display: 'flex',
          flexGrow: 1,
          justifyContent: 'center',
          alignItems: 'center',
        }}
      >
        <div>
          No candidate selected. Please{' '}
          <Button
            onClick={(): void => {
              navigate('/find-a-candidate');
            }}
          >
            select a candidate
          </Button>{' '}
          first.
        </div>
      </Box>
    );

  if (!result)
    return (
      <Box
        sx={{
          display: 'flex',
          flexGrow: 1,
          justifyContent: 'center',
          alignItems: 'center',
        }}
      >
        <div>Loading visualization...</div>
      </Box>
    );

  /** Loads the surrounding document content for the given node. */
  const fetchRAGMeta = async (node: Node): Promise<void> => {
    try {
      const ragContent = await apiClient.getCandidateRAGContent(node.id);
      /* Only apply the content if this node is still the selected one; a slow response must
       * not replace a newer selection. */
      setNode(current =>
        current !== null && current.id === node.id
          ? { ...current, fullContent: ragContent.content }
          : current
      );
    } catch (error) {
      const msg = `Error obtaining content for ${node.id}.`;
      console.error(msg, error);
      setSnack(msg, 'error');
    }
  };

  /** Shows the clicked point; for document nodes also requests the full content. */
  const onNodeSelected = (metadata: Metadata): void => {
    let node: Partial<Node>;
    if (metadata.docType === 'query') {
      node = {
        ...metadata,
        content: `Similarity results for the query **${querySet.query || ''}**
The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '2' : '3'}-dimensional space. Larger dots represent relative similarity in N-dimensional space.
`,
        emoji: emojiMap[metadata.docType],
        sx: {
          m: 0.5,
          p: 2,
          width: '3rem',
          display: 'flex',
          alignContent: 'center',
          justifyContent: 'center',
          flexGrow: 0,
          flexWrap: 'wrap',
          backgroundColor: colorMap[metadata.docType] || '#ff8080',
        },
      };
      setNode(node as Node);
      return;
    }
    node = {
      content: `Loading...`,
      ...metadata,
      emoji: emojiMap[metadata.docType] || '❓',
    };
    setNode(node as Node);
    void fetchRAGMeta(node as Node);
  };

  return (
    <Box
      className="VectorVisualizer"
      ref={boxRef}
      sx={{
        ...sx,
      }}
    >
      <Box sx={{ p: 0, m: 0, gap: 0 }}>
        <Paper
          sx={{
            p: 0.5,
            m: 0,
            display: 'flex',
            flexGrow: 0,
            height: isMobile ? 'auto' : 'auto', //"320px",
            minHeight: isMobile ? 'auto' : 'auto', //"320px",
            maxHeight: isMobile ? 'auto' : 'auto', //"320px",
            position: 'relative',
            flexDirection: 'column',
          }}
        >
          <FormControlLabel
            sx={{
              display: 'flex',
              position: 'relative',
              width: 'fit-content',
              ml: 1,
              mb: '-2.5rem',
              zIndex: 100,
              flexBasis: 0,
              flexGrow: 0,
            }}
            control={<Switch checked={!view2D} />}
            onChange={(): void => {
              // Clearing the result triggers a fetch of the collection in the new dimensionality.
              setView2D(!view2D);
              setResult(null);
            }}
            label="3D"
          />
          <Plot
            ref={plotlyRef}
            onClick={(event: { points: { customdata: Metadata }[] }): void => {
              const point = event.points && event.points[0];
              if (point && point.customdata) {
                onNodeSelected(point.customdata);
              }
            }}
            data={plotData ?? []}
            useResizeHandler={true}
            config={config}
            style={{
              display: 'flex',
              flexGrow: 1,
              minHeight: '240px',
              padding: 0,
              margin: 0,
              width: '100%',
              height: '100%',
              overflow: 'hidden',
            }}
            layout={{
              ...layout,
              width: plotDimensions.width,
              height: plotDimensions.height,
            }}
          />
        </Paper>
        <Paper
          sx={{
            display: 'flex',
            flexDirection: isMobile ? 'column' : 'row',
            mt: 0.5,
            p: 0.5,
            flexGrow: 1,
            minHeight: 'fit-content',
          }}
        >
          {node !== null && (
            <Box
              sx={{
                display: 'flex',
                fontSize: '0.75rem',
                flexDirection: 'column',
                flexGrow: 1,
                maxWidth: '100%',
                flexBasis: 1,
                maxHeight: 'min-content',
              }}
            >
              <TableContainer component={Paper} sx={{ mb: isMobile ? 1 : 0, mr: isMobile ? 0 : 1 }}>
                <Table size="small" sx={{ tableLayout: 'fixed' }}>
                  <TableBody
                    sx={{
                      '& td': { verticalAlign: 'top', fontSize: '0.75rem' },
                      '& td:first-of-type': {
                        whiteSpace: 'nowrap',
                        width: '1rem',
                      },
                    }}
                  >
                    <TableRow>
                      <TableCell>Type</TableCell>
                      <TableCell>
                        {node.emoji} {node.docType}
                      </TableCell>
                    </TableRow>
                    {node.source_file !== undefined && (
                      <TableRow>
                        <TableCell>File</TableCell>
                        <TableCell>{node.source_file.replace(/^.*\//, '')}</TableCell>
                      </TableRow>
                    )}
                    {node.path !== undefined && (
                      <TableRow>
                        <TableCell>Section</TableCell>
                        <TableCell>{node.path}</TableCell>
                      </TableRow>
                    )}
                    {node.distance !== undefined && (
                      <TableRow>
                        <TableCell>Distance</TableCell>
                        <TableCell>{node.distance}</TableCell>
                      </TableRow>
                    )}
                  </TableBody>
                </Table>
              </TableContainer>
              {node.content !== '' && node.content !== undefined && (
                <Paper
                  elevation={6}
                  sx={{
                    display: 'flex',
                    flexDirection: 'column',
                    border: '1px solid #808080',
                    minHeight: 'fit-content',
                    mt: 1,
                  }}
                >
                  <Box
                    sx={{
                      display: 'flex',
                      background: '#404040',
                      p: 1,
                      color: 'white',
                    }}
                  >
                    Vector Embedded Content
                  </Box>
                  <Box sx={{ display: 'flex', p: 1, flexGrow: 1 }}>{node.content}</Box>
                </Paper>
              )}
            </Box>
          )}
          <Box
            sx={{
              display: 'flex',
              flexDirection: 'column',
              flexGrow: 2,
              flexBasis: 0,
              flexShrink: 1,
            }}
          >
            {node === null && (
              <Paper sx={{ m: 0.5, p: 2, flexGrow: 1 }}>
                Click a point in the scatter-graph to see information about that node.
              </Paper>
            )}
            {node !== null && node.fullContent && (
              <Scrollable
                autoscroll={false}
                sx={{
                  display: 'flex',
                  flexDirection: 'column',
                  m: 0,
                  p: 0.5,
                  pl: 1,
                  flexShrink: 1,
                  position: 'relative',
                  maxWidth: '100%',
                }}
              >
                {node.fullContent.split('\n').map((line, index) => {
                  // Line number within the source document, offset by where the chunk begins.
                  const lineNumber = index + 1 + node.chunkBegin;
                  // Highlight the lines that were actually embedded.
                  const bgColor =
                    lineNumber > node.lineBegin && lineNumber <= node.lineEnd ? '#f0f0f0' : 'auto';
                  return (
                    <Box
                      key={lineNumber}
                      sx={{
                        display: 'flex',
                        flexDirection: 'row',
                        borderBottom: '1px solid #d0d0d0',
                        ':first-of-type': { borderTop: '1px solid #d0d0d0' },
                        backgroundColor: bgColor,
                      }}
                    >
                      <Box
                        sx={{
                          fontFamily: 'courier',
                          fontSize: '0.8rem',
                          minWidth: '2rem',
                          pt: '0.1rem',
                          align: 'left',
                          verticalAlign: 'top',
                        }}
                      >
                        {lineNumber}
                      </Box>
                      <pre
                        style={{
                          margin: 0,
                          padding: 0,
                          border: 'none',
                          minHeight: '1rem',
                          overflow: 'hidden',
                        }}
                      >
                        {line || ' '}
                      </pre>
                    </Box>
                  );
                })}
                {!node.lineBegin && (
                  <pre style={{ margin: 0, padding: 0, border: 'none' }}>{node.content}</pre>
                )}
              </Scrollable>
            )}
          </Box>
        </Paper>
        {!inline && querySet.query !== undefined && querySet.query !== '' && (
          <Paper
            sx={{
              display: 'flex',
              flexDirection: 'column',
              justifyContent: 'center',
              flexGrow: 0,
              minHeight: '2.5rem',
              maxHeight: '2.5rem',
              height: '2.5rem',
              alignItems: 'center',
              mt: 1,
              pb: 0,
            }}
          >
            {querySet.query !== undefined && querySet.query !== '' && `Query: ${querySet.query}`}
            {querySet.ids.length === 0 && 'Enter query below to perform a similarity search.'}
          </Paper>
        )}
        {!inline && (
          <Box className="Query" sx={{ display: 'flex', flexDirection: 'row', p: 1 }}>
            <TextField
              variant="outlined"
              fullWidth
              type="text"
              value={newQuery}
              onChange={(e): void => {
                setNewQuery(e.target.value);
              }}
              onKeyDown={handleKeyPress}
              placeholder="Enter query to find related documents..."
              id="QueryInput"
            />
            <Tooltip title="Send">
              <Button
                sx={{ m: 1 }}
                variant="contained"
                onClick={(): void => {
                  void sendQuery(newQuery);
                }}
              >
                <SendIcon />
              </Button>
            </Tooltip>
          </Box>
        )}
      </Box>
    </Box>
  );
};
export type { VectorVisualizerProps };
export { VectorVisualizer };