package org.elasticsearch.xpack.vectors.query;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues;

/**
 * Vector similarity and distance functions callable from {@link ScoreScript}s, operating on
 * dense vector fields and on sparse vector fields (whose use logs a deprecation warning).
 */
public class ScoreScriptUtils {
    /** Logger for the deprecation warnings emitted by the vector function constructors below. */
    private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(ScoreScriptUtils.class);

    /**
     * Warning logged when a vector function receives the field's doc values object
     * (e.g. {@code doc['field']}) instead of the field name.
     */
    static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, "
        + "and the form function(query, 'field') should be used instead. "
        + "For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').";

    /**
     * Cosine similarity between the query vector and a dense document vector.
     *
     * <p>The constructor asks {@link DenseVectorFunction} to normalize the query vector to unit
     * length. As a result, scoring only has to divide the dot product by the document vector's
     * magnitude.
     */
    public static final class CosineSimilarity extends DenseVectorFunction {
        public CosineSimilarity(ScoreScript scoreScript, List<Number> list, Object obj) {
            super(scoreScript, list, obj, true);
        }

        /**
         * Computes the cosine similarity for the script's current document.
         *
         * @return dot(query, doc) / |doc|, where the query has already been unit-normalized
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double cosineSimilarity() {
            BytesRef encodedVector = getEncodedVector();
            ByteBuffer docVector = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
            double dotProduct = 0.0d;
            // Reads one float from the encoded document vector per query dimension, in order.
            for (float queryValue : this.queryVector) {
                dotProduct += queryValue * docVector.getFloat();
            }
            return dotProduct / this.docValues.getMagnitude();
        }
    }

    /**
     * Cosine similarity between a sparse query vector and a sparse document vector.
     * The query vector's magnitude is computed once, when the function is constructed.
     */
    public static final class CosineSimilaritySparse extends SparseVectorFunction {
        final double queryVectorMagnitude;

        public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> map, Object obj) {
            super(scoreScript, map, obj);
            double sumOfSquares = 0.0d;
            for (float queryValue : this.queryValues) {
                sumOfSquares += queryValue * queryValue;
            }
            this.queryVectorMagnitude = Math.sqrt(sumOfSquares);
        }

        /**
         * Computes the cosine similarity for the script's current document.
         *
         * <p>For indices created on or after 7.5.0, the document magnitude is decoded from the
         * encoded value. For older indices, it is recomputed from the decoded components.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double cosineSimilaritySparse() {
            BytesRef encoded = getEncodedVector();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(this.docValues.indexVersion(), encoded);
            float[] docVector = VectorEncoderDecoder.decodeSparseVector(this.docValues.indexVersion(), encoded);
            double dotProduct = ScoreScriptUtils.intDotProductSparse(this.queryValues, this.queryDims, docVector, docDims);
            double docMagnitude = this.docValues.indexVersion().onOrAfter(Version.V_7_5_0)
                ? VectorEncoderDecoder.decodeMagnitude(this.docValues.indexVersion(), encoded)
                : legacyMagnitude(docVector);
            return dotProduct / (docMagnitude * this.queryVectorMagnitude);
        }

        /** Euclidean magnitude of {@code values}, rounded to float precision (pre-7.5.0 indices). */
        private static double legacyMagnitude(float[] values) {
            double sumOfSquares = 0.0d;
            for (float value : values) {
                sumOfSquares += value * value;
            }
            return (float) Math.sqrt(sumOfSquares);
        }
    }

    /**
     * Base class for script functions that compare a query vector with a dense vector field.
     *
     * <p>The constructor resolves the field's doc values and checks that the query has the
     * same number of dimensions as the field. It then copies the query into a {@code float[]},
     * optionally normalizing it to unit length.
     */
    public static class DenseVectorFunction {
        final ScoreScript scoreScript;
        final float[] queryVector;
        final VectorScriptDocValues.DenseVectorScriptDocValues docValues;

        public DenseVectorFunction(ScoreScript scoreScript, List<Number> list, Object obj) {
            this(scoreScript, list, obj, false);
        }

        /**
         * @param scoreScript the script whose current document is being scored
         * @param list the query vector components
         * @param obj the field name as a {@code String}, or (deprecated, logs a warning) the
         *     field's {@code DenseVectorScriptDocValues}
         * @param z whether to divide the query vector by its Euclidean magnitude
         * @throws IllegalArgumentException if {@code obj} has an unsupported type, or if the
         *     query's dimension count differs from the field's
         */
        public DenseVectorFunction(ScoreScript scoreScript, List<Number> list, Object obj, boolean z) {
            this.scoreScript = scoreScript;
            if (obj instanceof String) {
                this.docValues = (VectorScriptDocValues.DenseVectorScriptDocValues) scoreScript.getDoc().get((String) obj);
            } else {
                if (!(obj instanceof VectorScriptDocValues.DenseVectorScriptDocValues)) {
                    throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or VectorScriptDocValues");
                }
                this.docValues = (VectorScriptDocValues.DenseVectorScriptDocValues) obj;
                ScoreScriptUtils.deprecationLogger.critical(DeprecationCategory.SCRIPTING, "vector_function_signature", ScoreScriptUtils.DEPRECATION_MESSAGE, new Object[0]);
            }
            if (this.docValues.dims() != list.size()) {
                throw new IllegalArgumentException("The query vector has a different number of dimensions [" + list.size() + "] than the document vectors [" + this.docValues.dims() + "].");
            }
            this.queryVector = new float[list.size()];
            double sumOfSquares = 0.0d;
            for (int i = 0; i < list.size(); i++) {
                float value = list.get(i).floatValue();
                this.queryVector[i] = value;
                sumOfSquares += value * value;
            }
            if (z) {
                double magnitude = Math.sqrt(sumOfSquares);
                for (int dim = 0; dim < this.queryVector.length; dim++) {
                    this.queryVector[dim] = (float) (this.queryVector[dim] / magnitude);
                }
            }
        }

        /**
         * Positions the doc values on the script's current document and returns its encoded vector.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        BytesRef getEncodedVector() {
            try {
                this.docValues.setNextDocId(this.scoreScript._getDocId());
                BytesRef encodedValue = this.docValues.getEncodedValue();
                if (encodedValue == null) {
                    throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
                }
                return encodedValue;
            } catch (IOException e) {
                throw ExceptionsHelper.convertToElastic(e);
            }
        }
    }

    /** Dot product between the query vector and a dense document vector. */
    public static final class DotProduct extends DenseVectorFunction {
        public DotProduct(ScoreScript scoreScript, List<Number> list, Object obj) {
            super(scoreScript, list, obj);
        }

        /**
         * Computes the dot product for the script's current document.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double dotProduct() {
            BytesRef encodedVector = getEncodedVector();
            ByteBuffer docVector = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
            double sum = 0.0d;
            // Reads one float from the encoded document vector per query dimension, in order.
            for (float queryValue : this.queryVector) {
                sum += queryValue * docVector.getFloat();
            }
            return sum;
        }
    }

    /** Dot product between a sparse query vector and a sparse document vector. */
    public static final class DotProductSparse extends SparseVectorFunction {
        public DotProductSparse(ScoreScript scoreScript, Map<String, Number> map, Object obj) {
            super(scoreScript, map, obj);
        }

        /**
         * Computes the dot product for the script's current document. Dimensions present in
         * only one of the two vectors contribute nothing.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double dotProductSparse() {
            BytesRef encoded = getEncodedVector();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(this.docValues.indexVersion(), encoded);
            float[] docVector = VectorEncoderDecoder.decodeSparseVector(this.docValues.indexVersion(), encoded);
            return ScoreScriptUtils.intDotProductSparse(this.queryValues, this.queryDims, docVector, docDims);
        }
    }

    /** L1 (Manhattan) distance between the query vector and a dense document vector. */
    public static final class L1Norm extends DenseVectorFunction {
        public L1Norm(ScoreScript scoreScript, List<Number> list, Object obj) {
            super(scoreScript, list, obj);
        }

        /**
         * Computes the sum of absolute component differences for the script's current document.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l1norm() {
            BytesRef encodedVector = getEncodedVector();
            ByteBuffer docVector = ByteBuffer.wrap(encodedVector.bytes, encodedVector.offset, encodedVector.length);
            double distance = 0.0d;
            // Reads one float from the encoded document vector per query dimension, in order.
            for (float queryValue : this.queryVector) {
                distance += Math.abs(queryValue - docVector.getFloat());
            }
            return distance;
        }
    }

    /** L1 (Manhattan) distance between a sparse query vector and a sparse document vector. */
    public static final class L1NormSparse extends SparseVectorFunction {
        public L1NormSparse(ScoreScript scoreScript, Map<String, Number> map, Object obj) {
            super(scoreScript, map, obj);
        }

        /**
         * Computes the sum of absolute component differences for the script's current document.
         * A dimension missing from one vector is treated as 0 in that vector.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l1normSparse() {
            BytesRef encoded = getEncodedVector();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(this.docValues.indexVersion(), encoded);
            float[] docVector = VectorEncoderDecoder.decodeSparseVector(this.docValues.indexVersion(), encoded);
            int q = 0;
            int d = 0;
            double distance = 0.0d;
            // Single merge walk over both dimension lists (assumed sorted ascending): always
            // consume the smaller dimension, or both when they match.
            while (q < this.queryDims.length || d < docDims.length) {
                if (d == docDims.length || (q < this.queryDims.length && this.queryDims[q] < docDims[d])) {
                    distance += Math.abs(this.queryValues[q++]);
                } else if (q == this.queryDims.length || this.queryDims[q] > docDims[d]) {
                    distance += Math.abs(docVector[d++]);
                } else {
                    distance += Math.abs(this.queryValues[q++] - docVector[d++]);
                }
            }
            return distance;
        }
    }

    /** L2 (Euclidean) distance between the query vector and a dense document vector. */
    public static final class L2Norm extends DenseVectorFunction {
        public L2Norm(ScoreScript scoreScript, List<Number> list, Object obj) {
            super(scoreScript, list, obj);
        }

        /**
         * Computes the Euclidean distance for the script's current document.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l2norm() {
            BytesRef encoded = getEncodedVector();
            ByteBuffer docVector = ByteBuffer.wrap(encoded.bytes, encoded.offset, encoded.length);
            double sumOfSquares = 0.0d;
            for (int dim = 0; dim < this.queryVector.length; dim++) {
                double diff = this.queryVector[dim] - docVector.getFloat();
                sumOfSquares += diff * diff;
            }
            return Math.sqrt(sumOfSquares);
        }
    }

    /** L2 (Euclidean) distance between a sparse query vector and a sparse document vector. */
    public static final class L2NormSparse extends SparseVectorFunction {
        public L2NormSparse(ScoreScript scoreScript, Map<String, Number> map, Object obj) {
            super(scoreScript, map, obj);
        }

        /**
         * Computes the Euclidean distance for the script's current document. A dimension missing
         * from one vector is treated as 0 in that vector.
         *
         * <p>All differences and squares are computed in double precision. This way a component
         * contributes the same amount whether the merge loop or a tail loop visits it, and large
         * components cannot overflow float when squared.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        public double l2normSparse() {
            BytesRef encoded = getEncodedVector();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(this.docValues.indexVersion(), encoded);
            float[] docVector = VectorEncoderDecoder.decodeSparseVector(this.docValues.indexVersion(), encoded);
            int q = 0;
            int d = 0;
            double sumOfSquares = 0.0d;
            // Merge walk over both dimension lists (assumed sorted ascending).
            while (q < this.queryDims.length && d < docDims.length) {
                if (this.queryDims[q] == docDims[d]) {
                    double diff = (double) this.queryValues[q] - docVector[d];
                    sumOfSquares += diff * diff;
                    q++;
                    d++;
                } else if (this.queryDims[q] > docDims[d]) {
                    sumOfSquares += square(docVector[d]);
                    d++;
                } else {
                    sumOfSquares += square(this.queryValues[q]);
                    q++;
                }
            }
            // At most one of these tail loops runs: the other vector is already exhausted.
            while (q < this.queryDims.length) {
                sumOfSquares += square(this.queryValues[q]);
                q++;
            }
            while (d < docDims.length) {
                sumOfSquares += square(docVector[d]);
                d++;
            }
            return Math.sqrt(sumOfSquares);
        }

        /** Squares {@code value} after widening it to double, so the result cannot overflow float. */
        private static double square(float value) {
            double widened = value;
            return widened * widened;
        }
    }

    /**
     * Base class for script functions that compare a sparse query vector with a sparse vector field.
     *
     * <p>The query is given as a map from dimension (an integer encoded as a string) to value. It
     * is parsed into the parallel arrays {@link #queryDims} and {@link #queryValues}, which are
     * then passed together to {@code VectorEncoderDecoder.sortSparseDimsFloatValues}. Every
     * construction logs the sparse-vector-field deprecation warning.
     */
    public static class SparseVectorFunction {
        final ScoreScript scoreScript;
        final float[] queryValues;
        final int[] queryDims;
        final VectorScriptDocValues.SparseVectorScriptDocValues docValues;

        /**
         * @param scoreScript the script whose current document is being scored
         * @param map the query vector as dimension-string to value
         * @param obj the field name as a {@code String}, or (deprecated, logs a warning) the
         *     field's {@code SparseVectorScriptDocValues}
         * @throws IllegalArgumentException if a dimension key is not an integer, or if
         *     {@code obj} has an unsupported type
         */
        public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> map, Object obj) {
            this.scoreScript = scoreScript;
            int dimCount = map.size();
            this.queryValues = new float[dimCount];
            this.queryDims = new int[dimCount];
            int next = 0;
            for (Map.Entry<String, Number> dimAndValue : map.entrySet()) {
                try {
                    this.queryDims[next] = Integer.parseInt(dimAndValue.getKey());
                    this.queryValues[next] = dimAndValue.getValue().floatValue();
                } catch (NumberFormatException e) {
                    throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
                }
                next++;
            }
            VectorEncoderDecoder.sortSparseDimsFloatValues(this.queryDims, this.queryValues, dimCount);

            if (obj instanceof String) {
                this.docValues = (VectorScriptDocValues.SparseVectorScriptDocValues) scoreScript.getDoc().get((String) obj);
            } else if (obj instanceof VectorScriptDocValues.SparseVectorScriptDocValues) {
                this.docValues = (VectorScriptDocValues.SparseVectorScriptDocValues) obj;
                ScoreScriptUtils.deprecationLogger.critical(DeprecationCategory.SCRIPTING, "vector_function_signature", ScoreScriptUtils.DEPRECATION_MESSAGE, new Object[0]);
            } else {
                throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or VectorScriptDocValues");
            }
            ScoreScriptUtils.deprecationLogger.critical(DeprecationCategory.MAPPINGS, "sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE, new Object[0]);
        }

        /**
         * Positions the doc values on the script's current document and returns its encoded vector.
         *
         * @throws IllegalArgumentException if the current document has no value for the field
         */
        BytesRef getEncodedVector() {
            final BytesRef encodedValue;
            try {
                this.docValues.setNextDocId(this.scoreScript._getDocId());
                encodedValue = this.docValues.getEncodedValue();
            } catch (IOException e) {
                throw ExceptionsHelper.convertToElastic(e);
            }
            if (encodedValue == null) {
                throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
            }
            return encodedValue;
        }
    }

    /**
     * Dot product of two sparse vectors, each given as parallel value and dimension arrays.
     *
     * <p>The arrays are walked as a merge, always advancing past the smaller dimension. The
     * dimension arrays are therefore assumed to be sorted ascending. Dimensions present in only
     * one vector contribute nothing.
     *
     * <p>NOTE(review): the decompiler reported this method as originally {@code private}. It is
     * kept {@code public} so that any external callers do not break.
     *
     * @param fArr values of the first vector
     * @param iArr dimensions of the first vector, parallel to {@code fArr}
     * @param fArr2 values of the second vector
     * @param iArr2 dimensions of the second vector, parallel to {@code fArr2}
     * @return the sum of value products over matching dimensions
     */
    public static double intDotProductSparse(float[] fArr, int[] iArr, float[] fArr2, int[] iArr2) {
        double dotProduct = 0.0d;
        for (int a = 0, b = 0; a < fArr.length && b < fArr2.length; ) {
            int dimA = iArr[a];
            int dimB = iArr2[b];
            if (dimA < dimB) {
                a++;
            } else if (dimA > dimB) {
                b++;
            } else {
                dotProduct += fArr[a++] * fArr2[b++];
            }
        }
        return dotProduct;
    }
}
