/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.transport;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.nodes.TransportNodesAction;
import org.opensearch.ad.caching.ADCacheProvider;
import org.opensearch.ad.caching.ADPriorityCache;
import org.opensearch.ad.ml.ADRealTimeInferencer;
import org.opensearch.ad.transport.ADHCImputeAction;
import org.opensearch.ad.transport.ADHCImputeNodeRequest;
import org.opensearch.ad.transport.ADHCImputeNodeResponse;
import org.opensearch.ad.transport.ADHCImputeNodesResponse;
import org.opensearch.ad.transport.ADHCImputeRequest;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.cluster.HashRing;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.util.ActionListenerExecutor;
import org.opensearch.transport.TransportService;

/**
 * Node-level transport action that imputes a missing interval for the entity models of an AD config.
 *
 * <p>Each node that receives the request feeds an all-NaN {@link Sample} covering
 * {@code [dataStartMillis, dataEndMillis]} into the real-time inferencer. This happens for every cached
 * model of the config that:
 * <ul>
 *   <li>has an entity;</li>
 *   <li>has not yet seen data for that interval; and</li>
 *   <li>is owned by this node according to the hash ring.</li>
 * </ul>
 */
public class ADHCImputeTransportAction
extends TransportNodesAction<ADHCImputeRequest, ADHCImputeNodesResponse, ADHCImputeNodeRequest, ADHCImputeNodeResponse> {
    private static final Logger LOG = LogManager.getLogger(ADHCImputeTransportAction.class);

    /** Executor name used both for node-level execution and for the config-fetch callback. */
    private static final String AD_THREAD_POOL = "ad-threadpool";

    private final ADCacheProvider cache;
    private final NodeStateManager nodeStateManager;
    private final ADRealTimeInferencer adInferencer;
    private final HashRing hashRing;

    @Inject
    public ADHCImputeTransportAction(ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, ADCacheProvider priorityCache, NodeStateManager nodeStateManager, ADRealTimeInferencer adInferencer, HashRing hashRing) {
        super(ADHCImputeAction.NAME, threadPool, clusterService, transportService, actionFilters, ADHCImputeRequest::new, ADHCImputeNodeRequest::new, AD_THREAD_POOL, ADHCImputeNodeResponse.class);
        this.cache = priorityCache;
        this.nodeStateManager = nodeStateManager;
        this.adInferencer = adInferencer;
        this.hashRing = hashRing;
    }

    @Override
    protected ADHCImputeNodeRequest newNodeRequest(ADHCImputeRequest request) {
        return new ADHCImputeNodeRequest(request);
    }

    @Override
    protected ADHCImputeNodeResponse newNodeResponse(StreamInput response) throws IOException {
        return new ADHCImputeNodeResponse(response);
    }

    @Override
    protected ADHCImputeNodesResponse newResponse(ADHCImputeRequest request, List<ADHCImputeNodeResponse> responses, List<FailedNodeException> failures) {
        return new ADHCImputeNodesResponse(this.clusterService.getClusterName(), responses, failures);
    }

    /**
     * Schedules imputation for this node's eligible models and returns any previously recorded failure.
     *
     * <p>The config is resolved through {@link NodeStateManager#getConfig} with a listener wrapped by
     * {@link ActionListenerExecutor} on the AD thread pool. Failures from that listener are stored via
     * {@code setException}. The response is built right after the fetch is scheduled, so the exception
     * returned here may have been recorded by an earlier request rather than this one.
     *
     * @param nodeRequest per-node request carrying the config id, interval bounds and task id
     * @return a response for the local node, carrying the fetched-and-cleared exception or {@code null}
     */
    @Override
    protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest) {
        String configId = nodeRequest.getRequest().getConfigId();
        this.nodeStateManager.getConfig(configId, AnalysisType.AD, ActionListenerExecutor.wrap(configOptional -> {
            if (configOptional.isEmpty()) {
                LOG.warn(String.format(Locale.ROOT, "cannot find config %s", configId));
                return;
            }
            Config config = (Config) configOptional.get();
            imputeEligibleModels(
                config,
                configId,
                nodeRequest.getRequest().getDataStartMillis(),
                nodeRequest.getRequest().getDataEndMillis(),
                nodeRequest.getRequest().getTaskId()
            );
        }, e -> this.nodeStateManager.setException(configId, (Exception) e), this.threadPool.executor(AD_THREAD_POOL)));

        Optional<Exception> previousException = this.nodeStateManager.fetchExceptionAndClear(configId);
        // Passing null when no failure was recorded keeps the original response contract.
        return new ADHCImputeNodeResponse(this.clusterService.localNode(), previousException.orElse(null));
    }

    /**
     * Feeds a fresh all-NaN sample for the interval to each cached model of the config that
     * {@link #shouldProcessModelState} accepts.
     */
    private void imputeEligibleModels(Config config, String configId, long dataStartMillis, long dataEndMillis, String taskId) {
        int featureSize = config.getEnabledFeatureIds().size();
        for (ModelState<ThresholdedRandomCutForest> modelState : ((ADPriorityCache) this.cache.get()).getAllModels(configId)) {
            if (!shouldProcessModelState(modelState, dataEndMillis, this.clusterService, this.hashRing)) {
                continue;
            }
            // A new array per model: each Sample gets its own buffer rather than sharing one across models.
            double[] nanArray = new double[featureSize];
            Arrays.fill(nanArray, Double.NaN);
            this.adInferencer.process(new Sample(nanArray, Instant.ofEpochMilli(dataStartMillis), Instant.ofEpochMilli(dataEndMillis)), modelState, config, taskId);
        }
    }

    /**
     * Decides whether a cached model should receive an imputed sample for the current interval.
     *
     * <p>All of the following must hold:
     * <ul>
     *   <li>the model has an entity;</li>
     *   <li>its last seen data end time is not {@link Instant#MIN};</li>
     *   <li>{@code dataEndTime} is strictly after that last seen time;</li>
     *   <li>the hash ring's real-time owner of the entity is the local node.</li>
     * </ul>
     * The cheap checks run first, so the hash-ring lookup only happens for candidates that could qualify.
     *
     * @param modelState     cached model state to evaluate
     * @param dataEndTime    end of the current interval, in epoch milliseconds
     * @param clusterService used to obtain the local node id
     * @param hashRing       used to find the entity's owning node
     * @return {@code true} if the model should be processed on this node
     */
    private boolean shouldProcessModelState(ModelState<ThresholdedRandomCutForest> modelState, long dataEndTime, ClusterService clusterService, HashRing hashRing) {
        if (!modelState.getEntity().isPresent()) {
            return false;
        }
        Instant lastSeen = modelState.getLastSeenDataEndTime();
        // Compare by value, not identity: an Instant equal to MIN but not the same object would otherwise
        // slip through and make toEpochMilli() throw ArithmeticException.
        if (Instant.MIN.equals(lastSeen) || dataEndTime <= lastSeen.toEpochMilli()) {
            return false;
        }
        Optional<DiscoveryNode> owningNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(modelState.getEntity().get().toString());
        return owningNode.isPresent() && owningNode.get().getId().equals(clusterService.localNode().getId());
    }
}

