/*
 * Decompiled with CFR 0.152.
 */
package org.carrot2.text.vsm;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.sorting.IndirectSort;
import java.util.Arrays;
import java.util.function.IntToDoubleFunction;
import org.carrot2.attrs.AttrComposite;
import org.carrot2.attrs.AttrDouble;
import org.carrot2.attrs.AttrInteger;
import org.carrot2.attrs.AttrObject;
import org.carrot2.attrs.AttrStringArray;
import org.carrot2.language.TokenTypeUtils;
import org.carrot2.math.mahout.matrix.DoubleMatrix2D;
import org.carrot2.math.mahout.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.math.mahout.matrix.impl.SparseDoubleMatrix2D;
import org.carrot2.math.matrix.MatrixUtils;
import org.carrot2.text.preprocessing.PreprocessingContext;
import org.carrot2.text.vsm.LogTfIdfTermWeighting;
import org.carrot2.text.vsm.TermWeighting;
import org.carrot2.text.vsm.VectorSpaceModelContext;

/**
 * Builds the term-document matrix and the term-phrase matrix of the vector space model from the
 * data held in a {@link PreprocessingContext}, storing the results in a
 * {@link VectorSpaceModelContext}.
 */
public class TermDocumentMatrixBuilder extends AttrComposite {
    /** Multiplier applied to the weight of stems that occur in any of the boosted fields (0..10, default 2.0). */
    public final AttrDouble boostedFieldWeight;

    /** Names of fields whose stems get their weight multiplied by {@link #boostedFieldWeight} (default: none). */
    public AttrStringArray boostFields;

    /**
     * Upper bound on the number of cells (rows x documents) of the term-document matrix (min 5000,
     * default 37500). The number of rows is capped at {@code maximumMatrixSize / documentCount}.
     */
    public final AttrInteger maximumMatrixSize;

    /**
     * Stems whose document frequency divided by the number of documents exceeds this ratio are left
     * out of the matrix (0..1, default 0.9).
     */
    public final AttrDouble maxWordDf;

    /** Term weighting used for matrix cells; defaults to {@link LogTfIdfTermWeighting}. */
    public TermWeighting termWeighting;

    public TermDocumentMatrixBuilder() {
        this.boostedFieldWeight = this.attributes.register("boostedFieldWeight",
            AttrDouble.builder().label("Boosted fields weight").min(0.0).max(10.0).defaultValue(2.0));
        this.boostFields = this.attributes.register("boostFields",
            AttrStringArray.builder().label("Boosted fields").defaultValue(new String[0]));
        this.maximumMatrixSize = this.attributes.register("maximumMatrixSize",
            AttrInteger.builder().label("Maximum term-document matrix size").min(5000).defaultValue(37500));
        this.maxWordDf = this.attributes.register("maxWordDf",
            AttrDouble.builder().label("Maximum word document frequency").min(0.0).max(1.0).defaultValue(0.9));
        this.attributes.register("termWeighting",
            ((AttrObject.Builder)AttrObject.builder(TermWeighting.class).label("Term weighting for term-document matrix"))
                .getset(() -> this.termWeighting, v -> {
                    this.termWeighting = v;
                })
                .defaultValue(LogTfIdfTermWeighting::new));
    }

    /**
     * Builds the term-document matrix and stores it, together with the stem-to-row mapping, in
     * {@code vsmContext.termDocumentMatrix} and {@code vsmContext.stemToRowIndex}.
     *
     * <p>Rows are the stems selected by {@link #computeRequiredStemIndices}, ordered by descending
     * aggregate weight and capped at {@code maximumMatrixSize / documentCount}. Columns are documents.
     * When there are no documents, an empty 0x0 matrix and an empty mapping are stored.
     *
     * @param vsmContext context providing the preprocessing data and receiving the results
     */
    public void buildTermDocumentMatrix(VectorSpaceModelContext vsmContext) {
        PreprocessingContext preprocessingContext = vsmContext.preprocessingContext;
        int documentCount = preprocessingContext.documentCount;
        int[] stemsTf = preprocessingContext.allStems.tf;
        int[][] stemsTfByDocument = preprocessingContext.allStems.tfByDocument;
        byte[] stemsFieldIndices = preprocessingContext.allStems.fieldIndices;

        if (documentCount == 0) {
            vsmContext.termDocumentMatrix = new DenseDoubleMatrix2D(0, 0);
            vsmContext.stemToRowIndex = new IntIntHashMap();
            return;
        }

        IntToDoubleFunction fieldIndexToBoost = createFieldBoost(preprocessingContext);
        int[] stemsToInclude = computeRequiredStemIndices(preprocessingContext);
        TermWeighting termWeighting = this.termWeighting;

        // Corpus-level weight of each candidate stem; used only to rank stems into rows.
        double[] stemsWeight = new double[stemsToInclude.length];
        for (int i = 0; i < stemsToInclude.length; ++i) {
            int stemIndex = stemsToInclude[i];
            double weight = termWeighting.calculateTermWeight(
                stemsTf[stemIndex], stemsTfByDocument[stemIndex].length / 2, documentCount);
            stemsWeight[i] = weight * fieldIndexToBoost.applyAsDouble(stemsFieldIndices[stemIndex]);
        }

        // Positions in stemsToInclude, ordered by descending stem weight.
        int[] stemWeightOrder = IndirectSort.mergesort(0, stemsWeight.length,
            (a, b) -> Double.compare(stemsWeight[b], stemsWeight[a]));

        // Cap the number of rows so the matrix stays within maximumMatrixSize cells.
        int maxRows = this.maximumMatrixSize.get() / documentCount;
        int rowCount = Math.min(maxRows, stemsToInclude.length);
        DenseDoubleMatrix2D tdMatrix = new DenseDoubleMatrix2D(rowCount, documentCount);
        IntIntHashMap stemToRowIndex = new IntIntHashMap();

        for (int row = 0; row < rowCount; ++row) {
            int stemIndex = stemsToInclude[stemWeightOrder[row]];
            // Interleaved (documentIndex, termFrequency) pairs; one pair per document containing the stem.
            int[] tfByDocument = stemsTfByDocument[stemIndex];
            int df = tfByDocument.length / 2;
            double fieldWeight = fieldIndexToBoost.applyAsDouble(stemsFieldIndices[stemIndex]);
            for (int j = 0; j < df; ++j) {
                int documentIndex = tfByDocument[2 * j];
                int tf = tfByDocument[2 * j + 1];
                double weight = termWeighting.calculateTermWeight(tf, df, documentCount);
                tdMatrix.set(row, documentIndex, weight * fieldWeight);
            }
            stemToRowIndex.put(stemIndex, row);
        }

        vsmContext.termDocumentMatrix = tdMatrix;
        vsmContext.stemToRowIndex = stemToRowIndex;
    }

    /**
     * Creates the function that maps a stem's field bitmask (an entry of
     * {@code allStems.fieldIndices}) to its weight multiplier.
     *
     * <p>The multiplier is {@link #boostedFieldWeight} when any boosted field's bit is set, and 1.0
     * otherwise. When no boost fields are configured, the multiplier is always 1.0.
     */
    private IntToDoubleFunction createFieldBoost(PreprocessingContext preprocessingContext) {
        String[] fieldsToBoost = this.boostFields.get();
        if (fieldsToBoost.length == 0) {
            return fieldIndices -> 1.0;
        }

        // One entry for every possible value of the byte-sized per-stem field bitmask.
        double[] boosts = new double[256];
        Arrays.fill(boosts, 1.0);
        PreprocessingContext.AllFields allFields = preprocessingContext.allFields;
        for (String fieldName : fieldsToBoost) {
            int fieldIndex = allFields.fieldIndex(fieldName);
            // Unknown field, or a field index that cannot be represented in a byte bitmask
            // (also avoids Java's shift-distance wrap-around for indices >= 32).
            if (fieldIndex < 0 || fieldIndex >= Byte.SIZE) {
                continue;
            }
            int mask = 1 << fieldIndex;
            double boostedWeight = this.boostedFieldWeight.get();
            for (int bitmask = 0; bitmask < boosts.length; ++bitmask) {
                if ((bitmask & mask) != 0) {
                    boosts[bitmask] = boostedWeight;
                }
            }
        }

        // Bitmasks are stored as signed bytes and sign-extend when widened to int. Mask to 0..255
        // so a bitmask with bit 7 set does not become a negative array index.
        return fieldIndices -> boosts[fieldIndices & 0xFF];
    }

    /**
     * Builds the term-phrase matrix for all phrase labels and stores it in
     * {@code context.termPhraseMatrix}.
     *
     * <p>The stored matrix is the transposed view (phrases x stems) of the column-L2-normalized
     * aligned matrix. It is built only if there are phrase labels ({@code firstPhraseIndex >= 0}) and
     * {@code context.stemToRowIndex} (as produced by {@link #buildTermDocumentMatrix}) is non-empty;
     * otherwise {@code context.termPhraseMatrix} is left unchanged.
     *
     * @param context context providing the preprocessing data and stem-to-row mapping
     */
    public void buildTermPhraseMatrix(VectorSpaceModelContext context) {
        PreprocessingContext preprocessingContext = context.preprocessingContext;
        IntIntHashMap stemToRowIndex = context.stemToRowIndex;
        int[] labelsFeatureIndex = preprocessingContext.allLabels.featureIndex;
        int firstPhraseIndex = preprocessingContext.allLabels.firstPhraseIndex;
        if (firstPhraseIndex < 0 || stemToRowIndex.size() == 0) {
            return;
        }

        // Labels from firstPhraseIndex onwards are phrases.
        int[] phraseFeatureIndices =
            Arrays.copyOfRange(labelsFeatureIndex, firstPhraseIndex, labelsFeatureIndex.length);
        DoubleMatrix2D phraseMatrix =
            TermDocumentMatrixBuilder.buildAlignedMatrix(context, phraseFeatureIndices, this.termWeighting);
        MatrixUtils.normalizeColumnL2(phraseMatrix, null);
        context.termPhraseMatrix = phraseMatrix.viewDice();
    }

    /**
     * Returns the indices of the stems that should become rows of the term-document matrix.
     *
     * <p>A label that is a single word contributes its stem. A phrase label contributes the stems of
     * its words, skipping words that {@link TokenTypeUtils#isCommon} reports as common. In both cases,
     * a stem is included only if its document-frequency ratio is at most {@link #maxWordDf}.
     */
    private int[] computeRequiredStemIndices(PreprocessingContext context) {
        int[] labelsFeatureIndex = context.allLabels.featureIndex;
        int[] wordsStemIndex = context.allWords.stemIndex;
        short[] wordsTypes = context.allWords.type;
        int[][] phrasesWordIndices = context.allPhrases.wordIndices;
        int wordCount = wordsStemIndex.length;
        int[][] stemsTfByDocument = context.allStems.tfByDocument;
        int documentCount = context.documentCount;
        // The initial size is only a capacity hint; the bit set grows as needed.
        BitSet requiredStemIndices = new BitSet(labelsFeatureIndex.length);
        double maxWordDf = this.maxWordDf.get();

        for (int featureIndex : labelsFeatureIndex) {
            // Feature indices below wordCount are words; the rest index into allPhrases.
            if (featureIndex < wordCount) {
                addStemIndex(wordsStemIndex, documentCount, stemsTfByDocument, requiredStemIndices, featureIndex, maxWordDf);
                continue;
            }
            for (int wordIndex : phrasesWordIndices[featureIndex - wordCount]) {
                if (TokenTypeUtils.isCommon(wordsTypes[wordIndex])) {
                    continue;
                }
                addStemIndex(wordsStemIndex, documentCount, stemsTfByDocument, requiredStemIndices, wordIndex, maxWordDf);
            }
        }
        return requiredStemIndices.asIntLookupContainer().toArray();
    }

    /** Marks the stem of word {@code wordIndex} as required, unless its df ratio exceeds {@code maxWordDf}. */
    private void addStemIndex(int[] wordsStemIndex, int documentCount, int[][] stemsTfByDocument,
                              BitSet requiredStemIndices, int wordIndex, double maxWordDf) {
        int stemIndex = wordsStemIndex[wordIndex];
        int df = stemsTfByDocument[stemIndex].length / 2;
        if ((double) df / (double) documentCount <= maxWordDf) {
            requiredStemIndices.set(stemIndex);
        }
    }

    /**
     * Builds a matrix whose rows match the rows of the term-document matrix (via
     * {@code vsmContext.stemToRowIndex}) and whose columns correspond to the given features.
     *
     * <p>A cell (row, i) holds the corpus-level term weight of that row's stem if the stem occurs in
     * feature {@code featureIndex[i]}, and is 0 otherwise. Stems without a row are skipped.
     *
     * @param vsmContext context providing the preprocessing data and stem-to-row mapping
     * @param featureIndex feature indices; values below the word count are words, the rest are phrases
     * @param termWeighting weighting used to compute cell values
     * @return a sparse matrix, or an empty dense matrix with zero columns when {@code featureIndex} is empty
     */
    static DoubleMatrix2D buildAlignedMatrix(VectorSpaceModelContext vsmContext, int[] featureIndex, TermWeighting termWeighting) {
        IntIntHashMap stemToRowIndex = vsmContext.stemToRowIndex;
        if (featureIndex.length == 0) {
            return new DenseDoubleMatrix2D(stemToRowIndex.size(), 0);
        }

        SparseDoubleMatrix2D phraseMatrix = new SparseDoubleMatrix2D(stemToRowIndex.size(), featureIndex.length);
        PreprocessingContext preprocessingContext = vsmContext.preprocessingContext;
        int[] wordsStemIndex = preprocessingContext.allWords.stemIndex;
        int[] stemsTf = preprocessingContext.allStems.tf;
        int[][] stemsTfByDocument = preprocessingContext.allStems.tfByDocument;
        int[][] phrasesWordIndices = preprocessingContext.allPhrases.wordIndices;
        int documentCount = preprocessingContext.documentCount;
        int wordCount = wordsStemIndex.length;

        for (int column = 0; column < featureIndex.length; ++column) {
            int feature = featureIndex[column];
            int[] wordIndices = feature < wordCount ? new int[] {feature} : phrasesWordIndices[feature - wordCount];
            for (int wordIndex : wordIndices) {
                int stemIndex = wordsStemIndex[wordIndex];
                int slot = stemToRowIndex.indexOf(stemIndex);
                if (!stemToRowIndex.indexExists(slot)) {
                    continue;
                }
                int rowIndex = stemToRowIndex.indexGet(slot);
                double weight = termWeighting.calculateTermWeight(
                    stemsTf[stemIndex], stemsTfByDocument[stemIndex].length / 2, documentCount);
                phraseMatrix.setQuick(rowIndex, column, weight);
            }
        }
        return phraseMatrix;
    }
}

