/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.spark.bulkwriter;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Range;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.cassandra.spark.bulkwriter.BulkWriteValidator;
import org.apache.cassandra.spark.bulkwriter.BulkWriterConfig;
import org.apache.cassandra.spark.bulkwriter.BulkWriterContext;
import org.apache.cassandra.spark.bulkwriter.DecoratedKey;
import org.apache.cassandra.spark.bulkwriter.JobInfo;
import org.apache.cassandra.spark.bulkwriter.RingInstance;
import org.apache.cassandra.spark.bulkwriter.SortedSSTableWriter;
import org.apache.cassandra.spark.bulkwriter.StreamResult;
import org.apache.cassandra.spark.bulkwriter.StreamSession;
import org.apache.cassandra.spark.bulkwriter.WriteResult;
import org.apache.cassandra.spark.bulkwriter.token.MultiClusterReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.ReplicaAwareFailureHandler;
import org.apache.cassandra.spark.bulkwriter.token.TokenRangeMapping;
import org.apache.cassandra.spark.bulkwriter.util.TaskContextUtils;
import org.apache.cassandra.spark.data.BridgeUdtValue;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.cassandra.spark.data.CqlTable;
import org.apache.cassandra.spark.data.ReplicationFactor;
import org.apache.cassandra.spark.utils.DigestAlgorithm;
import org.apache.cassandra.util.ThreadUtil;
import org.apache.spark.TaskContext;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Writes the rows of one Spark partition into SSTables through {@link StreamSession}s and collects the
 * {@link StreamResult} of every session.
 *
 * <p>The task's token range is split into the ranges of the cluster's token range mapping. One stream session
 * is opened per sub-range, and the previous session is finalized asynchronously whenever the rows move on to
 * another sub-range. The sub-range cursor only moves forward, so rows are expected in sub-range order: a row
 * whose token is not in the current or a later sub-range fails the task.
 *
 * <p>An instance is intended for a single {@link #write(Iterator)} call: {@link #streamFutures} is never
 * cleared, and the worker executor is shut down when {@code write} returns.
 */
public class RecordWriter {
    /** Placeholder replication factor (SimpleStrategy, RF 1) handed to the bridge when building {@link #cqlTable}. */
    public static final ReplicationFactor IGNORED_REPLICATION_FACTOR =
        new ReplicationFactor(ReplicationFactor.ReplicationStrategy.SimpleStrategy, ImmutableMap.of("replication_factor", 1));
    private static final Logger LOGGER = LoggerFactory.getLogger(RecordWriter.class);

    private final BulkWriterContext writerContext;
    private final String[] columnNames;
    private final SSTableWriterFactory tableWriterFactory;
    private final DigestAlgorithm digestAlgorithm;
    private final BulkWriteValidator writeValidator;
    private final ReplicaAwareFailureHandler<RingInstance> failureHandler;
    private final Supplier<TaskContext> taskContextSupplier;
    // UDT name -> UDT definition from cqlTable, filled lazily by getUdt
    private final ConcurrentHashMap<String, CqlField.CqlUdt> udtCache = new ConcurrentHashMap<>();
    // stream session id -> pending result of that session's stream; one entry per finalized session
    private final Map<String, Future<StreamResult>> streamFutures;
    // single worker thread handed to every stream session this writer creates
    private final ExecutorService executorService;
    // per-partition scratch directory; each stream session writes into a sub-directory named after its id
    private final Path baseDir;
    private final CqlTable cqlTable;
    // session for the sub-range currently being written, or null when no session is open
    private StreamSession<?> streamSession = null;

    /**
     * Creates a writer for the current Spark task that uses {@link SortedSSTableWriter} for its SSTables.
     *
     * @param config      bulk writer configuration used to build the {@link BulkWriterContext}
     * @param columnNames column names, in the same order as the values of each row passed to {@link #write}
     */
    public RecordWriter(BulkWriterConfig config, String[] columnNames) {
        this(BulkWriterContext.from(config), columnNames, TaskContext::get, SortedSSTableWriter::new);
    }

    /**
     * Test-visible constructor allowing the task context and SSTable writer to be substituted.
     * It runs {@code startupValidate()} on the cluster and builds the table schema (including UDTs).
     */
    @VisibleForTesting
    RecordWriter(BulkWriterContext writerContext, String[] columnNames, Supplier<TaskContext> taskContextSupplier, SSTableWriterFactory tableWriterFactory) {
        this.writerContext = writerContext;
        this.columnNames = columnNames;
        this.taskContextSupplier = taskContextSupplier;
        this.tableWriterFactory = tableWriterFactory;
        this.failureHandler = new MultiClusterReplicaAwareFailureHandler<>(writerContext.cluster().getPartitioner());
        this.writeValidator = new BulkWriteValidator(writerContext, this.failureHandler);
        this.digestAlgorithm = writerContext.job().digestAlgorithmSupplier().get();
        this.streamFutures = new HashMap<>();
        this.executorService = Executors.newSingleThreadExecutor(ThreadUtil.threadFactory("RecordWriter-worker"));
        this.baseDir = TaskContextUtils.getPartitionUniquePath(System.getProperty("java.io.tmpdir"), writerContext.job().getId(), taskContextSupplier.get());
        writerContext.cluster().startupValidate();
        this.cqlTable = writerContext.bridge().buildSchema(writerContext.schema().getTableSchema().createStatement,
                                                           writerContext.job().qualifiedTableName().keyspace(),
                                                           IGNORED_REPLICATION_FACTOR,
                                                           writerContext.cluster().getPartitioner(),
                                                           writerContext.schema().getUserDefinedTypeStatements());
    }

    /**
     * Writes every row of this task's partition and waits for all resulting streams to complete.
     *
     * <p>Before writing, this runs {@code validateClOrFail} against the current token range mapping and
     * {@code validateTimeSkew} for the task's token range. After all streams complete, the mapping is fetched
     * again, and the task fails if the ranges or instances covering its token range have changed.
     *
     * @param sourceIterator rows as (decorated key, column values) pairs, in sub-range order
     * @return the results of all stream sessions, and whether the mapping had pending instances when the task started
     * @throws RuntimeException wrapping any failure raised while writing rows, streaming or re-validating the mapping
     */
    public WriteResult write(Iterator<Tuple2<DecoratedKey, Object[]>> sourceIterator) {
        try {
            return this.writePartition(sourceIterator);
        } finally {
            // The executor belongs to this writer alone; without an explicit shutdown its worker thread outlives
            // the task. shutdown() only rejects new submissions, so any already-submitted stream still finishes.
            this.executorService.shutdown();
        }
    }

    /** Body of {@link #write}: validation, the row loop, and waiting for stream completion. */
    private WriteResult writePartition(Iterator<Tuple2<DecoratedKey, Object[]>> sourceIterator) {
        TaskContext taskContext = this.taskContextSupplier.get();
        int partitionId = taskContext.partitionId();
        LOGGER.info("[{}]: Processing bulk writer partition", partitionId);
        Range<BigInteger> taskTokenRange = this.getTokenRange(taskContext);
        Preconditions.checkState(!taskTokenRange.isEmpty(), "Token range for the partition %s is empty", taskTokenRange);

        TokenRangeMapping<RingInstance> initialTokenRangeMapping = this.writerContext.cluster().getTokenRangeMapping(false);
        boolean isClusterBeingResized = !initialTokenRangeMapping.pendingInstances().isEmpty();
        LOGGER.info("[{}]: Fetched token range mapping for keyspace: {} with write instances: {} containing pending instances: {}",
                    partitionId,
                    this.writerContext.job().qualifiedTableName().keyspace(),
                    initialTokenRangeMapping.allInstances().size(),
                    initialTokenRangeMapping.pendingInstances().size());

        this.writeValidator.setPhase("Environment Validation");
        this.writeValidator.validateClOrFail(initialTokenRangeMapping);
        this.writeValidator.setPhase("UploadAndCommit");
        this.writerContext.cluster().validateTimeSkew(taskTokenRange);

        Iterator<Tuple2<DecoratedKey, Object[]>> dataIterator = new JavaInterruptibleIterator<>(taskContext, sourceIterator);
        JobInfo job = this.writerContext.job();
        // Reused for every row; getBindValuesForColumns overwrites the entry of every column each time.
        Map<String, Object> valueMap = new HashMap<>();
        try {
            List<Range<BigInteger>> subRanges = this.subRangesOfTask(initialTokenRangeMapping, taskTokenRange);
            int currentRangeIndex = 0;
            Range<BigInteger> currentRange = subRanges.get(currentRangeIndex);
            while (dataIterator.hasNext()) {
                if (this.streamSession != null) {
                    // Surface a failure of the previously scheduled stream before writing more rows.
                    this.streamSession.throwIfLastStreamFailed();
                }
                Tuple2<DecoratedKey, Object[]> rowData = dataIterator.next();
                BigInteger token = rowData._1().getToken();
                // Advance (never rewind) to the sub-range containing this token.
                while (!currentRange.contains(token)) {
                    if (++currentRangeIndex >= subRanges.size()) {
                        String errMsg = String.format("Received Token %s outside the expected ranges %s", token, subRanges);
                        throw new IllegalStateException(errMsg);
                    }
                    currentRange = subRanges.get(currentRangeIndex);
                }
                this.maybeSwitchToNewStreamSession(taskContext, currentRange);
                this.writeRow(rowData, valueMap, partitionId, this.streamSession.getTokenRange());
            }
            if (this.streamSession != null) {
                this.flushAsync(partitionId);
            }
            List<StreamResult> results = this.waitForStreamCompletionAndValidate(partitionId, initialTokenRangeMapping, taskTokenRange);
            return new WriteResult(results, isClusterBeingResized);
        } catch (Exception exception) {
            LOGGER.error("[{}] Failed to write job={}, taskStageAttemptNumber={}, taskAttemptNumber={}",
                         partitionId, job.getId(), taskContext.stageAttemptNumber(), taskContext.attemptNumber(), exception);
            if (exception instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(exception);
        }
    }

    /**
     * Returns the ranges the task's token range is written in: the task range itself when the mapping contains
     * it exactly, otherwise every mapping range that overlaps it (in the mapping's iteration order).
     *
     * @throws IllegalStateException if no mapping range overlaps the task's token range
     */
    private List<Range<BigInteger>> subRangesOfTask(TokenRangeMapping<RingInstance> mapping, Range<BigInteger> taskTokenRange) {
        Set<Range<BigInteger>> mappingRanges = new LinkedHashSet<>(mapping.getRangeMap().asMapOfRanges().keySet());
        List<Range<BigInteger>> subRanges = mappingRanges.contains(taskTokenRange)
                                            ? Collections.singletonList(taskTokenRange)
                                            : this.getIntersectingSubRanges(mappingRanges, taskTokenRange);
        if (subRanges.isEmpty()) {
            // Fail with a descriptive message rather than an IndexOutOfBoundsException on subRanges.get(0).
            throw new IllegalStateException(String.format("No token range in the mapping intersects the task token range %s", taskTokenRange));
        }
        return subRanges;
    }

    /**
     * Blocks until every scheduled stream finishes, then re-validates the token range mapping for the task.
     *
     * @return the result of every finalized stream session
     * @throws RuntimeException if any stream failed or was interrupted, or if the mapping changed
     */
    @NotNull
    private List<StreamResult> waitForStreamCompletionAndValidate(int partitionId, TokenRangeMapping<RingInstance> initialTokenRangeMapping, Range<BigInteger> taskTokenRange) {
        LOGGER.info("[{}] Done with all writers and waiting for stream to complete", partitionId);
        List<StreamResult> results = this.streamFutures.values()
                                                       .stream()
                                                       .map(RecordWriter::awaitStream)
                                                       .collect(Collectors.toList());
        this.validateTaskTokenRangeMappings(partitionId, initialTokenRangeMapping, taskTokenRange);
        return results;
    }

    /** Waits for one stream future, rethrowing any failure unchecked and restoring the interrupt flag if interrupted. */
    private static StreamResult awaitStream(Future<StreamResult> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the sub-ranges of {@code tokenRange} that fall inside {@code taskTokenRange}, with their replicas. */
    private Map<Range<BigInteger>, List<RingInstance>> taskTokenRangeMapping(TokenRangeMapping<RingInstance> tokenRange, Range<BigInteger> taskTokenRange) {
        return tokenRange.getSubRanges(taskTokenRange).asMapOfRanges();
    }

    /** Returns every distinct instance that appears as a replica anywhere in {@code mapping}. */
    private Set<RingInstance> instancesFromMapping(Map<Range<BigInteger>, List<RingInstance>> mapping) {
        return mapping.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
    }

    /**
     * Ensures a stream session for {@code currentRange} is open. If a session for another range is open, it is
     * flushed (finalized asynchronously) before the new session is created.
     */
    private void maybeSwitchToNewStreamSession(TaskContext taskContext, Range<BigInteger> currentRange) throws IOException {
        if (this.streamSession != null && this.streamSession.getTokenRange().equals(currentRange)) {
            return;
        }
        if (this.streamSession != null) {
            this.flushAsync(taskContext.partitionId());
        }
        this.streamSession = this.createStreamSession(taskContext, currentRange);
    }

    /**
     * Creates a stream session for {@code range}. The session gets its own directory under {@link #baseDir}
     * and a fresh SSTable writer from {@link #tableWriterFactory}.
     */
    private StreamSession<?> createStreamSession(TaskContext taskContext, Range<BigInteger> range) throws IOException {
        LOGGER.info("[{}] Creating new stream session. range={}", taskContext.partitionId(), range);
        String sessionId = TaskContextUtils.createStreamSessionId(taskContext);
        Path perSessionDirectory = this.baseDir.resolve(sessionId);
        Files.createDirectories(perSessionDirectory);
        SortedSSTableWriter sstableWriter = this.tableWriterFactory.create(this.writerContext, perSessionDirectory, this.digestAlgorithm, taskContext.partitionId());
        LOGGER.info("[{}][{}] Created new SSTable writer with directory={}", taskContext.partitionId(), sessionId, perSessionDirectory);
        return this.writerContext.transportContext().createStreamSession(this.writerContext, sessionId, sstableWriter, range, this.failureHandler, this.executorService);
    }

    /** Returns, in iteration order, the ranges of {@code ranges} that share a non-empty intersection with {@code tokenRange}. */
    private List<Range<BigInteger>> getIntersectingSubRanges(Set<Range<BigInteger>> ranges, Range<BigInteger> tokenRange) {
        return ranges.stream()
                     .filter(r -> r.isConnected(tokenRange) && !r.intersection(tokenRange).isEmpty())
                     .collect(Collectors.toList());
    }

    /**
     * Fails the task if the sub-ranges or replica instances covering {@code taskTokenRange} differ between
     * {@code startTaskMapping} and the mapping fetched now.
     *
     * @throws RuntimeException describing the differing instances and ranges
     */
    private void validateTaskTokenRangeMappings(int partitionId, TokenRangeMapping<RingInstance> startTaskMapping, Range<BigInteger> taskTokenRange) {
        TokenRangeMapping<RingInstance> endTaskMapping = this.writerContext.cluster().getTokenRangeMapping(false);
        Map<Range<BigInteger>, List<RingInstance>> startMapping = this.taskTokenRangeMapping(startTaskMapping, taskTokenRange);
        Map<Range<BigInteger>, List<RingInstance>> endMapping = this.taskTokenRangeMapping(endTaskMapping, taskTokenRange);
        Set<RingInstance> initialInstances = this.instancesFromMapping(startMapping);
        Set<RingInstance> endInstances = this.instancesFromMapping(endMapping);
        boolean haveMappingsChanged = !startMapping.keySet().equals(endMapping.keySet()) || !initialInstances.equals(endInstances);
        if (!haveMappingsChanged) {
            return;
        }
        Set<Range<BigInteger>> rangeDelta = RecordWriter.symmetricDifference(startMapping.keySet(), endMapping.keySet());
        Set<String> instanceDelta = RecordWriter.symmetricDifference(initialInstances, endInstances)
                                                .stream()
                                                .map(RingInstance::ipAddressWithPort)
                                                .collect(Collectors.toSet());
        String message = String.format("[%s] Token range mappings have changed since the task started with non-overlapping instances: %s and ranges: %s", partitionId, instanceDelta, rangeDelta);
        LOGGER.error(message);
        throw new RuntimeException(message);
    }

    /** Returns the elements that are in exactly one of {@code set1} and {@code set2}. */
    static <T> Set<T> symmetricDifference(Set<T> set1, Set<T> set2) {
        return Stream.concat(set1.stream().filter(element -> !set2.contains(element)),
                             set2.stream().filter(element -> !set1.contains(element)))
                     .collect(Collectors.toSet());
    }

    /** Returns the token range the job's token partitioner assigns to this task's partition. */
    private Range<BigInteger> getTokenRange(TaskContext taskContext) {
        return this.writerContext.job().getTokenPartitioner().getTokenRange(taskContext.partitionId());
    }

    /**
     * Adds one row to the current stream session after checking that its token lies in {@code range}.
     * Runtime failures from the session are logged with the row's key and rethrown.
     */
    private void writeRow(Tuple2<DecoratedKey, Object[]> keyAndRowData, Map<String, Object> valueMap, int partitionId, Range<BigInteger> range) throws IOException {
        DecoratedKey key = keyAndRowData._1();
        BigInteger token = key.getToken();
        // Template form: the message is only built when the check fails, not once per row.
        Preconditions.checkState(range.contains(token), "Received Token %s outside of expected range %s", token, range);
        try {
            this.streamSession.addRow(token, this.getBindValuesForColumns(valueMap, this.columnNames, keyAndRowData._2()));
        } catch (RuntimeException exception) {
            String message = String.format("[%s]: Failed to write data to SSTable: SBW DecoratedKey was %s", partitionId, key);
            LOGGER.error(message, exception);
            throw exception;
        }
    }

    /**
     * Fills {@code map} with column name -> value pairs, converting values of UDT-bearing columns via
     * {@link #maybeConvertUdt}.
     *
     * @return {@code map} itself
     * @throws IllegalArgumentException if the number of values differs from the number of columns
     */
    private Map<String, Object> getBindValuesForColumns(Map<String, Object> map, String[] columnNames, Object[] values) {
        // Template form: the message is only built when the check fails, not once per row.
        Preconditions.checkArgument(values.length == columnNames.length,
                                    "Number of values does not match the number of columns %s, %s",
                                    values.length, columnNames.length);
        for (int i = 0; i < columnNames.length; ++i) {
            Object value = this.cqlTable.containsUdt(columnNames[i]) ? this.maybeConvertUdt(values[i]) : values[i];
            map.put(columnNames[i], value);
        }
        return map;
    }

    /**
     * Recursively converts {@link BridgeUdtValue}s, including those nested in non-empty lists, sets and maps or
     * in other UDTs, with the matching UDT's {@code convertForCqlWriter}. Other values are returned unchanged.
     *
     * <p>Non-empty collections are copied into a new {@link ArrayList}, {@link HashSet} or {@link HashMap}. The
     * fields of a {@link BridgeUdtValue} are converted in place in its {@code udtMap}.
     */
    private Object maybeConvertUdt(Object value) {
        if (value instanceof List && !((List<?>) value).isEmpty()) {
            List<Object> resultList = new ArrayList<>();
            for (Object entry : (List<?>) value) {
                resultList.add(this.maybeConvertUdt(entry));
            }
            return resultList;
        }
        if (value instanceof Set && !((Set<?>) value).isEmpty()) {
            Set<Object> resultSet = new HashSet<>();
            for (Object entry : (Set<?>) value) {
                resultSet.add(this.maybeConvertUdt(entry));
            }
            return resultSet;
        }
        if (value instanceof Map && !((Map<?, ?>) value).isEmpty()) {
            Map<Object, Object> resultMap = new HashMap<>();
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
                resultMap.put(this.maybeConvertUdt(entry.getKey()), this.maybeConvertUdt(entry.getValue()));
            }
            return resultMap;
        }
        if (value instanceof BridgeUdtValue) {
            BridgeUdtValue udtValue = (BridgeUdtValue) value;
            // Depth-first: convert nested fields before the UDT itself. Entry.setValue is the modification the
            // entrySet contract explicitly permits during iteration (a put on the map is not).
            for (Map.Entry<String, Object> field : udtValue.udtMap.entrySet()) {
                field.setValue(this.maybeConvertUdt(field.getValue()));
            }
            return this.getUdt(udtValue.name).convertForCqlWriter(udtValue.udtMap, this.writerContext.bridge().getVersion(), false);
        }
        return value;
    }

    /**
     * Returns the table's UDT named {@code udtName}, caching lookups in {@link #udtCache}.
     *
     * @throws IllegalArgumentException if the table has no UDT with that name
     */
    private CqlField.CqlType getUdt(String udtName) {
        // ConcurrentHashMap.computeIfAbsent is atomic per key, so no method-level lock is needed.
        return this.udtCache.computeIfAbsent(udtName, name -> {
            for (CqlField.CqlUdt udt : this.cqlTable.udts()) {
                if (udt.cqlName().equals(name)) {
                    return udt;
                }
            }
            throw new IllegalArgumentException("Could not find udt with name " + name);
        });
    }

    /**
     * Finalizes the current stream session asynchronously, records its future under the session id, and clears
     * the current session.
     *
     * @throws IllegalStateException if no session is open
     */
    private void flushAsync(int partitionId) throws IOException {
        Preconditions.checkState(this.streamSession != null);
        LOGGER.info("[{}][{}] Closing writer and scheduling SStable stream with {} rows",
                    partitionId, this.streamSession.sessionID, this.streamSession.rowCount());
        Future<StreamResult> future = this.streamSession.finalizeStreamAsync();
        this.streamFutures.put(this.streamSession.sessionID, future);
        this.streamSession = null;
    }

    /**
     * Iterator wrapper that calls {@link TaskContext#killTaskIfInterrupted()} before every {@code hasNext()},
     * so a killed Spark task stops consuming rows.
     */
    private static class JavaInterruptibleIterator<T> implements Iterator<T> {
        private final TaskContext taskContext;
        private final Iterator<T> delegate;

        JavaInterruptibleIterator(TaskContext taskContext, Iterator<T> delegate) {
            this.taskContext = taskContext;
            this.delegate = delegate;
        }

        @Override
        public boolean hasNext() {
            this.taskContext.killTaskIfInterrupted();
            return this.delegate.hasNext();
        }

        @Override
        public T next() {
            return this.delegate.next();
        }
    }

    /** Creates the SSTable writer backing one stream session. */
    @FunctionalInterface
    public static interface SSTableWriterFactory {
        /**
         * @param writerContext   context of the bulk write job
         * @param outDir          directory the writer writes its SSTables into
         * @param digestAlgorithm digest algorithm configured for the job
         * @param partitionId     Spark partition the writer belongs to
         */
        public SortedSSTableWriter create(BulkWriterContext writerContext, Path outDir, DigestAlgorithm digestAlgorithm, int partitionId);
    }
}

