import { ModelType } from 'types/ModelType';
import { DependencyList, useEffect, useState } from 'react';
import { ProjectedEmbeddings } from 'types/backend/response/ProjectedEmbeddings';
import BackendQueryEngine from 'backend/BackendQueryEngine';
import { ProjectionMethod } from '../types/ProjectionMethod';
import { useLoadingCursor } from 'hooks/useLoadingCursor';

/**
 * Describes one projected-embeddings request issued by {@link useSampledEmbeddings}.
 */
interface SampledEmbeddingsPayload {
    /** Model whose embeddings are requested from the backend. */
    model: ModelType;
    /** Layer to fetch embeddings for (presumably a layer index — confirm against the backend API). */
    layer: number;
    /** Projection method passed through to the backend query. */
    projectionMethod: ProjectionMethod;
    /** Sample cap forwarded to the backend as the `max_samples` option. */
    maxSamples: number;
}

/**
 * Fetches projected embeddings for every payload in parallel and exposes them as React state.
 *
 * The request is (re-)issued whenever `deps` change. Results of a superseded request are
 * discarded, and the loading cursor stays set only while the current request is pending.
 *
 * @param payloads - One request description per embedding set to fetch.
 * @param deps - Dependency list passed straight to `useEffect`. The effect reads `payloads`
 *   only when one of these values changes, so callers must include everything that affects
 *   the request.
 * @returns The projected embeddings in the same order as `payloads` (as `Promise.all`
 *   preserves order), or `undefined` until the first request has resolved.
 */
export const useSampledEmbeddings = (payloads: SampledEmbeddingsPayload[], deps: DependencyList) => {
    const [projectedEmbeddings, setProjectedEmbeddings] = useState<ProjectedEmbeddings[]>();

    const [setLoadingCursor, unsetLoadingCursor] = useLoadingCursor();

    // Query the projected embeddings for all payloads whenever the caller-supplied deps change.
    useEffect(() => {
        // Cleared by the cleanup so a superseded request cannot overwrite newer results.
        let isRequestStillValid = true;
        // Whether this effect run still owns the loading cursor. Without this flag, a stale
        // request that resolves after cleanup would unset the cursor the newer run just set.
        // It also prevents a second unset in cleanup after a request already completed.
        let ownsLoadingCursor = true;
        setLoadingCursor();

        const releaseLoadingCursor = () => {
            if (ownsLoadingCursor) {
                ownsLoadingCursor = false;
                unsetLoadingCursor();
            }
        };

        const responsePromises = payloads.map(({ maxSamples, model, layer, projectionMethod }) =>
            BackendQueryEngine.getProjectedEmbeddings(model, layer, projectionMethod, { max_samples: maxSamples })
        );

        Promise.all(responsePromises)
            .then((embeddings) => {
                if (isRequestStillValid) {
                    setProjectedEmbeddings(embeddings);
                }
            })
            .catch((error: unknown) => {
                // Surface backend failures instead of leaving an unhandled promise rejection.
                if (isRequestStillValid) {
                    console.error('Failed to fetch projected embeddings', error);
                }
            })
            .finally(releaseLoadingCursor);

        return () => {
            isRequestStillValid = false;
            releaseLoadingCursor();
        };
        // `deps` is intentionally supplied by the caller rather than derived from the closure.
    }, deps);

    return projectedEmbeddings;
};
