package org.tribuo.interop.tensorflow.example;

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;

/* loaded from: input_file:org/tribuo/interop/tensorflow/example/MLPExamples.class */
/**
 * Static helpers for building example multi-layer perceptron (MLP) TensorFlow graphs.
 * <p>
 * Not instantiable.
 */
public abstract class MLPExamples {
    /** Seed for the Glorot weight initializer, fixed so that repeated builds are reproducible. */
    private static final long INITIALIZER_SEED = 12345L;

    /** Value every bias element is initialised to. */
    private static final float INITIAL_BIAS = 0.1f;

    private MLPExamples() {
    }

    /**
     * Builds a float32 MLP graph: an input placeholder, one ReLU-activated fully connected layer per
     * entry of {@code hiddenSizes}, and a final linear (un-activated) fully connected output layer.
     * <p>
     * Weights use a truncated-normal Glorot initializer with a fixed seed; biases are filled with
     * {@code 0.1f}. The native {@link Graph} is always closed before returning, including when
     * graph construction fails.
     *
     * @param inputName   name given to the input placeholder op
     * @param numFeatures width of the input placeholder (its shape is {@code [-1, numFeatures]}); must be positive
     * @param hiddenSizes width of each hidden layer, in order; must be non-empty with all entries positive
     * @param numOutputs  width of the output layer; must be positive
     * @return a {@code GraphDefTuple} built from the serialized graph, {@code inputName}, and the name of
     *         the output {@code Add} op
     * @throws IllegalArgumentException if {@code numFeatures} or {@code numOutputs} is below 1, if
     *                                  {@code hiddenSizes} is empty, or if any hidden size is below 1
     */
    public static GraphDefTuple buildMLPGraph(String inputName, int numFeatures, int[] hiddenSizes, int numOutputs) {
        if (numFeatures < 1) {
            throw new IllegalArgumentException("Must have a positive number of features, found " + numFeatures);
        }
        if (numOutputs < 1) {
            throw new IllegalArgumentException("Must have a positive number of outputs, found " + numOutputs);
        }
        if (hiddenSizes.length < 1) {
            throw new IllegalArgumentException("Must supply a hidden layer dimension.");
        }
        for (int hiddenSize : hiddenSizes) {
            if (hiddenSize < 1) {
                throw new IllegalArgumentException("Hidden dimensions must be positive, found " + hiddenSize);
            }
        }

        // try-with-resources: the graph owns native memory and must be released even if op construction throws.
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);
            Glorot<TFloat32> initializer =
                    new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, INITIALIZER_SEED);

            Placeholder<TFloat32> input = tf.withName(inputName)
                    .placeholder(TFloat32.class, Placeholder.shape(Shape.of(new long[]{-1, numFeatures})));

            // The running layer output starts as the placeholder and is replaced by each hidden layer's ReLU.
            Operand<TFloat32> layer = input;
            long prevSize = numFeatures;
            for (int hiddenSize : hiddenSizes) {
                layer = tf.nn.relu(denseLayer(tf, initializer, layer, prevSize, hiddenSize));
                prevSize = hiddenSize;
            }
            Add<TFloat32> output = denseLayer(tf, initializer, layer, prevSize, numOutputs);

            // Both the GraphDef and the op name are extracted before the graph is closed.
            GraphDef graphDef = graph.toGraphDef();
            String outputName = output.op().name();
            return new GraphDefTuple(graphDef, inputName, outputName);
        }
    }

    /**
     * Appends a fully connected layer computing {@code input x W + b} to the graph behind {@code tf}.
     * <p>
     * {@code W} is a {@code [inSize, outSize]} variable initialised by {@code initializer}, and {@code b}
     * is a {@code [outSize]} variable filled with {@link #INITIAL_BIAS}. Ops are always created in the
     * order weights, matmul, biases, add. Auto-generated op names may depend on this order.
     *
     * @param tf          op builder for the target graph
     * @param initializer weight initializer
     * @param input       layer input, presumably {@code [batch, inSize]}
     * @param inSize      number of input units
     * @param outSize     number of output units
     * @return the {@code Add} op producing the (pre-activation) layer output
     */
    private static Add<TFloat32> denseLayer(Ops tf, Glorot<TFloat32> initializer, Operand<TFloat32> input,
                                            long inSize, int outSize) {
        Variable<TFloat32> weights =
                tf.variable(initializer.call(tf, tf.array(new long[]{inSize, outSize}), TFloat32.class));
        MatMul<TFloat32> product = tf.linalg.matMul(input, weights);
        Variable<TFloat32> biases =
                tf.variable(tf.fill(tf.array(new int[]{outSize}), tf.constant(INITIAL_BIAS)));
        return tf.math.add(product, biases);
    }
}
