package org.tribuo.interop.tensorflow.example;

import java.util.Arrays;
import org.tensorflow.Graph;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.nn.BiasAdd;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;

/* loaded from: input_file:org/tribuo/interop/tensorflow/example/CNNExamples.class */
/**
 * Static factory for example convolutional network graphs built with TensorFlow-Java.
 * <p>
 * The class only exposes static methods and cannot be instantiated.
 */
public abstract class CNNExamples {
    /** Padding mode passed to every convolution and pooling op. */
    private static final String PADDING_TYPE = "SAME";

    /** Fixed seed for the weight initializer so graph construction is repeatable. */
    private static final long INITIALIZER_SEED = 12345L;

    /** Number of units in the hidden fully connected layer. */
    private static final int HIDDEN_UNITS = 512;

    private CNNExamples() {
    }

    /**
     * Builds a LeNet-style CNN and returns its graph definition together with the
     * names of the input placeholder and the output op.
     * <p>
     * Layers, in order:
     * <ol>
     *   <li>float32 placeholder of shape {@code [-1, i, i, 1]}, named {@code str};</li>
     *   <li>pixel scaling {@code (x - i2 / 2) / i2};</li>
     *   <li>5x5 convolution to 32 channels, bias, ReLU, 2x2 max-pool (stride 2);</li>
     *   <li>5x5 convolution to 64 channels, bias, ReLU, 2x2 max-pool (stride 2);</li>
     *   <li>flatten, dense layer with 512 units, bias, ReLU;</li>
     *   <li>dense layer with {@code i3} units plus bias (no activation applied).</li>
     * </ol>
     * Weights use a Glorot truncated-normal initializer with a fixed seed. The first
     * convolution bias starts at 0.0, every other bias starts at 0.1.
     *
     * @param str name given to the input placeholder
     * @param i   image height and width in pixels (images are square), must be at least 1
     * @param i2  pixel depth used to centre and scale the input, must be at least 1
     * @param i3  number of output units, must be at least 1
     * @return the serialised graph, the input name and the name of the final add op
     * @throws IllegalArgumentException if {@code i}, {@code i2} or {@code i3} is less than 1
     */
    public static GraphDefTuple buildLeNetGraph(String str, int i, int i2, int i3) {
        if (i < 1) {
            throw new IllegalArgumentException("Must have a positive image size, found " + i);
        }
        if (i2 < 1) {
            throw new IllegalArgumentException("Must have a positive pixel depth, found " + i2);
        }
        if (i3 < 1) {
            throw new IllegalArgumentException("Must have a positive number of outputs, found " + i3);
        }

        final GraphDef graphDef;
        final String outputName;

        // try-with-resources releases the native graph even if construction fails part-way.
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);
            Glorot<TFloat32> initializer =
                    new Glorot<>(VarianceScaling.Distribution.TRUNCATED_NORMAL, INITIALIZER_SEED);

            // NOTE: ops are created in a fixed order on purpose; auto-generated op names
            // depend on creation order, so do not hoist the nested calls below.
            Placeholder<TFloat32> input = tf.withName(str).placeholder(
                    TFloat32.class, Placeholder.shape(Shape.of(-1, i, i, 1)));

            // First convolution over the centred and scaled pixels. Both constants must be
            // float32 to match the placeholder, hence the explicit float conversion of i2.
            Conv2d<TFloat32> conv1 = tf.nn.conv2d(
                    tf.math.div(
                            tf.math.sub(input, tf.constant(i2 / 2.0f)),
                            tf.constant((float) i2)),
                    tf.variable(initializer.call(tf, tf.array(5L, 5L, 1L, 32L), TFloat32.class)),
                    Arrays.asList(1L, 1L, 1L, 1L),
                    PADDING_TYPE);
            BiasAdd<TFloat32> conv1Biased = tf.nn.biasAdd(
                    conv1,
                    tf.variable(tf.fill(tf.array(32), tf.constant(0.0f))));
            MaxPool<TFloat32> pool1 = tf.nn.maxPool(
                    tf.nn.relu(conv1Biased),
                    tf.array(1, 2, 2, 1),
                    tf.array(1, 2, 2, 1),
                    PADDING_TYPE);

            // Second convolution block.
            Conv2d<TFloat32> conv2 = tf.nn.conv2d(
                    pool1,
                    tf.variable(initializer.call(tf, tf.array(5L, 5L, 32L, 64L), TFloat32.class)),
                    Arrays.asList(1L, 1L, 1L, 1L),
                    PADDING_TYPE);
            BiasAdd<TFloat32> conv2Biased = tf.nn.biasAdd(
                    conv2,
                    tf.variable(tf.fill(tf.array(64), tf.constant(0.1f))));
            MaxPool<TFloat32> pool2 = tf.nn.maxPool(
                    tf.nn.relu(conv2Biased),
                    tf.array(1, 2, 2, 1),
                    tf.array(1, 2, 2, 1),
                    PADDING_TYPE);

            // Number of features per example: product of the three non-batch dimensions.
            long[] pooledDims = pool2.shape().subShape(1, 4).asArray();
            long flatSize = pooledDims[0] * pooledDims[1] * pooledDims[2];

            // Hidden dense layer on the flattened feature map.
            MatMul<TFloat32> hidden = tf.linalg.matMul(
                    tf.reshape(
                            pool2,
                            tf.concat(
                                    Arrays.asList(tf.array(-1L), tf.array(flatSize)),
                                    tf.constant(0))),
                    tf.variable(initializer.call(
                            tf,
                            tf.concat(
                                    Arrays.asList(tf.array(flatSize), tf.array((long) HIDDEN_UNITS)),
                                    tf.constant(0)),
                            TFloat32.class)));
            Add<TFloat32> hiddenBiased = tf.math.add(
                    hidden,
                    tf.variable(tf.fill(tf.array(HIDDEN_UNITS), tf.constant(0.1f))));

            // Output layer; the final add is the op whose name is reported to the caller.
            MatMul<TFloat32> projected = tf.linalg.matMul(
                    tf.nn.relu(hiddenBiased),
                    tf.variable(initializer.call(
                            tf, tf.array((long) HIDDEN_UNITS, (long) i3), TFloat32.class)));
            Add<TFloat32> logits = tf.math.add(
                    projected,
                    tf.variable(tf.fill(tf.array(i3), tf.constant(0.1f))));

            graphDef = graph.toGraphDef();
            outputName = logits.op().name();
        }

        return new GraphDefTuple(graphDef, str, outputName);
    }
}
