diff --git a/frontend/src/VectorVisualizer.tsx b/frontend/src/VectorVisualizer.tsx index 5d1b473..0cae3ed 100644 --- a/frontend/src/VectorVisualizer.tsx +++ b/frontend/src/VectorVisualizer.tsx @@ -10,7 +10,7 @@ import SendIcon from '@mui/icons-material/Send'; import FormControlLabel from '@mui/material/FormControlLabel'; import Switch from '@mui/material/Switch'; -import { SeverityType } from './Snack'; +import { SetSnackType } from './Snack'; interface Metadata { doc_type?: string; @@ -18,7 +18,7 @@ interface Metadata { } interface ResultData { - embeddings: number[][] | number[][][]; + embeddings: (number[])[]; documents: string[]; metadatas: Metadata[]; ids: string[]; @@ -42,7 +42,7 @@ interface PlotData { interface VectorVisualizerProps { connectionBase: string; sessionId?: string; - setSnack: (message: string, severity: SeverityType) => void; + setSnack: SetSnackType; inline?: boolean; rag?: any; } @@ -141,11 +141,19 @@ const VectorVisualizer: React.FC = ({ setSnack, rag, inli if (!result || !result.embeddings) return; if (result.embeddings.length === 0) return; - const vectors: number[][] = [...result.embeddings as number[][]]; + const vectors: (number[])[] = [...result.embeddings]; const documents = [...result.documents || []]; const metadatas = [...result.metadatas || []]; const ids = [...result.ids || []]; + let is2D = vectors.every((v: number[]) => v.length === 2); + let is3D = vectors.every((v: number[]) => v.length === 3); + + console.log(`Embeddings are ${is2D ? '2D' : is3D ? '3D' : 'invaalid'} and view2D is ${view2D}`); + if ((view2D && !is2D) || (!view2D && !is3D)) { + return; + } + if (view2D && rag && rag.umap_embedding_2d) { metadatas.unshift({ doc_type: 'query' }); documents.unshift('Query'); @@ -169,10 +177,11 @@ const VectorVisualizer: React.FC = ({ setSnack, rag, inli } } - const is2D = vectors.every((v: number[]) => v.length === 2); - const is3D = vectors.every((v: number[]) => v.length === 3); + is2D = vectors.every((v: number[]) => v.length === 2); + is3D = vectors.every((v: number[]) => v.length === 3); + if (!is2D && !is3D) { - console.error('Vectors are neither 2D nor 3D'); + console.warn('Modified vectors are neither 2D nor 3D'); return; }