package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Paths;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.interop.tensorflow.protos.TensorFlowCheckpointModelProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowCheckpointModel.class */
/**
 * A TensorFlow-backed model whose variable values are restored from a checkpoint on disk
 * (identified by a checkpoint directory plus a checkpoint name) rather than being stored
 * inside the model object itself.
 *
 * <p>The graph definition is carried with the model (in both the Java-serialized and the
 * protobuf forms); the checkpoint location is stored as two strings. If the checkpoint cannot
 * be restored when the model is constructed or deserialized, a warning is logged and
 * {@link #isInitialized()} returns {@code false}; callers can then fix the location with
 * {@link #setCheckpointDirectory(String)} / {@link #setCheckpointName(String)} and call
 * {@link #initialize()}.
 *
 * @param <T> the output type produced by this model.
 */
public final class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowModel<T> implements Closeable {
    private static final Logger logger = Logger.getLogger(TensorFlowCheckpointModel.class.getName());
    private static final long serialVersionUID = 200;

    /** Protobuf serialization version written by this class, and the only one it accepts. */
    public static final int CURRENT_VERSION = 0;

    /** Directory component of the checkpoint location. */
    private String checkpointDirectory;

    /** Name component of the checkpoint location, resolved against {@link #checkpointDirectory}. */
    private String checkpointName;

    /** True only while the current session has successfully restored the checkpoint. */
    private boolean initialized;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds the model and attempts to restore the checkpoint into the inherited session.
     *
     * <p>A {@link TensorFlowException} raised by the restore is logged at WARNING level and not
     * rethrown; the model is then left with {@link #isInitialized()} returning {@code false}.
     *
     * @param str the model name (forwarded to the superclass).
     * @param modelProvenance the model provenance (forwarded to the superclass).
     * @param immutableFeatureMap the feature domain (forwarded to the superclass).
     * @param immutableOutputInfo the output domain (forwarded to the superclass).
     * @param graphDef the TensorFlow graph definition (forwarded to the superclass).
     * @param str2 the checkpoint directory.
     * @param str3 the checkpoint name.
     * @param i the batch size (forwarded to the superclass).
     * @param str4 the output name (forwarded to the superclass).
     * @param featureConverter the feature converter (forwarded to the superclass).
     * @param outputConverter the output converter (forwarded to the superclass).
     */
    public TensorFlowCheckpointModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, GraphDef graphDef, String str2, String str3, int i, String str4, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, graphDef, i, str4, featureConverter, outputConverter);
        this.checkpointDirectory = str2;
        this.checkpointName = str3;
        try {
            this.session.restore(resolvePath());
            this.initialized = true;
        } catch (TensorFlowException e) {
            // Best effort: the checkpoint may live elsewhere on this machine, so the caller
            // is allowed to repoint the model and call initialize() later.
            logger.log(Level.WARNING, "Failed to initialise model in directory " + str2, e);
        }
    }

    /**
     * Rebuilds a checkpoint model from its protobuf form.
     *
     * @param i the serialized version; must equal {@link #CURRENT_VERSION}.
     * @param str not read by this method.
     * @param any the packed {@link TensorFlowCheckpointModelProto}.
     * @return the deserialized model.
     * @throws InvalidProtocolBufferException if the message cannot be unpacked or the graph
     *         definition cannot be parsed.
     * @throws IllegalArgumentException if the version is not supported.
     * @throws IllegalStateException if the output domain's output class differs from the
     *         output converter's type witness.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // output type is validated against the converter's type witness below
    public static TensorFlowCheckpointModel<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        TensorFlowCheckpointModelProto unpack = any.unpack(TensorFlowCheckpointModelProto.class);
        OutputConverter outputConverter = (OutputConverter) ProtoUtil.deserialize(unpack.getOutputConverter());
        FeatureConverter featureConverter = (FeatureConverter) ProtoUtil.deserialize(unpack.getFeatureConverter());
        ModelDataCarrier deserialize = ModelDataCarrier.deserialize(unpack.getMetadata());
        // Guard the unchecked construction below: the domain's outputs must be the type the converter handles.
        if (!deserialize.outputDomain().getOutput(0).getClass().equals(outputConverter.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + deserialize.outputDomain().getClass() + " and " + outputConverter.getTypeWitness());
        }
        return new TensorFlowCheckpointModel<>(deserialize.name(), deserialize.provenance(), deserialize.featureDomain(), deserialize.outputDomain(), GraphDef.parseFrom(unpack.getModelDef()), unpack.getCheckpointDirectory(), unpack.getCheckpointName(), unpack.getBatchSize(), unpack.getOutputName(), featureConverter, outputConverter);
    }

    /**
     * Joins the checkpoint directory and checkpoint name into the path string handed to
     * {@code Session.restore}.
     *
     * @return the resolved checkpoint path.
     */
    private final String resolvePath() {
        return Paths.get(this.checkpointDirectory, this.checkpointName).toString();
    }

    /**
     * Reports whether the current session has successfully restored the checkpoint.
     *
     * @return true if the checkpoint has been loaded into the current session.
     */
    public boolean isInitialized() {
        return this.initialized;
    }

    /**
     * Discards the current session (if any), opens a fresh one on the model graph and
     * restores the checkpoint from the currently configured location.
     *
     * <p>Unlike the constructor, a restore failure is not caught here and propagates to the
     * caller. In that case {@link #isInitialized()} returns {@code false}, because the
     * replacement session holds no restored state.
     */
    public final void initialize() {
        if (this.session != null) {
            this.session.close();
            this.session = null;
        }
        // The previously restored state went away with the old session; do not keep
        // reporting "initialized" if the restore below throws.
        this.initialized = false;
        this.session = new Session(this.modelGraph);
        this.session.restore(resolvePath());
        this.initialized = true;
    }

    /**
     * Sets the checkpoint directory. Does not reload the checkpoint; call
     * {@link #initialize()} to apply the new location.
     *
     * @param str the new checkpoint directory.
     */
    public void setCheckpointDirectory(String str) {
        this.checkpointDirectory = str;
    }

    /**
     * Gets the checkpoint directory.
     *
     * @return the checkpoint directory.
     */
    public String getCheckpointDirectory() {
        return this.checkpointDirectory;
    }

    /**
     * Sets the checkpoint name. Does not reload the checkpoint; call
     * {@link #initialize()} to apply the new location.
     *
     * @param str the new checkpoint name.
     */
    public void setCheckpointName(String str) {
        this.checkpointName = str;
    }

    /**
     * Gets the checkpoint name.
     *
     * @return the checkpoint name.
     */
    public String getCheckpointName() {
        return this.checkpointName;
    }

    /**
     * Creates a {@link TensorFlowNativeModel} from this model's graph definition and the
     * variables extracted from the current session, reusing this model's metadata and
     * converters.
     *
     * @return a native model built from this model's graph and session state.
     */
    public TensorFlowNativeModel<T> convertToNativeModel() {
        return new TensorFlowNativeModel<>(this.name, this.provenance, this.featureIDMap, this.outputIDInfo, this.modelGraph.toGraphDef(), TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session), this.batchSize, this.outputName, this.featureConverter, this.outputConverter);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] */
    /**
     * Creates a new model with the supplied name and provenance that points at the same
     * checkpoint location and reuses this model's domains, graph definition and converters.
     *
     * @param str the name for the copy.
     * @param modelProvenance the provenance for the copy.
     * @return the new model.
     */
    public TensorFlowCheckpointModel<T> m16copy(String str, ModelProvenance modelProvenance) {
        return new TensorFlowCheckpointModel<>(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.modelGraph.toGraphDef(), this.checkpointDirectory, this.checkpointName, this.batchSize, this.outputName, this.featureConverter, this.outputConverter);
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    /**
     * Serializes this model to its protobuf form. The message records the graph definition
     * and the checkpoint directory and name; no variable values are read from the session here.
     *
     * @return the protobuf representation of this model.
     */
    public ModelProto m17serialize() {
        ModelDataCarrier createDataCarrier = createDataCarrier();
        TensorFlowCheckpointModelProto.Builder newBuilder = TensorFlowCheckpointModelProto.newBuilder();
        newBuilder.setMetadata(createDataCarrier.serialize());
        newBuilder.setModelDef(ByteString.copyFrom(this.modelGraph.toGraphDef().toByteArray()));
        newBuilder.setOutputName(this.outputName);
        newBuilder.setBatchSize(this.batchSize);
        newBuilder.setCheckpointDirectory(this.checkpointDirectory);
        newBuilder.setCheckpointName(this.checkpointName);
        newBuilder.setOutputConverter((OutputConverterProto) this.outputConverter.serialize());
        newBuilder.setFeatureConverter((FeatureConverterProto) this.featureConverter.serialize());
        ModelProto.Builder newBuilder2 = ModelProto.newBuilder();
        newBuilder2.setSerializedData(Any.pack(newBuilder.m361build()));
        newBuilder2.setClassName(TensorFlowCheckpointModel.class.getName());
        newBuilder2.setVersion(CURRENT_VERSION);
        return newBuilder2.build();
    }

    /**
     * Java serialization hook: writes the default fields followed by the graph definition
     * as a byte array.
     *
     * @param objectOutputStream the stream to write to.
     * @throws IOException if the stream cannot be written.
     * @throws IllegalStateException if the model has been closed.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        if (this.closed) {
            throw new IllegalStateException("Can't serialize a closed model, the state has gone.");
        }
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeObject(this.modelGraph.toGraphDef().toByteArray());
    }

    /**
     * Java serialization hook: reads the default fields, rebuilds the graph and a new session
     * from the stored graph definition, then attempts to restore the checkpoint.
     *
     * <p>A failed restore is logged at WARNING level and leaves {@link #isInitialized()}
     * returning {@code false}.
     *
     * @param objectInputStream the stream to read from.
     * @throws IOException if the stream cannot be read or the graph definition cannot be parsed.
     * @throws ClassNotFoundException if a serialized class cannot be resolved.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        byte[] bArr = (byte[]) objectInputStream.readObject();
        this.modelGraph = new Graph();
        this.modelGraph.importGraphDef(GraphDef.parseFrom(bArr));
        this.session = new Session(this.modelGraph);
        // defaultReadObject() restored the flag as it was when the model was written; it says
        // nothing about this brand-new session, so clear it until the restore succeeds.
        this.initialized = false;
        try {
            this.session.restore(resolvePath());
            this.initialized = true;
        } catch (TensorFlowException e) {
            logger.log(Level.WARNING, "Failed to initialise model after deserialization, attempted to load from " + this.checkpointDirectory, e);
        }
    }
}
