/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite3.internal.sql.engine.rule.logical;

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.ignite3.internal.sql.engine.rule.logical.ImmutableIgniteMultiJoinOptimizeBushyRule;
import org.apache.ignite3.internal.util.IgniteUtils;
import org.immutables.value.Value;
import org.jetbrains.annotations.Nullable;

@Value.Enclosing
public class IgniteMultiJoinOptimizeBushyRule
extends RelRule<Config>
implements TransformationRule {
    private static final int MAX_JOIN_SIZE = 20;
    private static final Comparator<Vertex> VERTEX_COMPARATOR = Comparator.comparingInt(v -> v.size).reversed().thenComparingDouble(v -> v.cost);

    private IgniteMultiJoinOptimizeBushyRule(Config config) {
        super((RelRule.Config)config);
    }

    public void onMatch(RelOptRuleCall call) {
        MultiJoin multiJoinRel = (MultiJoin)call.rel(0);
        int numberOfRelations = multiJoinRel.getInputs().size();
        if (numberOfRelations > 20) {
            return;
        }
        if (multiJoinRel.isFullOuterJoin()) {
            return;
        }
        for (JoinRelType joinType : multiJoinRel.getJoinTypes()) {
            if (joinType == JoinRelType.INNER) continue;
            return;
        }
        LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
        RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
        RelBuilder relBuilder = call.builder();
        RelMetadataQuery mq = call.getMetadataQuery();
        ArrayList<RexNode> unusedConditions = new ArrayList<RexNode>();
        Int2ObjectMap<List<Edge>> edges = IgniteMultiJoinOptimizeBushyRule.collectEdges(multiJoin, unusedConditions);
        Int2ObjectOpenHashMap bestPlan = new Int2ObjectOpenHashMap();
        BitSet connections = new BitSet(1 << numberOfRelations);
        int id = 1;
        int fieldOffset = 0;
        for (RelNode input : multiJoinRel.getInputs()) {
            Mappings.TargetMapping mapping = Mappings.offsetSource((Mappings.TargetMapping)Mappings.createIdentity((int)input.getRowType().getFieldCount()), (int)fieldOffset, (int)multiJoin.getNumTotalFields());
            bestPlan.put(id, (Object)new Vertex(id, mq.getRowCount(input), input, mapping));
            connections.set(id);
            id <<= 1;
            fieldOffset += input.getRowType().getFieldCount();
        }
        Vertex bestSoFar = null;
        for (int s = 3; s < 1 << numberOfRelations; ++s) {
            if (IgniteUtils.isPow2(s)) continue;
            int lhs = Integer.lowestOneBit(s);
            while (lhs < s / 2 + 1) {
                int rhs = s - lhs;
                List<Edge> edges0 = connections.get(lhs) && connections.get(rhs) ? IgniteMultiJoinOptimizeBushyRule.findEdges(lhs, rhs, edges) : List.of();
                if (!edges0.isEmpty()) {
                    connections.set(s);
                    Vertex planLhs = (Vertex)bestPlan.get(lhs);
                    Vertex planRhs = (Vertex)bestPlan.get(rhs);
                    Vertex newPlan = IgniteMultiJoinOptimizeBushyRule.createJoin(planLhs, planRhs, edges0, mq, relBuilder, rexBuilder);
                    Vertex currentBest = (Vertex)bestPlan.get(s);
                    if (currentBest == null || currentBest.cost > newPlan.cost) {
                        bestPlan.put(s, (Object)newPlan);
                        bestSoFar = IgniteMultiJoinOptimizeBushyRule.chooseBest(bestSoFar, newPlan);
                    }
                    IgniteMultiJoinOptimizeBushyRule.aggregateEdges(edges, lhs, rhs);
                }
                lhs = s & lhs - s;
            }
        }
        int allRelationsMask = (1 << numberOfRelations) - 1;
        Vertex best = bestSoFar == null || bestSoFar.id != allRelationsMask ? IgniteMultiJoinOptimizeBushyRule.composeCartesianJoin(allRelationsMask, (Int2ObjectMap<Vertex>)bestPlan, edges, bestSoFar, mq, relBuilder, rexBuilder) : bestSoFar;
        RelNode result = relBuilder.push(best.rel).filter(new RexNode[]{(RexNode)RexUtil.composeConjunction((RexBuilder)rexBuilder, unusedConditions).accept((RexVisitor)new RexPermuteInputsShuttle(best.mapping, new RelNode[]{best.rel}))}).project((Iterable)relBuilder.fields(best.mapping)).build();
        call.transformTo(result);
    }

    private static void aggregateEdges(Int2ObjectMap<List<Edge>> edges, int lhs, int rhs) {
        int id = lhs | rhs;
        if (!edges.containsKey(id)) {
            Set used = Collections.newSetFromMap(new IdentityHashMap());
            ArrayList union = new ArrayList((Collection)edges.getOrDefault(lhs, List.of()));
            used.addAll(union);
            ((List)edges.getOrDefault(rhs, List.of())).forEach(edge -> {
                if (used.add(edge)) {
                    union.add(edge);
                }
            });
            if (!union.isEmpty()) {
                edges.put(id, union);
            }
        }
    }

    private static Vertex composeCartesianJoin(int allRelationsMask, Int2ObjectMap<Vertex> bestPlan, Int2ObjectMap<List<Edge>> edges, @Nullable Vertex bestSoFar, RelMetadataQuery mq, RelBuilder relBuilder, RexBuilder rexBuilder) {
        ArrayList<Vertex> options;
        if (bestSoFar != null) {
            options = new ArrayList<Vertex>();
            for (Vertex option : bestPlan.values()) {
                if ((option.id & bestSoFar.id) != 0) continue;
                options.add(option);
            }
        } else {
            options = new ArrayList(bestPlan.values());
        }
        options.sort(VERTEX_COMPARATOR);
        Iterator it = options.iterator();
        if (bestSoFar == null) {
            bestSoFar = (Vertex)it.next();
        }
        while (it.hasNext() && bestSoFar.id != allRelationsMask) {
            Vertex input = (Vertex)it.next();
            if ((bestSoFar.id & input.id) != 0) continue;
            List<Edge> edges0 = IgniteMultiJoinOptimizeBushyRule.findEdges(bestSoFar.id, input.id, edges);
            IgniteMultiJoinOptimizeBushyRule.aggregateEdges(edges, bestSoFar.id, input.id);
            bestSoFar = IgniteMultiJoinOptimizeBushyRule.createJoin(bestSoFar, input, edges0, mq, relBuilder, rexBuilder);
        }
        assert (bestSoFar.id == allRelationsMask);
        return bestSoFar;
    }

    private static Vertex chooseBest(@Nullable Vertex currentBest, Vertex candidate) {
        if (currentBest == null) {
            return candidate;
        }
        if (VERTEX_COMPARATOR.compare(currentBest, candidate) > 0) {
            return candidate;
        }
        return currentBest;
    }

    private static Int2ObjectMap<List<Edge>> collectEdges(LoptMultiJoin multiJoin, List<RexNode> unusedConditions) {
        Int2ObjectOpenHashMap edges = new Int2ObjectOpenHashMap();
        for (RexNode condition : multiJoin.getJoinFilters()) {
            int[] inputRefs = multiJoin.getFactorsRefByJoinFilter(condition).toArray();
            if (inputRefs.length < 2) {
                unusedConditions.add(condition);
                continue;
            }
            if (condition.isA(SqlKind.OR)) {
                unusedConditions.add(condition);
                continue;
            }
            int connectedInputs = 0;
            for (int i : inputRefs) {
                connectedInputs |= 1 << i;
            }
            Edge edge = new Edge(connectedInputs, condition);
            for (int i : inputRefs) {
                ((List)edges.computeIfAbsent(1 << i, k -> new ArrayList())).add(edge);
            }
        }
        return edges;
    }

    private static Vertex createJoin(Vertex lhs, Vertex rhs, List<Edge> edges, RelMetadataQuery metadataQuery, RelBuilder relBuilder, RexBuilder rexBuilder) {
        Vertex minorFactor;
        Vertex majorFactor;
        double rightSize;
        ArrayList<RexNode> conditions = new ArrayList<RexNode>();
        for (Edge e : edges) {
            conditions.add(e.condition);
        }
        double leftSize = metadataQuery.getRowCount(lhs.rel);
        if (leftSize >= (rightSize = metadataQuery.getRowCount(rhs.rel).doubleValue())) {
            majorFactor = lhs;
            minorFactor = rhs;
        } else {
            majorFactor = rhs;
            minorFactor = lhs;
        }
        Mappings.TargetMapping mapping = Mappings.merge((Mappings.TargetMapping)majorFactor.mapping, (Mappings.TargetMapping)Mappings.offsetTarget((Mappings.TargetMapping)minorFactor.mapping, (int)majorFactor.rel.getRowType().getFieldCount()));
        RexNode condition = (RexNode)RexUtil.composeConjunction((RexBuilder)rexBuilder, conditions).accept((RexVisitor)new RexPermuteInputsShuttle(mapping, new RelNode[]{majorFactor.rel, minorFactor.rel}));
        RelNode join = relBuilder.push(majorFactor.rel).push(minorFactor.rel).join(JoinRelType.INNER, condition).build();
        double selfCost = metadataQuery.getRowCount(join);
        return new Vertex(lhs.id | rhs.id, selfCost + lhs.cost + rhs.cost, join, mapping);
    }

    private static List<Edge> findEdges(int lhs, int rhs, Int2ObjectMap<List<Edge>> edges) {
        ArrayList<Edge> result = new ArrayList<Edge>();
        List fromLeft = (List)edges.getOrDefault(lhs, List.of());
        for (Edge edge : fromLeft) {
            int requiredInputs = edge.connectedInputs & ~lhs;
            if (requiredInputs == 0 || edge.connectedInputs == requiredInputs || (requiredInputs &= ~rhs) != 0) continue;
            result.add(edge);
        }
        return result;
    }

    private static class Vertex {
        private final int id;
        private final byte size;
        private final double cost;
        private final Mappings.TargetMapping mapping;
        private final RelNode rel;

        Vertex(int id, double cost, RelNode rel, Mappings.TargetMapping mapping) {
            this.id = id;
            this.size = (byte)Integer.bitCount(id);
            this.cost = cost;
            this.rel = rel;
            this.mapping = mapping;
        }
    }

    private static class Edge {
        private final int connectedInputs;
        private final RexNode condition;

        Edge(int connectedInputs, RexNode condition) {
            this.connectedInputs = connectedInputs;
            this.condition = condition;
        }
    }

    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableIgniteMultiJoinOptimizeBushyRule.Config.of().withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs());

        default public IgniteMultiJoinOptimizeBushyRule toRule() {
            return new IgniteMultiJoinOptimizeBushyRule(this);
        }
    }
}

