package org.tribuo.common.sgd;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.SplittableRandom;
import org.tribuo.common.sgd.protos.FMParametersProto;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.ParametersProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;
import org.tribuo.protos.ProtoUtil;

/* loaded from: input_file:org/tribuo/common/sgd/FMParameters.class */
/**
 * Parameters for a factorization machine: a linear model (bias vector plus weight matrix)
 * augmented, for each output, with a factor matrix that models pairwise feature interactions.
 * <p>
 * Tensor layout as returned by {@link #get()}:
 * <ul>
 *     <li>{@code weights[0]} - the bias vector, one entry per output.</li>
 *     <li>{@code weights[1]} - the linear weight matrix, shape [numOutputs, numFeatures].</li>
 *     <li>{@code weights[k + 2]} - the factor matrix for output {@code k}, shape [numFactors, numFeatures].</li>
 * </ul>
 */
public final class FMParameters implements FeedForwardParameters {
    private static final long serialVersionUID = 1;

    /** Protobuf serialization version written by {@code m6serialize()} and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Used to merge sparse factor gradients across a minibatch. */
    private static final Merger merger = new HeapMerger();

    private Tensor[] weights;
    private DenseVector biasVector;
    private DenseMatrix weightMatrix;
    private final int numFactors;

    /**
     * Constructs a fresh set of factorization machine parameters.
     * <p>
     * The bias vector and weight matrix are freshly allocated and not randomly initialised;
     * each factor matrix is filled with Gaussian noise via {@code initializeMatrix}.
     *
     * @param rng        source of seeds for the factor initialisation.
     * @param numFeatures the number of features (column count of every matrix).
     * @param numOutputs the number of outputs (bias vector size, one factor matrix each).
     * @param numFactors the number of factors per feature (row count of each factor matrix).
     * @param variance   the scale applied to the Gaussian noise used for the factors.
     */
    public FMParameters(SplittableRandom rng, int numFeatures, int numOutputs, int numFactors, double variance) {
        this.weights = new Tensor[numOutputs + 2];
        this.biasVector = new DenseVector(numOutputs);
        this.weightMatrix = new DenseMatrix(numOutputs, numFeatures);
        this.weights[0] = this.biasVector;
        this.weights[1] = this.weightMatrix;
        for (int k = 0; k < numOutputs; k++) {
            DenseMatrix factors = new DenseMatrix(numFactors, numFeatures);
            initializeMatrix(rng, variance, factors);
            this.weights[k + 2] = factors;
        }
        this.numFactors = numFactors;
    }

    /**
     * Wraps an already-validated tensor array without copying it.
     *
     * @param weights    the tensors, in the layout described on the class.
     * @param numFactors the number of factors per feature.
     */
    private FMParameters(Tensor[] weights, int numFactors) {
        this.weights = weights;
        this.biasVector = (DenseVector) weights[0];
        this.weightMatrix = (DenseMatrix) weights[1];
        this.numFactors = numFactors;
    }

    /**
     * Deserializes parameters from a protobuf message.
     *
     * @param version   the serialization version of the message.
     * @param className the class name recorded in the message (unused here).
     * @param message   the packed {@code FMParametersProto}.
     * @return the deserialized parameters.
     * @throws InvalidProtocolBufferException if the message is not an {@code FMParametersProto}.
     * @throws IllegalArgumentException       if the version is unsupported or the tensors are
     *                                        missing, of the wrong type, or of inconsistent shape.
     */
    public static FMParameters deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        FMParametersProto proto = message.unpack(FMParametersProto.class);
        int numFactors = proto.getNumFactors();
        List<TensorProto> tensorProtos = proto.getWeightsList();
        if (tensorProtos.isEmpty()) {
            // Previously this surfaced as an ArrayIndexOutOfBoundsException on weights[0].
            throw new IllegalArgumentException("Invalid protobuf, no weight tensors found");
        }
        Tensor[] weights = new Tensor[tensorProtos.size()];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = ProtoUtil.deserialize(tensorProtos.get(i));
        }
        if (!(weights[0] instanceof DenseVector)) {
            throw new IllegalArgumentException("Invalid protobuf, expected bias vector found " + weights[0].getClass());
        }
        int numOutputs = ((DenseVector) weights[0]).size();
        if (numOutputs + 2 != weights.length) {
            throw new IllegalArgumentException("Invalid protobuf, expected " + (numOutputs + 2) + " weight tensors, found " + weights.length);
        }
        if (!(weights[1] instanceof DenseMatrix)) {
            throw new IllegalArgumentException("Invalid protobuf, expected DenseMatrix, found " + weights[1].getClass());
        }
        DenseMatrix weightMatrix = (DenseMatrix) weights[1];
        int numFeatures = weightMatrix.getDimension2Size();
        if (weightMatrix.getDimension1Size() != numOutputs) {
            throw new IllegalArgumentException("Invalid protobuf, expected weight matrix of shape [" + numOutputs + "," + numFeatures + "], found " + Arrays.toString(weightMatrix.getShape()));
        }
        for (int i = 2; i < weights.length; i++) {
            // predict/gradients cast every factor tensor to DenseMatrix, so reject anything else here
            // rather than failing later with a ClassCastException.
            if (!(weights[i] instanceof DenseMatrix)) {
                throw new IllegalArgumentException("Invalid protobuf, expected DenseMatrix factor matrix, found " + weights[i].getClass());
            }
            DenseMatrix factors = (DenseMatrix) weights[i];
            if (factors.getDimension1Size() != numFactors || factors.getDimension2Size() != numFeatures) {
                throw new IllegalArgumentException("Invalid protobuf, expected factor matrix of shape [" + numFactors + ", " + numFeatures + "], found " + Arrays.toString(factors.getShape()));
            }
        }
        return new FMParameters(weights, numFactors);
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    /**
     * Serializes these parameters into a {@code ParametersProto} wrapping an {@code FMParametersProto}.
     *
     * @return the protobuf representation.
     */
    public ParametersProto m6serialize() {
        ParametersProto.Builder builder = ParametersProto.newBuilder();
        builder.setVersion(CURRENT_VERSION);
        builder.setClassName(FMParameters.class.getName());
        FMParametersProto.Builder fmBuilder = FMParametersProto.newBuilder();
        fmBuilder.setNumFactors(this.numFactors);
        for (Tensor tensor : this.weights) {
            fmBuilder.addWeights((TensorProto) tensor.serialize());
        }
        builder.setSerializedData(Any.pack(fmBuilder.m47build()));
        return builder.build();
    }

    /**
     * Fills {@code matrix} with samples from {@code Random.nextGaussian()} multiplied by {@code variance}.
     * <p>
     * NOTE(review): multiplying a standard normal by {@code variance} makes it the standard
     * deviation, not the variance, of the samples - confirm this is intended.
     *
     * @param rng      supplies the seed for a fresh {@link Random}.
     * @param variance the noise scale.
     * @param matrix   the matrix to overwrite.
     */
    private static void initializeMatrix(SplittableRandom rng, double variance, DenseMatrix matrix) {
        Random random = new Random(rng.nextLong());
        int rows = matrix.getDimension1Size();
        int cols = matrix.getDimension2Size();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                matrix.set(i, j, random.nextGaussian() * variance);
            }
        }
    }

    /**
     * Computes the model scores for a single example.
     * <p>
     * For output {@code k} the score is {@code bias[k] + W[k,:]·x} plus the pairwise term
     * {@code 0.5 * sum_f ((sum_i V[f,i] x_i)^2 - sum_i (V[f,i] x_i)^2)}, where {@code V} is the
     * factor matrix for output {@code k}.
     *
     * @param example the feature vector.
     * @return one score per output.
     */
    public DenseVector predict(SGDVector example) {
        DenseVector output = this.weightMatrix.leftMultiply(example);
        output.intersectAndAddInPlace(this.biasVector);
        DenseVector interactions = new DenseVector(this.biasVector.size());
        for (int i = 2; i < this.weights.length; i++) {
            DenseMatrix factors = (DenseMatrix) this.weights[i];
            double total = 0.0;
            for (int f = 0; f < this.numFactors; f++) {
                double linearSum = 0.0;
                double sumOfSquares = 0.0;
                for (VectorTuple tuple : example) {
                    double prod = factors.get(f, tuple.index) * tuple.value;
                    linearSum += prod;
                    sumOfSquares += prod * prod;
                }
                total += (linearSum * linearSum) - sumOfSquares;
            }
            interactions.set(i - 2, total / 2.0);
        }
        output.intersectAndAddInPlace(interactions);
        return output;
    }

    /**
     * Computes the gradient of every parameter tensor for one example.
     *
     * @param score    a pair whose second element is used as the per-output gradient.
     * @param features the example's feature vector.
     * @return gradients in the same layout as {@link #get()}; factor gradients for outputs with a
     *         zero output gradient are empty {@code DenseSparseMatrix} instances.
     */
    public Tensor[] gradients(Pair<Double, SGDVector> score, SGDVector features) {
        Tensor[] gradients = new Tensor[this.weights.length];
        SGDVector outputGradient = score.getB();

        // The bias gradient is the output gradient itself, stored densely.
        if (outputGradient instanceof SparseVector) {
            gradients[0] = ((SparseVector) outputGradient).densify();
        } else {
            gradients[0] = outputGradient.copy();
        }

        // Linear weight gradient.
        gradients[1] = outputGradient.outer(features);

        int numFeatures = features.size();
        for (int i = 2; i < this.weights.length; i++) {
            double curOutputGradient = outputGradient.get(i - 2);
            if (curOutputGradient == 0.0) {
                // Every factor gradient for this output is zero; use a cheap empty matrix.
                gradients[i] = new DenseSparseMatrix(this.numFactors, numFeatures);
            } else {
                gradients[i] = factorGradient((DenseMatrix) this.weights[i], features, curOutputGradient);
            }
        }
        return gradients;
    }

    /**
     * Computes the gradient of one factor matrix, scaled by that output's gradient.
     * Element (f, i) is {@code outputGradient * (x_i * sum_j V[f,j] x_j - V[f,i] * x_i^2)}.
     *
     * @param factors        the factor matrix for this output.
     * @param features       the example's feature vector.
     * @param outputGradient the (non-zero) gradient for this output.
     * @return a {@code DenseSparseMatrix} for sparse features, otherwise a {@code DenseMatrix}.
     */
    private Matrix factorGradient(DenseMatrix factors, SGDVector features, double outputGradient) {
        DenseVector factorSums = factors.leftMultiply(features);
        Matrix grad;
        if (features instanceof SparseVector) {
            List<SparseVector> rows = new ArrayList<>(this.numFactors);
            for (int f = 0; f < this.numFactors; f++) {
                rows.add(((SparseVector) features).copy());
            }
            grad = new DenseSparseMatrix(rows);
        } else {
            int numFeatures = features.size();
            DenseMatrix dense = new DenseMatrix(this.numFactors, numFeatures);
            for (int f = 0; f < this.numFactors; f++) {
                for (int k = 0; k < numFeatures; k++) {
                    dense.set(f, k, features.get(k));
                }
            }
            grad = dense;
        }
        for (int f = 0; f < this.numFactors; f++) {
            // Each row starts as a copy of x and is rewritten in place; this relies on getRow
            // returning a mutable view of the matrix row.
            SGDVector row = grad.getRow(f);
            double factorSum = factorSums.get(f);
            int rowIdx = f;
            row.foreachIndexedInPlace((idx, value) -> value * factorSum - factors.get(rowIdx, idx) * value * value);
            row.scaleInPlace(outputGradient);
        }
        return grad;
    }

    /**
     * Returns freshly allocated tensors with the same shapes as the parameters.
     *
     * @return new tensors in the same layout as {@link #get()}.
     */
    public Tensor[] getEmptyCopy() {
        Tensor[] copy = new Tensor[this.weights.length];
        copy[0] = new DenseVector(this.biasVector.size());
        copy[1] = new DenseMatrix(this.weightMatrix.getDimension1Size(), this.weightMatrix.getDimension2Size());
        for (int i = 2; i < this.weights.length; i++) {
            DenseMatrix factors = (DenseMatrix) this.weights[i];
            copy[i] = new DenseMatrix(factors.getDimension1Size(), factors.getDimension2Size());
        }
        return copy;
    }

    /**
     * Returns the live parameter array (not a copy); mutations affect this object.
     *
     * @return the parameter tensors.
     */
    public Tensor[] get() {
        return this.weights;
    }

    /**
     * Replaces the parameter tensors. The call is silently ignored if the array length differs
     * from the current one.
     *
     * @param newWeights the replacement tensors, in the layout described on the class.
     */
    public void set(Tensor[] newWeights) {
        if (newWeights.length == this.weights.length) {
            // Cast before assigning so a wrongly-typed array cannot leave the fields inconsistent.
            DenseVector newBias = (DenseVector) newWeights[0];
            DenseMatrix newWeightMatrix = (DenseMatrix) newWeights[1];
            this.weights = newWeights;
            this.biasVector = newBias;
            this.weightMatrix = newWeightMatrix;
        }
    }

    /**
     * Adds each update tensor into the corresponding parameter tensor in place.
     *
     * @param updates the updates, in the same layout as {@link #get()}.
     */
    public void update(Tensor[] updates) {
        for (int i = 0; i < updates.length; i++) {
            this.weights[i].intersectAndAddInPlace(updates[i]);
        }
    }

    /**
     * Sums the first {@code size} gradient arrays element-wise.
     * <p>
     * Dense gradients are accumulated in place into {@code gradients[0]}, so the input arrays are
     * mutated.
     *
     * @param gradients the per-example gradients.
     * @param size      how many leading entries of {@code gradients} to merge.
     * @return the merged gradients.
     * @throws IllegalStateException if a gradient has an unexpected tensor type.
     */
    public Tensor[] merge(Tensor[][] gradients, int size) {
        Tensor[] output = new Tensor[this.weights.length];
        for (int i = 0; i < this.weights.length; i++) {
            Tensor first = gradients[0][i];
            if (first instanceof DenseVector || first instanceof DenseMatrix) {
                for (int j = 1; j < size; j++) {
                    first.intersectAndAddInPlace(gradients[j][i]);
                }
                output[i] = first;
            } else if (first instanceof DenseSparseMatrix) {
                output[i] = mergeFactorGradients(gradients, i, size);
            } else {
                throw unexpectedGradientType(first);
            }
        }
        return output;
    }

    /**
     * Merges one factor-gradient slot whose first entry is a {@code DenseSparseMatrix}.
     * <p>
     * If every entry is a {@code DenseSparseMatrix} the sparse merger is used. Otherwise (dense
     * features mixed with zero-output-gradient placeholders) everything is accumulated into the
     * first {@code DenseMatrix} found.
     *
     * @param gradients the per-example gradients.
     * @param index     the tensor slot to merge.
     * @param size      how many leading entries of {@code gradients} to merge.
     * @return the merged gradient for this slot.
     */
    private static Tensor mergeFactorGradients(Tensor[][] gradients, int index, int size) {
        int denseIdx = -1;
        for (int j = 0; j < size; j++) {
            Tensor cur = gradients[j][index];
            if (cur instanceof DenseMatrix) {
                if (denseIdx < 0) {
                    denseIdx = j;
                }
            } else if (!(cur instanceof DenseSparseMatrix)) {
                throw unexpectedGradientType(cur);
            }
        }
        if (denseIdx < 0) {
            DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
            for (int j = 0; j < size; j++) {
                updates[j] = (DenseSparseMatrix) gradients[j][index];
            }
            return merger.merge(updates);
        }
        Tensor accumulator = gradients[denseIdx][index];
        for (int j = 0; j < size; j++) {
            if (j != denseIdx) {
                accumulator.intersectAndAddInPlace(gradients[j][index]);
            }
        }
        return accumulator;
    }

    /** Builds the exception thrown for an unsupported gradient tensor type. */
    private static IllegalStateException unexpectedGradientType(Tensor gradient) {
        return new IllegalStateException("Unexpected gradient type, expected DenseVector, DenseMatrix or DenseSparseMatrix, received " + gradient.getClass().getName());
    }

    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Returns a deep copy of these parameters.
     *
     * @return a new {@code FMParameters} with copied tensors.
     */
    public FMParameters m5copy() {
        Tensor[] copied = new Tensor[this.weights.length];
        for (int i = 0; i < this.weights.length; i++) {
            copied[i] = this.weights[i].copy();
        }
        return new FMParameters(copied, this.numFactors);
    }
}
