/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.cql3.functions;

import com.google.common.collect.ImmutableList;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.cql3.functions.AggregateFcts;
import org.apache.cassandra.cql3.functions.AggregateFunction;
import org.apache.cassandra.cql3.functions.Arguments;
import org.apache.cassandra.cql3.functions.FunctionArguments;
import org.apache.cassandra.cql3.functions.FunctionFactory;
import org.apache.cassandra.cql3.functions.FunctionParameter;
import org.apache.cassandra.cql3.functions.NativeAggregateFunction;
import org.apache.cassandra.cql3.functions.NativeFunction;
import org.apache.cassandra.cql3.functions.NativeFunctions;
import org.apache.cassandra.cql3.functions.NativeScalarFunction;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.CollectionType;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.db.marshal.ListType;
import org.apache.cassandra.db.marshal.MapType;
import org.apache.cassandra.db.marshal.SetType;
import org.apache.cassandra.transport.ProtocolVersion;

public class CollectionFcts {
    /**
     * Registers one factory per collection function into the given registry.
     * <p>
     * The functions are {@code map_keys}, {@code map_values}, {@code collection_count}, {@code collection_min},
     * {@code collection_max}, {@code collection_sum} and {@code collection_avg}. Each factory builds the concrete
     * function for the argument type it is resolved against. The accepted argument kinds come from the
     * {@link FunctionParameter} declared for each factory:
     * <ul>
     *   <li>{@code anyMap} for {@code map_keys} and {@code map_values};</li>
     *   <li>{@code anyCollection} for {@code collection_count};</li>
     *   <li>{@code setOrList} for {@code collection_min} and {@code collection_max};</li>
     *   <li>{@code numericSetOrList} for {@code collection_sum} and {@code collection_avg}.</li>
     * </ul>
     *
     * @param functions the registry the factories are added to
     */
    public static void addFunctionsTo(NativeFunctions functions) {
        functions.add(new FunctionFactory("map_keys", new FunctionParameter[]{FunctionParameter.anyMap()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeMapKeysFunction(name.name, (MapType<?, ?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("map_values", new FunctionParameter[]{FunctionParameter.anyMap()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeMapValuesFunction(name.name, (MapType<?, ?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("collection_count", new FunctionParameter[]{FunctionParameter.anyCollection()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeCollectionCountFunction(name.name, (CollectionType<?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("collection_min", new FunctionParameter[]{FunctionParameter.setOrList()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeCollectionMinFunction(name.name, (CollectionType<?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("collection_max", new FunctionParameter[]{FunctionParameter.setOrList()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeCollectionMaxFunction(name.name, (CollectionType<?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("collection_sum", new FunctionParameter[]{FunctionParameter.numericSetOrList()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeCollectionSumFunction(name.name, (CollectionType<?>) argTypes.get(0));
            }
        });

        functions.add(new FunctionFactory("collection_avg", new FunctionParameter[]{FunctionParameter.numericSetOrList()}) {
            @Override
            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType) {
                return makeCollectionAvgFunction(name.name, (CollectionType<?>) argTypes.get(0));
            }
        });
    }

    /**
     * Builds a function that returns the keys of a map argument as a set.
     *
     * @param name      the function name
     * @param inputType the map type of the single argument
     * @param <K>       the map key type
     * @param <V>       the map value type
     * @return a scalar function returning a set of the map's keys, or {@code null} for a {@code null} argument
     */
    private static <K, V> NativeScalarFunction makeMapKeysFunction(String name, MapType<K, V> inputType) {
        // NOTE(review): the `false` flag is presumably isMultiCell, i.e. the result set type is non-multi-cell -- confirm.
        final SetType<K> outputType = SetType.getInstance(inputType.getKeysType(), false);
        return new NativeScalarFunction(name, outputType, inputType) {
            @Override
            public ByteBuffer execute(Arguments arguments) {
                if (arguments.containsNulls()) {
                    return null;
                }
                Map<K, V> map = arguments.get(0);
                Set<K> keys = map.keySet();
                return outputType.decompose(keys);
            }
        };
    }

    /**
     * Builds a function that returns the values of a map argument as a list, in the map's iteration order.
     *
     * @param name      the function name
     * @param inputType the map type of the single argument
     * @param <K>       the map key type
     * @param <V>       the map value type
     * @return a scalar function returning a list of the map's values, or {@code null} for a {@code null} argument
     */
    private static <K, V> NativeScalarFunction makeMapValuesFunction(String name, MapType<K, V> inputType) {
        // NOTE(review): the `false` flag is presumably isMultiCell, i.e. the result list type is non-multi-cell -- confirm.
        final ListType<V> outputType = ListType.getInstance(inputType.getValuesType(), false);
        return new NativeScalarFunction(name, outputType, inputType) {
            @Override
            public ByteBuffer execute(Arguments arguments) {
                if (arguments.containsNulls()) {
                    return null;
                }
                Map<K, V> map = arguments.get(0);
                List<V> values = ImmutableList.copyOf(map.values());
                return outputType.decompose(values);
            }
        };
    }

    /**
     * Builds a function that returns the number of elements of a collection argument as an {@code int}.
     *
     * @param name      the function name
     * @param inputType the collection type of the single argument
     * @param <T>       the Java type the collection deserializes to
     * @return a scalar function returning the element count, or {@code null} for a {@code null} argument
     */
    private static <T> NativeScalarFunction makeCollectionCountFunction(String name, final CollectionType<T> inputType) {
        return new NativeScalarFunction(name, Int32Type.instance, inputType) {
            @Override
            public Arguments newArguments(ProtocolVersion version) {
                // The argument is read back as a raw ByteBuffer in execute(). Presumably a no-op instance skips
                // deserialization so the size can be read straight from the serialized form -- confirm.
                return FunctionArguments.newNoopInstance(version, 1);
            }

            @Override
            public ByteBuffer execute(Arguments arguments) {
                if (arguments.containsNulls()) {
                    return null;
                }
                ByteBuffer collection = arguments.get(0);
                int size = inputType.size(collection);
                return Int32Type.instance.decompose(size);
            }
        };
    }

    /**
     * Builds a function returning the minimum element of a set or list. It delegates to the native {@code min}
     * aggregate for the element type; counter elements use the counter-specific aggregate.
     *
     * @param name      the function name
     * @param inputType the set or list type of the single argument
     * @param <T>       the Java type the collection deserializes to
     * @return a scalar function computing the minimum element
     */
    private static <T> NativeScalarFunction makeCollectionMinFunction(String name, CollectionType<T> inputType) {
        AbstractType<?> elementsType = elementsType(inputType);
        NativeAggregateFunction function = elementsType.isCounter()
                                           ? AggregateFcts.minFunctionForCounter
                                           : AggregateFcts.makeMinFunction(elementsType);
        return new CollectionAggregationFunction(name, inputType, function);
    }

    /**
     * Builds a function returning the maximum element of a set or list. It delegates to the native {@code max}
     * aggregate for the element type; counter elements use the counter-specific aggregate.
     *
     * @param name      the function name
     * @param inputType the set or list type of the single argument
     * @param <T>       the Java type the collection deserializes to
     * @return a scalar function computing the maximum element
     */
    private static <T> NativeScalarFunction makeCollectionMaxFunction(String name, CollectionType<T> inputType) {
        AbstractType<?> elementsType = elementsType(inputType);
        NativeAggregateFunction function = elementsType.isCounter()
                                           ? AggregateFcts.maxFunctionForCounter
                                           : AggregateFcts.makeMaxFunction(elementsType);
        return new CollectionAggregationFunction(name, inputType, function);
    }

    /**
     * Builds a function returning the sum of the elements of a numeric set or list.
     *
     * @param name      the function name
     * @param inputType the numeric set or list type of the single argument
     * @param <T>       the Java type the collection deserializes to
     * @return a scalar function computing the sum of the elements
     * @throws AssertionError if the element type is not one of the supported numeric native types
     */
    private static <T> NativeScalarFunction makeCollectionSumFunction(String name, CollectionType<T> inputType) {
        CQL3Type elementsType = elementsType(inputType).asCQL3Type();
        NativeAggregateFunction function = getSumFunction((CQL3Type.Native) elementsType);
        return new CollectionAggregationFunction(name, inputType, function);
    }

    /**
     * Selects the native {@code sum} aggregate matching a numeric CQL native type.
     *
     * @param type the element type of the collection
     * @return the {@code sum} aggregate for that type
     * @throws AssertionError if {@code type} is not tinyint, smallint, int, bigint, float, double, varint or decimal
     */
    private static NativeAggregateFunction getSumFunction(CQL3Type.Native type) {
        switch (type) {
            case TINYINT:
                return AggregateFcts.sumFunctionForByte;
            case SMALLINT:
                return AggregateFcts.sumFunctionForShort;
            case INT:
                return AggregateFcts.sumFunctionForInt32;
            case BIGINT:
                return AggregateFcts.sumFunctionForLong;
            case FLOAT:
                return AggregateFcts.sumFunctionForFloat;
            case DOUBLE:
                return AggregateFcts.sumFunctionForDouble;
            case VARINT:
                return AggregateFcts.sumFunctionForVarint;
            case DECIMAL:
                return AggregateFcts.sumFunctionForDecimal;
            default:
                throw new AssertionError("Expected numeric collection but found " + type);
        }
    }

    /**
     * Builds a function returning the average of the elements of a numeric set or list.
     *
     * @param name      the function name
     * @param inputType the numeric set or list type of the single argument
     * @param <T>       the Java type the collection deserializes to
     * @return a scalar function computing the average of the elements
     * @throws AssertionError if the element type is not one of the supported numeric native types
     */
    private static <T> NativeScalarFunction makeCollectionAvgFunction(String name, CollectionType<T> inputType) {
        CQL3Type elementsType = elementsType(inputType).asCQL3Type();
        NativeAggregateFunction function = getAvgFunction((CQL3Type.Native) elementsType);
        return new CollectionAggregationFunction(name, inputType, function);
    }

    /**
     * Selects the native {@code avg} aggregate matching a numeric CQL native type.
     *
     * @param type the element type of the collection
     * @return the {@code avg} aggregate for that type
     * @throws AssertionError if {@code type} is not tinyint, smallint, int, bigint, float, double, varint or decimal
     */
    private static NativeAggregateFunction getAvgFunction(CQL3Type.Native type) {
        switch (type) {
            case TINYINT:
                return AggregateFcts.avgFunctionForByte;
            case SMALLINT:
                return AggregateFcts.avgFunctionForShort;
            case INT:
                return AggregateFcts.avgFunctionForInt32;
            case BIGINT:
                return AggregateFcts.avgFunctionForLong;
            case FLOAT:
                return AggregateFcts.avgFunctionForFloat;
            case DOUBLE:
                return AggregateFcts.avgFunctionForDouble;
            case VARINT:
                return AggregateFcts.avgFunctionForVarint;
            case DECIMAL:
                return AggregateFcts.avgFunctionForDecimal;
            default:
                throw new AssertionError("Expected numeric collection but found " + type);
        }
    }

    /**
     * Returns the element type of a list or set type.
     *
     * @param type a list or set collection type
     * @return the type of the collection's elements
     * @throws AssertionError if {@code type} is neither a list nor a set (for example, a map)
     */
    private static AbstractType<?> elementsType(CollectionType<?> type) {
        if (type.kind == CollectionType.Kind.LIST) {
            return ((ListType<?>) type).getElementsType();
        }
        if (type.kind == CollectionType.Kind.SET) {
            return ((SetType<?>) type).getElementsType();
        }
        throw new AssertionError("Cannot get the element type of: " + type);
    }

    /**
     * Scalar function that applies a native aggregate to the elements of a single collection argument.
     * <p>
     * Each serialized element is fed to a fresh aggregate as one input. The aggregate's result is the function's
     * result, so the return type is the aggregate's return type. A {@code null} argument yields {@code null}.
     */
    private static class CollectionAggregationFunction extends NativeScalarFunction {
        /** Type of the collection argument, used to iterate over its serialized elements. */
        private final CollectionType<?> inputType;
        /** Aggregate applied across the collection's elements. */
        private final NativeAggregateFunction aggregateFunction;

        public CollectionAggregationFunction(String name, CollectionType<?> inputType, NativeAggregateFunction aggregateFunction) {
            super(name, aggregateFunction.returnType, inputType);
            this.inputType = inputType;
            this.aggregateFunction = aggregateFunction;
        }

        @Override
        public Arguments newArguments(ProtocolVersion version) {
            // execute() iterates the argument as a raw ByteBuffer. Presumably a no-op instance skips
            // deserialization of the whole collection -- confirm.
            return FunctionArguments.newNoopInstance(version, 1);
        }

        @Override
        public ByteBuffer execute(Arguments arguments) {
            if (arguments.containsNulls()) {
                return null;
            }
            ProtocolVersion version = arguments.getProtocolVersion();
            // One Arguments instance is reused for every element: slot 0 is overwritten before each addInput call.
            Arguments args = aggregateFunction.newArguments(version);
            AggregateFunction.Aggregate aggregate = aggregateFunction.newAggregate();
            ByteBuffer collection = arguments.get(0);
            inputType.forEach(collection, element -> {
                args.set(0, element);
                aggregate.addInput(args);
            });
            return aggregate.compute(version);
        }
    }
}

