/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.relnode;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.Generated;
import org.apache.calcite.adapter.enumerable.EnumerableHashJoin;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.adapter.enumerable.EnumerableRelImplementor;
import org.apache.calcite.adapter.enumerable.JavaRowFormat;
import org.apache.calcite.adapter.enumerable.PhysType;
import org.apache.calcite.adapter.enumerable.PhysTypeImpl;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.tree.Blocks;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.MethodCallExpression;
import org.apache.calcite.linq4j.tree.Node;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.metadata.model.JoinDesc;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.query.relnode.ColumnRowType;
import org.apache.kylin.query.relnode.ContextUtil;
import org.apache.kylin.query.relnode.OlapContext;
import org.apache.kylin.query.relnode.OlapProjectRel;
import org.apache.kylin.query.relnode.OlapRel;
import org.apache.kylin.query.relnode.OlapTableScan;
import org.apache.kylin.query.schema.OlapTable;
import org.apache.kylin.query.util.ICutContextStrategy;
import org.apache.kylin.query.util.RexUtils;

/**
 * Kylin's join node: an {@link EnumerableHashJoin} that also takes part in Kylin's OLAP-context
 * planning via {@link OlapRel}.
 *
 * <p>When this join carries an {@link OlapContext} shared by both inputs, it is executed through the
 * context's generated exec function (see {@link #implement}). Otherwise {@link #implementEnumerable}
 * falls back to a plain Calcite {@link EnumerableHashJoin} over the converted inputs.
 */
public class OlapJoinRel
extends EnumerableHashJoin
implements OlapRel {
    /** Cost multiplier applied to RIGHT joins and to joins whose condition is always true. */
    static final double LARGE_JOIN_FACTOR = 100.0;
    /** Cost multiplier applied to every other join. */
    private static final double SMALL_JOIN_FACTOR = 0.05;
    /** Fraction of the parent's row-count estimate reported by {@link #estimateRowCount}. */
    private static final double ROW_COUNT_FACTOR = 0.1;
    /** Zero-length array used as the {@code toArray} type token for join key names. */
    static final String[] COLUMN_ARRAY_MARKER = new String[0];

    /** Context assigned to this join; {@code null} makes {@link #isRuntimeJoin()} return true. */
    private OlapContext context;
    /** Contexts collected from the inputs' subtrees. */
    private Set<OlapContext> subContexts = Sets.newHashSet();
    /** Output column mapping: left columns, then right columns (plus any rewrite fields). */
    private ColumnRowType columnRowType;
    /** Number of left-input columns; output indexes below this belong to the left side. */
    private int columnRowTypePivot;
    /** Starts {@code true}; cleared when {@link #pushRelInfoToContext} attaches this join to a context. */
    private boolean isPreCalJoin = true;
    /** Computed in {@link #implementOlap}; one of the gates for adding rewrite fields in {@link #implementRewrite}. */
    private boolean aboveTopPreCalcJoin = false;
    /** When set, one of the conditions under which {@link #implementContext} allocates both inputs separately. */
    private boolean joinCondEqualNullSafe = false;

    /**
     * Creates a join node whose trait set must carry Kylin's {@code CONVENTION}.
     *
     * @throws InvalidRelException propagated from the {@link EnumerableHashJoin} constructor
     * @throws IllegalArgumentException if the node's convention is not {@code CONVENTION}
     */
    public OlapJoinRel(RelOptCluster cluster, RelTraitSet traits, RelNode left, RelNode right, RexNode condition,
            Set<CorrelationId> variablesSet, JoinRelType joinType) throws InvalidRelException {
        super(cluster, traits, left, right, condition, variablesSet, joinType);
        Preconditions.checkArgument(this.getConvention() == CONVENTION);
        // Store the row type in the field up front.
        this.rowType = this.getRowType();
    }

    /**
     * Scales the inherited cost. RIGHT joins and always-true (cross) joins are multiplied by
     * {@link #LARGE_JOIN_FACTOR}. All other joins are multiplied by {@code 0.05}.
     */
    @Override
    public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        boolean expensive = this.joinType == JoinRelType.RIGHT || this.condition.isAlwaysTrue();
        double factor = expensive ? LARGE_JOIN_FACTOR : SMALL_JOIN_FACTOR;
        return super.computeSelfCost(planner, mq).multiplyBy(factor);
    }

    /** Reports one tenth of the inherited row-count estimate. */
    @Override
    public double estimateRowCount(RelMetadataQuery mq) {
        return super.estimateRowCount(mq) * ROW_COUNT_FACTOR;
    }

    /** Returns true iff the parent node is an {@link OlapProjectRel} that merely permutes columns. */
    protected boolean isParentMerelyPermutation(OlapRel.OlapImpl olapImpl) {
        if (olapImpl.getParentNode() instanceof OlapProjectRel) {
            return ((OlapProjectRel) olapImpl.getParentNode()).isMerelyPermutation();
        }
        return false;
    }

    /**
     * Concatenates the inputs' column row types (left first) and records the pivot index.
     *
     * @throws IllegalStateException if the column count disagrees with this node's row type
     */
    protected ColumnRowType buildColumnRowType() {
        ColumnRowType leftColumnRowType = ((OlapRel) this.left).getColumnRowType();
        ColumnRowType rightColumnRowType = ((OlapRel) this.right).getColumnRowType();
        this.columnRowTypePivot = leftColumnRowType.getAllColumns().size();
        List<TblColRef> columns = new ArrayList<>(leftColumnRowType.getAllColumns());
        columns.addAll(rightColumnRowType.getAllColumns());
        if (columns.size() != this.rowType.getFieldCount()) {
            throw new IllegalStateException("RowType=" + this.rowType.getFieldCount() + ", ColumnRowType=" + columns.size());
        }
        List<Set<TblColRef>> sourceColumns = new ArrayList<>(columns.size());
        for (TblColRef col : columns) {
            sourceColumns.add(col.getSourceColumns());
        }
        return new ColumnRowType(columns, sourceColumns);
    }

    /**
     * Builds a {@link JoinDesc} from the equi-join condition. Each left-side column becomes a foreign
     * key, paired with every right-side column it is equated to; the result is sorted by FK.
     */
    protected JoinDesc buildJoin(RexCall condition) {
        Map<TblColRef, Set<TblColRef>> joinColumns = this.translateJoinColumn(condition);
        List<String> pks = new ArrayList<>();
        List<String> fks = new ArrayList<>();
        List<TblColRef> pkCols = new ArrayList<>();
        List<TblColRef> fkCols = new ArrayList<>();
        for (Map.Entry<TblColRef, Set<TblColRef>> columnPair : joinColumns.entrySet()) {
            TblColRef fromCol = columnPair.getKey();
            for (TblColRef toCol : columnPair.getValue()) {
                fks.add(fromCol.getName());
                pks.add(toCol.getName());
                fkCols.add(fromCol);
                pkCols.add(toCol);
            }
        }
        JoinDesc join = new JoinDesc();
        join.setForeignKey(fks.toArray(COLUMN_ARRAY_MARKER));
        join.setPrimaryKey(pks.toArray(COLUMN_ARRAY_MARKER));
        join.setForeignKeyColumns(fkCols.toArray(new TblColRef[0]));
        join.setPrimaryKeyColumns(pkCols.toArray(new TblColRef[0]));
        join.sortByFK();
        return join;
    }

    /**
     * Maps each left-side join column to the right-side columns it is equated to. Returns an empty
     * map when {@code condition} is not a {@link RexCall}.
     */
    protected Map<TblColRef, Set<TblColRef>> translateJoinColumn(RexNode condition) {
        Map<TblColRef, Set<TblColRef>> joinColumns = new HashMap<>();
        if (condition instanceof RexCall) {
            this.translateJoinColumn((RexCall) condition, joinColumns);
        }
        return joinColumns;
    }

    /**
     * Recursively collects equality pairs from {@code condition} into {@code joinColumns}.
     *
     * <p>Behavior by operator kind:
     * <ul>
     *   <li>AND: each operand is visited and must itself be a {@link RexCall}.</li>
     *   <li>EQUALS: both operands must be {@link RexInputRef}; a {@link ClassCastException} is
     *       thrown otherwise.</li>
     *   <li>Any other kind: ignored.</li>
     * </ul>
     * {@link #columnRowType} and {@link #columnRowTypePivot} must already be built.
     */
    void translateJoinColumn(RexCall condition, Map<TblColRef, Set<TblColRef>> joinColumns) {
        SqlKind kind = condition.getOperator().getKind();
        if (kind == SqlKind.AND) {
            for (RexNode operand : condition.getOperands()) {
                this.translateJoinColumn((RexCall) operand, joinColumns);
            }
        } else if (kind == SqlKind.EQUALS) {
            List<RexNode> operands = condition.getOperands();
            RexInputRef op0 = (RexInputRef) operands.get(0);
            TblColRef col0 = this.columnRowType.getColumnByIndex(op0.getIndex());
            RexInputRef op1 = (RexInputRef) operands.get(1);
            TblColRef col1 = this.columnRowType.getColumnByIndex(op1.getIndex());
            // Key the map by whichever operand comes from the left input.
            if (op0.getIndex() < this.columnRowTypePivot) {
                joinColumns.computeIfAbsent(col0, key -> new HashSet<>()).add(col1);
            } else {
                joinColumns.computeIfAbsent(col1, key -> new HashSet<>()).add(col0);
            }
        }
    }

    /**
     * Emits a call to the context's generated exec function. The call targets the first table scan's
     * {@link OlapTable} expression and passes the root expression and the context id.
     */
    @Override
    public EnumerableRel.Result implement(EnumerableRelImplementor implementor, EnumerableRel.Prefer pref) {
        String execFunc = this.context.genExecFunc(this);
        PhysType physType = PhysTypeImpl.of(implementor.getTypeFactory(), this.getRowType(), pref.preferArray());
        RelOptTable factTable = this.context.getFirstTableScan().getTable();
        Expression tableExpr = Objects.requireNonNull(factTable.getExpression(OlapTable.class));
        MethodCallExpression exprCall = Expressions.call(tableExpr, execFunc,
                new Expression[] { implementor.getRootExpression(), Expressions.constant(this.context.getId()) });
        return implementor.result(physType, Blocks.toBlock(exprCall));
    }

    @Override
    public boolean hasSubQuery() {
        return false;
    }

    /** Replaces {@code trait} in this node's trait set and returns the previous trait set. */
    @Override
    public RelTraitSet replaceTraitSet(RelTrait trait) {
        RelTraitSet oldTraitSet = this.traitSet;
        this.traitSet = this.traitSet.replace(trait);
        return oldTraitSet;
    }

    /** Adds the context id (as rendered by {@code displayCtxId}) to the plan explanation. */
    @Override
    public RelWriter explainTerms(RelWriter pw) {
        return super.explainTerms(pw).item("ctx", this.displayCtxId(this.context));
    }

    /** Copies this node as another {@link OlapJoinRel}, keeping its correlation variables. */
    @Override
    public EnumerableHashJoin copy(RelTraitSet traitSet, RexNode conditionExpr, RelNode left, RelNode right,
            JoinRelType joinType, boolean semiJoinDone) {
        try {
            return new OlapJoinRel(this.getCluster(), traitSet, left, right, conditionExpr, this.variablesSet, joinType);
        } catch (InvalidRelException e) {
            throw new AssertionError(e);
        }
    }

    /**
     * Returns true when this join has no context, or when its two inputs belong to different contexts.
     *
     * <p>Side effect: when a context is assigned, the context's return tuple info is set from this
     * node's row type and column row type.
     */
    public boolean isRuntimeJoin() {
        if (this.context != null) {
            this.context.setReturnTupleInfo(this.rowType, this.columnRowType);
        }
        return this.context == null || ((OlapRel) this.left).getContext() != ((OlapRel) this.right).getContext();
    }

    /**
     * Visits both inputs with fresh visitor states, then decides which sides still holding a free
     * table get their own context. Finally merges both states into {@code state} and records the
     * inputs' sub-contexts.
     */
    @Override
    public void implementContext(OlapRel.ContextImpl contextImpl, OlapRel.ContextVisitorState state) {
        OlapRel.ContextVisitorState leftState = OlapRel.ContextVisitorState.init();
        contextImpl.fixSharedOlapTableScanOnTheLeft((BiRel) this);
        contextImpl.visitChild(this.getInput(0), this, leftState);
        OlapRel.ContextVisitorState rightState = OlapRel.ContextVisitorState.init();
        contextImpl.fixSharedOlapTableScanOnTheRight((BiRel) this);
        contextImpl.visitChild(this.getInput(1), this, rightState);

        // If either side contains a model view, every remaining free table gets its own context.
        if (leftState.hasModelView() || rightState.hasModelView()) {
            if (leftState.hasFreeTable()) {
                contextImpl.allocateContext((OlapRel) this.getInput(0), this);
                leftState.setHasFreeTable(false);
            }
            if (rightState.hasFreeTable()) {
                contextImpl.allocateContext((OlapRel) this.getInput(1), this);
                rightState.setHasFreeTable(false);
            }
        }
        // LEFT join with a filtered right side: the right side gets its own context.
        if (this.getJoinType() == JoinRelType.LEFT && rightState.hasFilter() && rightState.hasFreeTable()) {
            contextImpl.allocateContext((OlapRel) this.getInput(1), this);
            rightState.setHasFreeTable(false);
        }

        if (this.getJoinType() == JoinRelType.INNER || this.getJoinType() == JoinRelType.LEFT) {
            this.allocateContextsForInnerOrLeftJoin(contextImpl, leftState, rightState);
        } else {
            // Other join types: every side that still has a free table gets its own context.
            if (leftState.hasFreeTable()) {
                contextImpl.allocateContext((OlapRel) this.left, this);
                leftState.setHasFreeTable(false);
            }
            if (rightState.hasFreeTable()) {
                contextImpl.allocateContext((OlapRel) this.right, this);
                rightState.setHasFreeTable(false);
            }
        }
        state.merge(leftState).merge(rightState);
        this.subContexts.addAll(ContextUtil.collectSubContext(this.left));
        this.subContexts.addAll(ContextUtil.collectSubContext(this.right));
    }

    /**
     * Context allocation for INNER/LEFT joins.
     *
     * <ul>
     *   <li>Exactly one side has a free table: that side is allocated.</li>
     *   <li>Both sides have free tables: both are allocated only when
     *       {@link #mustSplitBothSides} says so.</li>
     * </ul>
     */
    private void allocateContextsForInnerOrLeftJoin(OlapRel.ContextImpl contextImpl,
            OlapRel.ContextVisitorState leftState, OlapRel.ContextVisitorState rightState) {
        if (!leftState.hasFreeTable() && rightState.hasFreeTable()) {
            contextImpl.allocateContext((OlapRel) this.right, this);
            rightState.setHasFreeTable(false);
        } else if (leftState.hasFreeTable() && !rightState.hasFreeTable()) {
            contextImpl.allocateContext((OlapRel) this.left, this);
            leftState.setHasFreeTable(false);
        } else if (leftState.hasFreeTable() && rightState.hasFreeTable()
                && this.mustSplitBothSides(leftState, rightState)) {
            contextImpl.allocateContext((OlapRel) this.left, this);
            contextImpl.allocateContext((OlapRel) this.right, this);
            leftState.setHasFreeTable(false);
            rightState.setHasFreeTable(false);
        }
    }

    /**
     * Returns true when two free sides cannot share one context. Conditions are evaluated in order
     * with short-circuiting.
     */
    private boolean mustSplitBothSides(OlapRel.ContextVisitorState leftState, OlapRel.ContextVisitorState rightState) {
        return this.isCrossJoin()
                || this.hasSameFirstTable(leftState, rightState)
                || this.isRightSideIncrementalTable(rightState)
                || RexUtils.joinMoreThanOneTable((Join) this)
                || !RexUtils.isMerelyTableColumnReference(this, this.condition)
                || this.joinCondEqualNullSafe;
    }

    private boolean isRightSideIncrementalTable(OlapRel.ContextVisitorState rightState) {
        return rightState.hasIncrementalTable();
    }

    /** True when neither side has an incremental table and both sides report a first table. */
    private boolean hasSameFirstTable(OlapRel.ContextVisitorState leftState, OlapRel.ContextVisitorState rightState) {
        return !leftState.hasIncrementalTable() && !rightState.hasIncrementalTable()
                && leftState.hasFirstTable() && rightState.hasFirstTable();
    }

    /** True when the join has no equi-join keys on at least one side. */
    private boolean isCrossJoin() {
        return this.joinInfo.leftKeys.isEmpty() || this.joinInfo.rightKeys.isEmpty();
    }

    public ImmutableIntList getLeftKeys() {
        return this.joinInfo.leftKeys;
    }

    public ImmutableIntList getRightKeys() {
        return this.joinInfo.rightKeys;
    }

    /**
     * Cuts this join out of its context and clears its context and column row type.
     *
     * <ul>
     *   <li>Pushed-down join: only the input sharing the current context is revisited.</li>
     *   <li>Otherwise: each input gets a fresh context.</li>
     * </ul>
     */
    @Override
    public void implementCutContext(ICutContextStrategy.ContextCutImpl contextCutImpl) {
        if (!this.isPreCalJoin) {
            RelNode input = this.context == ((OlapRel) this.left).getContext() ? this.left : this.right;
            contextCutImpl.visitChild(input);
            this.context = null;
            this.columnRowType = null;
        } else {
            this.context = null;
            this.columnRowType = null;
            contextCutImpl.allocateContext((OlapRel) this.getInput(0), this);
            contextCutImpl.allocateContext((OlapRel) this.getInput(1), this);
        }
    }

    /** Assigns {@code context} to this node and both inputs, and collects the inputs' sub-contexts. */
    @Override
    public void setContext(OlapContext context) {
        this.context = context;
        for (RelNode input : this.getInputs()) {
            ((OlapRel) input).setContext(context);
            this.subContexts.addAll(ContextUtil.collectSubContext(input));
        }
    }

    /**
     * Attaches {@code context} to this join when this join has no context yet and one of these holds:
     * this join is the context's parent-of-top node, or either input accepts the context (checked
     * left first, short-circuit). On success, marks the join as not pre-calculated and returns true.
     */
    @Override
    public boolean pushRelInfoToContext(OlapContext context) {
        if (this.context != null) {
            return false;
        }
        if (this == context.getParentOfTopNode()
                || ((OlapRel) this.getLeft()).pushRelInfoToContext(context)
                || ((OlapRel) this.getRight()).pushRelInfoToContext(context)) {
            this.context = context;
            this.isPreCalJoin = false;
            return true;
        }
        return false;
    }

    /**
     * Registers this join with its context (if any), visits both inputs and builds the column row
     * type. It then records join info on the context, or pushes the join columns down to the
     * sub-contexts when there is no context.
     */
    @Override
    public void implementOlap(OlapRel.OlapImpl implementor) {
        if (this.context != null) {
            this.context.getAllOlapJoins().add(this);
            this.aboveTopPreCalcJoin = !this.isPreCalJoin || !this.context.isHasPreCalcJoin();
            this.context.setHasJoin(true);
            this.context.setHasPreCalcJoin(this.context.isHasPreCalcJoin() || this.isPreCalJoin);
        }
        implementor.visitChild(this.left, this);
        implementor.visitChild(this.right, this);
        this.columnRowType = this.buildColumnRowType();
        if (this.context != null) {
            this.collectCtxOlapInfoIfExist();
        } else {
            this.pushDownJoinColsToSubContexts(flattenJoinColumns(this.translateJoinColumn(this.getCondition())));
        }
    }

    /**
     * Records join information on the current (non-null) context.
     *
     * <ul>
     *   <li>Pre-calculated join, or a parent-of-top that is an OlapRel with a different context:
     *       a {@link JoinDesc} is added. Its type is INNER/LEFT, or {@code null} for other types.</li>
     *   <li>Otherwise: join columns belonging to the context's tables are added as subquery join
     *       participants, and all join columns are pushed down to the sub-contexts.</li>
     * </ul>
     * When this join is the context's top node and the context has no aggregation, all columns are
     * amended.
     */
    private void collectCtxOlapInfoIfExist() {
        RelNode parentOfTop = this.context.getParentOfTopNode();
        if (this.isPreCalJoin
                || (parentOfTop instanceof OlapRel && ((OlapRel) parentOfTop).getContext() != this.context)) {
            JoinDesc join = this.buildJoin((RexCall) this.getCondition());
            String joinType = this.getJoinType() == JoinRelType.INNER || this.getJoinType() == JoinRelType.LEFT
                    ? this.getJoinType().name()
                    : null;
            join.setType(joinType);
            this.context.getJoins().add(join);
        } else {
            Set<TblColRef> joinCols = flattenJoinColumns(this.translateJoinColumn(this.getCondition()));
            joinCols.stream()
                    .flatMap(col -> col.getSourceColumns().stream())
                    .filter(this.context::belongToContextTables)
                    .forEach(colRef -> {
                        this.context.getSubqueryJoinParticipants().add(colRef);
                        this.context.getAllColumns().add(colRef);
                    });
            this.pushDownJoinColsToSubContexts(joinCols);
        }
        if (this == this.context.getTopNode() && !this.context.isHasAgg()) {
            ContextUtil.amendAllColsIfNoAgg(this);
        }
    }

    /** Flattens a left-column to right-columns map into the set of every column it mentions. */
    private static Set<TblColRef> flattenJoinColumns(Map<TblColRef, Set<TblColRef>> joinColumns) {
        return joinColumns.entrySet().stream()
                .flatMap(e -> Stream.concat(Stream.of(e.getKey()), e.getValue().stream()))
                .collect(Collectors.toSet());
    }

    /**
     * Visits both inputs for rewrite and, when a context is assigned, re-derives the row type.
     * If rewriting is needed above the top pre-calculated join, the context's rewrite fields not
     * already present (case-sensitive lookup) are appended to the row type and column row type.
     */
    @Override
    public void implementRewrite(OlapRel.RewriteImpl rewriteImpl) {
        rewriteImpl.visitChild(this, this.left);
        rewriteImpl.visitChild(this, this.right);
        if (this.context == null) {
            return;
        }
        this.rowType = this.deriveRowType();
        if (this.context.hasPrecalculatedFields() && this.aboveTopPreCalcJoin
                && OlapRel.RewriteImpl.needRewrite(this.context)) {
            int paramIndex = this.rowType.getFieldList().size();
            List<RelDataTypeField> newFieldList = new ArrayList<>();
            for (Map.Entry<String, RelDataType> rewriteField : this.context.getRewriteFields().entrySet()) {
                String fieldName = rewriteField.getKey();
                if (this.rowType.getField(fieldName, true, false) != null) {
                    continue;
                }
                newFieldList.add(new RelDataTypeFieldImpl(fieldName, paramIndex++, rewriteField.getValue()));
            }
            List<RelDataTypeField> fieldList = new ArrayList<>(this.rowType.getFieldList());
            fieldList.addAll(newFieldList);
            this.rowType = this.getCluster().getTypeFactory().createStructType(fieldList);
            this.columnRowType = this.rebuildColumnRowType(newFieldList, this.context);
        }
    }

    /**
     * For a runtime join (see {@link #isRuntimeJoin()}), returns a plain {@link EnumerableHashJoin}
     * over {@code inputs}; otherwise returns this node.
     *
     * @throws IllegalStateException if the fallback join cannot be created
     */
    @Override
    public EnumerableRel implementEnumerable(List<EnumerableRel> inputs) {
        if (!this.isRuntimeJoin()) {
            return this;
        }
        try {
            return EnumerableHashJoin.create(inputs.get(0), inputs.get(1), this.condition, this.variablesSet,
                    this.joinType);
        } catch (Exception e) {
            throw new IllegalStateException("Can't create EnumerableHashJoin!", e);
        }
    }

    private void pushDownJoinColsToSubContexts(Collection<TblColRef> joinColumns) {
        for (OlapContext subContext : this.subContexts) {
            this.collectJoinColsToContext(joinColumns, subContext);
        }
    }

    /**
     * Adds the source columns of {@code joinColumns} that belong to {@code context}'s tables to the
     * context's column set. When the context has no outer-join participants yet and is directly
     * under this join, those columns are also added as outer-join participants.
     */
    private void collectJoinColsToContext(Collection<TblColRef> joinColumns, OlapContext context) {
        Set<TblColRef> sourceJoinKeyCols = joinColumns.stream()
                .flatMap(col -> col.getSourceColumns().stream())
                .filter(context::belongToContextTables)
                .collect(Collectors.toSet());
        context.getAllColumns().addAll(sourceJoinKeyCols);
        if (context.getOuterJoinParticipants().isEmpty() && this.isDirectOuterJoin(this, context)) {
            context.getOuterJoinParticipants().addAll(sourceJoinKeyCols);
        }
    }

    /**
     * Returns true when {@code context} is reachable from {@code currentNode}. From this join, any
     * input may lead to it. Below this join, only the first input of {@link Project}/{@link Filter}
     * nodes is followed. Non-join nodes are cast to {@link OlapRel}.
     */
    private boolean isDirectOuterJoin(RelNode currentNode, OlapContext context) {
        if (currentNode == this) {
            for (RelNode input : currentNode.getInputs()) {
                if (this.isDirectOuterJoin(input, context)) {
                    return true;
                }
            }
            return false;
        }
        if (((OlapRel) currentNode).getContext() == context) {
            return true;
        }
        if (currentNode instanceof Project || currentNode instanceof Filter) {
            return this.isDirectOuterJoin(currentNode.getInput(0), context);
        }
        return false;
    }

    @Override
    public void setSubContexts(Set<OlapContext> contexts) {
        this.subContexts = contexts;
    }

    /**
     * Rebuilds the column row type as the inputs' columns followed by one column per missing field.
     *
     * <p>For each missing field, the column comes from the first of the context's table scans that
     * has a column of that name; otherwise a LITERAL inner column is created.
     *
     * @throws IllegalStateException if the column count disagrees with this node's row type
     */
    private ColumnRowType rebuildColumnRowType(List<RelDataTypeField> missingFields, OlapContext context) {
        List<TblColRef> columns = new ArrayList<>(((OlapRel) this.left).getColumnRowType().getAllColumns());
        columns.addAll(((OlapRel) this.right).getColumnRowType().getAllColumns());
        for (RelDataTypeField dataTypeField : missingFields) {
            String fieldName = dataTypeField.getName();
            TblColRef aggOutCol = findColumnInTableScans(context, fieldName);
            if (aggOutCol == null) {
                aggOutCol = TblColRef.newInnerColumn(fieldName, TblColRef.InnerDataTypeEnum.LITERAL);
            }
            // NOTE(review): mutates the descriptor of a column that may also be held by a table scan.
            aggOutCol.getColumnDesc().setId(String.valueOf(dataTypeField.getIndex()));
            columns.add(aggOutCol);
        }
        if (columns.size() != this.rowType.getFieldCount()) {
            throw new IllegalStateException("RowType=" + this.rowType.getFieldCount() + ", ColumnRowType=" + columns.size());
        }
        return new ColumnRowType(columns);
    }

    /** Returns the first column named {@code fieldName} among the context's table scans, or {@code null}. */
    private static TblColRef findColumnInTableScans(OlapContext context, String fieldName) {
        for (OlapTableScan tableScan : context.getAllTableScans()) {
            TblColRef col = tableScan.getColumnRowType().getColumnByName(fieldName);
            if (col != null) {
                return col;
            }
        }
        return null;
    }

    @Override
    @Generated
    public OlapContext getContext() {
        return this.context;
    }

    @Override
    @Generated
    public Set<OlapContext> getSubContexts() {
        return this.subContexts;
    }

    @Override
    @Generated
    public ColumnRowType getColumnRowType() {
        return this.columnRowType;
    }

    @Generated
    public boolean isJoinCondEqualNullSafe() {
        return this.joinCondEqualNullSafe;
    }

    @Generated
    private int getColumnRowTypePivot() {
        return this.columnRowTypePivot;
    }

    @Generated
    private boolean isPreCalJoin() {
        return this.isPreCalJoin;
    }

    @Generated
    private boolean isAboveTopPreCalcJoin() {
        return this.aboveTopPreCalcJoin;
    }

    @Generated
    public void setJoinCondEqualNullSafe(boolean joinCondEqualNullSafe) {
        this.joinCondEqualNullSafe = joinCondEqualNullSafe;
    }
}

