/* VectorVisualizer: interactive 2D/3D scatter plot (react-plotly.js) of the current candidate's vector-store embeddings, with a similarity-query input and a details panel for the clicked point. */ import React, { useEffect, useState, useRef } from 'react'; import Box from '@mui/material/Box'; import Paper from '@mui/material/Paper'; import Plot from 'react-plotly.js'; import TextField from '@mui/material/TextField'; import Tooltip from '@mui/material/Tooltip'; import Button from '@mui/material/Button'; import SendIcon from '@mui/icons-material/Send'; import FormControlLabel from '@mui/material/FormControlLabel'; import Switch from '@mui/material/Switch'; import useMediaQuery from '@mui/material/useMediaQuery'; import { SxProps, useTheme } from '@mui/material/styles'; import Table from '@mui/material/Table'; import TableBody from '@mui/material/TableBody'; import TableCell from '@mui/material/TableCell'; import TableContainer from '@mui/material/TableContainer'; import TableRow from '@mui/material/TableRow'; import { Scrollable } from './Scrollable'; import './VectorVisualizer.css'; import { BackstoryPageProps } from './BackstoryTab'; import { useAuth } from 'hooks/AuthContext'; import * as Types from 'types/types'; import { useAppState } from 'hooks/GlobalContext'; import { useNavigate } from 'react-router-dom'; /** Props for VectorVisualizer. `inline`: when true, the query summary and the query input are not rendered. `rag`: optional initial query set (emptyQuerySet is used otherwise). */ interface VectorVisualizerProps extends BackstoryPageProps { inline?: boolean; rag?: Types.ChromaDBGetResponse; } // interface Metadata { // id: string; // docType: string; // content: string; // distance?: number; // } /* Per-point metadata handed to Plotly as customdata; fields read in this file include id, docType, content and distance. */ type Metadata = Record; /* Placeholder query set used when no `rag` prop is supplied. */ const emptyQuerySet: Types.ChromaDBGetResponse = { ids: [], documents: [], metadatas: [], embeddings: [], distances: [], name: 'Empty', size: 0, dimensions: 2, query: '', }; /* One Plotly scatter / scatter3d trace as built by this component. */ interface PlotData { name: string; mode: string; type: string; x: number[]; y: number[]; z?: number[]; text: string[]; marker: { color: string[]; size: number[]; symbol: string; opacity: number; }; customdata: Metadata[]; hovertemplate: string; } /* Plotly config: responsive, no logo or send-to-cloud, scroll-zoom and double-click disabled, lasso/select mode-bar buttons removed. The commented list below appears to enumerate other removable mode-bar button names. */ const config: Partial = { responsive: true, autosizable: true, displaylogo: false, showSendToCloud: false, staticPlot: false, frameMargins: 0, scrollZoom: false, doubleClick: false, // | "lasso2d" 
// | "select2d" // | "sendDataToCloud" // | "zoom2d" // | "pan2d" // | "zoomIn2d" // | "zoomOut2d" // | "autoScale2d" // | "resetScale2d" // | "hoverClosestCartesian" // | "hoverCompareCartesian" // | "zoom3d" // | "pan3d" // | "orbitRotation" // | "tableRotation" // | "handleDrag3d" // | "resetCameraDefault3d" // | "resetCameraLastSave3d" // | "hoverClosest3d" // | "zoomInGeo" // | "zoomOutGeo" // | "resetGeo" // | "hoverClosestGeo" // | "hoverClosestGl2d" // | "hoverClosestPie" // | "toggleHover" // | "toImage" // | "resetViews" // | "toggleSpikelines" // | "zoomInMapbox" // | "zoomOutMapbox" // | "resetViewMapbox" // | "togglespikelines" // | "togglehover" // | "hovercompare" // | "hoverclosest" // | "v1hovermode"; modeBarButtonsToRemove: ['lasso2d', 'select2d'], }; /* Shared Plotly layout; width and height are overridden at render time from plotDimensions. */ const layout: Partial = { autosize: false, clickmode: 'event+select', paper_bgcolor: '#FFFFFF', // white plot_bgcolor: '#FFFFFF', // white plot background font: { family: 'Roboto, sans-serif', color: '#2E2E2E', // charcoal black }, hovermode: 'closest', scene: { bgcolor: '#FFFFFF', // 3D plot background zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, }, xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' }, margin: { r: 0, b: 0, l: 0, t: 0 }, legend: { x: 0.8, // Horizontal position (0 to 1, 0 is left, 1 is right) y: 0, // Vertical position (0 to 1, 0 is bottom, 1 is top) xanchor: 'left', yanchor: 'top', orientation: 'h', // 'h' = horizontal legend ('v' = vertical) }, showlegend: true, // Show the legend }; /** Min-max scale values into [0, 1]; if all values are equal, every value maps to 0.5. NOTE(review): Math.min/Math.max spread the whole array as call arguments, which may throw for very large arrays. */ const normalizeDimension = (arr: number[]): number[] => { const min = Math.min(...arr); const max = Math.max(...arr); const range = max - min; if (range === 0) return arr.map(() => 0.5); // flat dimension return arr.map(v => (v - min) / range); }; /* Emoji shown beside a node's docType in the details panel (onNodeSelected falls back to '❓' for unlisted non-query types). */ const emojiMap: Record = { query: '🔍', resume: '📄', projects: '📁', jobs: '📁', 'performance-reviews': '📄', news: '📰', }; /* Marker color per docType; unlisted types are drawn as '#4d4d4d'. */ const colorMap: Record = { query: 
'#D4A017', // Golden Ochre — strong highlight resume: '#4A7A7D', // Dusty Teal — secondary theme color projects: '#1A2536', // Midnight Blue — rich and deep news: '#D3CDBF', // Warm Gray — soft and neutral 'performance-reviews': '#8FD0D0', // Light cyan — pale accent jobs: '#F3aD8F', // Light salmon — warm accent }; /* Marker sizes: points outside the query set use DEFAULT_SIZE, or DEFAULT_UNFOCUS_SIZE while a query is shown; query-set point sizes are scaled up from DEFAULT_SIZE by similarity. */ const DEFAULT_SIZE = 6; const DEFAULT_UNFOCUS_SIZE = 2; /* Details of the point selected in the plot, shown in the side panel. content is the embedded chunk; fullContent is fetched on demand; chunkBegin, lineBegin and lineEnd drive the line numbering and highlighting of fullContent. */ type Node = { id: string; content: string; // Portion of content that was used for embedding fullContent: string | undefined; // Portion of content plus/minus buffer emoji: string; docType: string; source_file: string; distance: number | undefined; path: string; chunkBegin: number; lineBegin: number; chunkEnd: number; lineEnd: number; sx: SxProps; }; /** Renders the current candidate's embedding vectors as a 2D or 3D Plotly scatter plot (toggled by a "3D" switch). Points in the active query set are drawn as a separate, larger "Query" trace, with the query's own projected embedding prepended when available. Clicking a point shows its metadata; for non-query points the surrounding source content is fetched and shown with the embedded lines highlighted. Unless `inline`, a text field runs new similarity searches. */ const VectorVisualizer: React.FC = (props: VectorVisualizerProps) => { const { user, apiClient } = useAuth(); const { rag, inline, sx } = props; const { setSnack } = useAppState(); const [plotData, setPlotData] = useState(null); const [newQuery, setNewQuery] = useState(''); const [querySet, setQuerySet] = useState(rag || emptyQuerySet); const [result, setResult] = useState(null); const [view2D, setView2D] = useState(true); const plotlyRef = useRef(null); const boxRef = useRef(null); const [node, setNode] = useState(null); const theme = useTheme(); const isMobile = useMediaQuery(theme.breakpoints.down('md')); const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 }); const navigate = useNavigate(); const candidate: Types.Candidate | null = user?.userType === 'candidate' ? 
(user as Types.Candidate) : null; /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered * off screen (eg., the VectorVisualizer is not on the tab the app loads to). * Runs after every render (no dependency array); copies the .plot-container size onto .svg-container and * only calls setPlotDimensions when the measured size differs, which ends the re-render cycle. * NOTE(review): both containers come from document.querySelector, i.e. the first match in the whole page rather * than inside boxRef — confirm only one Plotly chart is mounted at a time. */ useEffect(() => { if (!boxRef.current) { return; } const resize = (): void => { requestAnimationFrame(() => { const plotContainer = document.querySelector('.plot-container') as HTMLElement; const svgContainer = document?.querySelector('.svg-container') as HTMLElement; if (plotContainer && svgContainer) { const plotContainerRect = plotContainer.getBoundingClientRect(); svgContainer.style.width = `${plotContainerRect.width}px`; svgContainer.style.height = `${plotContainerRect.height}px`; if ( plotDimensions.width !== plotContainerRect.width || plotDimensions.height !== plotContainerRect.height ) { setPlotDimensions({ width: plotContainerRect.width, height: plotContainerRect.height, }); } } }); }; resize(); }); // Get the collection to visualize: whenever `result` is null, request the candidate's vectors with 2 or 3 dimensions (per view2D); no-op without a candidate. useEffect(() => { if (result) { return; } const fetchCollection = async (): Promise => { if (!candidate) { return; } try { const result = await apiClient.getCandidateVectors(view2D ? 
2 : 3); setResult(result); } catch (error) { console.error('Error obtaining collection information:', error); setSnack('Unable to obtain collection information.', 'error'); } }; fetchCollection(); }, [result, setSnack, view2D, apiClient, candidate]); /* Split the fetched vectors into the "All data" and "Query" traces (membership decided by id in querySet) and store them in plotData. Skips work unless every vector's dimensionality matches the current 2D/3D view. */ useEffect(() => { if (!result || !result.embeddings) return; if (result.embeddings.length === 0) return; const full: Types.ChromaDBGetResponse = { ...result, ids: [...(result.ids || [])], documents: [...(result.documents || [])], embeddings: [...result.embeddings], metadatas: [...(result.metadatas || [])], }; const is2D = full.embeddings.every((v: number[]) => v.length === 2); const is3D = full.embeddings.every((v: number[]) => v.length === 3); if ((view2D && !is2D) || (!view2D && !is3D)) { return; } /* NOTE(review): unreachable — the check above already returned unless the vectors are all 2D (2D view) or all 3D (3D view). */ if (!is2D && !is3D) { console.warn('Modified vectors are neither 2D nor 3D'); return; } const query: Types.ChromaDBGetResponse = { ids: [], documents: [], embeddings: [], metadatas: [], distances: [], query: '', size: 0, dimensions: 2, name: '', }; const filtered: Types.ChromaDBGetResponse = { ids: [], documents: [], embeddings: [], metadatas: [], distances: [], query: '', size: 0, dimensions: 2, name: '', }; /* Loop through all items and divide into two groups: * filtered is for any item not in the querySet * query is for any item that is in the querySet */ full.ids.forEach((id, index) => { const foundIndex = querySet.ids.indexOf(id); /* Update metadata to hold the doc content and id. NOTE(review): only the metadatas array was copied, so these writes mutate the metadata objects still held in `result` state. */ full.metadatas[index].id = id; full.metadatas[index].content = full.documents[index]; if (foundIndex !== -1) { /* The query set will contain the distance to the query */ full.metadatas[index].distance = querySet.distances ? 
querySet.distances[foundIndex] : undefined; query.ids.push(id); query.documents.push(full.documents[index]); query.embeddings.push(full.embeddings[index]); query.metadatas.push(full.metadatas[index]); } else { /* The filtered set does not have a distance */ full.metadatas[index].distance = undefined; filtered.ids.push(id); filtered.documents.push(full.documents[index]); filtered.embeddings.push(full.embeddings[index]); filtered.metadatas.push(full.metadatas[index]); } }); /* Prepend the query's own projected embedding (for the current view) as a synthetic 'query' point at distance 0. */ if (view2D && querySet.umapEmbedding2D && querySet.umapEmbedding2D.length) { query.ids.unshift('query'); query.metadatas.unshift({ id: 'query', docType: 'query', content: querySet.query || '', distance: 0, }); query.embeddings.unshift(querySet.umapEmbedding2D); } if (!view2D && querySet.umapEmbedding3D && querySet.umapEmbedding3D.length) { query.ids.unshift('query'); query.metadatas.unshift({ id: 'query', docType: 'query', content: querySet.query || '', distance: 0, }); query.embeddings.unshift(querySet.umapEmbedding3D); } const filtered_docTypes = filtered.metadatas.map(m => m.docType || 'unknown'); const query_docTypes = query.metadatas.map(m => m.docType || 'unknown'); const has_query = query.metadatas.length > 0; const filtered_sizes = filtered.metadatas.map(_m => has_query ? DEFAULT_UNFOCUS_SIZE : DEFAULT_SIZE ); const filtered_colors = filtered_docTypes.map(type => colorMap[type] || '#4d4d4d'); /* NOTE(review): the "All data" and "Query" traces are each min-max normalized over their own points only, so the two traces do not share a coordinate frame and query hits are not drawn where they sit within the full set. Normalizing both over the combined range looks intended — confirm. */ const filtered_x = normalizeDimension(filtered.embeddings.map((v: number[]) => v[0])); const filtered_y = normalizeDimension(filtered.embeddings.map((v: number[]) => v[1])); const filtered_z = is3D ? 
normalizeDimension(filtered.embeddings.map((v: number[]) => v[2])) : undefined; /* Query-set marker size grows as distance shrinks: DEFAULT_SIZE + 2 * DEFAULT_SIZE * (1 - distance)^3. NOTE(review): `m.distance || 1` also turns a distance of 0 into 1, so the synthetic query point (distance 0) gets the smallest size rather than the largest; `?? 1` may have been intended. */ const query_sizes = query.metadatas.map( m => DEFAULT_SIZE + 2 * DEFAULT_SIZE * Math.pow(1 - (m.distance || 1), 3) ); const query_colors = query_docTypes.map(type => colorMap[type] || '#4d4d4d'); const query_x = normalizeDimension(query.embeddings.map((v: number[]) => v[0])); const query_y = normalizeDimension(query.embeddings.map((v: number[]) => v[1])); const query_z = is3D ? normalizeDimension(query.embeddings.map((v: number[]) => v[2])) : undefined; const data: PlotData[] = [ { name: 'All data', x: filtered_x, y: filtered_y, mode: 'markers', marker: { size: filtered_sizes, symbol: 'circle', color: filtered_colors, opacity: 1, }, text: filtered.ids, customdata: filtered.metadatas, type: is3D ? 'scatter3d' : 'scatter', hovertemplate: ' ', }, { name: 'Query', x: query_x, y: query_y, mode: 'markers', marker: { size: query_sizes, symbol: 'circle', color: query_colors, opacity: 1, }, text: query.ids, customdata: query.metadatas, type: is3D ? 'scatter3d' : 'scatter', hovertemplate: '%{text}', }, ]; if (is3D) { data[0].z = filtered_z; data[1].z = query_z; } setPlotData(data); }, [result, querySet, view2D]); /* Submit the current query text when Enter is pressed in the query field. */ const handleKeyPress = (event: React.KeyboardEvent): void => { if (event.key === 'Enter') { sendQuery(newQuery); } }; /* Run a similarity search for `query` (ignored if blank), clear the input, and replace querySet with the response; on failure an error snack is shown (the error object itself is not logged). */ const sendQuery = async (query: string): Promise => { if (!query.trim()) return; setNewQuery(''); try { const result = await apiClient.getCandidateSimilarContent(query); console.log(result); setQuerySet(result); } catch (error) { const msg = `Error obtaining similar content to ${query}.`; setSnack(msg, 'error'); } }; /* Placeholders while vectors load or without a candidate. NOTE(review): vectors are only fetched when there is a candidate, so with no candidate `result` stays null and the loading placeholder is shown instead of the "No candidate selected" message; checking `candidate` first looks intended. */ if (!result) return (
Loading visualization...
); if (!candidate) return (
No candidate selected. Please{' '} {' '} first.
); /* Fetch the full source content around `node` and show it in the details panel; errors are logged and reported via snack. NOTE(review): the response is applied even if another point was selected meanwhile, so a slow earlier request can overwrite a later selection. */ const fetchRAGMeta = async (node: Node): Promise => { try { const result = await apiClient.getCandidateRAGContent(node.id); const update: Node = { ...node, fullContent: result.content, }; setNode(update); } catch (error) { const msg = `Error obtaining content for ${node.id}.`; console.error(msg, error); setSnack(msg, 'error'); } }; /* Plotly click handler. For the synthetic 'query' point, show a summary of the query. For any other point, show its metadata (its own content, set when the plot was built, overrides the 'Loading...' placeholder) and fetch the surrounding content. */ const onNodeSelected = (metadata: Metadata): void => { let node: Partial; console.log(metadata); if (metadata.docType === 'query') { node = { ...metadata, content: `Similarity results for the query **${querySet.query || ''}** The scatter graph shows the query in N-dimensional space, mapped to ${ view2D ? '2' : '3' }-dimensional space. Larger dots represent relative similarity in N-dimensional space. `, emoji: emojiMap[metadata.docType], sx: { m: 0.5, p: 2, width: '3rem', display: 'flex', alignContent: 'center', justifyContent: 'center', flexGrow: 0, flexWrap: 'wrap', backgroundColor: colorMap[metadata.docType] || '#ff8080', }, }; setNode(node as Node); return; } node = { content: `Loading...`, ...metadata, emoji: emojiMap[metadata.docType] || '❓', }; setNode(node as Node); fetchRAGMeta(node as Node); }; /* Render: 2D/3D switch (toggling also clears `result` so the fetch effect reloads vectors for the new view), the plot, the selected-node details and content panels, and, unless inline, the query summary and query input. */ return ( } onChange={(): void => { setView2D(!view2D); setResult(null); }} label="3D" /> { onNodeSelected(event.points[0].customdata); }} data={plotData} useResizeHandler={true} config={config} style={{ display: 'flex', flexGrow: 1, minHeight: '240px', padding: 0, margin: 0, width: '100%', height: '100%', overflow: 'hidden', }} layout={{ ...layout, width: plotDimensions.width, height: plotDimensions.height, }} /> {node !== null && ( Type {node.emoji} {node.docType} {node.source_file !== undefined && ( File {node.source_file.replace(/^.*\//, '')} )} {node.path !== undefined && ( Section {node.path} )} {node.distance !== undefined && ( Distance {node.distance} )}
{node.content !== '' && node.content !== undefined && ( Vector Embedded Content {node.content} )}
)} {node === null && ( Click a point in the scatter-graph to see information about that node. )} {node !== null && node.fullContent && ( {node.fullContent.split('\n').map((line, index) => { /* Turn the 0-based split index into a 1-based line number offset by chunkBegin; lines in (lineBegin, lineEnd] get the '#f0f0f0' background as the embedded chunk. */ index += 1 + node.chunkBegin; const bgColor = index > node.lineBegin && index <= node.lineEnd ? '#f0f0f0' : 'auto'; return ( {index}
                        {line || ' '}
                      
); })} {!node.lineBegin && (
{node.content}
)}
)}
{!inline && querySet.query !== undefined && querySet.query !== '' && ( {querySet.query !== undefined && querySet.query !== '' && `Query: ${querySet.query}`} {querySet.ids.length === 0 && 'Enter query below to perform a similarity search.'} )} {!inline && ( { setNewQuery(e.target.value); }} onKeyDown={handleKeyPress} placeholder="Enter query to find related documents..." id="QueryInput" /> )}
); }; export type { VectorVisualizerProps }; export { VectorVisualizer };