// backstory/frontend/src/VectorVisualizer.tsx (TypeScript / React)

import React, { useEffect, useState } from 'react';
import Box from '@mui/material/Box';
import Plot from 'react-plotly.js';
import TextField from '@mui/material/TextField';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import SendIcon from '@mui/icons-material/Send';
/**
 * Free-form metadata attached to a stored document. This component only reads
 * `type` (to pick marker styling); any other keys are carried along untouched.
 */
interface Metadata {
  [key: string]: any;
  type?: string;
}

/**
 * Vector-store contents handed to the visualizer. The component plots
 * `embeddings` only when every entry is a 2-D or 3-D point. The nested
 * `number[][][]` variant is accepted by the type but is not plotted as-is.
 */
interface ResultData {
  metadatas: Metadata[];
  documents: string[];
  embeddings: number[][] | number[][][];
}

/**
 * Per-point arrays fed to Plotly. The arrays are built index-for-index from the
 * same vector list. `z` is present only for 3-D data.
 */
interface PlotData {
  text: string[];
  symbols: string[];
  sizes: number[];
  colors: string[];
  x: number[];
  y: number[];
  z?: number[];
}

/** Props for {@link VectorVisualizer}. */
interface VectorVisualizerProps {
  /** URL prefix for the backend API (the similarity endpoint is appended to it). */
  connectionBase: string;
  /** Session identifier; while undefined the component renders only a loading placeholder. */
  sessionId?: string;
  /** Data to plot. */
  result: ResultData;
}

/**
 * JSON body returned by the similarity endpoint. Only `query` and
 * `vector_embedding` are read by this component.
 * NOTE(review): the remaining fields look like parallel per-match arrays — confirm
 * against the backend.
 */
interface ChromaResult {
  ids: string[];
  documents: string[];
  metadatas: Metadata[];
  distances: number[];
  query_embedding: number[];
  /** Echo of the query text, shown above the input box. */
  query?: string;
  /** Query position in the plotted space; drawn as an extra highlighted point. */
  vector_embedding?: number[];
}
/** Maximum number of document characters shown in a point's hover label. */
const HOVER_TEXT_LIMIT = 100;

/** Marker appearance for one document type. */
interface MarkerStyle {
  color: string;
  size: number;
  symbol: string;
}

/** Styling for any document type without an explicit entry in MARKER_STYLES. */
const DEFAULT_MARKER_STYLE: MarkerStyle = { color: '#ff0000', size: 5, symbol: 'circle' };

/**
 * Per-type marker overrides. A Map is used rather than a plain object so that a
 * metadata type such as "constructor" or "toString" cannot resolve to an
 * inherited Object.prototype member.
 */
const MARKER_STYLES = new Map<string, MarkerStyle>([
  ['query', { color: '#00ff00', size: 10, symbol: 'circle' }],
]);

/** Returns the marker styling for a document type, falling back to the default. */
const markerStyleFor = (type: string): MarkerStyle => MARKER_STYLES.get(type) ?? DEFAULT_MARKER_STYLE;

/**
 * Converts the store's embeddings, plus an optional query point, into the
 * per-point arrays Plotly needs.
 *
 * @param result - Embeddings, documents and metadata; entries are matched by index.
 * @param queryEmbedding - Similarity response. When it carries `vector_embedding`,
 *   that vector is prepended as a point of type "query".
 * @returns The plot arrays. Returns `null` when there is nothing to plot, or when
 *   the vectors are not all numeric points of one dimensionality (2 or 3).
 */
const buildPlotData = (result: ResultData, queryEmbedding?: ChromaResult): PlotData | null => {
  if (!result || !result.embeddings || result.embeddings.length === 0) return null;

  // `embeddings` is typed loosely (number[][] | number[][][]). The shape check
  // below rejects anything that is not a list of numeric 2-D/3-D points.
  const vectors: number[][] = [...(result.embeddings as number[][])];
  const documents = [...(result.documents || [])];
  const metadatas = [...(result.metadatas || [])];

  if (queryEmbedding !== undefined && queryEmbedding.vector_embedding !== undefined) {
    // The query is drawn as the first point so it gets the "query" styling.
    vectors.unshift(queryEmbedding.vector_embedding);
    documents.unshift(queryEmbedding.query || '');
    metadatas.unshift({ type: 'query' });
  }

  const dims = vectors[0]?.length;
  const isUniformPointSet =
    (dims === 2 || dims === 3) &&
    vectors.every((v) => Array.isArray(v) && v.length === dims && v.every((c) => typeof c === 'number'));
  if (!isUniformPointSet) {
    console.error('Vectors are neither 2D nor 3D');
    return null;
  }

  // Build every per-point array from the vector list so they always have equal
  // length, even if documents/metadatas are shorter than embeddings.
  const types = vectors.map((_, i) => metadatas[i]?.type || 'unknown');
  const styles = types.map((type) => markerStyleFor(type));
  const text = vectors.map((_, i) => {
    const doc = documents[i] ?? '';
    // Only mark text as truncated when something was actually cut off.
    const snippet = doc.length > HOVER_TEXT_LIMIT ? `${doc.slice(0, HOVER_TEXT_LIMIT)}...` : doc;
    return `Type: ${types[i]}<br>Text: ${snippet}`;
  });

  return {
    x: vectors.map((v) => v[0]),
    y: vectors.map((v) => v[1]),
    ...(dims === 3 ? { z: vectors.map((v) => v[2]) } : {}),
    colors: styles.map((s) => s.color),
    sizes: styles.map((s) => s.size),
    symbols: styles.map((s) => s.symbol),
    text,
  };
};

/**
 * Interactive scatter plot of a vector store's 2-D or 3-D embeddings. It includes
 * a text box that sends a similarity query to the backend and overlays the
 * returned query embedding as an extra, highlighted point.
 *
 * @param result - Embeddings, documents and metadata to plot (matched by index).
 * @param connectionBase - URL prefix for the backend API.
 * @param sessionId - Session used in the similarity request. While undefined,
 *   only a loading placeholder is rendered.
 */
const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionBase, sessionId }) => {
  const [plotData, setPlotData] = useState<PlotData | null>(null);
  const [query, setQuery] = useState<string>('');
  const [queryEmbedding, setQueryEmbedding] = useState<ChromaResult | undefined>(undefined);

  useEffect(() => {
    const next = buildPlotData(result, queryEmbedding);
    // When the new input cannot be plotted, keep showing the previous plot.
    if (next !== null) {
      setPlotData(next);
    }
  }, [result, queryEmbedding]);

  /**
   * Sends `text` to the similarity endpoint and stores the response so its query
   * vector is overlaid on the plot. Failures are logged and the previous
   * overlay is left in place.
   */
  const sendQuery = async (text: string): Promise<void> => {
    if (!text.trim() || sessionId === undefined) return;
    setQuery('');
    try {
      const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, {
        method: 'PUT',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          query: text,
        }),
      });
      // Do not treat an error body as a ChromaResult.
      if (!response.ok) {
        throw new Error(`HTTP ${response.status} ${response.statusText}`);
      }
      const chroma: ChromaResult = await response.json();
      setQueryEmbedding(chroma);
    } catch (error: unknown) {
      console.error('Similarity query failed:', error);
    }
  };

  const handleKeyPress = (event: React.KeyboardEvent<HTMLDivElement>) => {
    if (event.key === 'Enter') {
      void sendQuery(query);
    }
  };

  if (!plotData || sessionId === undefined) {
    return (
      <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
        <div>Loading visualization...</div>
      </Box>
    );
  }

  return (
    <>
      <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
        <Plot
          data={[
            {
              x: plotData.x,
              y: plotData.y,
              z: plotData.z,
              mode: 'markers',
              marker: {
                size: plotData.sizes,
                symbol: plotData.symbols,
                color: plotData.colors,
                opacity: 0.8,
              },
              text: plotData.text,
              hoverinfo: 'text',
              type: plotData.z?.length ? 'scatter3d' : 'scatter',
            },
          ]}
          useResizeHandler={true}
          config={{ responsive: true }}
          style={{ width: '100%', height: '100%' }}
          layout={{
            autosize: true,
            title: 'Vector Store Visualization',
            // Top-level axes apply to 2-D plots only.
            xaxis: { title: 'x' },
            yaxis: { title: 'y' },
            // 3-D axes must be configured on the scene.
            scene: {
              xaxis: { title: 'x' },
              yaxis: { title: 'y' },
              zaxis: { title: 'z' },
            },
            margin: { r: 20, b: 10, l: 10, t: 40 },
          }}
        />
      </Box>
      {queryEmbedding !== undefined && (
        <Box sx={{ display: 'flex', flexDirection: 'column', p: 1 }}>
          <Box sx={{ fontSize: '0.8rem', mb: 1 }}>
            Query: {queryEmbedding.query}
          </Box>
        </Box>
      )}
      <Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>
        <TextField
          variant="outlined"
          fullWidth
          type="text"
          value={query}
          onChange={(e) => setQuery(e.target.value)}
          onKeyDown={handleKeyPress}
          placeholder="Enter query to find related documents..."
          id="QueryInput"
        />
        <Tooltip title="Send">
          <Button sx={{ m: 1 }} variant="contained" onClick={() => { void sendQuery(query); }}><SendIcon /></Button>
        </Tooltip>
      </Box>
    </>
  );
};
// Public surface: the component itself plus the types callers need to build its props.
export { VectorVisualizer };
export type { Metadata, ResultData, VectorVisualizerProps };