package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.OutputConverterProto;
import org.tribuo.interop.tensorflow.protos.TensorFlowFrozenExternalModelProto;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowFrozenExternalModel.class */
/**
 * A Tribuo {@link Model} wrapping a TensorFlow frozen graph, i.e. a serialized {@link GraphDef}
 * whose output is fetched from a single named operation.
 * <p>
 * Examples are converted into a {@link TensorMap} by a {@link FeatureConverter}, fed into a
 * {@link Session} over the wrapped {@link Graph}, and the fetched {@link Tensor} is turned into
 * {@link Prediction}s by an {@link OutputConverter}.
 * <p>
 * The model owns a TensorFlow {@link Graph} and {@link Session}; both are released by
 * {@link #close()}, after which the model can no longer predict or be serialized.
 *
 * @param <T> The Tribuo output type.
 */
public final class TensorFlowFrozenExternalModel<T extends Output<T>> extends ExternalModel<T, TensorMap, Tensor> implements Closeable {
    private static final long serialVersionUID = 200;

    /** The protobuf format version written by serialization and accepted by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** The wrapped TensorFlow graph. Transient: written/read manually in writeObject/readObject. */
    private transient Graph model;

    /** The session used to run {@link #model}. Transient: rebuilt in readObject. */
    private transient Session session;

    /** Converts Tribuo feature vectors into the TensorFlow input tensors. */
    private final FeatureConverter featureConverter;

    /** Converts the fetched output tensor into Tribuo predictions. */
    private final OutputConverter<T> outputConverter;

    /**
     * Not read by this class; always set to the empty string by the constructors.
     * NOTE(review): presumably retained so the Java-serialized form stays compatible — confirm.
     *
     * @deprecated Unused.
     */
    @Deprecated
    private final String inputName;

    /** Name of the graph operation fetched as the model output. */
    private final String outputName;

    /**
     * Builds a model from an external feature mapping. Takes ownership of {@code graph}
     * (it is closed by {@link #close()}) and opens a new {@link Session} over it.
     *
     * @param str The model name.
     * @param modelProvenance The model provenance.
     * @param immutableFeatureMap The Tribuo feature domain.
     * @param immutableOutputInfo The Tribuo output domain.
     * @param map The external feature mapping passed through to {@link ExternalModel}.
     * @param graph The TensorFlow graph to run.
     * @param str2 The name of the output operation to fetch.
     * @param featureConverter The feature converter.
     * @param outputConverter The output converter.
     */
    private TensorFlowFrozenExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, Map<String, Integer> map, Graph graph, String str2, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, outputConverter.generatesProbabilities(), map);
        this.inputName = "";
        this.model = graph;
        this.session = new Session(graph);
        this.outputName = str2;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Builds a model from explicit forward/backward feature index mappings. Takes ownership of
     * {@code graph} (it is closed by {@link #close()}) and opens a new {@link Session} over it.
     *
     * @param str The model name.
     * @param modelProvenance The model provenance.
     * @param immutableFeatureMap The Tribuo feature domain.
     * @param immutableOutputInfo The Tribuo output domain.
     * @param iArr The forward feature mapping.
     * @param iArr2 The backward feature mapping.
     * @param graph The TensorFlow graph to run.
     * @param str2 The name of the output operation to fetch.
     * @param featureConverter The feature converter.
     * @param outputConverter The output converter.
     */
    private TensorFlowFrozenExternalModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, int[] iArr, int[] iArr2, Graph graph, String str2, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, iArr, iArr2, outputConverter.generatesProbabilities());
        this.inputName = "";
        this.model = graph;
        this.session = new Session(graph);
        this.outputName = str2;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Deserializes a model from its protobuf form.
     *
     * @param i The serialized format version; must be between 0 and {@link #CURRENT_VERSION}.
     * @param str Not read by this method (presumably the serialized class name).
     * @param any The packed {@link TensorFlowFrozenExternalModelProto}.
     * @return The deserialized model.
     * @throws InvalidProtocolBufferException If the payload or the embedded GraphDef cannot be parsed.
     * @throws IllegalArgumentException If the version is unsupported.
     * @throws IllegalStateException If the output domain and converter disagree, or the feature
     *                               mappings do not form a bijection.
     */
    public static TensorFlowFrozenExternalModel<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException {
        if (i < 0 || i > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version " + CURRENT_VERSION);
        }
        TensorFlowFrozenExternalModelProto proto = any.unpack(TensorFlowFrozenExternalModelProto.class);
        OutputConverter outputConverter = (OutputConverter) ProtoUtil.deserialize(proto.getOutputConverter());
        FeatureConverter featureConverter = (FeatureConverter) ProtoUtil.deserialize(proto.getFeatureConverter());
        ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        // Compare (and report) the class of an actual output, not the class of the domain object.
        Class<?> domainOutputClass = carrier.outputDomain().getOutput(0).getClass();
        if (!domainOutputClass.equals(outputConverter.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + domainOutputClass + " and " + outputConverter.getTypeWitness());
        }
        int[] forwardMapping = Util.toPrimitiveInt(proto.getForwardFeatureMappingList());
        int[] backwardMapping = Util.toPrimitiveInt(proto.getBackwardFeatureMappingList());
        if (!validateFeatureMapping(forwardMapping, backwardMapping, carrier.featureDomain())) {
            throw new IllegalStateException("Invalid protobuf, external<->Tribuo feature mapping does not form a bijection");
        }
        // Parse first so a malformed GraphDef fails before any Graph is allocated.
        GraphDef graphDef = GraphDef.parseFrom(proto.getModelDef());
        Graph graph = importGraph(graphDef);
        return new TensorFlowFrozenExternalModel<>(carrier.name(), carrier.provenance(), carrier.featureDomain(), carrier.outputDomain(), forwardMapping, backwardMapping, graph, proto.getOutputName(), featureConverter, outputConverter);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: convertFeatures, reason: merged with bridge method [inline-methods] */
    /**
     * Converts a single feature vector into the graph's input tensors.
     *
     * @param sparseVector The features.
     * @return The input tensors; closed by {@link #externalPrediction}.
     */
    public TensorMap m19convertFeatures(SparseVector sparseVector) {
        // The cast selects the single-vector overload of FeatureConverter.convert.
        return this.featureConverter.convert((SGDVector) sparseVector);
    }

    /**
     * Converts a batch of feature vectors into the graph's input tensors.
     *
     * @param list The feature vectors.
     * @return The input tensors; closed by {@link #externalPrediction}.
     */
    protected TensorMap convertFeaturesList(List<SparseVector> list) {
        return this.featureConverter.convert((List<? extends SGDVector>) list);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Feeds the inputs into the session, fetches the output operation and returns the first
     * fetched tensor. The input {@code tensorMap} is always closed, even if the run fails.
     *
     * @param tensorMap The input tensors.
     * @return The output tensor; closed by {@code convertOutput}.
     */
    public Tensor externalPrediction(TensorMap tensorMap) {
        try {
            return (Tensor) tensorMap.feedInto(this.session.runner()).fetch(this.outputName).run().get(0);
        } finally {
            tensorMap.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts a single-example output tensor into a prediction. The tensor is always closed,
     * even if conversion fails.
     *
     * @param tensor The output tensor.
     * @param i The number of features used, passed through to the converter.
     * @param example The example being predicted.
     * @return The prediction.
     */
    public Prediction<T> convertOutput(Tensor tensor, int i, Example<T> example) {
        try {
            return this.outputConverter.convertToPrediction(tensor, this.outputIDInfo, i, example);
        } finally {
            tensor.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Converts a batch output tensor into predictions. The tensor is always closed, even if
     * conversion fails.
     *
     * @param tensor The output tensor.
     * @param iArr The number of features used per example, passed through to the converter.
     * @param list The examples being predicted.
     * @return The predictions.
     */
    public List<Prediction<T>> convertOutput(Tensor tensor, int[] iArr, List<Example<T>> list) {
        try {
            return this.outputConverter.convertToBatchPrediction(tensor, this.outputIDInfo, iArr, list);
        } finally {
            tensor.close();
        }
    }

    /**
     * This model does not compute feature importances, so this always returns an empty map.
     *
     * @param i Ignored.
     * @return An empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * Copies this model by re-importing its GraphDef into a fresh {@link Graph}, so the copy
     * owns (and closes) its own graph and session.
     *
     * @param str The new model name.
     * @param modelProvenance The new provenance.
     * @return The copy.
     */
    protected Model<T> copy(String str, ModelProvenance modelProvenance) {
        Graph graph = importGraph(this.model.toGraphDef());
        return new TensorFlowFrozenExternalModel<>(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.featureForwardMapping, this.featureBackwardMapping, graph, this.outputName, this.featureConverter, this.outputConverter);
    }

    /**
     * Closes the session and the graph. Calling this more than once has no further effect.
     * The graph is closed even if closing the session throws.
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        Session oldSession = this.session;
        Graph oldModel = this.model;
        // Clear first so repeated calls do not re-close already released resources.
        this.session = null;
        this.model = null;
        try {
            if (oldSession != null) {
                oldSession.close();
            }
        } finally {
            if (oldModel != null) {
                oldModel.close();
            }
        }
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    /**
     * Serializes this model (metadata, GraphDef, output name, feature mappings and converters)
     * into a {@link ModelProto} tagged with {@link #CURRENT_VERSION}.
     *
     * @return The protobuf form of this model.
     */
    public ModelProto m20serialize() {
        ModelDataCarrier carrier = createDataCarrier();
        TensorFlowFrozenExternalModelProto.Builder modelBuilder = TensorFlowFrozenExternalModelProto.newBuilder();
        modelBuilder.setMetadata(carrier.serialize());
        modelBuilder.setModelDef(this.model.toGraphDef().toByteString());
        modelBuilder.setOutputName(this.outputName);
        modelBuilder.addAllForwardFeatureMapping(Arrays.stream(this.featureForwardMapping).boxed().collect(Collectors.toList()));
        modelBuilder.addAllBackwardFeatureMapping(Arrays.stream(this.featureBackwardMapping).boxed().collect(Collectors.toList()));
        modelBuilder.setOutputConverter((OutputConverterProto) this.outputConverter.serialize());
        modelBuilder.setFeatureConverter((FeatureConverterProto) this.featureConverter.serialize());
        ModelProto.Builder builder = ModelProto.newBuilder();
        builder.setSerializedData(Any.pack(modelBuilder.m408build()));
        builder.setClassName(TensorFlowFrozenExternalModel.class.getName());
        builder.setVersion(CURRENT_VERSION);
        return builder.build();
    }

    /**
     * Loads a frozen TensorFlow graph from disk and wraps it as a Tribuo model.
     *
     * @param outputFactory The output factory, used for the output domain and dataset provenance.
     * @param map The feature mapping; its keys are used as the Tribuo feature names
     *            (presumably Tribuo name -> external index — confirm against callers).
     * @param map2 The output mapping passed to {@code createOutputInfo}.
     * @param str The name of the graph operation to fetch as the model output.
     * @param featureConverter The feature converter.
     * @param outputConverter The output converter.
     * @param str2 The filesystem path of the serialized GraphDef.
     * @param <T> The output type.
     * @return The wrapped model.
     * @throws IllegalArgumentException If the file cannot be read or parsed.
     */
    public static <T extends Output<T>> TensorFlowFrozenExternalModel<T> createTensorflowModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, String str, FeatureConverter featureConverter, OutputConverter<T> outputConverter, String str2) {
        try {
            Path path = Paths.get(str2);
            // Do every step that can fail before allocating the Graph, so failures cannot leak it.
            GraphDef graphDef = GraphDef.parseFrom(Files.readAllBytes(path));
            URL url = path.toUri().toURL();
            ModelProvenance provenance = new ModelProvenance(TensorFlowFrozenExternalModel.class.getName(), OffsetDateTime.now(), new ExternalDatasetProvenance("unknown-external-data", outputFactory, false, map.size(), map2.size()), new ExternalTrainerProvenance(url));
            // Java evaluates arguments left to right (JLS 15.7.4), so the graph is imported
            // only after the feature map and output info have been built.
            return new TensorFlowFrozenExternalModel<>("tf-frozen-graph", provenance, ExternalModel.createFeatureMap(map.keySet()), ExternalModel.createOutputInfo(outputFactory, map2), map, importGraph(graphDef), str, featureConverter, outputConverter);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to load model from path " + str2, e);
        }
    }

    /**
     * Writes the non-transient state, then the GraphDef bytes of the wrapped graph.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.defaultWriteObject();
        objectOutputStream.writeObject(this.model.toGraphDef().toByteArray());
    }

    /**
     * Reads the non-transient state, then rebuilds the graph and session from the GraphDef bytes
     * written by writeObject.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        byte[] graphBytes = (byte[]) objectInputStream.readObject();
        this.model = importGraph(GraphDef.parseFrom(graphBytes));
        this.session = new Session(this.model);
    }

    /* renamed from: convertFeaturesList, reason: collision with other method in class */
    protected /* bridge */ /* synthetic */ Object m18convertFeaturesList(List list) {
        return convertFeaturesList((List<SparseVector>) list);
    }

    /**
     * Creates a new {@link Graph} and imports {@code graphDef} into it. If the import fails,
     * the graph is closed before the exception propagates.
     *
     * @param graphDef The graph definition to import.
     * @return A new graph owned by the caller.
     */
    private static Graph importGraph(GraphDef graphDef) {
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
        return graph;
    }
}
