package org.elasticsearch.xpack.ml.action;

import java.util.Objects;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.class */
/**
 * Transport handler for {@link InternalInferModelAction}.
 *
 * <p>For each request it loads the model identified by {@code request.getModelId()} through the
 * {@link ModelLoadingService}, runs inference once per entry of {@code request.getObjectsToInfer()},
 * and answers with a response carrying the collected results, the model id and a "licensed" flag.
 *
 * <p>When the ML API feature check on the current license fails, the trained model's configuration
 * is fetched first and the request only proceeds if the model's own license level is allowed or the
 * request reports that it was previously licensed.
 */
public class TransportInternalInferModelAction extends HandledTransportAction<InternalInferModelAction.Request, InternalInferModelAction.Response> {
    private final ModelLoadingService modelLoadingService;
    private final Client client;
    private final XPackLicenseState licenseState;
    private final TrainedModelProvider trainedModelProvider;

    /**
     * Creates the action and registers it under the internal inference action name.
     *
     * @param transportService     passed through to {@link HandledTransportAction}
     * @param actionFilters        passed through to {@link HandledTransportAction}
     * @param modelLoadingService  used to obtain the {@link LocalModel} for a model id
     * @param client               supplies the thread pool whose executor runs the inference chain
     * @param xPackLicenseState    license state consulted before inference is allowed
     * @param trainedModelProvider used to look up a model's configuration when the ML feature check fails
     */
    @Inject
    public TransportInternalInferModelAction(TransportService transportService, ActionFilters actionFilters, ModelLoadingService modelLoadingService, Client client, XPackLicenseState xPackLicenseState, TrainedModelProvider trainedModelProvider) {
        super("cluster:internal/xpack/ml/inference/infer", transportService, actionFilters, InternalInferModelAction.Request::new);
        this.modelLoadingService = modelLoadingService;
        this.client = client;
        this.licenseState = xPackLicenseState;
        this.trainedModelProvider = trainedModelProvider;
    }

    /**
     * Checks licensing, loads the requested model and runs inference over every requested document.
     *
     * @param task           the task this execution belongs to (not used by this method)
     * @param request        carries the model id, the objects to infer on and the config update
     * @param actionListener receives the built response, or the failure from license checking,
     *                       model lookup, model loading or inference
     * @throws NullPointerException if {@code actionListener} is {@code null}
     */
    @Override
    protected void doExecute(Task task, InternalInferModelAction.Request request, ActionListener<InternalInferModelAction.Response> actionListener) {
        Objects.requireNonNull(actionListener);
        InternalInferModelAction.Response.Builder responseBuilder = InternalInferModelAction.Response.builder();

        // Shared by both licensing paths below: once the model is available, infer on every document.
        CheckedConsumer<LocalModel, Exception> onModelLoaded = model -> inferAll(model, request, responseBuilder, actionListener);
        ActionListener<LocalModel> modelListener = ActionListener.wrap(onModelLoaded, actionListener::onFailure);

        if (MachineLearningField.ML_API_FEATURE.check(this.licenseState)) {
            responseBuilder.setLicensed(true);
            this.modelLoadingService.getModelForPipeline(request.getModelId(), modelListener);
            return;
        }

        // The ML feature check failed: decide based on the license level recorded on the model itself.
        this.trainedModelProvider.getTrainedModel(
            request.getModelId(),
            GetTrainedModelsAction.Includes.empty(),
            ActionListener.wrap(config -> {
                // Evaluate once so the reported flag and the branch taken cannot disagree.
                boolean allowed = this.licenseState.isAllowedByLicense(config.getLicenseLevel());
                responseBuilder.setLicensed(allowed);
                if (allowed || request.isPreviouslyLicensed()) {
                    this.modelLoadingService.getModelForPipeline(request.getModelId(), modelListener);
                } else {
                    actionListener.onFailure(LicenseUtils.newComplianceException("ml"));
                }
            }, actionListener::onFailure)
        );
    }

    /**
     * Queues one inference call per requested object on a task chain, then completes the listener
     * with all results. The model is released on both the success and the failure path before the
     * listener is notified.
     */
    private void inferAll(LocalModel model, InternalInferModelAction.Request request, InternalInferModelAction.Response.Builder responseBuilder, ActionListener<InternalInferModelAction.Response> actionListener) {
        // Both predicates are constant true. NOTE(review): presumably "keep going after each result"
        // and "stop at the first failure" -- confirm against TypedChainTaskExecutor.
        TypedChainTaskExecutor<InferenceResults> chain = new TypedChainTaskExecutor<>(
            this.client.threadPool().executor("same"),
            result -> true,
            failure -> true
        );
        request.getObjectsToInfer().forEach(fields -> chain.add(next -> {
            model.infer(fields, request.getUpdate(), next);
        }));
        chain.execute(ActionListener.wrap(results -> {
            model.release();
            actionListener.onResponse(responseBuilder.setInferenceResults(results).setModelId(model.getModelId()).build());
        }, failure -> {
            model.release();
            actionListener.onFailure(failure);
        }));
    }
}
