package org.elasticsearch.xpack.ml.inference.loadingservice;

import java.io.Closeable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.LongAdder;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.license.License;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;

/* loaded from: input_file:org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.class */
/**
 * An inference model held in local memory.
 *
 * <p>Besides running inference, this class:
 * <ul>
 *   <li>keeps per-model inference statistics and periodically hands them to a {@link TrainedModelStatsService};</li>
 *   <li>reference-counts its users, charging {@link #ramBytesUsed()} to a circuit breaker when the count rises from 0 to 1
 *       and refunding it when the count drops back to 0.</li>
 * </ul>
 *
 * <p>NOTE(review): the reference count starts at 1, but the constructor does not charge the circuit breaker. Presumably
 * whoever creates the instance accounts for that initial reference — confirm against the caller.
 */
public class LocalModel implements Closeable {

    /** Message template used when none of the model's input fields are present in a document. */
    private static final String ALL_FIELDS_MISSING_MESSAGE = "Model [{0}] could not be inferred as all fields were missing";

    private final InferenceDefinition trainedModelDefinition;
    private final String modelId;
    private final Set<String> fieldNames;
    /** Optional mapping of source field path to model field name; may be {@code null}. */
    private final Map<String, String> defaultFieldMap;
    private final InferenceStats.Accumulator statsAccumulator;
    private final TrainedModelStatsService trainedModelStatsService;
    /** Stats are persisted on inferences where {@code (count + 1) % persistenceQuotient == 0}; grows as the model gets busier. */
    private volatile long persistenceQuotient = 100;
    private final LongAdder currentInferenceCount;
    private final InferenceConfig inferenceConfig;
    private final License.OperationMode licenseLevel;
    private final CircuitBreaker trainedModelCircuitBreaker;
    private final AtomicLong referenceCount;
    /** Snapshot of the definition's RAM usage, taken once so charge and refund always use the same amount. */
    private final long cachedRamBytesUsed;

    /**
     * Creates a local model with a reference count of 1.
     *
     * <p>NOTE: this constructor was package-private in the original source.
     *
     * @param modelId                    id of the model
     * @param nodeId                     second identifier forwarded to the stats accumulator (presumably the node id — confirm)
     * @param trainedModelDefinition     the definition used to run inference
     * @param input                      supplies the field names the model consumes (copied)
     * @param defaultFieldMap            optional source-field to model-field mapping (copied); may be {@code null}
     * @param inferenceConfig            the model's own inference configuration
     * @param licenseLevel               license level reported by {@link #getLicenseLevel()}
     * @param trainedModelStatsService   destination for periodically persisted stats
     * @param trainedModelCircuitBreaker breaker charged with this model's memory while it is referenced
     */
    public LocalModel(
        String modelId,
        String nodeId,
        InferenceDefinition trainedModelDefinition,
        TrainedModelInput input,
        Map<String, String> defaultFieldMap,
        InferenceConfig inferenceConfig,
        License.OperationMode licenseLevel,
        TrainedModelStatsService trainedModelStatsService,
        CircuitBreaker trainedModelCircuitBreaker
    ) {
        this.trainedModelDefinition = trainedModelDefinition;
        this.cachedRamBytesUsed = trainedModelDefinition.ramBytesUsed();
        this.modelId = modelId;
        this.fieldNames = new HashSet<>(input.getFieldNames());
        this.statsAccumulator = new InferenceStats.Accumulator(modelId, nodeId, 1L);
        this.trainedModelStatsService = trainedModelStatsService;
        this.defaultFieldMap = defaultFieldMap == null ? null : new HashMap<>(defaultFieldMap);
        this.currentInferenceCount = new LongAdder();
        this.inferenceConfig = inferenceConfig;
        this.licenseLevel = licenseLevel;
        this.trainedModelCircuitBreaker = trainedModelCircuitBreaker;
        this.referenceCount = new AtomicLong(1L);
    }

    /**
     * Returns the RAM usage of the model definition, captured at construction time.
     * This is the amount charged to, and refunded from, the circuit breaker.
     */
    public long ramBytesUsed() {
        return cachedRamBytesUsed;
    }

    /** Returns the model id. */
    public String getModelId() {
        return modelId;
    }

    /** Returns the license level this model was created with. */
    public License.OperationMode getLicenseLevel() {
        return licenseLevel;
    }

    /** Returns the stats accumulated since the last reset, and resets the accumulator. */
    public InferenceStats getLatestStatsAndReset() {
        return statsAccumulator.currentStatsAndReset();
    }

    /**
     * Queues the current stats with the stats service, then possibly widens the persistence interval.
     *
     * <p>The interval starts at every 100 inferences. It moves to every 1000 once more than 1000 inferences have run,
     * and to every 10000 past 10000, so busy models persist stats less often.
     *
     * @param flush passed through unchanged to {@link TrainedModelStatsService#queueStats}
     */
    public void persistStats(boolean flush) {
        trainedModelStatsService.queueStats(getLatestStatsAndReset(), flush);
        long inferenceCount = currentInferenceCount.sum();
        if (persistenceQuotient < 1000 && inferenceCount > 1000) {
            persistenceQuotient = 1000L;
        }
        if (persistenceQuotient < 10000 && inferenceCount > 10000) {
            persistenceQuotient = 10000L;
        }
    }

    /**
     * Runs inference with the model's own config and without recording any statistics.
     *
     * <p>Note: {@code fields} may be mutated, because mapped default fields are added to it.
     *
     * @param fields the document fields
     * @return the inference result, or a {@link WarningInferenceResults} if none of the model's fields are present
     */
    public InferenceResults inferNoStats(Map<String, Object> fields) {
        mapFieldsIfNecessary(fields, defaultFieldMap);
        Map<String, Object> flattenedFields = MapHelper.dotCollapse(fields, fieldNames);
        if (flattenedFields.isEmpty()) {
            // Previously this warning was constructed and then discarded; return it, as the stats-recording path does.
            return new WarningInferenceResults(Messages.getMessage(ALL_FIELDS_MISSING_MESSAGE, new Object[] { modelId }));
        }
        return trainedModelDefinition.infer(flattenedFields, inferenceConfig);
    }

    /**
     * Runs inference, recording stats, and reports the outcome to {@code listener} on the calling thread.
     *
     * <p>Note: {@code fields} may be mutated, because mapped default fields are added to it.
     *
     * @param fields   the document fields
     * @param update   the request-level config update; it must support this model's config type
     * @param listener receives the result, a missing-fields warning, or the failure
     */
    public void infer(Map<String, Object> fields, InferenceConfigUpdate update, ActionListener<InferenceResults> listener) {
        if (update.isSupported(inferenceConfig) == false) {
            listener.onFailure(
                ExceptionsHelper.badRequestException(
                    "Model [{}] has inference config of type [{}] which is not supported by inference request of type [{}]",
                    new Object[] { modelId, inferenceConfig.getName(), update.getName() }
                )
            );
            return;
        }
        try {
            statsAccumulator.incInference();
            currentInferenceCount.increment();
            mapFieldsIfNecessary(fields, defaultFieldMap);
            Map<String, Object> flattenedFields = MapHelper.dotCollapse(fields, fieldNames);
            boolean shouldPersistStats = (currentInferenceCount.sum() + 1) % persistenceQuotient == 0;
            if (flattenedFields.isEmpty()) {
                statsAccumulator.incMissingFields();
                if (shouldPersistStats) {
                    persistStats(false);
                }
                listener.onResponse(new WarningInferenceResults(Messages.getMessage(ALL_FIELDS_MISSING_MESSAGE, new Object[] { modelId })));
                return;
            }
            InferenceResults result = trainedModelDefinition.infer(flattenedFields, update.apply(inferenceConfig));
            if (shouldPersistStats) {
                persistStats(false);
            }
            listener.onResponse(result);
        } catch (Exception e) {
            // NOTE(review): onResponse is called inside this try, so an exception thrown by the listener itself is
            // counted as an inference failure and the listener is then also notified via onFailure.
            statsAccumulator.incFailure();
            listener.onFailure(e);
        }
    }

    /**
     * Blocking form of {@link #infer(Map, InferenceConfigUpdate, ActionListener)}.
     * It relies on that method notifying its listener before returning.
     *
     * @return the inference result
     * @throws Exception the failure reported to the listener, if any
     */
    public InferenceResults infer(Map<String, Object> fields, InferenceConfigUpdate update) throws Exception {
        AtomicReference<InferenceResults> result = new AtomicReference<>();
        AtomicReference<Exception> failure = new AtomicReference<>();
        infer(fields, update, ActionListener.wrap(result::set, failure::set));
        Exception error = failure.get();
        if (error != null) {
            throw error;
        }
        return result.get();
    }

    /**
     * For each {@code source -> target} entry in {@code fieldMap}, copies the value found at {@code source} in
     * {@code fields} to the key {@code target}. The copy only happens when the source value is non-null and
     * {@code target} is not already present.
     *
     * @param fields   the document fields; mutated in place
     * @param fieldMap the mapping to apply; {@code null} means no mapping
     */
    public static void mapFieldsIfNecessary(Map<String, Object> fields, Map<String, String> fieldMap) {
        if (fieldMap == null) {
            return;
        }
        fieldMap.forEach((source, target) -> {
            Object value = MapHelper.dig(source, fields);
            if (value != null) {
                fields.putIfAbsent(target, value);
            }
        });
    }

    /**
     * Adds a reference to this model.
     *
     * <p>When this call takes the count from 0 to 1, it charges {@link #ramBytesUsed()} to the circuit breaker. If the
     * breaker trips, the increment is rolled back before the exception propagates. Otherwise a later {@link #release()}
     * would refund memory that was never charged.
     *
     * @return the reference count after incrementing
     */
    public long acquire() {
        long count = referenceCount.incrementAndGet();
        if (count == 1) {
            try {
                // Use the same amount that release() refunds, so the breaker accounting stays balanced.
                trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(ramBytesUsed(), modelId);
            } catch (RuntimeException e) {
                referenceCount.decrementAndGet();
                throw e;
            }
        }
        return count;
    }

    /** Returns the current reference count. */
    public long getReferenceCount() {
        return referenceCount.get();
    }

    /**
     * Drops a reference to this model. When the count reaches 0, refunds {@link #ramBytesUsed()} to the circuit breaker.
     *
     * @return the reference count after decrementing
     */
    public long release() {
        long count = referenceCount.decrementAndGet();
        assert count >= 0 : "reference count for model [" + modelId + "] dropped below zero";
        if (count == 0) {
            trainedModelCircuitBreaker.addWithoutBreaking(-ramBytesUsed());
        }
        return count;
    }

    /** Equivalent to {@link #release()}. */
    @Override
    public void close() {
        release();
    }

    @Override
    public String toString() {
        return "LocalModel{trainedModelDefinition=" + trainedModelDefinition
            + ", modelId='" + modelId
            + "', fieldNames=" + fieldNames
            + ", defaultFieldMap=" + defaultFieldMap
            + ", statsAccumulator=" + statsAccumulator
            + ", trainedModelStatsService=" + trainedModelStatsService
            + ", persistenceQuotient=" + persistenceQuotient
            + ", currentInferenceCount=" + currentInferenceCount
            + ", inferenceConfig=" + inferenceConfig
            + ", licenseLevel=" + licenseLevel
            + ", trainedModelCircuitBreaker=" + trainedModelCircuitBreaker
            + ", referenceCount=" + referenceCount
            + '}';
    }
}
