package org.tribuo.math.la;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Collectors;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.protos.SparseTensorProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.math.util.VectorNormalizer;
import org.tribuo.util.IntDoublePair;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/math/la/SparseVector.class */
public class SparseVector implements SGDVector {
    /** Java serialization version identifier. */
    private static final long serialVersionUID = 1;
    /** Protobuf format version written by serialization; {@code deserializeFromProto} rejects any other version. */
    public static final int CURRENT_VERSION = 0;
    /** Single-element shape array, always {@code {size}}. */
    private final int[] shape;
    /**
     * Positions of the explicitly stored (active) elements. The accessors binary-search this array,
     * so it is expected to be sorted ascending without duplicates.
     */
    protected final int[] indices;
    /** Values of the active elements, parallel to {@code indices}. */
    protected final double[] values;
    /** Dimension of the vector, i.e. the total element count including implicit zeros. */
    private final int size;

    /**
     * Iterates the active (explicitly stored) elements of a {@link SparseVector} in storage order.
     * <p>
     * A single {@link VectorTuple} is reused for every call to {@link #next()}, so callers that need
     * to keep an element must copy its index and value before advancing.
     * <p>
     * Note: the decompiler reports this class was originally private.
     */
    public static class SparseVectorIterator implements VectorIterator {
        private final SparseVector source;
        private final VectorTuple current = new VectorTuple();
        private int cursor = 0;

        /**
         * Creates an iterator positioned before the first active element.
         * @param sparseVector The vector to iterate.
         */
        public SparseVectorIterator(SparseVector sparseVector) {
            source = sparseVector;
        }

        @Override
        public boolean hasNext() {
            return cursor < source.indices.length;
        }

        /**
         * Advances to the next active element and returns the shared tuple describing it.
         * @return The reused tuple, now holding the next index/value pair.
         * @throws NoSuchElementException if the iterator is exhausted.
         */
        @Override
        public VectorTuple next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Off the end of the iterator.");
            }
            final int position = cursor;
            current.index = source.indices[position];
            current.value = source.values[position];
            cursor = position + 1;
            return current;
        }

        /**
         * Returns the shared tuple without advancing.
         * @return The tuple that {@link #next()} populates.
         */
        @Override
        public VectorTuple getReference() {
            return current;
        }
    }

    /**
     * Creates a vector of the given dimension with no active elements (all zeros).
     * <p>
     * Note: the decompiler reports this constructor was originally package-private.
     * @param i The dimension of the vector.
     */
    public SparseVector(int i) {
        this.shape = new int[]{i};
        this.size = i;
        this.values = new double[0];
        this.indices = new int[0];
    }

    /**
     * Wraps the supplied arrays directly, without copying, sorting or validating them.
     * <p>
     * Note: the decompiler reports this constructor was originally package-private.
     * @param i The dimension of the vector.
     * @param iArr The active indices; stored by reference.
     * @param dArr The active values, parallel to {@code iArr}; stored by reference.
     */
    public SparseVector(int i, int[] iArr, double[] dArr) {
        this.indices = iArr;
        this.values = dArr;
        this.shape = new int[]{i};
        this.size = i;
    }

    /**
     * Copy constructor: copies the active elements of {@code sparseVector} into fresh arrays.
     * @param sparseVector The vector to copy.
     */
    private SparseVector(SparseVector sparseVector) {
        this.size = sparseVector.size;
        final int active = sparseVector.numActiveElements();
        this.indices = new int[active];
        this.values = new double[active];
        int slot = 0;
        for (Iterator<VectorTuple> source = sparseVector.iterator2(); source.hasNext(); slot++) {
            final VectorTuple tuple = source.next();
            this.indices[slot] = tuple.index;
            this.values[slot] = tuple.value;
        }
        this.shape = new int[]{this.size};
    }

    public SparseVector(int i, int[] iArr, double d) {
        this.indices = Arrays.copyOf(iArr, iArr.length);
        this.values = new double[iArr.length];
        Arrays.fill(this.values, d);
        this.size = i;
        this.shape = new int[]{i};
    }

    /**
     * Builds a sparse vector from an example's features, mapping names to ids via the feature map.
     * <p>
     * Features the map does not know (id of -1 or below) are dropped. Repeated ids have their
     * values summed. If {@code z} is true, an extra bias element with value 1.0 is appended at
     * index {@code featureMap.size()}, and the dimension grows by one.
     * @param example The example to convert.
     * @param immutableFeatureMap The feature name to id mapping.
     * @param z Whether to append a bias element.
     * @param <T> The output type of the example.
     * @return A sparse vector with sorted indices.
     * @throws IllegalArgumentException if any feature value (or summed value) is NaN.
     */
    public static <T extends Output<T>> SparseVector createSparseVector(Example<T> example, ImmutableFeatureMap immutableFeatureMap, boolean z) {
        final int capacity = z ? example.size() + 1 : example.size();
        final int dimension = z ? immutableFeatureMap.size() + 1 : immutableFeatureMap.size();
        final int[] ids = new int[capacity];
        final double[] vals = new double[capacity];
        int count = 0;
        int largestId = -1;
        for (Feature feature : example) {
            final int id = immutableFeatureMap.getID(feature.getName());
            if (id > largestId) {
                // Fast path: ids arriving in increasing order are simply appended.
                largestId = id;
                ids[count] = id;
                vals[count] = feature.getValue();
                if (Double.isNaN(vals[count])) {
                    throw new IllegalArgumentException("Example contained a NaN feature, " + feature.toString());
                }
                count++;
            } else if (id > -1) {
                // Out-of-order id: either merge into an existing slot or shift to insert.
                final int found = Arrays.binarySearch(ids, 0, count, id);
                if (found >= 0) {
                    vals[found] = vals[found] + feature.getValue();
                    if (Double.isNaN(vals[found])) {
                        throw new IllegalArgumentException("Example contained a NaN feature, " + feature.toString());
                    }
                } else {
                    final int insertAt = -(found + 1);
                    final int tail = count - insertAt;
                    System.arraycopy(ids, insertAt, ids, insertAt + 1, tail);
                    System.arraycopy(vals, insertAt, vals, insertAt + 1, tail);
                    ids[insertAt] = id;
                    vals[insertAt] = feature.getValue();
                    if (Double.isNaN(vals[insertAt])) {
                        throw new IllegalArgumentException("Example contained a NaN feature, " + feature.toString());
                    }
                    count++;
                }
            }
        }
        if (z) {
            ids[count] = dimension - 1;
            vals[count] = 1.0d;
            count++;
        }
        return new SparseVector(dimension, Arrays.copyOf(ids, count), Arrays.copyOf(vals, count));
    }

    /**
     * Creates a sparse vector from parallel index/value arrays given in any order.
     * <p>
     * The pairs are sorted by index into new arrays, so the caller's arrays are not retained (when both
     * are empty they are zero-length and used as-is). Duplicate indices are not detected or merged.
     * @param i The dimension of the vector; valid indices are {@code 0 .. i-1}.
     * @param iArr The indices of the active elements.
     * @param dArr The values of the active elements, parallel to {@code iArr}.
     * @return A new SparseVector with sorted indices.
     * @throws IllegalArgumentException if the arrays differ in length, any index is negative, or any
     *     index is not less than the dimension.
     */
    public static SparseVector createSparseVector(int i, int[] iArr, double[] dArr) {
        if (iArr.length != dArr.length) {
            throw new IllegalArgumentException("Indices and values must be the same length, found indices.length = " + iArr.length + " and values.length = " + dArr.length);
        }
        if (iArr.length == 0) {
            return new SparseVector(i, iArr, dArr);
        }
        IntDoublePair[] pairs = new IntDoublePair[iArr.length];
        for (int k = 0; k < pairs.length; k++) {
            pairs[k] = new IntDoublePair(iArr[k], dArr[k]);
        }
        Arrays.sort(pairs, IntDoublePair.pairIndexComparator());
        int[] sortedIndices = new int[pairs.length];
        double[] sortedValues = new double[pairs.length];
        for (int k = 0; k < pairs.length; k++) {
            sortedIndices[k] = pairs[k].index;
            sortedValues[k] = pairs[k].value;
        }
        int minIndex = sortedIndices[0];
        if (minIndex < 0) {
            throw new IllegalArgumentException("Indices must be non-negative, found index = " + minIndex);
        }
        int maxIndex = sortedIndices[sortedIndices.length - 1];
        // Indices are zero-based, so the largest must be strictly smaller than the dimension.
        if (maxIndex >= i) {
            throw new IllegalArgumentException("Number of dimensions must be greater than the maximum index, dimensions = " + i + ", max index = " + maxIndex);
        }
        return new SparseVector(i, sortedIndices, sortedValues);
    }

    /**
     * Creates a sparse vector from a map of index to value.
     * <p>
     * Entries are sorted by key. The map itself is not retained.
     * @param i The dimension of the vector; valid indices are {@code 0 .. i-1}.
     * @param map Index to value mapping for the active elements.
     * @return A new SparseVector with sorted indices.
     * @throws IllegalArgumentException if any key is negative or not less than the dimension.
     */
    public static SparseVector createSparseVector(int i, Map<Integer, Double> map) {
        if (map.isEmpty()) {
            return new SparseVector(i, new int[0], new double[0]);
        }
        List<Map.Entry<Integer, Double>> entries = map.entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .collect(Collectors.toList());
        int[] iArr = new int[entries.size()];
        double[] dArr = new double[entries.size()];
        for (int k = 0; k < entries.size(); k++) {
            Map.Entry<Integer, Double> entry = entries.get(k);
            iArr[k] = entry.getKey();
            dArr[k] = entry.getValue();
        }
        if (iArr[0] < 0) {
            throw new IllegalArgumentException("Indices must be non-negative, found index = " + iArr[0]);
        }
        int maxIndex = iArr[iArr.length - 1];
        // Indices are zero-based, so the largest must be strictly smaller than the dimension.
        if (maxIndex >= i) {
            throw new IllegalArgumentException("Number of dimensions must be greater than the maximum index, dimensions = " + i + ", max index = " + maxIndex);
        }
        return new SparseVector(i, iArr, dArr);
    }

    /**
     * Wraps the arrays after checking they describe a valid vector (used when reading protos).
     * <p>
     * Note: the decompiler reports this method was originally package-private.
     * @param i The dimension; must be at least 1.
     * @param iArr Strictly increasing, non-negative indices, each below {@code i}; stored by reference.
     * @param dArr Values parallel to {@code iArr}; stored by reference.
     * @return A SparseVector wrapping the arrays.
     * @throws IllegalArgumentException if any check fails.
     */
    public static SparseVector createAndValidate(int i, int[] iArr, double[] dArr) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid proto, dimension must be positive, found " + i);
        }
        if (iArr.length != dArr.length) {
            throw new IllegalArgumentException("Invalid proto, mismatch between number of indices and values. indices.length = " + iArr.length + ", values.length = " + dArr.length);
        }
        int previous = -1;
        for (int k = 0; k < iArr.length; k++) {
            final int current = iArr[k];
            if (current <= previous) {
                throw new IllegalArgumentException("Invalid proto, indices are not non-negative and monotonic, indices = " + Arrays.toString(iArr));
            }
            if (current >= i) {
                throw new IllegalArgumentException("Invalid proto, an index is larger than the shape, indices = " + Arrays.toString(iArr));
            }
            previous = current;
        }
        return new SparseVector(i, iArr, dArr);
    }

    /**
     * Reconstructs a SparseVector from a packed {@link SparseTensorProto}.
     * <p>
     * Indices are read as little-endian ints and values as little-endian doubles. The result is
     * validated by {@code createAndValidate}.
     * @param i The serialized version; only {@link #CURRENT_VERSION} is accepted.
     * @param str The class name (unused here).
     * @param any The packed proto.
     * @return The deserialized vector.
     * @throws InvalidProtocolBufferException if {@code any} does not hold a SparseTensorProto.
     * @throws IllegalArgumentException if the version or proto contents are invalid.
     */
    public static SparseVector deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i != CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
        }
        final SparseTensorProto proto = any.unpack(SparseTensorProto.class);
        final int[] dims = Util.toPrimitiveInt(proto.getDimensionsList());
        if (dims.length != 1) {
            throw new IllegalArgumentException("Invalid proto, expected a vector, found shape " + Arrays.toString(dims));
        }
        if (dims[0] < 1) {
            throw new IllegalArgumentException("Invalid proto, shape must be positive, found " + dims[0] + " at position 0");
        }
        final int nonZero = proto.getNumNonZero();

        final IntBuffer indexBuffer = proto.getIndices().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
        if (indexBuffer.remaining() != nonZero) {
            throw new IllegalArgumentException("Invalid proto, claimed " + nonZero + ", but only had " + indexBuffer.remaining() + " indices");
        }
        final int[] decodedIndices = new int[nonZero];
        indexBuffer.get(decodedIndices);

        final DoubleBuffer valueBuffer = proto.getValues().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asDoubleBuffer();
        if (valueBuffer.remaining() != nonZero) {
            throw new IllegalArgumentException("Invalid proto, claimed " + nonZero + ", but only had " + valueBuffer.remaining() + " values");
        }
        final double[] decodedValues = new double[nonZero];
        valueBuffer.get(decodedValues);

        return createAndValidate(dims[0], decodedIndices, decodedValues);
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    /**
     * Serializes this vector into a {@link TensorProto} that wraps a {@link SparseTensorProto}.
     * <p>
     * Indices are encoded as little-endian 32-bit ints and values as little-endian 64-bit doubles.
     * @return The serialized form.
     */
    public TensorProto m21serialize() {
        final ByteBuffer indexBytes = ByteBuffer.allocate(indices.length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        // Writing through the view leaves indexBytes' own position at 0, so copyFrom sees every byte.
        indexBytes.asIntBuffer().put(indices);
        final ByteBuffer valueBytes = ByteBuffer.allocate(values.length * Double.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        valueBytes.asDoubleBuffer().put(values);

        final List<Integer> dimensions = Arrays.stream(shape).boxed().collect(Collectors.toList());
        final SparseTensorProto.Builder sparseBuilder = SparseTensorProto.newBuilder();
        sparseBuilder.addAllDimensions(dimensions);
        sparseBuilder.setIndices(ByteString.copyFrom(indexBytes));
        sparseBuilder.setValues(ByteString.copyFrom(valueBytes));
        sparseBuilder.setNumNonZero(values.length);

        final TensorProto.Builder tensorBuilder = TensorProto.newBuilder();
        tensorBuilder.setVersion(CURRENT_VERSION);
        tensorBuilder.setClassName(SparseVector.class.getName());
        tensorBuilder.setSerializedData(Any.pack(sparseBuilder.m789build()));
        return tensorBuilder.m836build();
    }

    /**
     * Returns a deep copy of this vector; the index and value arrays are duplicated.
     * @return A new SparseVector equal to this one.
     */
    @Override
    public SparseVector copy() {
        final SparseVector duplicate = new SparseVector(this);
        return duplicate;
    }

    /**
     * Returns the shape array, {@code {size}}.
     * <p>
     * The internal array is returned directly; callers must not modify it.
     * @return The shape.
     */
    @Override
    public int[] getShape() {
        return shape;
    }

    /**
     * Reshaping is not supported for sparse vectors.
     * @param iArr The requested shape (ignored).
     * @return Never returns normally.
     * @throws UnsupportedOperationException always.
     */
    @Override
    public Tensor reshape(int[] iArr) {
        final String reason = "Reshape not supported on sparse Tensors.";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Returns the dimension of this vector, counting implicit zeros.
     * @return The dimension.
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * Returns the number of explicitly stored elements.
     * @return The length of the value array.
     */
    @Override
    public int numActiveElements() {
        final int stored = values.length;
        return stored;
    }

    /**
     * Two vectors are equal when they have the same dimension and their iterators yield equal
     * tuples in the same order.
     * <p>
     * The dimension must be compared because {@link #hashCode()} includes it; otherwise vectors of
     * different dimension with the same active elements would be equal but hash differently.
     * @param obj The object to compare against.
     * @return True if equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof SGDVector)) {
            return false;
        }
        SGDVector other = (SGDVector) obj;
        if (other.size() != this.size) {
            return false;
        }
        Iterator<VectorTuple> ours = iterator2();
        Iterator<VectorTuple> theirs = other.iterator();
        while (ours.hasNext() && theirs.hasNext()) {
            if (!ours.next().equals(theirs.next())) {
                return false;
            }
        }
        // Equal only when both iterators are exhausted together.
        return !ours.hasNext() && !theirs.hasNext();
    }

    /**
     * Hashes the dimension, then the index array, then the value array.
     * @return The hash code.
     */
    @Override
    public int hashCode() {
        int result = Objects.hash(size);
        result = 31 * result + Arrays.hashCode(indices);
        result = 31 * result + Arrays.hashCode(values);
        return result;
    }

    /**
     * Returns {@code this + sGDVector} as a new vector; neither input is modified.
     * <p>
     * A dense argument delegates to {@code DenseVector.add}. A sparse argument produces a sparse
     * result whose active set is the union of both inputs' active sets.
     * @param sGDVector The vector to add.
     * @return The element-wise sum.
     * @throws IllegalArgumentException if the dimensions differ or the argument is neither dense nor sparse.
     */
    @Override
    public SGDVector add(SGDVector sGDVector) {
        if (sGDVector.size() != size) {
            throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + this.size + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            return sGDVector.add(this);
        }
        if (!(sGDVector instanceof SparseVector)) {
            throw new IllegalArgumentException("Vector other is not dense or sparse.");
        }
        final Map<Integer, Double> accumulated = new HashMap<>();
        for (Iterator<VectorTuple> mine = iterator2(); mine.hasNext(); ) {
            final VectorTuple tuple = mine.next();
            accumulated.put(tuple.index, tuple.value);
        }
        for (VectorTuple tuple : sGDVector) {
            accumulated.merge(tuple.index, tuple.value, Double::sum);
        }
        return createSparseVector(size, accumulated);
    }

    /**
     * Returns {@code this - sGDVector} as a new vector; neither input is modified.
     * <p>
     * A dense argument produces a dense result. A sparse argument produces a sparse result with an
     * active element at every index active in either input.
     * @param sGDVector The vector to subtract.
     * @return The element-wise difference.
     * @throws IllegalArgumentException if the dimensions differ or the argument is neither dense nor sparse.
     */
    @Override
    public SGDVector subtract(SGDVector sGDVector) {
        if (sGDVector.size() != this.size) {
            throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + this.size + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof DenseVector) {
            DenseVector output = ((DenseVector) sGDVector).copy();
            // Every element must be rewritten: where this vector is inactive the difference is
            // 0 - other[k], not other[k]. The sorted index array is walked in step with k.
            int cursor = 0;
            for (int k = 0; k < this.size; k++) {
                double ours = 0.0d;
                if (cursor < this.indices.length && this.indices[cursor] == k) {
                    ours = this.values[cursor];
                    cursor++;
                }
                output.set(k, ours - output.get(k));
            }
            return output;
        }
        if (!(sGDVector instanceof SparseVector)) {
            throw new IllegalArgumentException("Vector other is not dense or sparse.");
        }
        Map<Integer, Double> merged = new HashMap<>();
        Iterator<VectorTuple> mine = iterator2();
        while (mine.hasNext()) {
            VectorTuple tuple = mine.next();
            merged.put(tuple.index, tuple.value);
        }
        for (VectorTuple tuple : sGDVector) {
            merged.merge(tuple.index, -tuple.value, Double::sum);
        }
        return createSparseVector(this.size, merged);
    }

    /**
     * Adds {@code doubleUnaryOperator(other[k])} to this vector's value at each active index k,
     * mutating this vector in place. No new active elements are created.
     * <p>
     * For a {@link DenseVector} argument, every active index of this vector is updated. For a
     * {@link SparseVector} argument, only indices active in both vectors are updated.
     * @param tensor The vector supplying the addends.
     * @param doubleUnaryOperator Transformation applied to each addend.
     * @throws IllegalStateException if tensor is neither a SparseVector nor a DenseVector.
     * @throws IllegalArgumentException if the dimensions differ.
     */
    @Override
    public void intersectAndAddInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
        if (tensor instanceof SparseVector) {
            SparseVector other = (SparseVector) tensor;
            if (other.size() != this.size) {
                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + this.size + ", other = " + other.size());
            }
            int[] otherIndices = other.indices;
            double[] otherValues = other.values;
            int ours = 0;
            int theirs = 0;
            // Two-pointer merge over both sorted index arrays; an empty side simply skips the loop.
            while (ours < this.indices.length && theirs < otherIndices.length) {
                int left = this.indices[ours];
                int right = otherIndices[theirs];
                if (left == right) {
                    this.values[ours] = this.values[ours] + doubleUnaryOperator.applyAsDouble(otherValues[theirs]);
                    ours++;
                    theirs++;
                } else if (left < right) {
                    ours++;
                } else {
                    theirs++;
                }
            }
        } else if (tensor instanceof DenseVector) {
            DenseVector denseVector = (DenseVector) tensor;
            if (denseVector.size() != this.size) {
                throw new IllegalArgumentException("Can't intersect two vectors of different dimension, this = " + this.size + ", other = " + denseVector.size());
            }
            for (int k = 0; k < this.indices.length; k++) {
                this.values[k] = this.values[k] + doubleUnaryOperator.applyAsDouble(denseVector.get(this.indices[k]));
            }
        } else {
            throw new IllegalStateException("Unknown Tensor subclass " + tensor.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * Multiplies this vector's value at each active index k by {@code doubleUnaryOperator(other[k])},
     * mutating this vector in place.
     * <p>
     * For a {@link DenseVector} argument, every active index of this vector is scaled. For a
     * {@link SparseVector} argument, only indices active in both vectors are scaled; indices active
     * only in this vector are left unchanged.
     * @param tensor The vector supplying the multipliers.
     * @param doubleUnaryOperator Transformation applied to each multiplier.
     * @throws IllegalArgumentException if tensor is neither a SparseVector nor a DenseVector, or the
     *     dimensions differ.
     */
    @Override
    public void hadamardProductInPlace(Tensor tensor, DoubleUnaryOperator doubleUnaryOperator) {
        if (tensor instanceof SparseVector) {
            SparseVector other = (SparseVector) tensor;
            if (other.size() != this.size) {
                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + this.size + ", other = " + other.size());
            }
            int[] otherIndices = other.indices;
            double[] otherValues = other.values;
            int ours = 0;
            int theirs = 0;
            // Two-pointer merge over both sorted index arrays; an empty side simply skips the loop.
            while (ours < this.indices.length && theirs < otherIndices.length) {
                int left = this.indices[ours];
                int right = otherIndices[theirs];
                if (left == right) {
                    this.values[ours] = this.values[ours] * doubleUnaryOperator.applyAsDouble(otherValues[theirs]);
                    ours++;
                    theirs++;
                } else if (left < right) {
                    ours++;
                } else {
                    theirs++;
                }
            }
        } else if (tensor instanceof DenseVector) {
            DenseVector denseVector = (DenseVector) tensor;
            if (denseVector.size() != this.size) {
                throw new IllegalArgumentException("Can't hadamard product two vectors of different dimension, this = " + this.size + ", other = " + denseVector.size());
            }
            for (int k = 0; k < this.indices.length; k++) {
                this.values[k] = this.values[k] * doubleUnaryOperator.applyAsDouble(denseVector.get(this.indices[k]));
            }
        } else {
            throw new IllegalArgumentException("Invalid Tensor subclass " + tensor.getClass().getCanonicalName() + " for input");
        }
    }

    /**
     * Replaces each active value v with {@code doubleUnaryOperator(v)}, in index order.
     * Implicit zeros are not touched.
     * @param doubleUnaryOperator The transformation.
     */
    @Override
    public void foreachInPlace(DoubleUnaryOperator doubleUnaryOperator) {
        Arrays.setAll(values, k -> doubleUnaryOperator.applyAsDouble(values[k]));
    }

    /**
     * Replaces each active value with {@code toDoubleBiFunction(index, value)}, in index order.
     * Implicit zeros are not touched.
     * @param toDoubleBiFunction The transformation, given the element's index and current value.
     */
    @Override
    public void foreachIndexedInPlace(ToDoubleBiFunction<Integer, Double> toDoubleBiFunction) {
        Arrays.setAll(values, k -> toDoubleBiFunction.applyAsDouble(indices[k], values[k]));
    }

    /**
     * Returns a new vector with every active value multiplied by {@code d}; this vector is unchanged.
     * @param d The scale factor.
     * @return The scaled copy.
     */
    @Override
    public SparseVector scale(double d) {
        final double[] scaled = new double[values.length];
        for (int k = 0; k < scaled.length; k++) {
            scaled[k] = values[k] * d;
        }
        return new SparseVector(size, indices.clone(), scaled);
    }

    /**
     * Adds {@code d} to the value at index {@code i}, which must already be active.
     * @param i The index.
     * @param d The amount to add.
     * @throws IllegalArgumentException if {@code i} is not an active index.
     */
    @Override
    public void add(int i, double d) {
        final int slot = Arrays.binarySearch(indices, i);
        if (slot >= 0) {
            values[slot] += d;
            return;
        }
        throw new IllegalArgumentException("SparseVector cannot have new elements added.");
    }

    /**
     * Computes the inner product with another vector of the same dimension.
     * <p>
     * For a sparse argument, only indices active in both vectors contribute. Products are summed
     * in ascending index order.
     * @param sGDVector The other vector.
     * @return The dot product.
     * @throws IllegalArgumentException if the dimensions differ or the argument type is unknown.
     */
    @Override
    public double dot(SGDVector sGDVector) {
        if (sGDVector.size() != size) {
            throw new IllegalArgumentException("Can't dot two vectors of different lengths, this = " + this.size + ", other = " + sGDVector.size());
        }
        if (sGDVector instanceof SparseVector) {
            final SparseVector other = (SparseVector) sGDVector;
            final int[] otherIndices = other.indices;
            final double[] otherValues = other.values;
            double total = 0.0d;
            int a = 0;
            int b = 0;
            while (a < indices.length && b < otherIndices.length) {
                final int left = indices[a];
                final int right = otherIndices[b];
                if (left == right) {
                    total += values[a] * otherValues[b];
                    a++;
                    b++;
                } else if (left < right) {
                    a++;
                } else {
                    b++;
                }
            }
            return total;
        }
        if (sGDVector instanceof DenseVector) {
            double total = 0.0d;
            for (int k = 0; k < indices.length; k++) {
                total += sGDVector.get(indices[k]) * values[k];
            }
            return total;
        }
        throw new IllegalArgumentException("Unknown vector subclass " + sGDVector.getClass().getCanonicalName() + " for input");
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [double[], double[][]] */
    @Override // org.tribuo.math.la.SGDVector
    public Matrix outer(SGDVector sGDVector) {
        if (sGDVector instanceof SparseVector) {
            SparseVector sparseVector = (SparseVector) sGDVector;
            SparseVector[] sparseVectorArr = new SparseVector[this.size];
            int i = 0;
            Iterator<VectorTuple> iterator2 = iterator2();
            while (iterator2.hasNext()) {
                VectorTuple next = iterator2.next();
                while (i < next.index) {
                    sparseVectorArr[i] = new SparseVector(sGDVector.size(), new int[0], new double[0]);
                    i++;
                }
                sparseVectorArr[next.index] = sparseVector.scale(next.value);
                i++;
            }
            while (i < sparseVectorArr.length) {
                sparseVectorArr[i] = new SparseVector(sGDVector.size(), new int[0], new double[0]);
                i++;
            }
            return new DenseSparseMatrix(sparseVectorArr);
        }
        if (!(sGDVector instanceof DenseVector)) {
            throw new IllegalArgumentException("Unknown vector subclass " + sGDVector.getClass().getCanonicalName() + " for input");
        }
        DenseVector denseVector = (DenseVector) sGDVector;
        int size = denseVector.size();
        ?? r0 = new double[this.size];
        int i2 = 0;
        Iterator<VectorTuple> iterator22 = iterator2();
        while (iterator22.hasNext()) {
            VectorTuple next2 = iterator22.next();
            while (i2 < next2.index) {
                r0[i2] = new double[size];
                i2++;
            }
            r0[next2.index] = denseVector.scale(next2.value).elements;
            i2++;
        }
        while (i2 < r0.length) {
            r0[i2] = new double[size];
            i2++;
        }
        return new DenseMatrix((double[][]) r0);
    }

    /**
     * Sums the active values; implicit zeros contribute nothing.
     * @return The sum.
     */
    @Override
    public double sum() {
        double total = 0.0d;
        for (double v : values) {
            total += v;
        }
        return total;
    }

    /**
     * Computes the Euclidean (L2) norm over the active values.
     * @return The square root of the sum of squares.
     */
    @Override
    public double twoNorm() {
        double squares = 0.0d;
        for (double v : values) {
            squares += v * v;
        }
        return Math.sqrt(squares);
    }

    /**
     * Computes the L1 norm (sum of absolute values) over the active values.
     * @return The L1 norm.
     */
    @Override
    public double oneNorm() {
        double total = 0.0d;
        for (double v : values) {
            total += Math.abs(v);
        }
        return total;
    }

    /**
     * Returns the value at index {@code i}, or 0.0 if the index is not active.
     * @param i The index.
     * @return The element value.
     */
    @Override
    public double get(int i) {
        final int slot = Arrays.binarySearch(indices, i);
        return slot >= 0 ? values[slot] : 0.0d;
    }

    /**
     * Overwrites the value at index {@code i}, which must already be active.
     * @param i The index.
     * @param d The new value.
     * @throws IllegalArgumentException if {@code i} is not an active index.
     */
    @Override
    public void set(int i, double d) {
        final int slot = Arrays.binarySearch(indices, i);
        if (slot >= 0) {
            values[slot] = d;
            return;
        }
        throw new IllegalArgumentException("SparseVector cannot have new elements added.");
    }

    /**
     * Returns the index of the largest active value; implicit zeros are not considered.
     * <p>
     * Ties resolve to the lowest index. If no value compares greater than negative infinity (for
     * example, all NaN), the first active index is returned. With no active elements,
     * {@code indices[0]} throws ArrayIndexOutOfBoundsException.
     * @return The vector index of the maximum active value.
     */
    @Override
    public int indexOfMax() {
        int bestSlot = 0;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < values.length; k++) {
            if (values[k] > bestValue) {
                bestValue = values[k];
                bestSlot = k;
            }
        }
        return indices[bestSlot];
    }

    /**
     * Returns the largest active value, or negative infinity when there are none.
     * Implicit zeros and NaN values are not considered.
     * @return The maximum active value.
     */
    @Override
    public double maxValue() {
        double best = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            if (v > best) {
                best = v;
            }
        }
        return best;
    }

    /**
     * Returns the smallest active value, or positive infinity when there are none.
     * Implicit zeros and NaN values are not considered.
     * @return The minimum active value.
     */
    @Override
    public double minValue() {
        double best = Double.POSITIVE_INFINITY;
        for (double v : values) {
            if (v < best) {
                best = v;
            }
        }
        return best;
    }

    /**
     * Returns the indices active in this vector but not in {@code sparseVector}, in ascending order.
     * @param sparseVector The vector whose active indices are excluded.
     * @return A new array of indices.
     */
    public int[] difference(SparseVector sparseVector) {
        if (sparseVector.numActiveElements() == 0) {
            return Arrays.copyOf(indices, indices.length);
        }
        if (indices.length == 0) {
            return new int[0];
        }
        final int[] excluded = sparseVector.indices;
        final int[] kept = new int[indices.length];
        int count = 0;
        int cursor = 0;
        for (int idx : indices) {
            // Skip excluded indices smaller than idx; both arrays are sorted ascending.
            while (cursor < excluded.length && excluded[cursor] < idx) {
                cursor++;
            }
            if (cursor >= excluded.length || excluded[cursor] != idx) {
                kept[count++] = idx;
            }
        }
        return Arrays.copyOf(kept, count);
    }

    /**
     * Returns the indices active in both this vector and {@code sparseVector}, in ascending order.
     * @param sparseVector The other vector.
     * @return A new array of the shared indices (empty if none).
     */
    public int[] intersection(SparseVector sparseVector) {
        final int[] others = sparseVector.indices;
        final int[] shared = new int[Math.min(indices.length, others.length)];
        int count = 0;
        int a = 0;
        int b = 0;
        while (a < indices.length && b < others.length) {
            final int left = indices[a];
            final int right = others[b];
            if (left == right) {
                shared[count++] = left;
                a++;
                b++;
            } else if (left < right) {
                a++;
            } else {
                b++;
            }
        }
        return Arrays.copyOf(shared, count);
    }

    /**
     * In-place normalization is not supported for sparse vectors; this method always throws.
     *
     * @param vectorNormalizer Ignored.
     * @throws UnsupportedOperationException Always.
     */
    @Override // org.tribuo.math.la.SGDVector
    public void normalize(VectorNormalizer vectorNormalizer) {
        final String message = "Can't normalize a sparse array";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Folds every position of this vector, including the implicit zeros, into a single value.
     * <p>
     * Each position's value is passed through {@code doubleUnaryOperator}. The result is combined with
     * the running accumulator as {@code doubleBinaryOperator.applyAsDouble(transformed, accumulator)}.
     * The transform of {@code 0.0} is computed once, before iteration, and reused for every position
     * that has no stored value. Positions are visited in the order the active-element iterator yields
     * them, with gaps filled in. This assumes strictly increasing indices.
     *
     * @param d The initial accumulator value.
     * @param doubleUnaryOperator The per-element transform.
     * @param doubleBinaryOperator The reduction, called as (transformedElement, accumulator).
     * @return The final accumulator value.
     */
    @Override // org.tribuo.math.la.SGDVector
    public double reduce(double d, DoubleUnaryOperator doubleUnaryOperator, DoubleBinaryOperator doubleBinaryOperator) {
        final double transformedZero = doubleUnaryOperator.applyAsDouble(0.0d);
        double accumulator = d;
        int position = 0;
        for (Iterator<VectorTuple> active = iterator2(); active.hasNext(); ) {
            VectorTuple tuple = active.next();
            // Fill the gap of implicit zeros before this stored element.
            for (; position < tuple.index; position++) {
                accumulator = doubleBinaryOperator.applyAsDouble(transformedZero, accumulator);
            }
            accumulator = doubleBinaryOperator.applyAsDouble(doubleUnaryOperator.applyAsDouble(tuple.value), accumulator);
            position++;
        }
        // Trailing implicit zeros after the last stored element.
        for (; position < this.size; position++) {
            accumulator = doubleBinaryOperator.applyAsDouble(transformedZero, accumulator);
        }
        return accumulator;
    }

    /**
     * Computes the Euclidean (L2) distance to another vector of the same size.
     * <p>
     * Squares each element-wise difference, sums them, and takes the square root of the total.
     *
     * @param sGDVector The other vector.
     * @return The Euclidean distance.
     */
    @Override // org.tribuo.math.la.SGDVector
    public double euclideanDistance(SGDVector sGDVector) {
        DoubleUnaryOperator square = x -> x * x;
        DoubleUnaryOperator root = Math::sqrt;
        return distance(sGDVector, square, root);
    }

    /**
     * Computes the L1 (Manhattan) distance to another vector of the same size.
     * <p>
     * Sums the absolute element-wise differences; the total is returned unchanged.
     *
     * @param sGDVector The other vector.
     * @return The L1 distance.
     */
    @Override // org.tribuo.math.la.SGDVector
    public double l1Distance(SGDVector sGDVector) {
        DoubleUnaryOperator absolute = Math::abs;
        DoubleUnaryOperator unchanged = DoubleUnaryOperator.identity();
        return distance(sGDVector, absolute, unchanged);
    }

    /**
     * Computes a generic distance between this vector and {@code sGDVector}.
     * <p>
     * The per-element transform {@code doubleUnaryOperator} is applied in three cases, and the results
     * are summed:
     * <ul>
     *   <li>to {@code this - other} at indices active in both vectors;</li>
     *   <li>to the raw value at indices active in only one of them.</li>
     * </ul>
     * {@code doubleUnaryOperator2} is then applied to the sum. Indices active in neither vector are
     * skipped entirely. This is only correct for transforms with {@code f(0) == 0}, such as the square
     * and abs used by the Euclidean and L1 distances.
     * <p>
     * NOTE(review): an element present only in {@code other} contributes {@code f(v)} rather than
     * {@code f(0 - v)}. That is fine for even transforms like abs and square, but not for arbitrary ones.
     *
     * @param sGDVector The other vector; must have the same {@code size()} as this one.
     * @param doubleUnaryOperator The per-element transform.
     * @param doubleUnaryOperator2 The transform applied to the accumulated sum.
     * @return The distance.
     * @throws IllegalArgumentException If the two vectors have different sizes.
     */
    public double distance(SGDVector sGDVector, DoubleUnaryOperator doubleUnaryOperator, DoubleUnaryOperator doubleUnaryOperator2) {
        if (sGDVector.size() != this.size) {
            throw new IllegalArgumentException("Can't measure the distance between two vectors of different lengths, this = " + this.size + ", other = " + sGDVector.size());
        }
        double d = 0.0d;
        if (sGDVector.numActiveElements() != 0 && this.indices.length != 0) {
            // Both vectors have active elements: sorted merge over the two iterators.
            // Assumes both iterators yield strictly increasing indices — TODO confirm for all SGDVector impls.
            Iterator<VectorTuple> iterator2 = iterator2();
            Iterator<VectorTuple> it = sGDVector.iterator();
            VectorTuple next = iterator2.next();
            VectorTuple next2 = it.next();
            while (iterator2.hasNext() && it.hasNext()) {
                if (next.index == next2.index) {
                    d += doubleUnaryOperator.applyAsDouble(next.value - next2.value);
                    next = iterator2.next();
                    next2 = it.next();
                } else if (next.index < next2.index) {
                    d += doubleUnaryOperator.applyAsDouble(next.value);
                    next = iterator2.next();
                } else {
                    d += doubleUnaryOperator.applyAsDouble(next2.value);
                    next2 = it.next();
                }
            }
            // Only this vector has elements left; next2 holds the other vector's last element.
            while (iterator2.hasNext()) {
                if (next.index == next2.index) {
                    d += doubleUnaryOperator.applyAsDouble(next.value - next2.value);
                    // Replace next2 with a fresh tuple to mark it as consumed, so the final check
                    // below doesn't count it again. Assumes new VectorTuple() has index -1 — confirm
                    // against VectorTuple's default constructor.
                    next2 = new VectorTuple();
                } else {
                    d += doubleUnaryOperator.applyAsDouble(next.value);
                }
                next = iterator2.next();
            }
            // Only the other vector has elements left; next holds this vector's last element.
            while (it.hasNext()) {
                if (next.index == next2.index) {
                    d += doubleUnaryOperator.applyAsDouble(next.value - next2.value);
                    // Same consumed-marker trick as above, applied to this vector's last element.
                    next = new VectorTuple();
                } else {
                    d += doubleUnaryOperator.applyAsDouble(next2.value);
                }
                next2 = it.next();
            }
            // Settle the final pair. Tuples with index -1 were already consumed and are skipped.
            if (next.index == next2.index) {
                d += doubleUnaryOperator.applyAsDouble(next.value - next2.value);
            } else {
                if (next.index != -1) {
                    d += doubleUnaryOperator.applyAsDouble(next.value);
                }
                if (next2.index != -1) {
                    d += doubleUnaryOperator.applyAsDouble(next2.value);
                }
            }
        } else if (this.indices.length != 0) {
            // The other vector has no active elements: only this vector's values contribute.
            Iterator<VectorTuple> iterator22 = iterator2();
            while (iterator22.hasNext()) {
                d += doubleUnaryOperator.applyAsDouble(iterator22.next().value);
            }
        } else {
            // This vector has no active elements: only the other vector's values contribute.
            Iterator<VectorTuple> it2 = sGDVector.iterator();
            while (it2.hasNext()) {
                d += doubleUnaryOperator.applyAsDouble(it2.next().value);
            }
        }
        return doubleUnaryOperator2.applyAsDouble(d);
    }

    /**
     * Renders this vector as {@code SparseVector(size=N,tuples=[i,v],[j,w],...)}.
     * <p>
     * Separators are emitted explicitly rather than by overwriting a trailing comma. The previous
     * approach clobbered the {@code =} of {@code tuples=} when the vector had no active elements,
     * producing {@code SparseVector(size=N,tuples)}. Output for non-empty vectors is unchanged.
     *
     * @return A human-readable description of this vector.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("SparseVector(size=").append(this.size).append(",tuples=");
        for (int i = 0; i < this.indices.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append('[').append(this.indices[i]).append(',').append(this.values[i]).append(']');
        }
        sb.append(')');
        return sb.toString();
    }

    /**
     * Converts this vector into a {@link DenseVector} holding the same values.
     *
     * @return A new dense vector built from {@link #toArray()}.
     */
    public DenseVector densify() {
        double[] contents = toArray();
        return new DenseVector(contents);
    }

    /**
     * Returns this vector as a dense array.
     *
     * @return The same result as {@link #toArray()}.
     * @deprecated Use {@link #toArray()} instead; this method simply delegates to it.
     */
    @Deprecated
    public double[] toDenseArray() {
        double[] dense = this.toArray();
        return dense;
    }

    /**
     * Materialises this vector as a freshly allocated dense {@code double[]} of length {@code size}.
     * Positions with no stored value are left at {@code 0.0}.
     *
     * @return A dense copy of this vector.
     */
    @Override // org.tribuo.math.la.SGDVector
    public double[] toArray() {
        final double[] dense = new double[this.size];
        int k = 0;
        while (k < this.values.length) {
            dense[this.indices[k]] = this.values[k];
            k++;
        }
        return dense;
    }

    /**
     * Returns the sum of squared deviations from {@code d} over every position of this vector,
     * including the implicit zeros. The result is not divided by the number of elements.
     * <p>
     * Each implicit zero contributes {@code (0 - d)^2 = d * d}, so they are added in a single term.
     *
     * @param d The mean to measure deviations from.
     * @return The sum of squared deviations.
     */
    @Override // org.tribuo.math.la.SGDVector
    public double variance(double d) {
        double total = 0.0d;
        for (double v : this.values) {
            double diff = v - d;
            total += diff * diff;
        }
        int implicitZeros = this.size - this.values.length;
        return total + (implicitZeros * d * d);
    }

    /**
     * Returns an iterator over the stored (index, value) pairs of this vector.
     * <p>
     * The decompiler renamed this method from {@code iterator} when it was merged with its bridge method.
     * NOTE(review): the iterator appears to reuse one {@link VectorTuple} instance across {@code next()}
     * calls. Copy its fields if a tuple must outlive the next call.
     *
     * @return A new iterator over the active elements.
     */
    @Override // java.lang.Iterable
    public Iterator<VectorTuple> iterator2() {
        final SparseVectorIterator iter = new SparseVectorIterator(this);
        return iter;
    }

    /**
     * Transposes an array of equally sized sparse vectors.
     * <p>
     * Treats the input as rows of a matrix. Output vector {@code c} holds, at position {@code r}, the
     * value stored at index {@code c} of input vector {@code r}.
     *
     * @param sparseVectorArr The vectors to transpose; must be non-empty and all of the same size.
     * @return An array of {@code sparseVectorArr[0]}'s size vectors, each of length {@code sparseVectorArr.length}.
     * @throws IllegalArgumentException If the array is empty or the vectors have different sizes.
     */
    public static SparseVector[] transpose(SparseVector[] sparseVectorArr) {
        if (sparseVectorArr.length == 0) {
            throw new IllegalArgumentException("Can't transpose an empty array of SparseVectors");
        }
        int numVectors = sparseVectorArr.length;
        int dimension = sparseVectorArr[0].size;
        List<List<Integer>> newIndices = new ArrayList<>(dimension);
        List<List<Double>> newValues = new ArrayList<>(dimension);
        for (int col = 0; col < dimension; col++) {
            newIndices.add(new ArrayList<>());
            newValues.add(new ArrayList<>());
        }
        for (int row = 0; row < numVectors; row++) {
            SparseVector vector = sparseVectorArr[row];
            // Validate sizes up front. Otherwise a larger vector fails with an opaque
            // IndexOutOfBoundsException and a smaller one silently yields a ragged result.
            if (vector.size != dimension) {
                throw new IllegalArgumentException("All vectors must have the same size to be transposed, vector 0 has size " + dimension + ", vector " + row + " has size " + vector.size);
            }
            Iterator<VectorTuple> it = vector.iterator2();
            while (it.hasNext()) {
                VectorTuple tuple = it.next();
                // Rows are visited in increasing order, so each output's indices stay sorted.
                newIndices.get(tuple.index).add(row);
                newValues.get(tuple.index).add(tuple.value);
            }
        }
        SparseVector[] output = new SparseVector[dimension];
        for (int col = 0; col < dimension; col++) {
            output[col] = new SparseVector(numVectors, Util.toPrimitiveInt(newIndices.get(col)), Util.toPrimitiveDouble(newValues.get(col)));
        }
        return output;
    }

    /**
     * Transposes a dataset into one sparse vector per feature.
     * <p>
     * Delegates to the two-argument overload, passing the dataset's own {@code getFeatureIDMap()}.
     *
     * @param dataset The dataset to transpose.
     * @param <T> The output type.
     * @return The transposed feature vectors.
     */
    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset) {
        ImmutableFeatureMap featureIDs = dataset.getFeatureIDMap();
        return transpose(dataset, featureIDs);
    }

    /**
     * Transposes a dataset into one {@link SparseVector} per feature.
     * <p>
     * Output vector {@code f} has length {@code dataset.size()}. At position {@code e} it holds the
     * value of feature {@code f} in the {@code e}th example, in dataset iteration order. Feature ids
     * are looked up by name in {@code immutableFeatureMap}.
     *
     * @param dataset The dataset to transpose.
     * @param immutableFeatureMap The feature id map; must be the same size as the dataset's feature map.
     * @param <T> The output type.
     * @return An array of {@code immutableFeatureMap.size()} sparse vectors.
     * @throws IllegalArgumentException If the feature maps differ in size, or if an example contains a
     *                                  feature whose id is outside {@code [0, immutableFeatureMap.size())}.
     */
    public static <T extends Output<T>> SparseVector[] transpose(Dataset<T> dataset, ImmutableFeatureMap immutableFeatureMap) {
        if (dataset.getFeatureMap().size() != immutableFeatureMap.size()) {
            throw new IllegalArgumentException("The dataset's internal feature map and the supplied feature map have different sizes. dataset = " + dataset.getFeatureMap().size() + ", fMap = " + immutableFeatureMap.size());
        }
        int numExamples = dataset.size();
        int numFeatures = immutableFeatureMap.size();
        List<List<Integer>> exampleIndices = new ArrayList<>(numFeatures);
        List<List<Double>> featureValues = new ArrayList<>(numFeatures);
        for (int f = 0; f < numFeatures; f++) {
            exampleIndices.add(new ArrayList<>());
            featureValues.add(new ArrayList<>());
        }
        int exampleIdx = 0;
        for (Example<T> example : dataset) {
            for (Feature feature : example) {
                int id = immutableFeatureMap.getID(feature.getName());
                // Equal map sizes don't guarantee equal contents. Reject ids that can't be placed
                // (e.g. a not-found sentinel) rather than failing with an opaque IndexOutOfBoundsException.
                if (id < 0 || id >= numFeatures) {
                    throw new IllegalArgumentException("Feature '" + feature.getName() + "' in example " + exampleIdx + " has id " + id + " which is not in the supplied feature map of size " + numFeatures);
                }
                exampleIndices.get(id).add(exampleIdx);
                featureValues.get(id).add(feature.getValue());
            }
            exampleIdx++;
        }
        SparseVector[] output = new SparseVector[numFeatures];
        for (int f = 0; f < numFeatures; f++) {
            output[f] = new SparseVector(numExamples, Util.toPrimitiveInt(exampleIndices.get(f)), Util.toPrimitiveDouble(featureValues.get(f)));
        }
        return output;
    }
}
