package org.elasticsearch.xpack.ml.action;

import java.io.IOException;
import java.time.Instant;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.class */
/**
 * Master-node transport action that creates a new trained model.
 *
 * <p>The request is checked against the cluster's minimum node version. The model definition is then parsed (unless
 * decompression is deferred) and validated. The model id and tags are checked for collisions with existing model
 * aliases, tags and model ids. Finally, the configuration is persisted through {@link TrainedModelProvider}.
 *
 * <p>The typed overrides below are all that is needed: javac generates the erased bridge methods for
 * {@code doExecute} and {@code masterOperation} itself.
 */
public class TransportPutTrainedModelAction extends TransportMasterNodeAction<PutTrainedModelAction.Request, PutTrainedModelAction.Response> {

    /** Index pattern searched when checking model ids and tags for collisions. */
    private static final String INFERENCE_INDEX_PATTERN = ".ml-inference-*";

    /** Value recorded via {@code setCreatedBy} for models created through this API. */
    private static final String CREATED_BY_API_USER = "api_user";

    private final TrainedModelProvider trainedModelProvider;
    private final XPackLicenseState licenseState;
    private final NamedXContentRegistry xContentRegistry;
    private final Client client;

    @Inject
    public TransportPutTrainedModelAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        XPackLicenseState xPackLicenseState,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        Client client,
        TrainedModelProvider trainedModelProvider,
        NamedXContentRegistry namedXContentRegistry
    ) {
        super(
            PutTrainedModelAction.NAME,
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            PutTrainedModelAction.Request::new,
            indexNameExpressionResolver,
            PutTrainedModelAction.Response::new,
            ThreadPool.Names.SAME
        );
        this.licenseState = xPackLicenseState;
        this.trainedModelProvider = trainedModelProvider;
        this.xContentRegistry = namedXContentRegistry;
        this.client = client;
    }

    /**
     * Validates the model configuration carried by {@code request} and, if every check passes, stores it.
     *
     * <p>Checks run in this order. The first failure is reported to {@code actionListener` and processing stops.
     * <ol>
     *   <li>All nodes must be at least {@link Version#V_7_8_0}.</li>
     *   <li>Unless decompression is deferred, the definition is parsed and, when present, validated.</li>
     *   <li>When a definition is available, the inference config must support its target type, and all nodes must be at
     *       least the definition's minimal compatibility version.</li>
     *   <li>The model id must not be an existing model alias.</li>
     *   <li>Asynchronously, the model id must not match an existing tag, and the tags must not match an existing model
     *       id.</li>
     * </ol>
     * On success the listener receives the stored configuration with its definition cleared.
     */
    @Override
    protected void masterOperation(
        PutTrainedModelAction.Request request,
        ClusterState clusterState,
        ActionListener<PutTrainedModelAction.Response> actionListener
    ) {
        TrainedModelConfig requestedConfig = request.getTrainedModelConfig();
        String modelId = requestedConfig.getModelId();
        Version minNodeVersion = clusterState.nodes().getMinNodeVersion();

        if (minNodeVersion.before(Version.V_7_8_0)) {
            // The message has a single placeholder, so only the version is passed. Previously the model id was
            // passed first and ended up in the "version" slot.
            actionListener.onFailure(
                ExceptionsHelper.badRequestException(
                    "Creating a new model requires that all nodes are at least version [{}]",
                    Version.V_7_8_0.toString()
                )
            );
            return;
        }

        // Keep this try block narrow: only exceptions raised while parsing or validating the definition should be
        // reported as definition problems. Catching around the whole method also swallowed failures thrown by listener
        // callbacks and the follow-up checks, mislabelling them and possibly notifying the listener twice.
        try {
            if (request.isDeferDefinitionDecompression() == false) {
                requestedConfig.ensureParsedDefinition(xContentRegistry);
                // getModelDefinition() can be null (it is null-checked again below), so guard before validating.
                if (requestedConfig.getModelDefinition() != null) {
                    requestedConfig.getModelDefinition().getTrainedModel().validate();
                }
            }
        } catch (IOException e) {
            actionListener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]", e, modelId));
            return;
        } catch (ElasticsearchException e) {
            actionListener.onFailure(ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", e, modelId));
            return;
        }

        // NOTE(review): presumably null when decompression was deferred or no inline definition was supplied.
        boolean hasModelDefinition = requestedConfig.getModelDefinition() != null;
        if (hasModelDefinition) {
            if (requestedConfig.getInferenceConfig()
                .isTargetTypeSupported(requestedConfig.getModelDefinition().getTrainedModel().targetType()) == false) {
                actionListener.onFailure(
                    ExceptionsHelper.badRequestException(
                        "Model [{}] inference config type [{}] does not support definition target type [{}]",
                        modelId,
                        requestedConfig.getInferenceConfig().getName(),
                        requestedConfig.getModelDefinition().getTrainedModel().targetType()
                    )
                );
                return;
            }
            Version minimalCompatibilityVersion = requestedConfig.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
            if (minNodeVersion.before(minimalCompatibilityVersion)) {
                actionListener.onFailure(
                    ExceptionsHelper.badRequestException(
                        "Definition for [{}] requires that all nodes are at least version [{}]",
                        modelId,
                        minimalCompatibilityVersion.toString()
                    )
                );
                return;
            }
        }

        TrainedModelConfig.Builder builder = new TrainedModelConfig.Builder(requestedConfig).setVersion(Version.CURRENT)
            .setCreateTime(Instant.now())
            .setCreatedBy(CREATED_BY_API_USER)
            .setLicenseLevel(License.OperationMode.PLATINUM.description());
        if (hasModelDefinition) {
            builder.setModelSize(requestedConfig.getModelDefinition().ramBytesUsed())
                .setEstimatedOperations(requestedConfig.getModelDefinition().getTrainedModel().estimatedNumOperations());
        }
        TrainedModelConfig configToStore = builder.build();

        if (ModelAliasMetadata.fromState(clusterState).getModelId(configToStore.getModelId()) != null) {
            actionListener.onFailure(
                ExceptionsHelper.badRequestException(
                    "requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique",
                    modelId
                )
            );
            return;
        }

        // Last step: persist the config, then respond with a copy that has the definition cleared.
        ActionListener<Void> tagsCheckedListener = ActionListener.wrap(
            tagsOk -> trainedModelProvider.storeTrainedModel(
                configToStore,
                ActionListener.wrap(
                    stored -> actionListener.onResponse(
                        new PutTrainedModelAction.Response(new TrainedModelConfig.Builder(configToStore).clearDefinition().build())
                    ),
                    actionListener::onFailure
                )
            ),
            actionListener::onFailure
        );

        // Middle step: once the model id is known not to clash with a tag, check the tags against model ids.
        ActionListener<Void> modelIdCheckedListener = ActionListener.wrap(
            modelIdOk -> checkTagsAgainstModelIds(requestedConfig.getTags(), tagsCheckedListener),
            actionListener::onFailure
        );

        checkModelIdAgainstTags(modelId, modelIdCheckedListener);
    }

    /** Fails {@code listener} with a bad request if {@code modelId} is already used as a tag on a stored model. */
    private void checkModelIdAgainstTags(String modelId, ActionListener<Void> listener) {
        SearchRequest searchRequest = new SearchRequest(INFERENCE_INDEX_PATTERN).source(
            new SearchSourceBuilder().query(
                QueryBuilders.constantScoreQuery(
                    QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId))
                )
            ).size(0).trackTotalHitsUpTo(1)
        );
        failIfAnyHits(searchRequest, Messages.getMessage("The provided model_id {0} must not match existing tags.", modelId), listener);
    }

    /**
     * Fails {@code listener} with a bad request if any of {@code tags} is already used as a model id.
     * Completes immediately, without searching, when {@code tags} is empty.
     */
    private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> listener) {
        if (tags.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        SearchRequest searchRequest = new SearchRequest(INFERENCE_INDEX_PATTERN).source(
            new SearchSourceBuilder().query(
                QueryBuilders.constantScoreQuery(
                    QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags))
                )
            ).size(0).trackTotalHitsUpTo(1)
        );
        failIfAnyHits(searchRequest, Messages.getMessage("The provided tags {0} must not match existing model_ids.", tags), listener);
    }

    /**
     * Runs {@code searchRequest} under the ML origin and completes {@code listener}.
     * If any document matched, the listener fails with a bad request carrying {@code conflictMessage}.
     * Otherwise it receives {@code null}.
     * Search failures are forwarded unchanged.
     */
    private void failIfAnyHits(SearchRequest searchRequest, String conflictMessage, ActionListener<Void> listener) {
        ActionListener<SearchResponse> searchListener = ActionListener.wrap(searchResponse -> {
            if (searchResponse.getHits().getTotalHits().value > 0) {
                listener.onFailure(ExceptionsHelper.badRequestException(conflictMessage));
            } else {
                listener.onResponse(null);
            }
        }, listener::onFailure);
        ClientHelper.executeAsyncWithOrigin(
            client.threadPool().getThreadContext(),
            ClientHelper.ML_ORIGIN,
            searchRequest,
            searchListener,
            client::search
        );
    }

    /** Rejects the request while a global {@link ClusterBlockLevel#METADATA_WRITE} block is in place. */
    @Override
    public ClusterBlockException checkBlock(PutTrainedModelAction.Request request, ClusterState clusterState) {
        return clusterState.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }

    /**
     * Forwards the request to the master-node machinery only when the ML API feature is permitted by the current
     * license. Otherwise it fails with a license compliance exception.
     */
    @Override
    protected void doExecute(Task task, PutTrainedModelAction.Request request, ActionListener<PutTrainedModelAction.Response> actionListener) {
        if (MachineLearningField.ML_API_FEATURE.check(licenseState) == false) {
            actionListener.onFailure(LicenseUtils.newComplianceException("ml"));
            return;
        }
        super.doExecute(task, request, actionListener);
    }
}
