package org.elasticsearch.xpack.ml.aggs.correlation;

import java.io.IOException;
import java.util.Objects;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/* loaded from: input_file:org/elasticsearch/xpack/ml/aggs/correlation/CountCorrelationFunction.class */
/**
 * {@link CorrelationFunction} that correlates an {@link CountCorrelationIndicator indicator} of expected values
 * against the per-bucket counts of another {@link CountCorrelationIndicator}.
 * <p>
 * The indicator's expectations act as the "x" series. They are weighted uniformly, or by the indicator's
 * fractions when present. The correlating counts are folded into a proportion {@code p = total / docCount}.
 * The result is {@code cov / sqrt(var_x * var_y)}, where {@code var_y = (1 - p) p^2 + p (1 - p)^2}
 * (algebraically {@code p (1 - p)}).
 */
public class CountCorrelationFunction implements CorrelationFunction {
    public static final ParseField NAME = new ParseField("count_correlation");
    public static final ParseField INDICATOR = new ParseField("indicator");

    private static final ConstructingObjectParser<CountCorrelationFunction, Void> PARSER = new ConstructingObjectParser<>(
        "count_correlation_function",
        false,
        args -> new CountCorrelationFunction((CountCorrelationIndicator) args[0])
    );

    static {
        // The single constructor argument: the nested "indicator" object.
        PARSER.declareObject(
            ConstructingObjectParser.constructorArg(),
            (parser, context) -> CountCorrelationIndicator.fromXContent(parser),
            INDICATOR
        );
    }

    private final CountCorrelationIndicator indicator;

    /**
     * @param indicator the expected-value indicator that bucket counts are correlated against
     */
    public CountCorrelationFunction(CountCorrelationIndicator indicator) {
        this.indicator = indicator;
    }

    /**
     * Reads a function from the stream; the counterpart of {@link #writeTo(StreamOutput)}.
     *
     * @param in stream positioned at a serialized indicator
     * @throws IOException if reading from the stream fails
     */
    public CountCorrelationFunction(StreamInput in) throws IOException {
        this.indicator = new CountCorrelationIndicator(in);
    }

    /**
     * Parses a function from an object containing an {@code indicator} field.
     *
     * @param parser parser positioned at the function's object
     * @return the parsed function
     */
    public static CountCorrelationFunction fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /** Writes {@code {"indicator": <indicator>}}. */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(INDICATOR.getPreferredName(), indicator);
        builder.endObject();
        return builder;
    }

    /** @return {@code "count_correlation"} */
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Serializes only the indicator; read back by {@link #CountCorrelationFunction(StreamInput)}. */
    public void writeTo(StreamOutput out) throws IOException {
        indicator.writeTo(out);
    }

    /** @return {@code "count_correlation"} */
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Derived from the same state {@link #equals(Object)} compares. Instances with different indicators then
     * hash differently instead of all colliding on a constant.
     */
    @Override
    public int hashCode() {
        return Objects.hash(indicator);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        CountCorrelationFunction other = (CountCorrelationFunction) obj;
        return Objects.equals(indicator, other.indicator);
    }

    /**
     * Computes the correlation between this function's indicator and the given bucket counts.
     *
     * @param correlating per-bucket counts to correlate; must have as many expectations as the indicator
     * @return the correlation, or {@link Double#NaN} if the indicator mean/standard deviation is NaN or either
     *         variance is zero
     * @throws AggregationExecutionException if the series lengths differ, or if the correlating counts sum to
     *         more than the indicator's doc count
     */
    @Override // org.elasticsearch.xpack.ml.aggs.correlation.CorrelationFunction
    public double execute(CountCorrelationIndicator correlating) {
        final double[] xValues = indicator.getExpectations();
        final double[] yCounts = correlating.getExpectations();
        if (xValues.length != yCounts.length) {
            throw new AggregationExecutionException(
                "value lengths do not match; indicator.expectations ["
                    + xValues.length
                    + "] and number of buckets ["
                    + yCounts.length
                    + "]. Unable to calculate correlation"
            );
        }
        final double[] fractions = indicator.getFractions();

        // Mean and variance of the indicator series: uniform weights when no fractions are supplied.
        final double xMean;
        final double xVariance;
        if (fractions == null) {
            xMean = MovingFunctions.unweightedAvg(xValues);
            if (Double.isNaN(xMean)) {
                return Double.NaN;
            }
            final double xStdDev = MovingFunctions.stdDev(xValues, xMean);
            if (Double.isNaN(xStdDev)) {
                return Double.NaN;
            }
            xVariance = Math.pow(xStdDev, 2.0d);
        } else {
            double weightedMean = 0.0d;
            for (int i = 0; i < xValues.length; i++) {
                weightedMean += xValues[i] * fractions[i];
            }
            if (Double.isNaN(weightedMean)) {
                return Double.NaN;
            }
            xMean = weightedMean;
            double weightedVariance = 0.0d;
            for (int i = 0; i < xValues.length; i++) {
                weightedVariance += Math.pow(xValues[i] - xMean, 2.0d) * fractions[i];
            }
            xVariance = weightedVariance;
        }

        // Proportion of the indicator's documents covered by the correlating counts.
        final double docCount = indicator.getDocCount();
        final double yTotal = MovingFunctions.sum(yCounts);
        final double p = yTotal / docCount;
        if (p > 1.0d) {
            throw new AggregationExecutionException(
                "doc_count of indicator must be larger than the total count of the correlating values indicator count ["
                    + indicator.getDocCount()
                    + "] correlating value total count ["
                    + yTotal
                    + "]"
            );
        }
        // Kept in its expanded form (rather than p * (1 - p)) so results stay bit-identical.
        final double yVariance = ((1.0d - p) * p * p) + (p * (1.0d - p) * (1.0d - p));

        // Only used when no fractions are supplied; for an empty series we already returned NaN above.
        final double uniformWeight = 1.0d / xValues.length;
        double covariance = 0.0d;
        for (int i = 0; i < xValues.length; i++) {
            final double weight = fractions == null ? uniformWeight : fractions[i];
            final double xDeviation = xValues[i] - xMean;
            final double yCount = yCounts[i];
            covariance = (covariance - (((docCount * weight) - yCount) * xDeviation) * p) + (yCount * xDeviation * (1.0d - p));
        }
        final double scaledCovariance = covariance / docCount;

        final double varianceProduct = xVariance * yVariance;
        if (varianceProduct == 0.0d) {
            return Double.NaN;
        }
        return scaledCovariance / Math.sqrt(varianceProduct);
    }

    /**
     * Requires the bucket path to end in {@code _count}.
     * <p>
     * NOTE(review): this is a plain suffix check, so a metric whose name merely ends in "_count" is also
     * accepted — confirm whether that is intended.
     */
    @Override // org.elasticsearch.xpack.ml.aggs.correlation.CorrelationFunction
    public void validate(PipelineAggregationBuilder.ValidationContext context, String bucketPath) {
        if (!bucketPath.endsWith("_count")) {
            context.addBucketPathValidationError("count correlation requires that bucket_path points to bucket [_count]");
        }
    }
}
