package org.tribuo.math.onnx;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.SGDVector;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;

/* loaded from: input_file:org/tribuo/math/onnx/ONNXMathUtils.class */
/**
 * Helpers for converting Tribuo linear-algebra types into ONNX float tensor initializers.
 * <p>
 * Not instantiable: the class is abstract and its only constructor is private.
 */
public abstract class ONNXMathUtils {
    private ONNXMathUtils() {
    }

    /**
     * Builds a rank-1 float tensor initializer from a vector.
     * <p>
     * The tensor shape is {@code [vector.size()]}. Each tuple produced by iterating
     * {@code vector} is written as a {@code float} at that tuple's index.
     * <p>
     * NOTE(review): for sparse vectors only the iterated entries are written. Any
     * remaining positions hold whatever the buffer supplied by {@code ONNXContext.floatTensor}
     * was initialised with (presumably zero — confirm).
     *
     * @param context the ONNX context the initializer is created in
     * @param name the name of the initializer
     * @param vector the vector whose values are copied (narrowed from double to float)
     * @return the initializer produced by {@code context.floatTensor}
     */
    public static ONNXInitializer floatVector(ONNXContext context, String name, SGDVector vector) {
        List<Integer> dims = Collections.singletonList(vector.size());
        return context.floatTensor(name, dims, buffer ->
                vector.forEach(tuple -> buffer.put(tuple.index, (float) tuple.value)));
    }

    /**
     * Builds a float tensor initializer from a matrix, optionally transposed.
     * <p>
     * The tensor shape is {@code matrix.getShape()}, reversed when {@code transpose} is true.
     * Values are laid out row-major in the resulting shape:
     * <ul>
     *   <li>not transposed: element {@code (i, j)} goes to {@code i * dim2 + j};</li>
     *   <li>transposed: element {@code (i, j)} goes to {@code j * dim1 + i}.</li>
     * </ul>
     * Each value is narrowed from double to float.
     *
     * @param context the ONNX context the initializer is created in
     * @param name the name of the initializer
     * @param matrix the matrix whose values are copied
     * @param transpose whether to emit the transpose of {@code matrix}
     * @return the initializer produced by {@code context.floatTensor}
     */
    public static ONNXInitializer floatMatrix(ONNXContext context, String name, Matrix matrix, boolean transpose) {
        // Collect into an ArrayList explicitly: Collectors.toList() does not guarantee a
        // mutable list, and Collections.reverse needs one.
        List<Integer> dims = Arrays.stream(matrix.getShape())
                .boxed()
                .collect(Collectors.toCollection(ArrayList::new));
        if (transpose) {
            Collections.reverse(dims);
        }
        // Row strides are loop-invariant; compute them once rather than per element.
        final int dim1 = matrix.getDimension1Size();
        final int dim2 = matrix.getDimension2Size();
        return context.floatTensor(name, dims, buffer ->
                matrix.forEach(tuple -> {
                    int offset = transpose
                            ? (tuple.j * dim1) + tuple.i
                            : (tuple.i * dim2) + tuple.j;
                    buffer.put(offset, (float) tuple.value);
                }));
    }
}
