import React, { useEffect, useState } from 'react';
import Box from '@mui/material/Box';
import Plot from 'react-plotly.js';
import TextField from '@mui/material/TextField';
import Tooltip from '@mui/material/Tooltip';
import Button from '@mui/material/Button';
import SendIcon from '@mui/icons-material/Send';

/** Per-document metadata; `type` selects the marker style used in the plot. */
interface Metadata {
  type?: string;
  // Open-ended extra fields. Kept as `any` (not `unknown`) so existing consumers
  // of this exported type that read arbitrary fields keep compiling.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  [key: string]: any;
}

/** Embeddings plus the documents and metadata they belong to, aligned by index. */
interface ResultData {
  embeddings: number[][] | number[][][];
  documents: string[];
  metadatas: Metadata[];
}

/** Column-oriented marker data handed to Plotly; `z` is present only for 3-D input. */
interface PlotData {
  x: number[];
  y: number[];
  z?: number[];
  colors: string[];
  text: string[];
  sizes: number[];
  symbols: string[];
}

interface VectorVisualizerProps {
  result: ResultData;
  /** Prefix for the `/api/similarity/:sessionId` endpoint. */
  connectionBase: string;
  /** Session to query; a loading message is rendered until it is defined. */
  sessionId?: string;
}

/** Body returned by `PUT /api/similarity/:sessionId`. */
interface ChromaResult {
  distances: number[];
  documents: string[];
  ids: string[];
  metadatas: Metadata[];
  query_embedding: number[];
  query?: string;
  vector_embedding?: number[];
}

interface MarkerStyle {
  color: string;
  size: number;
  symbol: string;
}

/**
 * Marker style per metadata `type`. A Map is used instead of a plain object so a
 * type named e.g. "constructor" or "toString" cannot resolve to an inherited member.
 */
const MARKER_STYLES = new Map<string, MarkerStyle>([
  ['query', { color: '#00ff00', size: 10, symbol: 'circle' }],
]);

/** Style for every `type` not present in MARKER_STYLES. */
const DEFAULT_MARKER_STYLE: MarkerStyle = { color: '#ff0000', size: 5, symbol: 'circle' };

/** Number of document characters shown in a hover label before truncation. */
const PREVIEW_LENGTH = 100;

/** Private: true when `value` is a flat array of numbers (one embedding point). */
const isNumericVector = (value: unknown): value is number[] =>
  Array.isArray(value) && value.every((component) => typeof component === 'number');

/** Private: escapes characters Plotly would otherwise treat as its HTML-like label markup. */
const escapeHoverText = (text: string): string =>
  text.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');

/**
 * Private: converts the search result (plus an optional query embedding, drawn
 * first) into Plotly marker columns.
 *
 * @returns the plot data, or `null` when there is nothing to draw or the vectors
 *   are not all 2-D or all 3-D (an error is logged in that case).
 */
const buildPlotData = (
  result: ResultData | undefined,
  queryEmbedding: ChromaResult | undefined,
): PlotData | null => {
  if (!result?.embeddings || result.embeddings.length === 0) return null;

  const candidates: unknown[] = [...result.embeddings];
  const documents = [...(result.documents ?? [])];
  const metadatas = [...(result.metadatas ?? [])];

  // Only overlay the query point when the server actually returned a flat vector.
  if (isNumericVector(queryEmbedding?.vector_embedding)) {
    candidates.unshift(queryEmbedding.vector_embedding);
    documents.unshift(queryEmbedding.query ?? '');
    metadatas.unshift({ type: 'query' });
  }

  // Reject nested (number[][][]) embeddings instead of casting them to points.
  const vectors = candidates.filter(isNumericVector);
  const is2D = vectors.every((v) => v.length === 2);
  const is3D = vectors.every((v) => v.length === 3);
  if (vectors.length !== candidates.length || (!is2D && !is3D)) {
    console.error('Vectors are neither 2D nor 3D');
    return null;
  }

  // Size every column by the vector count so Plotly receives equal-length arrays
  // even if documents/metadatas are shorter.
  const types = vectors.map((_, i) => metadatas[i]?.type || 'unknown');
  const styles = types.map((type) => MARKER_STYLES.get(type) ?? DEFAULT_MARKER_STYLE);
  const text = types.map((type, i) => {
    const doc = documents[i] ?? '';
    const preview = doc.length > PREVIEW_LENGTH ? `${doc.slice(0, PREVIEW_LENGTH)}...` : doc;
    // Plotly renders newlines in label text as spaces; <br> is its line break.
    return `Type: ${escapeHoverText(type)}<br>Text: ${escapeHoverText(preview)}`;
  });

  // `?? 0` only satisfies noUncheckedIndexedAccess; lengths were validated above.
  const column = (axis: number): number[] => vectors.map((v) => v[axis] ?? 0);

  return {
    x: column(0),
    y: column(1),
    ...(is3D ? { z: column(2) } : {}),
    colors: styles.map((s) => s.color),
    sizes: styles.map((s) => s.size),
    symbols: styles.map((s) => s.symbol),
    text,
  };
};

/**
 * Scatter-plots 2-D or 3-D document embeddings and lets the user submit a text
 * query. The query is PUT to the similarity endpoint, and the returned
 * `vector_embedding` is overlaid as a "query" marker.
 */
const VectorVisualizer: React.FC<VectorVisualizerProps> = ({ result, connectionBase, sessionId }) => {
  const [plotData, setPlotData] = useState<PlotData | null>(null);
  const [query, setQuery] = useState('');
  const [queryEmbedding, setQueryEmbedding] = useState<ChromaResult | undefined>(undefined);

  useEffect(() => {
    // Invalid or empty input clears the plot rather than leaving a stale one on screen.
    setPlotData(buildPlotData(result, queryEmbedding));
  }, [result, queryEmbedding]);

  /** Sends `text` to the similarity endpoint and stores the response; failures are logged. */
  const sendQuery = async (text: string): Promise<void> => {
    if (!text.trim() || sessionId === undefined) return;
    setQuery('');
    try {
      const response = await fetch(
        `${connectionBase}/api/similarity/${encodeURIComponent(sessionId)}`,
        {
          method: 'PUT',
          headers: { 'Content-Type': 'application/json' },
          body: JSON.stringify({ query: text }),
        },
      );
      if (!response.ok) {
        throw new Error(`Similarity request failed: ${response.status} ${response.statusText}`);
      }
      // Not schema-validated here; buildPlotData ignores a non-numeric vector_embedding.
      const chroma: ChromaResult = await response.json();
      setQueryEmbedding(chroma);
    } catch (error: unknown) {
      console.error('Similarity query failed:', error);
    }
  };

  const handleKeyPress = (event: React.KeyboardEvent<HTMLDivElement>): void => {
    if (event.key === 'Enter') {
      // sendQuery handles its own errors, so the promise is intentionally not awaited.
      void sendQuery(query);
    }
  };

  if (!plotData || sessionId === undefined) {
    return <Box>Loading visualization...</Box>;
  }

  const is3D = plotData.z !== undefined;

  // NOTE(review): the markup below is reconstructed from the surviving fragments and
  // the imported components (the original tags were lost); confirm layout/styling.
  return (
    <>
      <Box sx={{ width: '100%' }}>
        <Plot
          data={[
            {
              type: is3D ? 'scatter3d' : 'scatter',
              mode: 'markers',
              x: plotData.x,
              y: plotData.y,
              z: plotData.z,
              text: plotData.text,
              hoverinfo: 'text',
              marker: {
                color: plotData.colors,
                size: plotData.sizes,
                symbol: plotData.symbols,
              },
            },
          ]}
          layout={{ autosize: true, hovermode: 'closest' }}
          useResizeHandler
          style={{ width: '100%', height: '100%' }}
        />
      </Box>
      {queryEmbedding !== undefined && <Box>Query: {queryEmbedding.query}</Box>}
      <Box sx={{ display: 'flex', gap: 1 }}>
        <TextField
          variant="outlined"
          fullWidth
          type="text"
          value={query}
          onChange={(e) => setQuery(e.target.value)}
          onKeyDown={handleKeyPress}
          placeholder="Enter query to find related documents..."
          id="QueryInput"
        />
        <Tooltip title="Send">
          <Button
            variant="contained"
            onClick={() => {
              void sendQuery(query);
            }}
          >
            <SendIcon />
          </Button>
        </Tooltip>
      </Box>
    </>
  );
};

export type { VectorVisualizerProps, ResultData, Metadata };
export { VectorVisualizer };