/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.opt.InterestingPoint;
import org.apache.sysml.hops.codegen.opt.PlanPartition;
import org.apache.sysml.hops.codegen.opt.PlanSelection;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Reachability graph over the extended materialization points of a plan partition.
 * It is used to find cut sets that split the plan enumeration search space into a
 * left and a right sub-problem.
 * <p>
 * Each {@link NodeLink} wraps one materialization point (an edge fromHop -> toHop).
 * A node B is an input of node A if, walking the HOP DAG from A's edge towards the
 * inputs, B's edge is reached without passing another materialization point. A
 * synthetic root node without an interesting point links to all materialization
 * points reachable in this way from the partition roots.
 */
public class ReachabilityGraph {
    /** Graph nodes of all materialization points, keyed by (fromHopID, toHopID). */
    private final HashMap<Pair<Long, Long>, NodeLink> _matPoints = new HashMap<>();
    /** Synthetic root node (no interesting point); its inputs are reachable from the partition roots. */
    private final NodeLink _root;
    /** Linearized search space: cut points (in cut set order) first, then all other points. */
    private InterestingPoint[] _searchSpace;
    /** Valid cut sets in ascending order of their score. */
    private CutSet[] _cutSets;

    /**
     * Builds the reachability graph for the given partition and derives its cut sets
     * and the linearized search space.
     *
     * @param part plan partition whose extended materialization points become the graph nodes
     * @param memo memo table used to resolve HOP ids to HOPs
     * @throws RuntimeException if the linearized search space does not have exactly one
     *         entry per materialization point
     */
    public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
        // create repository of materialization points
        for (InterestingPoint p : part.getMatPointsExt()) {
            _matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p));
        }

        // create reachability graph, starting at a synthetic root node
        _root = new NodeLink(null);
        HashSet<PlanSelection.VisitMarkCost> visited = new HashSet<>();
        for (Long hopID : part.getRoots()) {
            Hop rootHop = memo.getHopRefs().get(hopID);
            addInputNodeLinks(rootHop, _root, part, memo, visited);
        }

        // candidate cut points: materialization points that have at least one reachable
        // materialization point below them. They are sorted so that nodes with identical
        // input lists are adjacent.
        List<NodeLink> tmpCS = _matPoints.values().stream()
            .filter(p -> !p._inputs.isEmpty() && p._p != null)
            .sorted()
            .collect(Collectors.toList());

        // short-cut for partitions without cut sets: points ordered by decreasing size
        if (tmpCS.isEmpty()) {
            _cutSets = new CutSet[0];
            _searchSpace = sortBySize(part.getMatPointsExt(), memo, false);
            return;
        }

        // nodes with identical inputs form one composite cut set candidate
        ArrayList<ArrayList<NodeLink>> candCS = groupByInputs(tmpCS);

        // evaluate single (composite) candidates; invalid ones are collected in 'remain'
        ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
        ArrayList<Pair<CutSet, Double>> cutSets = evaluateCutSets(candCS, remain);

        // second chance: pairs of the remaining candidates. The number of pairs is
        // quadratic in the number of remaining candidates, hence the size limit.
        if (!remain.isEmpty() && remain.size() < 5) {
            cutSets.addAll(evaluateDisjointCutSetPairs(remain));
        }

        // order cut sets by ascending score
        _cutSets = cutSets.stream()
            .sorted(Comparator.comparing((Pair<CutSet, Double> p) -> p.getRight()))
            .map(Pair::getLeft)
            .toArray(CutSet[]::new);

        // linearize the search space: cut points in cut set order, then all remaining
        // points in decreasing order of their output size. 'probe' maps each point to
        // its position in the search space.
        HashMap<InterestingPoint, Integer> probe = new HashMap<>();
        ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
        for (CutSet cs : _cutSets) {
            for (InterestingPoint p : cs.cut) {
                lsearchSpace.add(p);
                probe.put(p, probe.size());
            }
        }
        for (InterestingPoint p : sortBySize(part.getMatPointsExt(), memo, false)) {
            if (!probe.containsKey(p)) {
                lsearchSpace.add(p);
                probe.put(p, probe.size());
            }
        }
        _searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);

        // materialize the search-space positions of all cut sets
        for (CutSet cs : _cutSets) {
            cs.updatePositions(probe);
        }

        // duplicate cut points across cut sets would make this check fail
        if (_searchSpace.length != part.getMatPointsExt().length) {
            throw new RuntimeException("Corrupt linearized search space: " + _searchSpace.length + " vs " + part.getMatPointsExt().length);
        }
    }

    /**
     * Returns the linearized search space: cut points first, then all other points.
     * The returned array is the internal array, not a copy.
     *
     * @return the sorted materialization points
     */
    public InterestingPoint[] getSortedSearchSpace() {
        return _searchSpace;
    }

    /**
     * Checks whether the plan materializes all points of at least one cut set.
     *
     * @param plan materialization flags indexed by search-space position
     * @return true if some cut set is fully contained in the plan
     */
    public boolean isCutSet(boolean[] plan) {
        return findCutSet(plan) != null;
    }

    /**
     * Checks whether the plan materializes all cut points of the given cut set.
     *
     * @param cs   cut set whose positions were set via {@code updatePositions}
     * @param plan materialization flags indexed by search-space position
     * @return true if every cut position of {@code cs} is set in {@code plan}
     */
    public boolean isCutSet(CutSet cs, boolean[] plan) {
        for (int pos : cs.posCut) {
            if (!plan[pos]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the first cut set (in score order) that is fully contained in the plan.
     *
     * @param plan materialization flags indexed by search-space position
     * @return the matching cut set
     * @throws RuntimeException if no cut set is contained in the plan
     */
    public CutSet getCutSet(boolean[] plan) {
        CutSet cs = findCutSet(plan);
        if (cs == null) {
            throw new RuntimeException("No valid cut set found.");
        }
        return cs;
    }

    /**
     * Returns the number of plans that can be skipped for a plan containing a cut set.
     * The value is 2^(n - pos - 1), where n is the plan length and pos is the
     * search-space position of the last cut point of the first matching cut set. This
     * equals the number of assignments of all positions after pos.
     *
     * @param plan materialization flags indexed by search-space position
     * @return the number of skippable plans
     * @throws RuntimeException if no cut set is contained in the plan
     */
    public long getNumSkipPlans(boolean[] plan) {
        CutSet cs = findCutSet(plan);
        if (cs == null) {
            throw new RuntimeException("Failed to compute number of skip plans for plan without cutset.");
        }
        int pos = cs.posCut[cs.posCut.length - 1];
        return UtilFunctions.pow(2, plan.length - pos - 1);
    }

    /**
     * Returns the left and right sub-problems induced by the first cut set that is
     * contained in the plan.
     *
     * @param plan materialization flags indexed by search-space position
     * @return a two-element array {left, right}
     * @throws RuntimeException if no cut set is contained in the plan
     */
    public SubProblem[] getSubproblems(boolean[] plan) {
        CutSet cs = getCutSet(plan);
        return new SubProblem[]{
            new SubProblem(cs.cut.length, cs.posLeft, cs.left),
            new SubProblem(cs.cut.length, cs.posRight, cs.right)};
    }

    @Override
    public String toString() {
        return "ReachabilityGraph(" + _matPoints.size() + "):\n" + _root.explain(new HashSet<>());
    }

    /** Returns the first cut set (in score order) fully contained in the plan, or null if none. */
    private CutSet findCutSet(boolean[] plan) {
        for (CutSet cs : _cutSets) {
            if (isCutSet(cs, plan)) {
                return cs;
            }
        }
        return null;
    }

    /**
     * Recursively walks from {@code current} towards its inputs. Every materialization
     * point edge (current -> input) it reaches is linked as an input of {@code parent},
     * and the walk continues below that point with its node as the new parent.
     * (hop, parent) combinations that were already processed are skipped.
     */
    private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part, CPlanMemoTable memo, HashSet<PlanSelection.VisitMarkCost> visited) {
        PlanSelection.VisitMarkCost mark = new PlanSelection.VisitMarkCost(current.getHopID(), parent._ID);
        if (visited.contains(mark)) {
            return;
        }
        for (Hop in : current.getInput()) {
            if (InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID())) {
                NodeLink tmp = _matPoints.get(Pair.of(current.getHopID(), in.getHopID()));
                parent.addInput(tmp);
                addInputNodeLinks(in, tmp, part, memo, visited);
            } else {
                addInputNodeLinks(in, parent, part, memo, visited);
            }
        }
        visited.add(mark);
    }

    /**
     * Adds to {@code inputs} all nodes transitively reachable from {@code current} that
     * are neither in {@code probe} nor reachable only through nodes in {@code probe}.
     * Nodes already in {@code inputs} have been fully expanded, so they are skipped;
     * this keeps the traversal linear on shared sub-graphs.
     */
    private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) {
        for (NodeLink c : current._inputs) {
            if (probe.contains(c) || inputs.contains(c)) {
                continue;
            }
            rCollectInputs(c, probe, inputs);
            inputs.add(c);
        }
    }

    /**
     * Splits the sorted candidate list into runs of adjacent nodes with identical inputs.
     * Each run forms one (composite) cut set candidate.
     */
    private static ArrayList<ArrayList<NodeLink>> groupByInputs(List<NodeLink> sorted) {
        ArrayList<ArrayList<NodeLink>> groups = new ArrayList<>();
        ArrayList<NodeLink> current = new ArrayList<>();
        for (NodeLink node : sorted) {
            if (!current.isEmpty() && !current.get(0).hasSameInputs(node)) {
                groups.add(current);
                current = new ArrayList<>();
            }
            current.add(node);
        }
        if (!current.isEmpty()) {
            groups.add(current);
        }
        return groups;
    }

    /**
     * Evaluates all pairs of the given (individually invalid) candidates. It returns the
     * valid pair cut sets, greedily keeping only those whose cut points are disjoint
     * from the pairs already kept.
     */
    private ArrayList<Pair<CutSet, Double>> evaluateDisjointCutSetPairs(ArrayList<ArrayList<NodeLink>> remain) {
        ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>();
        for (int i = 0; i < remain.size() - 1; i++) {
            for (int j = i + 1; j < remain.size(); j++) {
                ArrayList<NodeLink> pair = new ArrayList<>(remain.get(i));
                pair.addAll(remain.get(j));
                candCS2.add(pair);
            }
        }
        // invalid pairs are not retried, so they go into a throw-away list
        ArrayList<Pair<CutSet, Double>> cutSets2 = evaluateCutSets(candCS2, new ArrayList<>());

        ArrayList<Pair<CutSet, Double>> ret = new ArrayList<>();
        HashSet<InterestingPoint> testDisjoint = new HashSet<>();
        for (Pair<CutSet, Double> cs : cutSets2) {
            List<InterestingPoint> cut = Arrays.asList(cs.getLeft().cut);
            if (CollectionUtils.containsAny(testDisjoint, cut)) {
                continue;
            }
            ret.add(cs);
            testDisjoint.addAll(cut);
        }
        return ret;
    }

    /**
     * Evaluates cut set candidates. For a candidate, 'left' is the set of nodes reachable
     * from the root without passing a candidate node. 'right' is the set of nodes
     * reachable below the candidate nodes. A candidate is valid if both sets are
     * non-empty and disjoint.
     * <p>
     * Valid candidates are scored as an estimate of the number of enumerated plans:
     * score = (c-1)/c * 2^m + 1/c * 2^|left| + 1/c * 2^|right|, with c = 2^|cut| and
     * m the number of materialization points.
     *
     * @param candCS candidate cut sets (lists of nodes)
     * @param remain output list that receives all invalid candidates
     * @return valid cut sets paired with their scores, in candidate order
     */
    private ArrayList<Pair<CutSet, Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS, ArrayList<ArrayList<NodeLink>> remain) {
        ArrayList<Pair<CutSet, Double>> cutSets = new ArrayList<>();
        for (ArrayList<NodeLink> cand : candCS) {
            HashSet<NodeLink> probe = new HashSet<>(cand);
            HashSet<NodeLink> part1 = new HashSet<>();
            rCollectInputs(_root, probe, part1);
            HashSet<NodeLink> part2 = new HashSet<>();
            for (NodeLink rNode : cand) {
                rCollectInputs(rNode, probe, part2);
            }

            boolean valid = !part1.isEmpty() && !part2.isEmpty()
                && !CollectionUtils.containsAny(part1, part2);
            if (!valid) {
                remain.add(cand);
                continue;
            }

            double base = UtilFunctions.pow(2, _matPoints.size());
            double numComb = UtilFunctions.pow(2, cand.size());
            double score = (numComb - 1.0) / numComb * base
                + 1.0 / numComb * (double) UtilFunctions.pow(2, part1.size())
                + 1.0 / numComb * (double) UtilFunctions.pow(2, part2.size());

            InterestingPoint[] cut = cand.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            InterestingPoint[] left = part1.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            InterestingPoint[] right = part2.stream().map(n -> n._p).toArray(InterestingPoint[]::new);
            cutSets.add(Pair.of(new CutSet(cut, left, right), score));
        }
        return cutSets;
    }

    /**
     * Sorts the points by the output size (dim1 * dim2, each dimension at least 1) of
     * their target HOP. The sort is stable.
     *
     * @param asc true for ascending, false for descending order
     */
    private InterestingPoint[] sortBySize(InterestingPoint[] points, CPlanMemoTable memo, boolean asc) {
        long sign = asc ? 1 : -1;
        return Arrays.stream(points)
            .sorted(Comparator.comparingLong(p -> sign * getSize(memo.getHopRefs().get(p.getToHopID()))))
            .toArray(InterestingPoint[]::new);
    }

    /** Returns dim1 * dim2 of the HOP; unknown or non-positive dimensions count as 1. */
    private static long getSize(Hop hop) {
        return Math.max(hop.getDim1(), 1L) * Math.max(hop.getDim2(), 1L);
    }

    /**
     * Graph node wrapping one materialization point, or {@code null} for the synthetic root.
     * <p>
     * Equality and hashing use the unique node id, so every node equals only itself.
     * This is what the hash sets in cut set evaluation require. The natural ordering
     * ({@link #compareTo}) compares input lists only. It is therefore intentionally NOT
     * consistent with equals, and is only used to sort candidate cut points.
     */
    private static class NodeLink implements Comparable<NodeLink> {
        private static final IDSequence _seqID = new IDSequence();
        private final ArrayList<NodeLink> _inputs = new ArrayList<>();
        private final long _ID = _seqID.getNextID();
        private final InterestingPoint _p;

        private NodeLink(InterestingPoint p) {
            _p = p;
        }

        private void addInput(NodeLink in) {
            _inputs.add(in);
        }

        /** Returns true if both nodes have the same inputs (by node id), in the same order. */
        private boolean hasSameInputs(NodeLink that) {
            if (_inputs.size() != that._inputs.size()) {
                return false;
            }
            for (int i = 0; i < _inputs.size(); i++) {
                if (_inputs.get(i)._ID != that._inputs.get(i)._ID) {
                    return false;
                }
            }
            return true;
        }

        @Override
        public int hashCode() {
            return Long.hashCode(_ID);
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof NodeLink && ((NodeLink) o)._ID == _ID;
        }

        /**
         * Orders by decreasing number of inputs, then lexicographically by input ids.
         * Nodes with identical inputs therefore compare as 0 and end up adjacent after sorting.
         */
        @Override
        public int compareTo(NodeLink that) {
            if (_inputs.size() > that._inputs.size()) {
                return -1;
            }
            if (_inputs.size() < that._inputs.size()) {
                return 1;
            }
            for (int i = 0; i < _inputs.size(); i++) {
                int comp = Long.compare(_inputs.get(i)._ID, that._inputs.get(i)._ID);
                if (comp != 0) {
                    return comp;
                }
            }
            return 0;
        }

        @Override
        public String toString() {
            StringBuilder inputs = new StringBuilder();
            for (NodeLink in : _inputs) {
                if (inputs.length() > 0) {
                    inputs.append(",");
                }
                inputs.append(in._ID);
            }
            return _ID + " (" + inputs.toString() + ") " + (_p != null ? _p : "null");
        }

        /**
         * Renders this node and its not-yet-visited sub-graph, inputs first, one node per line.
         * Node ids are recorded in {@code visited}.
         */
        private String explain(HashSet<Long> visited) {
            if (visited.contains(_ID)) {
                return "";
            }
            StringBuilder sb = new StringBuilder();
            StringBuilder inputs = new StringBuilder();
            for (NodeLink in : _inputs) {
                String tmp = in.explain(visited);
                if (!tmp.isEmpty()) {
                    sb.append(tmp).append("\n");
                }
                if (inputs.length() > 0) {
                    inputs.append(",");
                }
                inputs.append(in._ID);
            }
            sb.append(_ID + " (" + inputs + ") " + (_p != null ? _p : "null"));
            visited.add(_ID);
            return sb.toString();
        }
    }

    /**
     * A cut set: the cut points plus the points on the left (root side) and right
     * (below the cut) of the cut.
     * <p>
     * Both {@code left} and {@code right} start with the cut points, followed by their
     * own points. The position arrays hold search-space positions; {@code posLeft} and
     * {@code posRight} cover only the non-cut part.
     */
    private static class CutSet {
        private final InterestingPoint[] cut;
        private final InterestingPoint[] left;
        private final InterestingPoint[] right;
        private int[] posCut;
        private int[] posLeft;
        private int[] posRight;

        private CutSet(InterestingPoint[] cutPoints, InterestingPoint[] l, InterestingPoint[] r) {
            cut = cutPoints;
            left = (InterestingPoint[]) ArrayUtils.addAll(cut, l);
            right = (InterestingPoint[]) ArrayUtils.addAll(cut, r);
        }

        /** Resolves the search-space positions of all cut, left-only and right-only points. */
        private void updatePositions(HashMap<InterestingPoint, Integer> probe) {
            posCut = lookupPositions(cut, 0, probe);
            posLeft = lookupPositions(left, cut.length, probe);
            posRight = lookupPositions(right, cut.length, probe);
        }

        /** Returns the probe positions of points[from..end). */
        private static int[] lookupPositions(InterestingPoint[] points, int from, HashMap<InterestingPoint, Integer> probe) {
            int[] ret = new int[points.length - from];
            for (int i = 0; i < ret.length; i++) {
                ret[i] = probe.get(points[from + i]);
            }
            return ret;
        }

        @Override
        public String toString() {
            return "Cut : " + Arrays.toString(cut);
        }
    }

    /**
     * One side of a cut set.
     * <p>
     * {@code freeMat} holds the cut points followed by this side's own points.
     * {@code offset} is the number of cut points. {@code freePos[i]} is the search-space
     * position of {@code freeMat[offset + i]}.
     */
    public static class SubProblem {
        public int offset;
        public int[] freePos;
        public InterestingPoint[] freeMat;

        public SubProblem(int off, int[] pos, InterestingPoint[] mat) {
            offset = off;
            freePos = pos;
            freeMat = mat;
        }

        @Override
        public String toString() {
            return "SubProblem: " + Arrays.toString(freeMat) + "; " + offset + "; " + Arrays.toString(freePos);
        }
    }
}

