package org.tribuo.interop.tensorflow.sequence;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.tensorflow.TensorFlowUtil;
import org.tribuo.interop.tensorflow.TensorMap;
import org.tribuo.interop.tensorflow.protos.SequenceFeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.SequenceOutputConverterProto;
import org.tribuo.interop.tensorflow.protos.TensorFlowSequenceModelProto;
import org.tribuo.interop.tensorflow.protos.TensorTupleProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.SequenceModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceModel;

/**
 * A {@link SequenceModel} backed by a TensorFlow {@link Graph} and {@link Session}.
 * <p>
 * At prediction time, the following steps run in order:
 * <ol>
 *     <li>The {@link SequenceFeatureConverter} encodes each {@link SequenceExample} into a {@link TensorMap}.</li>
 *     <li>That map is fed into the session.</li>
 *     <li>The {@link #predictOp} output is fetched.</li>
 *     <li>The {@link SequenceOutputConverter} decodes the fetched tensor into a list of predictions.</li>
 * </ol>
 * <p>
 * The graph and session are released by {@link #close()}. Both are transient: Java serialization writes the
 * graph definition bytes and the marshalled variable values, and rebuilds the graph and session on read.
 *
 * @param <T> The output type.
 */
public class TensorFlowSequenceModel<T extends Output<T>> extends SequenceModel<T> implements AutoCloseable {
    private static final long serialVersionUID = 200L;

    /**
     * The protobuf serialization version written and accepted by this class.
     */
    public static final int CURRENT_VERSION = 0;

    // Transient: rebuilt from the serialized graph definition and variable values in readObject.
    private transient Graph modelGraph;
    private transient Session session;

    /** Encodes a sequence example into the {@link TensorMap} fed into the session. */
    protected final SequenceFeatureConverter featureConverter;
    /** Decodes the fetched output tensor into a list of predictions. */
    protected final SequenceOutputConverter<T> outputConverter;
    /** The name of the graph operation fetched to produce predictions. */
    protected final String predictOp;

    // NOTE(review): the decompiler reported this constructor was originally package-private.
    /**
     * Builds a model in three steps:
     * <ol>
     *     <li>Imports {@code graphDef} into a new graph.</li>
     *     <li>Opens a session on that graph.</li>
     *     <li>Restores the supplied variable values into the session.</li>
     * </ol>
     * If any of these steps throws, the graph and session created so far are closed before the
     * exception propagates.
     *
     * @param name             The model name.
     * @param provenance       The model provenance.
     * @param featureIDMap     The feature domain.
     * @param outputIDMap      The output domain.
     * @param graphDef         The graph definition to import.
     * @param featureConverter Encodes sequence examples into tensors.
     * @param outputConverter  Decodes the fetched output tensor into predictions.
     * @param predictOp        The name of the operation fetched at prediction time.
     * @param tensorMap        The marshalled variable values to restore into the session.
     */
    public TensorFlowSequenceModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                                   ImmutableOutputInfo<T> outputIDMap, GraphDef graphDef,
                                   SequenceFeatureConverter featureConverter,
                                   SequenceOutputConverter<T> outputConverter, String predictOp,
                                   Map<String, TensorFlowUtil.TensorTuple> tensorMap) {
        super(name, provenance, featureIDMap, outputIDMap);
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
        this.predictOp = predictOp;
        this.modelGraph = new Graph();
        try {
            this.modelGraph.importGraphDef(graphDef);
            this.session = new Session(this.modelGraph);
            TensorFlowUtil.restoreMarshalledVariables(this.session, tensorMap);
        } catch (RuntimeException | Error e) {
            // Don't leak the graph/session when construction fails part-way through.
            releaseAfterFailure(e);
            throw e;
        }
    }

    /**
     * Deserialization factory for the protobuf form written by this class.
     *
     * @param version   The serialized object version.
     * @param className The class name (not used by this method).
     * @param message   The serialized data, which must unpack to a {@link TensorFlowSequenceModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the message or the embedded graph definition cannot be parsed.
     * @throws IllegalArgumentException       If {@code version} is outside {@code [0, CURRENT_VERSION]}.
     * @throws IllegalStateException          If the output domain has no output with id 0, or that output's
     *                                        class does not equal the output converter's type witness.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // converters and domains are deserialized without static type info
    public static TensorFlowSequenceModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        TensorFlowSequenceModelProto proto = message.unpack(TensorFlowSequenceModelProto.class);
        SequenceOutputConverter outputConverter = (SequenceOutputConverter) ProtoUtil.deserialize(proto.getOutputConverter());
        SequenceFeatureConverter featureConverter = (SequenceFeatureConverter) ProtoUtil.deserialize(proto.getFeatureConverter());
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());

        // The converter must produce the same output type as the stored domain; output 0 is used as the sample.
        Object firstOutput = carrier.outputDomain().getOutput(0);
        if (firstOutput == null) {
            throw new IllegalStateException("Invalid protobuf, output domain has no output with id 0");
        }
        if (!firstOutput.getClass().equals(outputConverter.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + firstOutput.getClass() + " and " + outputConverter.getTypeWitness());
        }

        GraphDef graphDef = GraphDef.parseFrom(proto.getModelDef());
        Map<String, TensorFlowUtil.TensorTuple> tensors = new HashMap<>();
        for (Map.Entry<String, TensorTupleProto> entry : proto.getTensorsMap().entrySet()) {
            tensors.put(entry.getKey(), new TensorFlowUtil.TensorTuple(entry.getValue()));
        }
        return new TensorFlowSequenceModel<>(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), graphDef, featureConverter, outputConverter, proto.getPredictOp(), tensors);
    }

    /**
     * Predicts the outputs for a sequence example.
     * <p>
     * The encoded feature map and the fetched output tensor are both closed before this method returns,
     * including when decoding throws.
     *
     * @param example The example to predict.
     * @return The predictions produced by the output converter.
     */
    @Override
    public List<Prediction<T>> predict(SequenceExample<T> example) {
        try (TensorMap feed = featureConverter.encode(example, featureIDMap);
             Tensor outputTensor = (Tensor) feed.feedInto(session.runner()).fetch(predictOp).run().get(0)) {
            return outputConverter.decode(outputTensor, example, outputIDMap);
        }
    }

    /**
     * This model does not compute feature importances.
     *
     * @param n Ignored.
     * @return An empty map.
     */
    @Override
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int n) {
        return Collections.emptyMap();
    }

    /**
     * Closes the session and then the graph, skipping either if it was never created.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        releaseTensorFlowResources();
    }

    /**
     * Closes the session (if present) and then the graph (if present).
     * The graph is closed even if closing the session throws.
     */
    private void releaseTensorFlowResources() {
        try {
            if (session != null) {
                session.close();
            }
        } finally {
            if (modelGraph != null) {
                modelGraph.close();
            }
        }
    }

    /**
     * Releases the graph and session after a failed construction or deserialization.
     * Any failure while closing is attached to {@code cause} as a suppressed exception, so the
     * original error is not lost.
     *
     * @param cause The failure that triggered the cleanup.
     */
    private void releaseAfterFailure(Throwable cause) {
        try {
            releaseTensorFlowResources();
        } catch (RuntimeException | Error closeFailure) {
            cause.addSuppressed(closeFailure);
        }
    }

    // The decompiler renamed this from serialize() when merging it with its bridge method; the name is kept as-is.
    /**
     * Serializes this model to a {@link SequenceModelProto}.
     * <p>
     * The embedded {@link TensorFlowSequenceModelProto} contains:
     * <ul>
     *     <li>the model metadata;</li>
     *     <li>the graph definition bytes;</li>
     *     <li>the marshalled variable values;</li>
     *     <li>the prediction operation name;</li>
     *     <li>both converters.</li>
     * </ul>
     * The outer proto records this class name and version {@code 0}.
     *
     * @return The serialized model.
     */
    public SequenceModelProto m606serialize() {
        ModelDataCarrier carrier = createDataCarrier();
        Map<String, TensorTupleProto> tensors = new HashMap<>();
        for (Map.Entry<String, TensorFlowUtil.TensorTuple> entry : TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session).entrySet()) {
            tensors.put(entry.getKey(), entry.getValue().serialize());
        }
        TensorFlowSequenceModelProto.Builder modelBuilder = TensorFlowSequenceModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setModelDef(ByteString.copyFrom(this.modelGraph.toGraphDef().toByteArray()));
        modelBuilder.putAllTensors(tensors);
        modelBuilder.setPredictOp(this.predictOp);
        modelBuilder.setOutputConverter((SequenceOutputConverterProto) this.outputConverter.serialize());
        modelBuilder.setFeatureConverter((SequenceFeatureConverterProto) this.featureConverter.serialize());

        SequenceModelProto.Builder wrapperBuilder = SequenceModelProto.newBuilder();
        wrapperBuilder.setSerializedData(Any.pack(modelBuilder.m550build()));
        wrapperBuilder.setClassName(TensorFlowSequenceModel.class.getName());
        wrapperBuilder.setVersion(CURRENT_VERSION);
        return wrapperBuilder.build();
    }

    /**
     * Writes the non-transient fields, then the graph definition bytes, then the marshalled variable map.
     *
     * @param out The output stream.
     * @throws IOException If the stream write fails.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
        out.writeObject(this.modelGraph.toGraphDef().toByteArray());
        out.writeObject(TensorFlowUtil.extractMarshalledVariables(this.modelGraph, this.session));
    }

    /**
     * Reads the state written by {@link #writeObject} and rebuilds the graph and session.
     * If rebuilding fails, any graph or session already created is closed before the exception propagates.
     * <p>
     * NOTE(review): Java deserialization of untrusted streams is unsafe; prefer the protobuf path
     * ({@link #deserializeFromProto}) for data from outside the process.
     *
     * @param in The input stream.
     * @throws IOException            If the stream read fails or the graph definition cannot be parsed.
     * @throws ClassNotFoundException If a serialized class cannot be found.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        byte[] graphBytes = (byte[]) in.readObject();
        @SuppressWarnings("unchecked") // written by writeObject as Map<String, TensorTuple>
        Map<String, TensorFlowUtil.TensorTuple> tensorMap = (Map<String, TensorFlowUtil.TensorTuple>) in.readObject();
        // Parse before allocating the graph so a malformed definition cannot leak it.
        GraphDef graphDef = GraphDef.parseFrom(graphBytes);
        this.modelGraph = new Graph();
        try {
            this.modelGraph.importGraphDef(graphDef);
            this.session = new Session(this.modelGraph);
            TensorFlowUtil.restoreMarshalledVariables(this.session, tensorMap);
        } catch (RuntimeException | Error e) {
            releaseAfterFailure(e);
            throw e;
        }
    }
}
