/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.heap;

import java.util.HashMap;
import java.util.Set;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.runtime.state.KeyExtractorFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.PriorityComparator;
import org.apache.flink.runtime.state.heap.HeapPriorityQueue;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
import org.apache.flink.util.Preconditions;

/**
 * A {@link HeapPriorityQueue} that additionally has set semantics: an element that is equal (by
 * {@code equals}/{@code hashCode}) to an element already contained is not inserted a second time.
 *
 * <p>Deduplication is backed by one {@link HashMap} per key group of the local {@link
 * KeyGroupRange}. An element's key group is computed from the key returned by the {@link
 * KeyExtractorFunction} and the total number of key groups. This per-key-group layout also lets
 * {@link #getSubsetForKeyGroup(int)} expose all elements of a single key group.
 *
 * <p>Each map stores the instance that was actually inserted into the heap as both key and value.
 * Removal with an equal-but-distinct instance can therefore still hand the heap the exact object it
 * holds.
 *
 * @param <T> type of the contained elements; deduplication relies on their {@code equals} and
 *     {@code hashCode}.
 */
public class HeapPriorityQueueSet<T extends HeapPriorityQueueElement>
        extends HeapPriorityQueue<T>
        implements KeyGroupedInternalPriorityQueue<T> {

    /** Extracts the key of an element; the key determines the element's key group. */
    private final KeyExtractorFunction<T> keyExtractor;

    /**
     * One deduplication map per key group in {@link #keyGroupRange}, indexed by the key group's
     * offset from the start of the range. Every entry maps an element to itself.
     */
    private final HashMap<T, T>[] deduplicationMapsByKeyGroup;

    /** The key groups handled by this queue. */
    private final KeyGroupRange keyGroupRange;

    /** Total number of key groups, needed to assign an element's key to its key group. */
    private final int totalNumberOfKeyGroups;

    /**
     * Creates a new heap-backed priority queue with set semantics.
     *
     * @param elementPriorityComparator comparator defining the priority order of the heap.
     * @param keyExtractor function extracting the key of an element, used for key-group assignment.
     * @param minimumCapacity minimum capacity passed to the heap. It is also spread evenly across
     *     the local key groups to presize each deduplication map.
     * @param keyGroupRange the key groups handled by this queue.
     * @param totalNumberOfKeyGroups total number of key groups.
     */
    public HeapPriorityQueueSet(
            @Nonnull PriorityComparator<T> elementPriorityComparator,
            @Nonnull KeyExtractorFunction<T> keyExtractor,
            @Nonnegative int minimumCapacity,
            @Nonnull KeyGroupRange keyGroupRange,
            @Nonnegative int totalNumberOfKeyGroups) {

        super(elementPriorityComparator, minimumCapacity);

        this.keyExtractor = keyExtractor;
        this.totalNumberOfKeyGroups = totalNumberOfKeyGroups;
        this.keyGroupRange = keyGroupRange;

        final int keyGroupsInLocalRange = keyGroupRange.getNumberOfKeyGroups();
        // Guard against an empty local key-group range, which would otherwise divide by zero.
        final int deduplicationSetSize = 1 + minimumCapacity / Math.max(1, keyGroupsInLocalRange);

        // Java cannot create generic arrays. The cast is safe because every slot is filled
        // with a HashMap<T, T> right below and the array never escapes this class.
        @SuppressWarnings("unchecked")
        final HashMap<T, T>[] maps = (HashMap<T, T>[]) new HashMap[keyGroupsInLocalRange];
        for (int i = 0; i < keyGroupsInLocalRange; ++i) {
            maps[i] = new HashMap<>(deduplicationSetSize);
        }
        this.deduplicationMapsByKeyGroup = maps;
    }

    /**
     * Removes the head element from the heap and also drops it from its deduplication map.
     *
     * @return the instance stored in the deduplication map for the removed head. Returns {@code
     *     null} when {@code super.poll()} yields no element (presumably because the queue is empty).
     */
    @Override
    @Nullable
    public T poll() {
        final T toRemove = super.poll();
        return toRemove != null ? getDedupMapForElement(toRemove).remove(toRemove) : null;
    }

    /**
     * Adds the element unless an equal element is already contained.
     *
     * @param element the element to add.
     * @return {@code false} if an equal element was already present; in that case the queue is
     *     left unchanged. Otherwise returns the result of {@link HeapPriorityQueue#add}. In that
     *     case the element has been inserted even if the returned value is {@code false}.
     */
    @Override
    public boolean add(@Nonnull T element) {
        // Only touch the heap if the element was not yet known to the deduplication map.
        return getDedupMapForElement(element).putIfAbsent(element, element) == null
                && super.add(element);
    }

    /**
     * Removes the element equal to {@code toRemove}, if contained.
     *
     * @param toRemove an element equal to the one to remove; it need not be the same instance.
     * @return {@code false} if no equal element was contained. Otherwise returns the result of
     *     {@link HeapPriorityQueue#remove} for the stored instance.
     */
    @Override
    public boolean remove(@Nonnull T toRemove) {
        // Use the stored instance for the heap removal: it is the exact object that add() put into
        // the heap, which may differ from the equal argument instance.
        final T storedElement = getDedupMapForElement(toRemove).remove(toRemove);
        return storedElement != null && super.remove(storedElement);
    }

    /** Removes all elements from the heap and from every deduplication map. */
    @Override
    public void clear() {
        super.clear();
        for (HashMap<T, T> elementHashMap : deduplicationMapsByKeyGroup) {
            elementHashMap.clear();
        }
    }

    /** Returns the deduplication map for the given (global) key group id. */
    private HashMap<T, T> getDedupMapForKeyGroup(@Nonnegative int keyGroupId) {
        return deduplicationMapsByKeyGroup[globalKeyGroupToLocalIndex(keyGroupId)];
    }

    /** Returns the deduplication map of the key group the element's key is assigned to. */
    private HashMap<T, T> getDedupMapForElement(T element) {
        final int keyGroup =
                KeyGroupRangeAssignment.assignToKeyGroup(
                        keyExtractor.extractKeyFromElement(element), totalNumberOfKeyGroups);
        return getDedupMapForKeyGroup(keyGroup);
    }

    /**
     * Translates a global key group id into an index into {@link #deduplicationMapsByKeyGroup}.
     * Key groups outside {@link #keyGroupRange} are rejected via {@code
     * Preconditions.checkArgument}.
     */
    private int globalKeyGroupToLocalIndex(int keyGroup) {
        Preconditions.checkArgument(
                keyGroupRange.contains(keyGroup),
                "%s does not contain key group %s",
                keyGroupRange,
                keyGroup);
        return keyGroup - keyGroupRange.getStartKeyGroup();
    }

    /**
     * Returns the elements of the given key group.
     *
     * <p>NOTE: this is the live {@link HashMap#keySet()} view of the deduplication map, not a copy.
     * It reflects later changes to the queue. Removing through it would update only the map and
     * not the heap, so callers should treat it as read-only.
     *
     * @param keyGroupId global key group id; must be contained in the local key-group range.
     */
    @Override
    @Nonnull
    public Set<T> getSubsetForKeyGroup(int keyGroupId) {
        return getDedupMapForKeyGroup(keyGroupId).keySet();
    }
}

