package org.tribuo.interop.tensorflow.sequence;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.PrimitiveProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceException;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.SkeletalTrainerProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;
import org.tribuo.sequence.SequenceTrainer;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/tensorflow/sequence/TensorFlowSequenceTrainer.class */
/**
 * A {@link SequenceTrainer} which trains a TensorFlow graph loaded from a serialized {@link GraphDef} protobuf.
 * <p>
 * The graph must contain a training operation ({@link #trainOp}), a loss operation ({@link #getLossOp}) and a
 * prediction operation ({@link #predictOp}). Each epoch the dataset is shuffled and fed to the graph in
 * minibatches, using {@link #featureConverter} and {@link #outputConverter} to build the input tensors.
 *
 * @param <T> The output type.
 */
public class TensorFlowSequenceTrainer<T extends Output<T>> implements SequenceTrainer<T> {
    private static final Logger log = Logger.getLogger(TensorFlowSequenceTrainer.class.getName());

    @Config(mandatory = true, description = "Path to the protobuf containing the TensorFlow graph.")
    protected Path graphPath;

    /** The graph parsed from {@link #graphPath} in {@link #postConfig()}. */
    private GraphDef graphDef;

    @Config(mandatory = true, description = "Sequence feature extractor.")
    protected SequenceFeatureConverter featureConverter;

    @Config(mandatory = true, description = "Sequence output extractor.")
    protected SequenceOutputConverter<T> outputConverter;

    @Config(description = "Minibatch size")
    protected int minibatchSize = 1;

    @Config(description = "Number of SGD epochs to run.")
    protected int epochs = 5;

    @Config(description = "Logging interval to print the loss.")
    protected int loggingInterval = 100;

    @Config(description = "Seed for the RNG.")
    protected long seed = 1L;

    @Config(mandatory = true, description = "Name of the training operation.")
    protected String trainOp;

    @Config(mandatory = true, description = "Name of the loss operation (to inspect the loss).")
    protected String getLossOp;

    @Config(mandatory = true, description = "Name of the prediction operation.")
    protected String predictOp;

    /** Master RNG, seeded from {@link #seed}; each call to {@link #train} splits off its own stream. */
    protected SplittableRandom rng;

    /** Number of times {@link #train} has been invoked; guarded by {@code this}. */
    protected int trainInvocationCounter;

    /**
     * Provenance for a {@link TensorFlowSequenceTrainer}. In addition to the configured parameters it records a
     * hash of the graph file and that file's last-modified time.
     */
    public static class TensorFlowSequenceTrainerProvenance extends SkeletalTrainerProvenance {
        private static final long serialVersionUID = 1L;
        public static final String GRAPH_HASH = "graph-hash";
        public static final String GRAPH_LAST_MOD = "graph-last-modified";
        private final StringProvenance graphHash;
        private final DateTimeProvenance graphLastModified;

        /**
         * Builds the provenance from a live trainer, hashing its graph file.
         *
         * @param host The trainer to describe.
         * @param <T>  The trainer's output type.
         */
        <T extends Output<T>> TensorFlowSequenceTrainerProvenance(TensorFlowSequenceTrainer<T> host) {
            super(host);
            this.graphHash = new StringProvenance(GRAPH_HASH, ProvenanceUtil.hashResource(DEFAULT_HASH_TYPE, host.graphPath));
            long lastModifiedMillis = host.graphPath.toFile().lastModified();
            this.graphLastModified = new DateTimeProvenance(GRAPH_LAST_MOD,
                    OffsetDateTime.ofInstant(Instant.ofEpochMilli(lastModifiedMillis), ZoneId.systemDefault()));
        }

        /**
         * Reconstructs the provenance from its serialized map form.
         *
         * @param map The provenance map.
         * @throws ProvenanceException If the graph hash or last-modified entries are missing or mistyped.
         */
        public TensorFlowSequenceTrainerProvenance(Map<String, Provenance> map) {
            this(extractTFProvenanceInfo(map));
        }

        private TensorFlowSequenceTrainerProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
            // extractTFProvenanceInfo guarantees both entries exist with these exact types.
            this.graphHash = (StringProvenance) info.instanceValues.get(GRAPH_HASH);
            this.graphLastModified = (DateTimeProvenance) info.instanceValues.get(GRAPH_LAST_MOD);
        }

        @Override
        public Map<String, PrimitiveProvenance<?>> getInstanceValues() {
            Map<String, PrimitiveProvenance<?>> instanceValues = super.getInstanceValues();
            instanceValues.put(graphHash.getKey(), graphHash);
            instanceValues.put(graphLastModified.getKey(), graphLastModified);
            return instanceValues;
        }

        /**
         * Extracts the standard trainer provenance and moves the graph hash and last-modified entries from the
         * configured parameters into the instance values.
         *
         * @param map The provenance map.
         * @return The extracted info.
         * @throws ProvenanceException If the graph hash or last-modified entries are missing or mistyped.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractTFProvenanceInfo(Map<String, Provenance> map) {
            SkeletalConfiguredObjectProvenance.ExtractedInfo info = SkeletalTrainerProvenance.extractProvenanceInfo(map);
            info.instanceValues.put(GRAPH_HASH, extractGraphHash(info));
            info.instanceValues.put(GRAPH_LAST_MOD, extractGraphLastModified(info));
            return info;
        }

        /** Removes the graph hash from the configured parameters and returns it as a {@link StringProvenance}. */
        private static StringProvenance extractGraphHash(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            if (!info.configuredParameters.containsKey(GRAPH_HASH)) {
                throw new ProvenanceException("Failed to find " + GRAPH_HASH + " when constructing SkeletalTrainerProvenance");
            }
            Provenance prov = info.configuredParameters.remove(GRAPH_HASH);
            if (prov instanceof StringProvenance) {
                // This is the type getInstanceValues() emits, so a round trip lands here.
                return (StringProvenance) prov;
            } else if (prov instanceof HashProvenance) {
                // Also accept the type this extractor used to demand, converting it to match the field type.
                return new StringProvenance(GRAPH_HASH, ((HashProvenance) prov).getValue());
            } else {
                throw new ProvenanceException(GRAPH_HASH + " was not of type StringProvenance or HashProvenance in class " + info.className);
            }
        }

        /** Removes the graph last-modified time from the configured parameters. */
        private static DateTimeProvenance extractGraphLastModified(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            if (!info.configuredParameters.containsKey(GRAPH_LAST_MOD)) {
                throw new ProvenanceException("Failed to find " + GRAPH_LAST_MOD + " when constructing SkeletalTrainerProvenance");
            }
            Provenance prov = info.configuredParameters.remove(GRAPH_LAST_MOD);
            if (!(prov instanceof DateTimeProvenance)) {
                throw new ProvenanceException(GRAPH_LAST_MOD + " was not of type DateTimeProvenance in class " + info.className);
            }
            return (DateTimeProvenance) prov;
        }
    }

    /**
     * Constructs a trainer and immediately loads the graph.
     *
     * @param graphPath        Path to the serialized {@link GraphDef}.
     * @param featureConverter Converts sequence examples into feature tensors.
     * @param outputConverter  Converts sequence outputs into target tensors.
     * @param minibatchSize    Number of sequences per minibatch; must be positive.
     * @param epochs           Number of passes over the dataset.
     * @param loggingInterval  Log the loss every this many minibatches; must be positive.
     * @param seed             Seed for the shuffling RNG.
     * @param trainOp          Name of the training operation.
     * @param getLossOp        Name of the loss operation.
     * @param predictOp        Name of the prediction operation.
     * @throws IOException If the graph file cannot be read.
     */
    public TensorFlowSequenceTrainer(Path graphPath, SequenceFeatureConverter featureConverter, SequenceOutputConverter<T> outputConverter, int minibatchSize, int epochs, int loggingInterval, long seed, String trainOp, String getLossOp, String predictOp) throws IOException {
        this.graphPath = graphPath;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.minibatchSize = minibatchSize;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.seed = seed;
        this.trainOp = trainOp;
        this.getLossOp = getLossOp;
        this.predictOp = predictOp;
        postConfig();
    }

    /**
     * No-argument constructor; presumably used by the OLCUT configuration system, which injects the
     * {@link Config} fields and then calls {@link #postConfig()}. TODO confirm.
     */
    private TensorFlowSequenceTrainer() { }

    /**
     * Validates the configuration, seeds the RNG and parses the graph from {@link #graphPath}.
     *
     * @throws IOException              If the graph file cannot be read or parsed.
     * @throws IllegalArgumentException If {@code minibatchSize} or {@code loggingInterval} is not positive.
     */
    public synchronized void postConfig() throws IOException {
        // A non-positive batch size never advances the batch loop; a zero interval divides by zero.
        if (minibatchSize < 1) {
            throw new IllegalArgumentException("minibatchSize must be positive, found " + minibatchSize);
        }
        if (loggingInterval < 1) {
            throw new IllegalArgumentException("loggingInterval must be positive, found " + loggingInterval);
        }
        this.rng = new SplittableRandom(seed);
        try (InputStream is = new BufferedInputStream(Files.newInputStream(graphPath))) {
            this.graphDef = GraphDef.parseFrom(is);
        }
    }

    /**
     * Trains the graph on the supplied dataset and wraps the result in a {@link TensorFlowSequenceModel}.
     *
     * @param sequenceDataset The training data.
     * @param map             Run-level provenance to record in the model.
     * @return The trained model.
     * @throws IllegalStateException If TensorFlow raises an error during training.
     */
    @Override
    public SequenceModel<T> train(SequenceDataset<T> sequenceDataset, Map<String, Provenance> map) {
        // Split the RNG and capture provenance under the lock so concurrent calls get independent streams.
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            localRNG = rng.split();
            trainerProvenance = getProvenance();
            trainInvocationCounter++;
        }
        ImmutableFeatureMap featureIDMap = sequenceDataset.getFeatureIDMap();
        ImmutableOutputInfo<T> outputIDInfo = sequenceDataset.getOutputIDInfo();
        int numExamples = sequenceDataset.size();
        List<SequenceExample<T>> batch = new ArrayList<>();
        int[] indices = Util.randperm(numExamples, localRNG);

        try (Graph graph = new Graph();
             Session session = new Session(graph)) {
            graph.importGraphDef(graphDef);
            preTrainingHook(session, sequenceDataset);
            int batchCounter = 0;
            for (int epoch = 0; epoch < epochs; epoch++) {
                log.log(Level.INFO, "Starting epoch " + epoch);
                Util.randpermInPlace(indices, localRNG);
                for (int start = 0; start < numExamples; start += minibatchSize) {
                    int end = Math.min(numExamples, start + minibatchSize);
                    batch.clear();
                    for (int k = start; k < end; k++) {
                        batch.add(sequenceDataset.getExample(indices[k]));
                    }
                    // All per-batch tensors are closed even if the session run fails.
                    try (TensorMap featureTensors = featureConverter.encode(batch, featureIDMap);
                         TensorMap outputTensors = outputConverter.encode(batch, outputIDInfo);
                         TensorMap hyperparameterTensors = getHyperparameterFeed()) {
                        Session.Runner runner = session.runner();
                        featureTensors.feedInto(runner);
                        outputTensors.feedInto(runner);
                        hyperparameterTensors.feedInto(runner);
                        try (TFloat32 loss = (TFloat32) runner.addTarget(trainOp).fetch(getLossOp).run().get(0)) {
                            if (batchCounter % loggingInterval == 0) {
                                log.info(String.format("loss %-5.6f [epoch %-2d batch %-4d #(%d - %d)/%d]",
                                        loss.getFloat(0), epoch, batchCounter, start, end, numExamples));
                            }
                            batchCounter++;
                        }
                    }
                }
            }
            TensorFlowUtil.annotateGraph(graph, session);
            ModelProvenance modelProvenance = new ModelProvenance(TensorFlowSequenceModel.class.getName(),
                    OffsetDateTime.now(), sequenceDataset.getProvenance(), trainerProvenance, map);
            // The graph and variables are copied out before the session and graph are closed.
            return new TensorFlowSequenceModel<>("tf-sequence-model", modelProvenance, featureIDMap, outputIDInfo,
                    graph.toGraphDef(), featureConverter, outputConverter, predictOp,
                    TensorFlowUtil.extractMarshalledVariables(graph, session));
        } catch (TensorFlowException e) {
            log.log(Level.SEVERE, "TensorFlow threw an error", e);
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the number of times {@link #train} has been called.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return trainInvocationCounter;
    }

    @Override
    public String toString() {
        return "TensorflowSequenceTrainer(graphPath=" + graphPath.toString()
                + ",exampleConverter=" + featureConverter.toString()
                + ",outputConverter=" + outputConverter.toString()
                + ",minibatchSize=" + minibatchSize
                + ",epochs=" + epochs
                + ",seed=" + seed + ")";
    }

    /**
     * Hook called after the graph is imported and before the first epoch (e.g. to initialise variables).
     * The default implementation does nothing.
     *
     * @param session         The training session.
     * @param sequenceDataset The training data.
     */
    protected void preTrainingHook(Session session, SequenceDataset<T> sequenceDataset) {
    }

    /**
     * Supplies extra tensors fed on every training step. The default implementation returns an empty map.
     * The returned map is closed after each step.
     *
     * @return The hyperparameter tensors.
     */
    protected TensorMap getHyperparameterFeed() {
        return new TensorMap(Collections.emptyMap());
    }

    /**
     * Returns a fresh provenance describing this trainer, including a hash of the graph file.
     *
     * @return The trainer provenance.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TensorFlowSequenceTrainerProvenance(this);
    }

    /**
     * Alias retained because this file previously exposed {@code getProvenance} under this decompiler-generated name.
     *
     * @return The trainer provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public TrainerProvenance m608getProvenance() {
        return getProvenance();
    }
}
