package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.bucket.nested.Nested;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.common.AbstractAucRoc;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Area under the ROC curve metric for classification evaluations.
 *
 * <p>The class given by {@code class_name} is treated as the positive class.
 * <ul>
 *   <li>Documents whose actual class equals it form the "true" group.</li>
 *   <li>All other documents form the "non-true" group.</li>
 * </ul>
 * For each group the metric collects percentiles of the predicted probability stored in the nested
 * top-classes entry whose predicted class equals {@code class_name}. It then builds the ROC curve
 * from the two percentile arrays using the helpers inherited from {@link AbstractAucRoc}.
 *
 * <p>Once a result has been computed, it is cached. Subsequent calls to {@link #aggs} request no
 * further aggregations, and {@link #process} becomes a no-op.
 */
public class AucRoc extends AbstractAucRoc {

    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
    public static final ParseField CLASS_NAME = new ParseField("class_name");

    /** Parses {@code include_curve} (optional, defaults to {@code false}) and {@code class_name}. */
    public static final ConstructingObjectParser<AucRoc, Void> PARSER = new ConstructingObjectParser<>(
        NAME.getPreferredName(),
        args -> new AucRoc((Boolean) args[0], (String) args[1])
    );

    static {
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
        // Declared optional at parse level; a missing value is rejected by the constructor's null check.
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), CLASS_NAME);
    }

    private static final String TRUE_AGG_NAME = NAME.getPreferredName() + "_true";
    private static final String NON_TRUE_AGG_NAME = NAME.getPreferredName() + "_non_true";
    private static final String NESTED_AGG_NAME = "nested";
    private static final String NESTED_FILTER_AGG_NAME = "nested_filter";
    private static final String PERCENTILES_AGG_NAME = "percentiles";

    private final boolean includeCurve;
    private final String className;
    /** Fields captured in {@link #aggs} so that {@link #process} can name them in error messages. */
    private final SetOnce<EvaluationFields> fields = new SetOnce<>();
    private final SetOnce<EvaluationMetricResult> result = new SetOnce<>();

    /**
     * Parses an {@code AucRoc} metric from x-content.
     *
     * @param xContentParser parser positioned at the metric's object
     * @return the parsed metric
     */
    public static AucRoc fromXContent(XContentParser xContentParser) {
        return PARSER.apply(xContentParser, null);
    }

    /**
     * @param includeCurve whether the result should contain the curve points; {@code null} means {@code false}
     * @param className    the class treated as positive; must not be {@code null}
     */
    public AucRoc(Boolean includeCurve, String className) {
        this.includeCurve = includeCurve != null && includeCurve;
        this.className = ExceptionsHelper.requireNonNull(className, CLASS_NAME.getPreferredName());
    }

    /** Reads the metric from the wire format produced by {@link #writeTo}. */
    public AucRoc(StreamInput in) throws IOException {
        this.includeCurve = in.readBoolean();
        this.className = in.readOptionalString();
    }

    @Override
    public String getWriteableName() {
        return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, NAME);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(includeCurve);
        out.writeOptionalString(className);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve);
        if (className != null) {
            builder.field(CLASS_NAME.getPreferredName(), className);
        }
        builder.endObject();
        return builder;
    }

    /** Returns the evaluation fields this metric reads: actual, predicted class and predicted probability. */
    @Override
    public Set<String> getRequiredFields() {
        return Sets.newHashSet(
            EvaluationFields.ACTUAL_FIELD.getPreferredName(),
            EvaluationFields.PREDICTED_CLASS_FIELD.getPreferredName(),
            EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName()
        );
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        AucRoc that = (AucRoc) obj;
        return includeCurve == that.includeCurve && Objects.equals(className, that.className);
    }

    @Override
    public int hashCode() {
        return Objects.hash(includeCurve, className);
    }

    /**
     * Builds the two top-level filter aggregations ("true" and "non-true" documents).
     *
     * <p>Each filter aggregation contains a nested filter. That nested filter selects the top-classes
     * entry for {@code class_name} and computes the 1st to 99th percentiles of its predicted probability.
     *
     * @return empty lists if a result has already been computed, otherwise the two filter aggregations
     *         and no pipeline aggregations
     */
    @Override
    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(
        EvaluationParameters evaluationParameters,
        EvaluationFields evaluationFields
    ) {
        if (result.get() != null) {
            return Tuple.tuple(Collections.<AggregationBuilder>emptyList(), Collections.<PipelineAggregationBuilder>emptyList());
        }
        this.fields.trySet(evaluationFields);

        // Percentiles 1, 2, ..., 99.
        double[] percents = IntStream.range(1, 100).asDoubleStream().toArray();
        AggregationBuilder percentilesAgg = AggregationBuilders.nested(NESTED_AGG_NAME, evaluationFields.getTopClassesField())
            .subAggregation(
                AggregationBuilders.filter(
                    NESTED_FILTER_AGG_NAME,
                    QueryBuilders.termQuery(evaluationFields.getPredictedClassField(), className)
                )
                    .subAggregation(
                        AggregationBuilders.percentiles(PERCENTILES_AGG_NAME)
                            .field(evaluationFields.getPredictedProbabilityField())
                            .percentiles(percents)
                    )
            );

        TermQueryBuilder actualIsTrueQuery = QueryBuilders.termQuery(evaluationFields.getActualField(), className);
        AggregationBuilder trueAgg = AggregationBuilders.filter(TRUE_AGG_NAME, actualIsTrueQuery).subAggregation(percentilesAgg);
        AggregationBuilder nonTrueAgg = AggregationBuilders.filter(
            NON_TRUE_AGG_NAME,
            QueryBuilders.boolQuery().mustNot(actualIsTrueQuery)
        ).subAggregation(percentilesAgg);
        return Tuple.tuple(Arrays.asList(trueAgg, nonTrueAgg), Collections.<PipelineAggregationBuilder>emptyList());
    }

    /**
     * Computes the AUC ROC from the aggregations requested by {@link #aggs} and caches the result.
     *
     * <p>Throws a bad-request exception in three cases:
     * <ul>
     *   <li>no document has {@code class_name} as its actual class;</li>
     *   <li>every document has {@code class_name} as its actual class;</li>
     *   <li>{@code class_name} is missing from some document's top classes.</li>
     * </ul>
     */
    @Override
    public void process(Aggregations aggregations) {
        if (result.get() != null) {
            return;
        }
        Filter trueAgg = aggregations.get(TRUE_AGG_NAME);
        Filter nonTrueAgg = aggregations.get(NON_TRUE_AGG_NAME);
        if (trueAgg.getDocCount() == 0) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have the value [{}]",
                getName(),
                fields.get().getActualField(),
                className
            );
        }
        if (nonTrueAgg.getDocCount() == 0) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires at least one [{}] to have a different value than [{}]",
                getName(),
                fields.get().getActualField(),
                className
            );
        }
        Filter trueClassAgg = nestedClassFilter(trueAgg);
        Filter nonTrueClassAgg = nestedClassFilter(nonTrueAgg);

        // Each document should contribute exactly one nested top-classes entry for className.
        long classInTopClassesCount = trueClassAgg.getDocCount() + nonTrueClassAgg.getDocCount();
        long totalDocCount = trueAgg.getDocCount() + nonTrueAgg.getDocCount();
        if (classInTopClassesCount < totalDocCount) {
            throw ExceptionsHelper.badRequestException(
                "[{}] requires that [{}] appears as one of the [{}] for every document (appeared in {} out of {}). "
                    + "This is probably caused by the {} value being less than the total number of actual classes in the dataset.",
                getName(),
                className,
                fields.get().getPredictedClassField(),
                classInTopClassesCount,
                totalDocCount,
                org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification.NUM_TOP_CLASSES.getPreferredName()
            );
        }

        List<AbstractAucRoc.AucRocPoint> curve = buildAucRocCurve(
            percentilesArray(trueClassAgg.getAggregations().get(PERCENTILES_AGG_NAME)),
            percentilesArray(nonTrueClassAgg.getAggregations().get(PERCENTILES_AGG_NAME))
        );
        result.set(new AbstractAucRoc.Result(calculateAucScore(curve), includeCurve ? curve : Collections.emptyList()));
    }

    /** Returns the nested-filter aggregation that {@link #aggs} placed under the given top-level filter. */
    private static Filter nestedClassFilter(Filter parent) {
        // The typed local is required: Aggregations.get infers only Aggregation when used as a receiver.
        Nested nested = parent.getAggregations().get(NESTED_AGG_NAME);
        return nested.getAggregations().get(NESTED_FILTER_AGG_NAME);
    }

    @Override
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(result.get());
    }
}
