225 lines
6.0 KiB
TypeScript
225 lines
6.0 KiB
TypeScript
import React, { useEffect, useState } from 'react';
|
|
import Box from '@mui/material/Box';
|
|
import Plot from 'react-plotly.js';
|
|
import TextField from '@mui/material/TextField';
|
|
import Tooltip from '@mui/material/Tooltip';
|
|
import Button from '@mui/material/Button';
|
|
import SendIcon from '@mui/icons-material/Send';
|
|
|
|
/**
 * Free-form metadata attached to a single stored document.
 *
 * The visualizer only reads `type`, which selects the marker styling for the
 * point; the synthetic query point is tagged with `type: 'query'`.
 */
interface Metadata {
  /** Category label for the document; points without one are plotted as 'unknown'. */
  type?: string;
  /** Any additional keys supplied by the backend are carried through untouched. */
  [key: string]: any;
}
|
|
|
|
/**
 * Collection contents handed to the visualizer.
 *
 * The three arrays are treated as parallel: entry `i` of `documents` and
 * `metadatas` describes the vector at `embeddings[i]`.
 */
interface ResultData {
  /**
   * One vector per document. The visualizer reads this as `number[][]` and
   * only plots when every vector has length 2 or every vector has length 3.
   */
  embeddings: number[][] | number[][][];
  /** Document text shown (truncated) in the hover label of each point. */
  documents: string[];
  /** Per-document metadata; `type` drives marker color, size and symbol. */
  metadatas: Metadata[];
}
|
|
|
|
/**
 * Flattened, per-point arrays in the shape consumed by the Plotly trace.
 * Index `i` in every array refers to the same point.
 */
interface PlotData {
  /** First coordinate of each point. */
  x: number[];
  /** Second coordinate of each point. */
  y: number[];
  /** Third coordinate of each point; present only when the vectors are 3D. */
  z?: number[];
  /** Marker color per point. */
  colors: string[];
  /** Hover text per point. */
  text: string[];
  /** Marker size per point. */
  sizes: number[];
  /** Marker symbol per point. */
  symbols: string[];
}
|
|
|
|
/** Props accepted by the `VectorVisualizer` component. */
interface VectorVisualizerProps {
  /** Embeddings, documents and metadata to plot. */
  result: ResultData;
  /** URL prefix prepended to `/api/similarity/<sessionId>` when sending a query. */
  connectionBase: string;
  /**
   * Session identifier placed in the similarity request URL. While it is
   * undefined the component renders only the loading placeholder.
   */
  sessionId?: string;
}
|
|
|
|
/**
 * JSON body returned by the similarity endpoint.
 *
 * The visualizer itself only reads `query` and `vector_embedding`; the
 * remaining fields are declared for completeness.
 * NOTE(review): the exact semantics of `distances`, `ids` and `query_embedding`
 * are defined by the backend and are not visible from this file.
 */
interface ChromaResult {
  distances: number[];
  documents: string[];
  ids: string[];
  metadatas: Metadata[];
  query_embedding: number[];
  /** Query text echoed back; shown under the plot and in the query point's hover label. */
  query?: string;
  /** Coordinates of the query point; must match the dimensionality of the stored vectors to be plotted. */
  vector_embedding?: number[];
}
|
|
|
|
/**
 * Builds the per-point arrays for the Plotly trace from the stored embeddings
 * and, when available, the most recent query embedding (prepended as the
 * first point and tagged with `type: 'query'`).
 *
 * Returns `null` when there is nothing to plot or when the vectors are not
 * uniformly 2D or uniformly 3D, so the caller can keep its previous plot.
 */
const buildPlotData = (
  result: ResultData,
  queryEmbedding: ChromaResult | undefined,
): PlotData | null => {
  if (!result || !result.embeddings || result.embeddings.length === 0) return null;

  // Copy the inputs so prepending the query point never mutates the props.
  const vectors: number[][] = [...(result.embeddings as number[][])];
  const documents = [...(result.documents || [])];
  const metadatas = [...(result.metadatas || [])];

  // Only accept a real array here: a JSON `null` would otherwise be pushed in
  // as a "vector" and crash the dimensionality check below.
  const queryVector = queryEmbedding?.vector_embedding;
  if (queryEmbedding && Array.isArray(queryVector)) {
    metadatas.unshift({ type: 'query' });
    documents.unshift(queryEmbedding.query || '');
    vectors.unshift(queryVector);
  }

  const is2D = vectors.every((v) => v.length === 2);
  const is3D = vectors.every((v) => v.length === 3);
  if (!is2D && !is3D) {
    console.error('Vectors are neither 2D nor 3D');
    return null;
  }

  // Placeholder styling: only the query point has dedicated values, every
  // other document type falls back to the defaults applied below.
  const colorMap: Record<string, string> = { query: '#00ff00' };
  const sizeMap: Record<string, number> = { query: 10 };
  const symbolMap: Record<string, string> = { query: 'circle' };

  // Metadata entries may be missing or null in backend responses.
  const docTypes = metadatas.map((m) => m?.type || 'unknown');

  const plot: PlotData = {
    x: vectors.map((v) => v[0]),
    y: vectors.map((v) => v[1]),
    colors: docTypes.map((type) => colorMap[type] ?? '#ff0000'),
    sizes: docTypes.map((type) => sizeMap[type] ?? 5),
    symbols: docTypes.map((type) => symbolMap[type] ?? 'circle'),
    text: documents.map(
      (doc, i) => `Type: ${docTypes[i]}<br>Text: ${(doc ?? '').slice(0, 100)}...`,
    ),
  };
  if (is3D) {
    plot.z = vectors.map((v) => v[2]);
  }
  return plot;
};

/**
 * Scatter plot (2D or 3D) of document embeddings with a query box.
 *
 * Submitting a query sends it to `<connectionBase>/api/similarity/<sessionId>`
 * and plots the returned `vector_embedding` as an extra, highlighted point.
 * Until plot data exists and `sessionId` is defined, a loading placeholder is
 * rendered instead.
 */
const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionBase, sessionId }) => {
  const [plotData, setPlotData] = useState<PlotData | null>(null);
  const [query, setQuery] = useState<string>('');
  const [queryEmbedding, setQueryEmbedding] = useState<ChromaResult | undefined>(undefined);

  useEffect(() => {
    const next = buildPlotData(result, queryEmbedding);
    // Keep the previous plot when the new inputs cannot be plotted.
    if (next !== null) {
      setPlotData(next);
    }
  }, [result, queryEmbedding]);

  /** Sends the query text to the similarity endpoint and stores the response. */
  const sendQuery = async (text: string): Promise<void> => {
    if (!text.trim()) return;
    setQuery('');

    try {
      const response = await fetch(`${connectionBase}/api/similarity/${sessionId}`, {
        method: 'PUT',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          query: text,
        }),
      });
      // fetch only rejects on network failure; HTTP errors must be checked explicitly.
      if (!response.ok) {
        throw new Error(`Similarity request failed: ${response.status} ${response.statusText}`);
      }
      const chroma: ChromaResult | null = await response.json();
      if (!chroma) {
        throw new Error('Similarity request returned an empty body');
      }
      setQueryEmbedding(chroma);
    } catch (error: unknown) {
      // Leave the current plot untouched; the failure is reported for diagnosis.
      console.error('Failed to fetch query embedding:', error);
    }
  };

  /** Submits the current query when Enter is pressed in the text field. */
  const handleKeyPress = (event: React.KeyboardEvent) => {
    if (event.key === 'Enter') {
      void sendQuery(query);
    }
  };

  if (!plotData || sessionId === undefined) return (
    <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
      <div>Loading visualization...</div>
    </Box>
  );

  const is3D = Boolean(plotData.z?.length);

  return (
    <>
      <Box sx={{ display: 'flex', flexGrow: 1, justifyContent: 'center', alignItems: 'center' }}>
        <Plot
          data={[
            {
              x: plotData.x,
              y: plotData.y,
              z: plotData.z,
              mode: 'markers',
              marker: {
                size: plotData.sizes,
                symbol: plotData.symbols,
                color: plotData.colors,
                opacity: 0.8,
              },
              text: plotData.text,
              hoverinfo: 'text',
              type: is3D ? 'scatter3d' : 'scatter',
            },
          ]}
          useResizeHandler={true}
          config={{ responsive: true }}
          style={{ width: '100%', height: '100%' }}
          layout={{
            autosize: true,
            title: 'Vector Store Visualization',
            // 3D traces take their axes from `scene`; top-level axes apply to 2D only.
            ...(is3D
              ? {
                  scene: {
                    xaxis: { title: 'x' },
                    yaxis: { title: 'y' },
                    zaxis: { title: 'z' },
                  },
                }
              : {
                  xaxis: { title: 'x' },
                  yaxis: { title: 'y' },
                }),
            margin: { r: 20, b: 10, l: 10, t: 40 },
          }}
        />
      </Box>

      {queryEmbedding !== undefined && (
        <Box sx={{ display: 'flex', flexDirection: 'column', p: 1 }}>
          <Box sx={{ fontSize: '0.8rem', mb: 1 }}>
            Query: {queryEmbedding.query}
          </Box>
        </Box>
      )}

      <Box className="Query" sx={{ display: "flex", flexDirection: "row", p: 1 }}>
        <TextField
          variant="outlined"
          fullWidth
          type="text"
          value={query}
          onChange={(e) => setQuery(e.target.value)}
          onKeyDown={handleKeyPress}
          placeholder="Enter query to find related documents..."
          id="QueryInput"
        />
        <Tooltip title="Send">
          <Button sx={{ m: 1 }} variant="contained" onClick={() => { void sendQuery(query); }}><SendIcon /></Button>
        </Tooltip>
      </Box>
    </>
  );
};
|
|
|
|
// Public types consumed by callers that build the component's props.
export type { VectorVisualizerProps, ResultData, Metadata };

// The component is exposed as a named export only.
export {
  VectorVisualizer
};
|