package org.tribuo.interop.tensorflow;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.optimizers.AdaDelta;
import org.tensorflow.framework.optimizers.AdaGrad;
import org.tensorflow.framework.optimizers.AdaGradDA;
import org.tensorflow.framework.optimizers.Adam;
import org.tensorflow.framework.optimizers.Adamax;
import org.tensorflow.framework.optimizers.Ftrl;
import org.tensorflow.framework.optimizers.GradientDescent;
import org.tensorflow.framework.optimizers.Momentum;
import org.tensorflow.framework.optimizers.Nadam;
import org.tensorflow.framework.optimizers.Optimizer;
import org.tensorflow.framework.optimizers.RMSProp;
import org.tensorflow.op.Op;
import org.tensorflow.types.family.TNumber;
import org.tribuo.interop.tensorflow.protos.TensorFlowCheckpointModelProto;

/* loaded from: input_file:org/tribuo/interop/tensorflow/GradientOptimiser.class */
/**
 * The gradient optimisers available to Tribuo's TensorFlow interop.
 * <p>
 * Each constant records the exact set of hyperparameter names its TensorFlow-Java
 * optimiser is constructed from. {@link #applyOptimiser(Graph, Operand, Map)} requires a
 * parameter map whose key set matches those names exactly.
 */
public enum GradientOptimiser {
    /** {@link AdaDelta}. */
    ADADELTA("learningRate", "rho", "epsilon"),
    /** {@link AdaGrad}. */
    ADAGRAD("learningRate", "initialAccumulatorValue"),
    /** {@link AdaGradDA}. */
    ADAGRADDA("learningRate", "initialAccumulatorValue", "l1Strength", "l2Strength"),
    /** {@link Adam}. */
    ADAM("learningRate", "betaOne", "betaTwo", "epsilon"),
    /** {@link Adamax}. */
    ADAMAX("learningRate", "betaOne", "betaTwo", "epsilon"),
    /** {@link Ftrl}. */
    FTRL("learningRate", "learningRatePower", "initialAccumulatorValue", "l1Strength", "l2Strength", "l2ShrinkageRegularizationStrength"),
    /** Plain {@link GradientDescent}. */
    GRADIENT_DESCENT("learningRate"),
    /** {@link Momentum} with Nesterov acceleration disabled. */
    MOMENTUM("learningRate", "momentum"),
    /** {@link Momentum} with Nesterov acceleration enabled. */
    NESTEROV("learningRate", "momentum"),
    /** {@link Nadam}. */
    NADAM("learningRate", "betaOne", "betaTwo", "epsilon"),
    /** {@link RMSProp}, constructed with its final {@code centered} flag set to {@code false}. */
    RMSPROP("learningRate", "decay", "momentum", "epsilon");

    /** Immutable set of the hyperparameter names this optimiser requires. */
    private final Set<String> args;

    GradientOptimiser(String... paramNames) {
        this.args = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(paramNames)));
    }

    /**
     * Returns the names of the hyperparameters this optimiser requires.
     *
     * @return An unmodifiable set of parameter names.
     */
    public Set<String> getParameterNames() {
        return this.args;
    }

    /**
     * Checks that the supplied names are exactly this optimiser's required parameter names.
     * No name may be missing and no extra name may be present.
     *
     * @param paramNames The parameter names to check.
     * @return True if {@code paramNames} contains exactly the required names.
     */
    public boolean validateParamNames(Set<String> paramNames) {
        // Equal sizes plus containment in both directions means the two sets are equal.
        return this.args.size() == paramNames.size() && this.args.containsAll(paramNames);
    }

    /**
     * Builds this optimiser in {@code graph} and adds an op that minimises {@code loss}.
     * <p>
     * The optimiser is created with the name {@code "tribuo-<variant>"}. The returned
     * minimisation op is named {@code "tribuo-" + this + "-minimize"}.
     *
     * @param graph           The graph to add the optimiser to.
     * @param loss            The loss operand to minimise.
     * @param optimiserParams Hyperparameter values keyed by name. The key set must equal
     *                        {@link #getParameterNames()}.
     * @param <T>             The numeric type of the loss.
     * @return The op produced by the optimiser's {@code minimize} call.
     * @throws IllegalArgumentException If the key set does not match the required names, or
     *                                  if a required value is {@code null}.
     */
    public <T extends TNumber> Op applyOptimiser(Graph graph, Operand<T> loss, Map<String, Float> optimiserParams) {
        if (!validateParamNames(optimiserParams.keySet())) {
            throw new IllegalArgumentException("Invalid optimiser parameters, expected " + this.args.toString() + ", found " + optimiserParams.keySet().toString());
        }
        // All TensorFlow-Java optimisers share the Optimizer supertype; switching on the
        // enum constant avoids any ordinal()-based dispatch table.
        final Optimizer optimiser;
        switch (this) {
            case ADADELTA:
                optimiser = new AdaDelta(graph, "tribuo-adadelta", param(optimiserParams, "learningRate"), param(optimiserParams, "rho"), param(optimiserParams, "epsilon"));
                break;
            case ADAGRAD:
                optimiser = new AdaGrad(graph, "tribuo-adagrad", param(optimiserParams, "learningRate"), param(optimiserParams, "initialAccumulatorValue"));
                break;
            case ADAGRADDA:
                optimiser = new AdaGradDA(graph, "tribuo-adagradda", param(optimiserParams, "learningRate"), param(optimiserParams, "initialAccumulatorValue"), param(optimiserParams, "l1Strength"), param(optimiserParams, "l2Strength"));
                break;
            case ADAM:
                optimiser = new Adam(graph, "tribuo-adam", param(optimiserParams, "learningRate"), param(optimiserParams, "betaOne"), param(optimiserParams, "betaTwo"), param(optimiserParams, "epsilon"));
                break;
            case ADAMAX:
                optimiser = new Adamax(graph, "tribuo-adamax", param(optimiserParams, "learningRate"), param(optimiserParams, "betaOne"), param(optimiserParams, "betaTwo"), param(optimiserParams, "epsilon"));
                break;
            case FTRL:
                optimiser = new Ftrl(graph, "tribuo-ftrl", param(optimiserParams, "learningRate"), param(optimiserParams, "learningRatePower"), param(optimiserParams, "initialAccumulatorValue"), param(optimiserParams, "l1Strength"), param(optimiserParams, "l2Strength"), param(optimiserParams, "l2ShrinkageRegularizationStrength"));
                break;
            case GRADIENT_DESCENT:
                optimiser = new GradientDescent(graph, "tribuo-sgd", param(optimiserParams, "learningRate"));
                break;
            case MOMENTUM:
                optimiser = new Momentum(graph, "tribuo-momentum", param(optimiserParams, "learningRate"), param(optimiserParams, "momentum"), false);
                break;
            case NESTEROV:
                optimiser = new Momentum(graph, "tribuo-nesterov", param(optimiserParams, "learningRate"), param(optimiserParams, "momentum"), true);
                break;
            case NADAM:
                optimiser = new Nadam(graph, "tribuo-nadam", param(optimiserParams, "learningRate"), param(optimiserParams, "betaOne"), param(optimiserParams, "betaTwo"), param(optimiserParams, "epsilon"));
                break;
            case RMSPROP:
                optimiser = new RMSProp(graph, "tribuo-rmsprop", param(optimiserParams, "learningRate"), param(optimiserParams, "decay"), param(optimiserParams, "momentum"), param(optimiserParams, "epsilon"), false);
                break;
            default:
                throw new IllegalStateException("Unimplemented switch branch " + toString());
        }
        return optimiser.minimize(loss, "tribuo-" + toString() + "-minimize");
    }

    /**
     * Looks up a required hyperparameter and unboxes it. A present-but-null value is
     * rejected with a descriptive error instead of a bare NullPointerException.
     *
     * @param params The hyperparameter map.
     * @param name   The parameter to fetch.
     * @return The parameter value.
     * @throws IllegalArgumentException If the value mapped to {@code name} is null.
     */
    private static float param(Map<String, Float> params, String name) {
        Float value = params.get(name);
        if (value == null) {
            throw new IllegalArgumentException("Optimiser parameter '" + name + "' must not be null");
        }
        return value;
    }
}
