/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.time.Instant;
import java.time.LocalDate;
import java.time.ZoneId;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantStringObjectInspector;
import org.apache.hadoop.hive.serde2.variant.Variant;
import org.apache.hadoop.hive.serde2.variant.VariantUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Hive UDF {@code variant_get(variant, path[, type])}.
 *
 * <p>Navigates a VARIANT value with a small JSONPath-like expression. The expression is
 * {@code $}, followed by any sequence of {@code .field} and {@code [index]} steps. The selected
 * value can optionally be converted to a primitive type.
 *
 * <p>Supported target type names are case-insensitive and surrounding whitespace is ignored:
 * <ul>
 *   <li>{@code boolean} / {@code bool}</li>
 *   <li>{@code int} / {@code integer}</li>
 *   <li>{@code bigint} / {@code long}</li>
 *   <li>{@code double}</li>
 *   <li>{@code string}</li>
 * </ul>
 *
 * <p>Without a type argument the result is a string: the JSON text for objects and arrays, and
 * the JSON scalar decoded by Jackson into a string otherwise.
 *
 * <p>NOTE(review): the {@code extended} example below shows a quoted {@code "hello"} for the
 * untyped case, but {@link #castValue} strips the JSON quoting via {@link #unescapeJson}. Confirm
 * which output is intended.
 */
@Description(name="variant_get", value="_FUNC_(variant, path[, type]) - Extracts a sub-variant from variant according to path, and casts it to type", extended="Example:\n> SELECT _FUNC_(parse_json('{\"a\": 1}'), '$.a', 'int');\n1\n> SELECT _FUNC_(parse_json('{\"a\": 1}'), '$.b', 'int');\nNULL\n> SELECT _FUNC_(parse_json('[1, \"2\"]'), '$[1]', 'string');\n2\n> SELECT _FUNC_(parse_json('[1, \"hello\"]'), '$[1]');\n\"hello\"")
public class GenericUDFVariantGet extends GenericUDF {
    private static final Logger LOG = LoggerFactory.getLogger(GenericUDFVariantGet.class);
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Inspector for the VARIANT (struct-encoded) first argument. */
    private StructObjectInspector variantOI;
    /** Inspector for the path argument. */
    private PrimitiveObjectInspector pathOI;
    /** Canonical Hive type name chosen in {@link #initialize}, or null when no type was given. */
    private String targetType;

    /**
     * Validates the arguments and picks the output inspector.
     *
     * @param arguments inspectors for {@code (variant, path[, type])}; the optional type must be
     *     a constant string
     * @return a Java string inspector when no type is given, otherwise the Java primitive
     *     inspector for the requested type
     * @throws UDFArgumentException on a wrong argument count, a wrong argument kind, or an
     *     unsupported type name
     */
    @Override
    public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
        if (arguments.length < 2 || arguments.length > 3) {
            throw new UDFArgumentException("variant_get requires 2 or 3 arguments");
        }
        if (!(arguments[0] instanceof StructObjectInspector)) {
            throw new UDFArgumentException("First argument must be VARIANT");
        }
        this.variantOI = (StructObjectInspector) arguments[0];
        if (!(arguments[1] instanceof PrimitiveObjectInspector)) {
            throw new UDFArgumentException("Second argument must be string path");
        }
        this.pathOI = (PrimitiveObjectInspector) arguments[1];

        // Reset, so that a re-initialized instance never keeps a type from an earlier call.
        this.targetType = null;
        if (arguments.length == 2) {
            return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        }

        if (!(arguments[2] instanceof WritableConstantStringObjectInspector typeOI)) {
            throw new UDFArgumentException("Third argument must be string type name");
        }
        Object requested = typeOI.getWritableConstantValue();
        String canonical = requested == null ? null : normalizeTargetType(requested.toString());
        if (canonical == null) {
            throw new UDFArgumentException("Unsupported target type: " + requested
                + " (expected one of boolean, int, bigint, double, string)");
        }
        this.targetType = canonical;
        // 'canonical' is always a Hive primitive type name ("bigint", not "long"), so this
        // lookup cannot return null the way a raw user-supplied alias could.
        PrimitiveObjectInspector.PrimitiveCategory category =
            PrimitiveObjectInspectorUtils.getTypeEntryFromTypeName(canonical).primitiveCategory;
        return PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(category);
    }

    /**
     * Evaluates one row.
     *
     * @return null if the variant or the path is null, if the path does not resolve, or if the
     *     value cannot be converted; otherwise the extracted and converted value
     * @throws HiveException wrapping any failure during extraction
     */
    @Override
    public Object evaluate(GenericUDF.DeferredObject[] arguments) throws HiveException {
        try {
            Object variantObj = arguments[0].get();
            if (variantObj == null) {
                return null;
            }
            Object pathObj = arguments[1].get();
            if (pathObj == null) {
                return null;
            }
            Variant variant = Variant.from((List) this.variantOI.getStructFieldsDataAsList(variantObj));
            String path = this.pathOI.getPrimitiveJavaObject(pathObj).toString();
            Variant result = extractValueByPath(variant, path);
            return castValue(result, this.targetType);
        } catch (HiveException e) {
            // Already a Hive error (e.g. from DeferredObject.get()); do not wrap it a second time.
            throw e;
        } catch (Exception e) {
            throw new HiveException("Failed to extract variant: " + e.getMessage(), e);
        }
    }

    @Override
    public String getDisplayString(String[] children) {
        return "variant_get(" + String.join(", ", children) + ")";
    }

    /**
     * Maps a user-supplied type name to the canonical Hive primitive type name.
     *
     * @param typeName requested type, such as {@code "INT"}, {@code "long"} or {@code " bool "}
     * @return one of {@code boolean, int, bigint, double, string}, or null if unsupported or null
     */
    private static String normalizeTargetType(String typeName) {
        if (typeName == null) {
            return null;
        }
        // Locale.ROOT: a default-locale lower-casing would turn "INT" into "ınt" under Turkish.
        return switch (typeName.strip().toLowerCase(Locale.ROOT)) {
            case "boolean", "bool" -> "boolean";
            case "int", "integer" -> "int";
            case "bigint", "long" -> "bigint";
            case "double" -> "double";
            case "string" -> "string";
            default -> null;
        };
    }

    /**
     * Walks {@code path} through {@code variant}.
     *
     * @return the selected sub-variant; null if either input is null, a step does not match the
     *     value's shape, or the path is syntactically invalid (logged at WARN)
     */
    private static Variant extractValueByPath(Variant variant, String path) {
        if (variant == null || path == null) {
            return null;
        }
        try {
            List<VariantToken> tokens = VariantPathParser.parse(path);
            Variant current = variant;
            for (VariantToken token : tokens) {
                if (current == null) {
                    return null;
                }
                current = token.get(current);
            }
            return current;
        } catch (IllegalArgumentException e) {
            LOG.warn("Invalid path syntax provided: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Converts an extracted value to the requested type.
     *
     * @param value extracted value; null and variant-NULL both map to SQL NULL
     * @param targetType a type name or alias accepted by {@link #normalizeTargetType}, or null
     *     for the untyped string rendering
     * @return the converted value, or null when a string payload cannot be parsed as a number
     * @throws IllegalArgumentException if {@code targetType} is not supported
     */
    private static Object castValue(Variant value, String targetType) {
        if (value == null || value.getType() == VariantUtil.Type.NULL) {
            return null;
        }
        if (targetType == null) {
            return unescapeJson(value.toJson(ZoneOffset.UTC));
        }
        String canonical = normalizeTargetType(targetType);
        if (canonical == null) {
            throw new IllegalArgumentException("Unsupported target type: " + targetType);
        }
        try {
            return switch (canonical) {
                case "boolean" -> toBoolean(value);
                case "int" -> toInteger(value);
                case "bigint" -> toLong(value);
                case "double" -> toDouble(value);
                case "string" -> toString(value);
                default -> throw new IllegalStateException("Unhandled canonical type: " + canonical);
            };
        } catch (NumberFormatException e) {
            // Raised when a STRING payload is not a valid number; the result is SQL NULL.
            LOG.warn("Cannot convert variant value to {}: {}", canonical, e.getMessage());
            return null;
        }
    }

    /**
     * Converts to Integer, or returns null for types without a sensible integer form.
     * NOTE(review): narrowing casts from LONG and DECIMAL wrap silently on overflow, and casts
     * from DOUBLE and FLOAT truncate. Confirm that is intended.
     */
    private static Integer toInteger(Variant value) {
        return switch (value.getType()) {
            case LONG -> (int) value.getLong();
            case DOUBLE -> (int) value.getDouble();
            case FLOAT -> (int) value.getFloat();
            case DECIMAL -> value.getDecimal().intValue();
            case STRING -> Integer.parseInt(value.getString());
            case BOOLEAN -> value.getBoolean() ? 1 : 0;
            default -> null;
        };
    }

    /** Converts to Long. DATE and TIMESTAMP values yield the raw value returned by {@code getLong()}. */
    private static Long toLong(Variant value) {
        return switch (value.getType()) {
            case LONG -> value.getLong();
            case DOUBLE -> (long) value.getDouble();
            case FLOAT -> (long) value.getFloat();
            case DECIMAL -> value.getDecimal().longValue();
            case STRING -> Long.parseLong(value.getString());
            case BOOLEAN -> value.getBoolean() ? 1L : 0L;
            case DATE -> value.getLong();
            case TIMESTAMP, TIMESTAMP_NTZ -> value.getLong();
            default -> null;
        };
    }

    /** Converts to Double, or returns null for types without a numeric form. */
    private static Double toDouble(Variant value) {
        return switch (value.getType()) {
            // Explicit widening: a long or float does not assignment-convert to Double.
            case LONG -> (double) value.getLong();
            case DOUBLE -> value.getDouble();
            case FLOAT -> (double) value.getFloat();
            case DECIMAL -> value.getDecimal().doubleValue();
            case STRING -> Double.parseDouble(value.getString());
            case BOOLEAN -> value.getBoolean() ? 1.0 : 0.0;
            default -> null;
        };
    }

    /**
     * Converts to Boolean. A number is true iff it is non-zero. A string is true iff it equals
     * "true", ignoring case.
     */
    private static Boolean toBoolean(Variant value) {
        return switch (value.getType()) {
            case BOOLEAN -> value.getBoolean();
            case LONG -> value.getLong() != 0L;
            case DOUBLE -> value.getDouble() != 0.0;
            case FLOAT -> value.getFloat() != 0.0f;
            case DECIMAL -> value.getDecimal().signum() != 0;
            case STRING -> Boolean.parseBoolean(value.getString());
            default -> null;
        };
    }

    /**
     * Renders the value as a string.
     *
     * <p>BINARY values become Base64 text. Objects and arrays become JSON text.
     *
     * <p>NOTE(review): TIMESTAMP_NTZ values go through {@link Instant#toString()}, so they also
     * get a trailing "Z". Confirm whether a zone-less rendering is wanted.
     */
    private static String toString(Variant value) {
        return switch (value.getType()) {
            case BOOLEAN -> String.valueOf(value.getBoolean());
            case LONG -> String.valueOf(value.getLong());
            case DOUBLE -> String.valueOf(value.getDouble());
            case FLOAT -> String.valueOf(value.getFloat());
            case DECIMAL -> value.getDecimal().toPlainString();
            case STRING -> value.getString();
            case BINARY -> Base64.getEncoder().encodeToString(value.getBinary());
            case DATE -> LocalDate.ofEpochDay(value.getLong()).toString();
            case TIMESTAMP, TIMESTAMP_NTZ -> Instant.EPOCH.plus(value.getLong(), ChronoUnit.MICROS).toString();
            case UUID -> value.getUuid().toString();
            case OBJECT, ARRAY -> value.toJson(ZoneOffset.UTC);
            default -> null;
        };
    }

    /**
     * Removes the JSON quoting from a scalar. Object and array JSON is returned unchanged.
     *
     * @return the decoded text, or null if Jackson cannot read {@code str} as a String
     */
    private static String unescapeJson(String str) {
        if (str == null) {
            return null;
        }
        if (str.startsWith("[") || str.startsWith("{")) {
            return str;
        }
        try {
            return MAPPER.readValue(str, String.class);
        } catch (JsonProcessingException e) {
            return null;
        }
    }

    /**
     * Parses paths of the form {@code $}, followed by any sequence of {@code .field} and
     * {@code [int]} steps.
     */
    private static final class VariantPathParser {
        private VariantPathParser() {
        }

        /**
         * Tokenizes {@code path}.
         *
         * @return the steps in order; an empty list for the bare root {@code $}
         * @throws IllegalArgumentException on any syntax error
         */
        public static List<VariantToken> parse(String path) {
            if (path == null || !path.startsWith("$")) {
                throw new IllegalArgumentException("Invalid path: must start with '$'.");
            }
            if (path.length() == 1) {
                return Collections.emptyList();
            }
            List<VariantToken> tokens = new ArrayList<>();
            int i = 1;
            while (i < path.length()) {
                char c = path.charAt(i);
                if (c == '.') {
                    // A field name runs until the next '.' or '[' or the end of the path.
                    int start = ++i;
                    while (i < path.length() && path.charAt(i) != '.' && path.charAt(i) != '[') {
                        ++i;
                    }
                    String key = path.substring(start, i);
                    if (key.isEmpty()) {
                        throw new IllegalArgumentException("Invalid path: empty field name at position " + start);
                    }
                    tokens.add(new FieldToken(key));
                } else if (c == '[') {
                    int start = ++i;
                    int end = path.indexOf(']', start);
                    if (end == -1) {
                        throw new IllegalArgumentException("Invalid path: unclosed array index at position " + start);
                    }
                    String indexStr = path.substring(start, end).trim();
                    try {
                        tokens.add(new IndexToken(Integer.parseInt(indexStr)));
                    } catch (NumberFormatException e) {
                        throw new IllegalArgumentException("Invalid path: non-integer array index '" + indexStr + "'", e);
                    }
                    i = end + 1;
                } else {
                    throw new IllegalArgumentException("Invalid path: unexpected character '" + c + "' at position " + i);
                }
            }
            return tokens;
        }
    }

    /** One navigation step. Returns null when the step does not apply to {@code var1}. */
    private static interface VariantToken {
        public Variant get(Variant var1);
    }

    /** Array-index step. It only applies to ARRAY values. */
    private record IndexToken(int index) implements VariantToken {
        @Override
        public Variant get(Variant target) {
            if (target != null && target.getType() == VariantUtil.Type.ARRAY) {
                return target.getElementAtIndex(this.index);
            }
            return null;
        }
    }

    /** Object-field step. It only applies to OBJECT values. */
    private record FieldToken(String key) implements VariantToken {
        @Override
        public Variant get(Variant target) {
            if (target != null && target.getType() == VariantUtil.Type.OBJECT) {
                return target.getFieldByKey(this.key);
            }
            return null;
        }
    }
}

