import React, { useContext, useEffect } from 'react';
import { ProjectedEmbeddings } from 'types/backend/response/ProjectedEmbeddings';
import { Annotations, Data, PlotMouseEvent, Point } from 'plotly.js';
import FigureContext from 'App/FigureContext';
import TokenContext from 'App/TokenContext';
import { DEFAULT_HOVER_TEXT, HoverText } from 'types/HoverText';
import { useTracesRest } from 'hooks/useTracesRest';
import { arrayWords } from 'types/Words';
import { useTracesTokens } from 'hooks/useTracesTokens';
import { ModelType } from 'types/ModelType';
import { TokenTable } from 'types/backend/response/TokenTable';
import ParameterContext from 'App/ParameterContext';
import BackendQueryEngine from 'backend/BackendQueryEngine';
import { Grid } from '@material-ui/core';
import StablePlot from '../../StablePlot/StablePlot';
import { ProjectionMethod } from '../../../types/ProjectionMethod';
import { tableau20 } from '../../../tools/helpers';

/** Props accepted by `Section4SubFigure`. */
interface Props {
    /** Model whose tokens and projected embeddings are plotted. */
    model: ModelType;
    /** Forwarded unchanged to `StablePlot` as its `revision` prop. */
    revision?: number;
    /** Projection method used to pick the background projection and to query word embeddings. */
    projectionMethod: ProjectionMethod;
    /** Annotations placed into the plot layout's `scene.annotations`. */
    annotations?: Annotations[];
    /** Camera eye for the plot scene; the component defaults it to `{ x: 1, y: 1, z: 1 }`. */
    eye?: Partial<Point>;
}

/**
 * Hook that requests, for every word in `arrayWords`, the token table returned by
 * `BackendQueryEngine.getTokens` and exposes the results as state.
 *
 * The GPT2 variants ('GPT2', 'GPT2LEFT', 'GPT2RIGHT') are all queried against the 'GPT2'
 * base model with the word prefixed by 'Ġ'; every other model is queried against 'BERT'
 * with the bare word.
 *
 * @param model Model currently displayed; the request is re-issued whenever it changes.
 * @returns One token table per entry of `arrayWords` (an empty array until the first
 *          batch of requests has resolved).
 */
const useWordTokenRows = (model: ModelType) => {
    const [wordTokens, setWordTokens] = React.useState<TokenTable[]>([]);

    // For each word (token), get the token rows in which it occurs
    useEffect(() => {
        // Flipped by the cleanup so that a superseded request never overwrites newer state.
        let isRequestStillValid = true;

        // A single predicate drives both the base model and the token spelling, so the
        // two can never disagree with each other.
        const isGpt2Variant = model === 'GPT2' || model === 'GPT2LEFT' || model === 'GPT2RIGHT';
        const baseModel: ModelType = isGpt2Variant ? 'GPT2' : 'BERT';

        const responsePromises = arrayWords.map((word) =>
            BackendQueryEngine.getTokens(baseModel, { token: isGpt2Variant ? 'Ġ' + word : word })
        );

        Promise.all(responsePromises)
            .then((tokens) => {
                if (isRequestStillValid) {
                    setWordTokens(tokens);
                }
            })
            .catch((error: unknown) => {
                // Keep the previous state; surface the failure instead of leaving an
                // unhandled promise rejection.
                console.error('Failed to load token rows', error);
            });

        return () => {
            isRequestStillValid = false;
        };
    }, [model]);

    return wordTokens;
};

/**
 * Hook that requests the projected embeddings for the rows of each token table via
 * `BackendQueryEngine.getProjectedEmbeddings` and exposes the results as state.
 *
 * @param model Model passed to the backend query.
 * @param layer Layer passed to the backend query.
 * @param wordTokens Token tables; the `row_id` of each one is sent as `row_ids`.
 * @param projectionMethod Projection method passed to the backend query.
 * @returns One projected-embedding response per entry of `wordTokens` (an empty array
 *          until the first batch of requests has resolved).
 */
const useWordEmbeddings = (
    model: ModelType,
    layer: number,
    wordTokens: TokenTable[],
    projectionMethod: ProjectionMethod
) => {
    const [wordEmbeddings, setWordEmbeddings] = React.useState<ProjectedEmbeddings[]>([]);

    // Get the embeddings for each of the token occurrences
    useEffect(() => {
        // Flipped by the cleanup so that a superseded request never overwrites newer state.
        let isRequestStillValid = true;

        const responsePromises = wordTokens.map((tokenTable) =>
            BackendQueryEngine.getProjectedEmbeddings(model, layer, projectionMethod, {
                row_ids: tokenTable.row_id,
            })
        );

        Promise.all(responsePromises)
            .then((embeddings) => {
                if (isRequestStillValid) {
                    setWordEmbeddings(embeddings);
                }
            })
            .catch((error: unknown) => {
                // Keep the previous state; surface the failure instead of leaving an
                // unhandled promise rejection.
                console.error('Failed to load projected embeddings', error);
            });

        return () => {
            isRequestStillValid = false;
        };
        // `projectionMethod` is a dependency because it is part of the query: without it a
        // change of projection method would keep showing embeddings of the previous method.
        // Model explicitly not in deps (should be done via update of wordTokens by previous effect)
    }, [layer, wordTokens, projectionMethod]);

    return wordEmbeddings;
};

/**
 * Hook that picks the layer id (from `ParameterContext`) and the precomputed projection
 * (from `FigureContext`) that belong to the given model / projection method combination.
 *
 * NOTE(review): for 'UMAP' only 'GPT2' is handled explicitly; 'GPT2LEFT' and 'GPT2RIGHT'
 * fall through to the BERT layer and the BERT UMAP projection. Behaviour is kept as is —
 * confirm that this combination is either intended or never requested.
 *
 * @param model Model to look up.
 * @param projectionMethod2 Projection method; 'UMAP' selects the UMAP projections, any
 *                          other value selects the default projections.
 * @returns Tuple of `[layer, projection]`.
 */
const useModelSpecificVars = (model: ModelType, projectionMethod2: ProjectionMethod): [number, ProjectedEmbeddings] => {
    const figures = useContext(FigureContext);
    const { gpt2LayerID, bertLayerID } = useContext(ParameterContext);

    if (projectionMethod2 === 'UMAP') {
        // Only the plain GPT2 model has its own UMAP projection; everything else uses BERT's.
        if (model === 'GPT2') {
            return [gpt2LayerID, figures.gpt2UmapProjections];
        }
        return [bertLayerID, figures.bertUmapProjections];
    }

    if (model === 'GPT2') {
        return [gpt2LayerID, figures.gpt2Projections];
    }
    if (model === 'GPT2LEFT') {
        return [gpt2LayerID, figures.gpt2LeftProjections];
    }
    if (model === 'GPT2RIGHT') {
        return [gpt2LayerID, figures.gpt2RightProjections];
    }
    return [bertLayerID, figures.bertProjections];
};

/**
 * Renders a `StablePlot` containing one trace per word of `arrayWords` plus a background
 * trace built from the model's projection, and below it a box showing the sentence that
 * belongs to the most recently hovered point.
 *
 * Renders nothing until both the projection and the word traces are available.
 *
 * @param model Model whose embeddings are plotted.
 * @param revision Forwarded to `StablePlot`.
 * @param projectionMethod Projection method used for the background projection and the word embeddings.
 * @param annotations Scene annotations placed into the plot layout.
 * @param eye Camera eye of the scene; defaults to `{ x: 1, y: 1, z: 1 }`.
 */
const Section4SubFigure: React.FunctionComponent<Props> = ({
    model,
    revision,
    projectionMethod,
    annotations,
    eye = {
        x: 1,
        y: 1,
        z: 1,
    },
}: Props) => {
    const { numPoints } = useContext(ParameterContext);
    const { sentences } = useContext(TokenContext);

    const [layer, projection] = useModelSpecificVars(model, projectionMethod);

    const [sentenceTextCurrentHover, setSentenceTextCurrentHover] = React.useState<HoverText>(DEFAULT_HOVER_TEXT);
    // Never replaced through a setter; the same array instance is handed to `useTracesTokens`
    // (presumably filled in there, indexed by trace and point — verify against that hook).
    const [sentenceTextHoverList] = React.useState<HoverText[][]>([]);

    const plotlyHover = (event: PlotMouseEvent) => {
        const datapoint = event.points[0];
        if (!datapoint) {
            return;
        }

        // Not every hoverable point has an entry: the background trace is appended after
        // the word traces, so its curve number can lie outside the hover list. Keep the
        // current hover text in that case instead of throwing or storing `undefined`.
        const traceHoverTexts = sentenceTextHoverList[datapoint.curveNumber];
        const hoverObject = traceHoverTexts ? traceHoverTexts[datapoint.pointNumber] : undefined;
        if (hoverObject) {
            setSentenceTextCurrentHover(hoverObject);
        }
    };

    const wordTokens = useWordTokenRows(model);
    const wordEmbeddings = useWordEmbeddings(model, layer, wordTokens, projectionMethod);

    // Create the html code for the word traces
    const data = useTracesTokens(
        wordTokens,
        wordEmbeddings,
        arrayWords,
        sentences,
        layer,
        numPoints,
        sentenceTextHoverList
    );

    const dataShadow = useTracesRest(projection);

    if (!(projection && data && arrayWords)) {
        return null;
    }

    // Word traces first, background trace last.
    const traces: Data[] = [...data, dataShadow];

    return (
        <Grid container spacing={2} direction={'column'}>
            <Grid item xs={12}>
                <StablePlot
                    revision={revision}
                    data={traces}
                    layout={{
                        colorway: tableau20, // Change default categorical colors
                        // width: width,
                        // height: height,
                        scene: {
                            annotations: annotations,
                            aspectmode: 'manual',
                            aspectratio: { x: 1, y: 1, z: 1 },
                            camera: {
                                eye: eye,
                            },
                            xaxis: {
                                title: '',
                                autorange: true,
                                showgrid: true,
                                zeroline: false,
                                showline: false,
                                ticks: '',
                                showticklabels: false,
                            },
                            yaxis: {
                                title: '',
                                autorange: true,
                                showgrid: true,
                                zeroline: false,
                                showline: false,
                                ticks: '',
                                showticklabels: false,
                            },
                            zaxis: {
                                title: '',
                                autorange: true,
                                showgrid: true,
                                zeroline: false,
                                showline: false,
                                ticks: '',
                                showticklabels: false,
                            },
                        },
                        margin: {
                            l: 10,
                            r: 10,
                            b: 0,
                            t: 0,
                            pad: 0,
                        },
                        autosize: true,
                    }}
                    config={{ responsive: true }}
                    onHover={plotlyHover}
                />
            </Grid>
            <Grid item xs={12}>
                <div className="sentenceBox">
                    <p>Sentence:</p>
                    {sentenceTextCurrentHover.left}
                    <b>{sentenceTextCurrentHover.word}</b>
                    {sentenceTextCurrentHover.right}
                </div>
            </Grid>
        </Grid>
    );
};

export default Section4SubFigure;
