import React, { useEffect, useState, useRef } from 'react';
import Box from '@mui/material/Box';
import Paper from '@mui/material/Paper';
import Plot from 'react-plotly.js';
import TextField from '@mui/material/TextField';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import SendIcon from '@mui/icons-material/Send';
import FormControlLabel from '@mui/material/FormControlLabel';
import Switch from '@mui/material/Switch';
import useMediaQuery from '@mui/material/useMediaQuery';
import { SxProps, useTheme } from '@mui/material/styles';
import Table from '@mui/material/Table';
import TableBody from '@mui/material/TableBody';
import TableCell from '@mui/material/TableCell';
import TableContainer from '@mui/material/TableContainer';
import TableRow from '@mui/material/TableRow';
import { Scrollable } from './Scrollable';
import './VectorVisualizer.css';
import { BackstoryPageProps } from './BackstoryTab';
import { useAuth } from 'hooks/AuthContext';
import * as Types from 'types/types';
import { useAppState, useSelectedCandidate } from 'hooks/GlobalContext';
import { useNavigate } from 'react-router-dom';

interface VectorVisualizerProps extends BackstoryPageProps {
  /** When true, the query display and query input are not rendered. */
  inline?: boolean;
  /** Optional initial similarity-search result to highlight. */
  rag?: any;
};

/** Per-point metadata attached to each Plotly marker as `customdata`. */
interface Metadata {
  id: string;
  docType: string;
  content: string;
  distance?: number;
}

/** Query set used before any similarity search has been run. */
const emptyQuerySet = {
  ids: [],
  documents: [],
  metadatas: [],
  embeddings: [],
};

interface PlotData {
  x: number[];
  y: number[];
  z?: number[];
  colors: string[];
  text: string[];
  sizes: number[];
  customdata: Metadata[];
}

const config: Partial = {
  responsive: true,
  autosizable: true,
  displaylogo: false,
  showSendToCloud: false,
  staticPlot: false,
  frameMargins: 0,
  scrollZoom: false,
  doubleClick: false,
  // Mode bar buttons that may be listed in modeBarButtonsToRemove:
  //   "lasso2d" | "select2d" | "sendDataToCloud" | "zoom2d" | "pan2d" | "zoomIn2d" | "zoomOut2d"
  //   | "autoScale2d" | "resetScale2d" | "hoverClosestCartesian" | "hoverCompareCartesian"
  //   | "zoom3d" | "pan3d" | "orbitRotation" | "tableRotation" | "handleDrag3d"
  //   | "resetCameraDefault3d" | "resetCameraLastSave3d" | "hoverClosest3d"
  //   | "zoomInGeo" | "zoomOutGeo" | "resetGeo" | "hoverClosestGeo" | "hoverClosestGl2d"
  //   | "hoverClosestPie" | "toggleHover" | "toImage" | "resetViews" | "toggleSpikelines"
  //   | "zoomInMapbox" | "zoomOutMapbox" | "resetViewMapbox" | "togglespikelines"
  //   | "togglehover" | "hovercompare" | "hoverclosest" | "v1hovermode"
  modeBarButtonsToRemove: [
    'lasso2d',
    'select2d',
  ]
};

const layout: Partial = {
  autosize: false,
  clickmode: 'event+select',
  paper_bgcolor: '#FFFFFF', // white
  plot_bgcolor: '#FFFFFF', // white plot background
  font: {
    family: 'Roboto, sans-serif',
    color: '#2E2E2E', // charcoal black
  },
  hovermode: 'closest',
  scene: {
    bgcolor: '#FFFFFF', // 3D plot background
    zaxis: { title: 'Z', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
  },
  xaxis: { title: 'X', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
  yaxis: { title: 'Y', gridcolor: '#cccccc', zerolinecolor: '#aaaaaa' },
  margin: { r: 0, b: 0, l: 0, t: 0 },
  legend: {
    x: 0.8, // Horizontal position (0 to 1, 0 is left, 1 is right)
    y: 0, // Vertical position (0 to 1, 0 is bottom, 1 is top)
    xanchor: 'left',
    yanchor: 'top',
    orientation: 'h' // 'h' = horizontal legend ('v' would be vertical)
  },
  showlegend: true // Show the legend
};

/**
 * Linearly rescale `arr` into [0, 1].
 *
 * A flat dimension (all values equal, including a single value) maps every
 * entry to 0.5; an empty input returns an empty array. Min/max are found with a
 * loop instead of `Math.min(...arr)` so very large collections cannot exceed the
 * engine's argument-count / call-stack limit.
 */
const normalizeDimension = (arr: number[]): number[] => {
  if (arr.length === 0) return [];
  let min = Infinity;
  let max = -Infinity;
  for (const v of arr) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  const range = max - min;
  if (range === 0) return arr.map(() => 0.5); // flat dimension
  return arr.map(v => (v - min) / range);
};

/** Emoji shown next to each document type in the node detail panel. */
const emojiMap: Record = {
  query: '🔍',
  resume: '📄',
  projects: '📁',
  jobs: '📁',
  'performance-reviews': '📄',
  news: '📰',
};

/** Marker colour per document type; unknown types fall back to '#4d4d4d'. */
const colorMap: Record = {
  query: '#D4A017', // Golden Ochre — strong highlight
  resume: '#4A7A7D', // Dusty Teal — secondary theme color
  projects: '#1A2536', // Midnight Blue — rich and deep
  news: '#D3CDBF', // Warm Gray — soft and neutral
  'performance-reviews': '#8FD0D0', // Light teal
  'jobs': '#F3aD8F', // Light peach
};

/** Marker size for points when no query is active. */
const DEFAULT_SIZE = 6.;
/** Marker size for non-matching points while a query is active. */
const DEFAULT_UNFOCUS_SIZE = 2.;

type Node = {
  id: string,
  content: string, // Portion of content that was used for embedding
  fullContent: string | undefined, // Portion of content plus/minus buffer
  emoji: string,
  docType: string,
  source_file: string,
  distance: number | undefined,
  path: string,
  chunkBegin: number,
  lineBegin: number,
  chunkEnd: number,
  lineEnd: number,
  sx: SxProps,
};

/**
 * Scatter-plot view of the current candidate's vector store, projected to
 * 2D or 3D, with an optional similarity-search query highlighted.
 */
const VectorVisualizer: React.FC = (props: VectorVisualizerProps) => {
  const { user, apiClient } = useAuth();
  const { rag, inline, sx } = props;
  const { setSnack } = useAppState();
  const [plotData, setPlotData] = useState(null);
  const [newQuery, setNewQuery] = useState('');
  const [querySet, setQuerySet] = useState(rag || emptyQuerySet);
  const [result, setResult] = useState(null);
  const [view2D, setView2D] = useState(true);
  const plotlyRef = useRef(null);
  const boxRef = useRef(null);
  const [node, setNode] = useState(null);
  const theme = useTheme();
  const isMobile = useMediaQuery(theme.breakpoints.down('md'));
  const [plotDimensions, setPlotDimensions] = useState({ width: 0, height: 0 });
  const navigate = useNavigate();
  const candidate: Types.Candidate | null = user?.userType === 'candidate' ? user as Types.Candidate : null;

  /* Force resize of Plotly as it tends to not be the correct size if it is initially rendered
   * off screen (eg., the VectorVisualizer is not on the tab the app loads to).
   * Runs after every render (no dependency list); the state update is guarded so
   * it only fires when the measured size actually changes, which avoids a render loop. */
  useEffect(() => {
    if (!boxRef.current) {
      return;
    }
    const resize = () => {
      requestAnimationFrame(() => {
        // NOTE(review): these selectors are document-wide and match the first plot on the page.
        const plotContainer = document.querySelector('.plot-container') as HTMLElement;
        const svgContainer = document.querySelector('.svg-container') as HTMLElement;
        if (plotContainer && svgContainer) {
          const plotContainerRect = plotContainer.getBoundingClientRect();
          svgContainer.style.width = `${plotContainerRect.width}px`;
          svgContainer.style.height = `${plotContainerRect.height}px`;
          if (plotDimensions.width !== plotContainerRect.width || plotDimensions.height !== plotContainerRect.height) {
            setPlotDimensions({ width: plotContainerRect.width, height: plotContainerRect.height });
          }
        }
      });
    };
    resize();
  });

  /* Get the collection to visualize. Re-runs when `result` is cleared (e.g. by the
   * 2D/3D toggle) and when the candidate becomes available after mount. The cleanup
   * flag discards a response that arrives after the inputs changed, so a late reply
   * for the previous view mode cannot overwrite the current one. */
  useEffect(() => {
    if (result) {
      return;
    }
    let cancelled = false;
    const fetchCollection = async () => {
      if (!candidate) {
        return;
      }
      try {
        const response = await apiClient.getCandidateVectors(view2D ? 2 : 3);
        if (!cancelled) {
          setResult(response);
        }
      } catch (error) {
        if (cancelled) {
          return;
        }
        console.error('Error obtaining collection information:', error);
        setSnack("Unable to obtain collection information.", "error");
      }
    };
    fetchCollection();
    return () => {
      cancelled = true;
    };
  }, [result, setSnack, view2D, candidate]);

  /* Build the Plotly traces whenever the collection, the active query, or the
   * 2D/3D mode changes. Points are split into two traces:
   *   - "All data": items that are not in the current query result
   *   - "Query":    items returned by the similarity search (plus the query itself) */
  useEffect(() => {
    if (!result || !result.embeddings) return;
    if (result.embeddings.length === 0) return;

    const full: Types.ChromaDBGetResponse = {
      ...result,
      ids: [...result.ids || []],
      documents: [...result.documents || []],
      embeddings: [...result.embeddings],
      metadatas: [...result.metadatas || []],
    };

    const is2D = full.embeddings.every((v: number[]) => v.length === 2);
    const is3D = full.embeddings.every((v: number[]) => v.length === 3);
    if (!is2D && !is3D) {
      console.warn('Modified vectors are neither 2D nor 3D');
      return;
    }
    /* The collection was fetched for the other view mode; wait for the refetch. */
    if ((view2D && !is2D) || (!view2D && !is3D)) {
      return;
    }

    const query: Types.ChromaDBGetResponse = { ids: [], documents: [], embeddings: [], metadatas: [], distances: [], query: '', size: 0, dimensions: 2, name: '' };
    const filtered: Types.ChromaDBGetResponse = { ids: [], documents: [], embeddings: [], metadatas: [], distances: [], query: '', size: 0, dimensions: 2, name: '' };

    /* id -> position in the query result, for O(1) distance lookup. Only the first
     * occurrence of an id is recorded, matching Array.prototype.indexOf. */
    const queryIndex = new Map<string, number>();
    querySet.ids.forEach((id: string, i: number) => {
      if (!queryIndex.has(id)) {
        queryIndex.set(id, i);
      }
    });

    full.ids.forEach((id, index) => {
      const foundIndex = queryIndex.get(id);
      const inQuery = foundIndex !== undefined;
      /* Build a new metadata object holding the doc content and id, rather than
       * mutating the object that still belongs to `result` state. Only query
       * results carry a distance to the query. */
      const metadata = {
        ...full.metadatas[index],
        id,
        content: full.documents[index],
        distance: inQuery && querySet.distances ? querySet.distances[foundIndex] : undefined,
      };
      const target = inQuery ? query : filtered;
      target.ids.push(id);
      target.documents.push(full.documents[index]);
      target.embeddings.push(full.embeddings[index]);
      target.metadatas.push(metadata);
    });

    /* Add the query itself as a point when the server returned its projection. */
    if (view2D && querySet.umapEmbedding2D && querySet.umapEmbedding2D.length) {
      query.ids.unshift('query');
      query.metadatas.unshift({ id: 'query', docType: 'query', content: querySet.query || '', distance: 0 });
      query.embeddings.unshift(querySet.umapEmbedding2D);
    }

    if (!view2D && querySet.umapEmbedding3D && querySet.umapEmbedding3D.length) {
      query.ids.unshift('query');
      query.metadatas.unshift({ id: 'query', docType: 'query', content: querySet.query || '', distance: 0 });
      query.embeddings.unshift(querySet.umapEmbedding3D);
    }

    const filtered_docTypes = filtered.metadatas.map(m => m.docType || 'unknown');
    const query_docTypes = query.metadatas.map(m => m.docType || 'unknown');

    const has_query = query.metadatas.length > 0;

    /* While a query is active, de-emphasize everything that is not a match. */
    const filtered_sizes = filtered.metadatas.map(() => has_query ? DEFAULT_UNFOCUS_SIZE : DEFAULT_SIZE);
    const filtered_colors = filtered_docTypes.map(type => colorMap[type] || '#4d4d4d');

    /* Size grows with similarity: distance 0 => 3 * DEFAULT_SIZE, distance >= 1 (or
     * missing) => DEFAULT_SIZE. `??` keeps a distance of 0 (the query point, exact
     * matches) from being treated as missing; the clamp keeps sizes positive. */
    const query_sizes = query.metadatas.map(m =>
      DEFAULT_SIZE + 2. * DEFAULT_SIZE * Math.pow(Math.max(0, 1. - (m.distance ?? 1.)), 3));
    const query_colors = query_docTypes.map(type => colorMap[type] || '#4d4d4d');

    /* Normalize both traces together so they share one coordinate frame; normalizing
     * each trace separately would stretch the query points away from their neighbours. */
    const filteredCount = filtered.embeddings.length;
    const allEmbeddings = [...filtered.embeddings, ...query.embeddings];
    const normalizeAxis = (axis: number): [number[], number[]] => {
      const values = normalizeDimension(allEmbeddings.map((v: number[]) => v[axis]));
      return [values.slice(0, filteredCount), values.slice(filteredCount)];
    };
    const [filtered_x, query_x] = normalizeAxis(0);
    const [filtered_y, query_y] = normalizeAxis(1);
    const [filtered_z, query_z] = is3D ? normalizeAxis(2) : [undefined, undefined];

    const data: any = [{
      name: 'All data',
      x: filtered_x,
      y: filtered_y,
      mode: 'markers',
      marker: {
        size: filtered_sizes,
        symbol: 'circle',
        color: filtered_colors,
        opacity: 1
      },
      text: filtered.ids,
      customdata: filtered.metadatas,
      type: is3D ? 'scatter3d' : 'scatter',
      hovertemplate: ' ',
    }, {
      name: 'Query',
      x: query_x,
      y: query_y,
      mode: 'markers',
      marker: {
        size: query_sizes,
        symbol: 'circle',
        color: query_colors,
        opacity: 1
      },
      text: query.ids,
      customdata: query.metadatas,
      type: is3D ? 'scatter3d' : 'scatter',
      hovertemplate: '%{text}',
    }];

    if (is3D) {
      data[0].z = filtered_z;
      data[1].z = query_z;
    }

    setPlotData(data);
  }, [result, querySet, view2D]);

  /** Submit the query input when Enter is pressed. */
  const handleKeyPress = (event: React.KeyboardEvent) => {
    if (event.key === 'Enter') {
      sendQuery(newQuery);
    }
  };

  /** Run a similarity search for `query` and highlight the results. Blank input is ignored. */
  const sendQuery = async (query: string) => {
    if (!query.trim()) return;
    setNewQuery('');
    try {
      const result = await apiClient.getCandidateSimilarContent(query);
      setQuerySet(result);
    } catch (error) {
      const msg = `Error obtaining similar content to ${query}.`
      console.error(msg, error);
      setSnack(msg, "error");
    }
  };

  /* Check the candidate first: without one the collection is never fetched, so
   * checking `result` first would leave the view on "Loading" forever. */
  if (!candidate) return (
No candidate selected. Please first.
  );

  if (!result) return (
Loading visualization...
  );

  /** Load the surrounding content for `node`; ignored if another node was selected meanwhile. */
  const fetchRAGMeta = async (node: Node) => {
    try {
      const result = await apiClient.getCandidateRAGContent(node.id);
      const update: Node = { ...node, fullContent: result.content };
      setNode((current: any) => (current?.id === node.id ? update : current));
    } catch (error) {
      const msg = `Error obtaining content for ${node.id}.`
      console.error(msg, error);
      setSnack(msg, "error");
    }
  };

  /** Show details for a clicked point; `metadata` is the marker's `customdata`. */
  const onNodeSelected = (metadata: any) => {
    let node: Node;
    if (metadata.docType === 'query') {
      node = {
        ...metadata,
        content: `Similarity results for the query **${querySet.query || ''}** The scatter graph shows the query in N-dimensional space, mapped to ${view2D ? '2' : '3'}-dimensional space. Larger dots represent relative similarity in N-dimensional space. `,
        emoji: emojiMap[metadata.docType],
        sx: {
          m: 0.5,
          p: 2,
          width: '3rem',
          display: "flex",
          alignContent: "center",
          justifyContent: "center",
          flexGrow: 0,
          flexWrap: "wrap",
          backgroundColor: colorMap[metadata.docType] || '#ff8080',
        }
      }
      setNode(node);
      return;
    }

    node = {
      content: `Loading...`,
      ...metadata,
      emoji: emojiMap[metadata.docType] || '❓',
    }

    setNode(node);
    fetchRAGMeta(node);
  };

  return ( } onChange={() => { setView2D(!view2D); setResult(null); }} label="3D" /> { onNodeSelected(event.points[0].customdata); }} data={plotData} useResizeHandler={true} config={config} style={{ display: "flex", flexGrow: 1, minHeight: '240px', padding: 0, margin: 0, width: "100%", height: "100%", overflow: "hidden", }} layout={{...layout, width: plotDimensions.width, height: plotDimensions.height }} /> {node !== null && Type {node.emoji} {node.docType} {node.source_file !== undefined && File {node.source_file.replace(/^.*\//, '')} } {node.path !== undefined && Section {node.path} } {node.distance !== undefined && Distance {node.distance} }
{node.content !== "" && node.content !== undefined && Vector Embedded Content {node.content} }
} {node === null && Click a point in the scatter-graph to see information about that node. } {node !== null && node.fullContent && { node.fullContent.split('\n').map((line, index) => { index += 1 + node.chunkBegin; const bgColor = (index > node.lineBegin && index <= node.lineEnd) ? '#f0f0f0' : 'auto'; return {index}
{line || " "}
; }) } {!node.lineBegin &&
{node.content}
}
}
{!inline && querySet.query !== undefined && querySet.query !== '' && {querySet.query !== undefined && querySet.query !== '' && `Query: ${querySet.query}`} {querySet.ids.length === 0 && "Enter query below to perform a similarity search."} } { !inline && setNewQuery(e.target.value)} onKeyDown={handleKeyPress} placeholder="Enter query to find related documents..." id="QueryInput" /> }
  );
};

export type { VectorVisualizerProps };
export { VectorVisualizer, };