/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationUtils;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationUtil;

/**
 * Score normalization technique that rescales each sub-query's scores into a [0, 1]-style range using
 * min-max normalization: {@code (score - min) / (max - min)}, where {@code min} and {@code max} are taken
 * per sub-query index over every non-null {@link CompoundTopDocs} entry passed in.
 */
public class MinMaxScoreNormalizationTechnique
implements ScoreNormalizationTechnique,
ExplainableTechnique {
    public static final String TECHNIQUE_NAME = "min_max";
    /** Substituted for a normalized score of exactly 0 so the lowest-scoring hit is not zeroed out. */
    private static final float MIN_SCORE = 0.001f;
    /** Returned when a sub-query's min and max are equal and the score matches them. */
    private static final float SINGLE_RESULT_SCORE = 1.0f;

    /**
     * Normalizes, in place, every {@link ScoreDoc#score} of every sub-query in the DTO's query top docs.
     *
     * @param normalizeScoresDTO carrier of the per-entry compound top docs to normalize; null entries are skipped
     */
    @Override
    public void normalize(NormalizeScoresDTO normalizeScoresDTO) {
        List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
        MinMaxScores minMaxScores = this.getMinMaxScoresResult(queryTopDocs);
        float[] minScores = minMaxScores.getMinScoresPerSubquery();
        float[] maxScores = minMaxScores.getMaxScoresPerSubquery();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    scoreDoc.score = this.normalizeSingleScore(scoreDoc.score, minScores[j], maxScores[j]);
                }
            }
        }
    }

    /**
     * Computes, in a single pass, the minimum and maximum raw score of each sub-query index across all
     * non-null entries. A sub-query index with no hits anywhere keeps the sentinels
     * {@code Float.MAX_VALUE} (min) / {@code -Float.MAX_VALUE} (max); they are never read because there
     * are no scores at that index to normalize.
     */
    private MinMaxScores getMinMaxScoresResult(List<CompoundTopDocs> queryTopDocs) {
        int numOfSubqueries = this.getNumOfSubqueries(queryTopDocs);
        float[] minScoresPerSubquery = new float[numOfSubqueries];
        float[] maxScoresPerSubquery = new float[numOfSubqueries];
        Arrays.fill(minScoresPerSubquery, Float.MAX_VALUE);
        // -Float.MAX_VALUE is the most negative finite float; Float.MIN_VALUE is the smallest *positive*
        // float and would wrongly dominate a max over all-zero scores.
        Arrays.fill(maxScoresPerSubquery, -Float.MAX_VALUE);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    minScoresPerSubquery[j] = Math.min(minScoresPerSubquery[j], scoreDoc.score);
                    maxScoresPerSubquery[j] = Math.max(maxScoresPerSubquery[j], scoreDoc.score);
                }
            }
        }
        return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery);
    }

    /**
     * @return the technique name, {@value #TECHNIQUE_NAME}
     */
    @Override
    public String describe() {
        return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME);
    }

    /**
     * Normalizes all scores exactly like {@link #normalize(NormalizeScoresDTO)} and also records each
     * document's normalized score per sub-query for explanation output.
     * <p>
     * Note: like {@code normalize}, this overwrites {@link ScoreDoc#score} in place with the normalized value.
     *
     * @param queryTopDocs compound top docs to normalize; null entries are skipped
     * @return explanation details keyed by document id and search shard
     */
    @Override
    public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
        MinMaxScores minMaxScores = this.getMinMaxScoresResult(queryTopDocs);
        float[] minScores = minMaxScores.getMinScoresPerSubquery();
        float[] maxScores = minMaxScores.getMaxScoresPerSubquery();
        HashMap<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            int numberOfSubQueries = topDocsPerSubQuery.size();
            for (int subQueryIndex = 0; subQueryIndex < numberOfSubQueries; ++subQueryIndex) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(subQueryIndex).scoreDocs) {
                    DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
                    float normalizedScore = this.normalizeSingleScore(scoreDoc.score, minScores[subQueryIndex], maxScores[subQueryIndex]);
                    ScoreNormalizationUtil.setNormalizedScore(normalizedScores, docIdAtSearchShard, subQueryIndex, numberOfSubQueries, normalizedScore);
                    scoreDoc.score = normalizedScore;
                }
            }
        }
        return ExplanationUtils.getDocIdAtQueryForNormalization(normalizedScores, this);
    }

    /**
     * Returns the largest number of sub-queries found in any non-null entry, or 0 when there is none.
     * Using the maximum (rather than an arbitrary entry's count) guarantees the per-sub-query arrays are
     * large enough for every entry, and an empty/all-null input no longer throws.
     */
    private int getNumOfSubqueries(List<CompoundTopDocs> queryTopDocs) {
        return queryTopDocs.stream()
            .filter(Objects::nonNull)
            .mapToInt(topDocs -> topDocs.getTopDocs().size())
            .max()
            .orElse(0);
    }

    /**
     * Min-max normalizes one score.
     *
     * @param score    raw score to normalize
     * @param minScore minimum raw score observed for this sub-query
     * @param maxScore maximum raw score observed for this sub-query
     * @return {@link #SINGLE_RESULT_SCORE} when min, max and score are all equal (avoids 0/0); otherwise
     *         {@code (score - min) / (max - min)}, with an exact 0 replaced by {@link #MIN_SCORE}
     */
    private float normalizeSingleScore(float score, float minScore, float maxScore) {
        if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) {
            return SINGLE_RESULT_SCORE;
        }
        float normalizedScore = (score - minScore) / (maxScore - minScore);
        return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
    }

    @Override
    @Generated
    public String toString() {
        return "MinMaxScoreNormalizationTechnique(TECHNIQUE_NAME=min_max)";
    }

    /**
     * Immutable holder of the per-sub-query minimum and maximum raw scores. Static so it does not
     * capture a reference to the enclosing technique instance.
     */
    private static final class MinMaxScores {
        private final float[] minScoresPerSubquery;
        private final float[] maxScoresPerSubquery;

        MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) {
            this.minScoresPerSubquery = minScoresPerSubquery;
            this.maxScoresPerSubquery = maxScoresPerSubquery;
        }

        float[] getMinScoresPerSubquery() {
            return this.minScoresPerSubquery;
        }

        float[] getMaxScoresPerSubquery() {
            return this.maxScoresPerSubquery;
        }
    }
}

