package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.SessionFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowModel.class */
/**
 * Base class for models whose predictions are computed by a TensorFlow graph.
 * <p>
 * An instance owns a {@link Graph} and a {@link Session} created from the {@link GraphDef}
 * passed to the constructor. Both are released by {@link #close()}; after that the
 * prediction and export methods throw {@link IllegalStateException}.
 * <p>
 * Examples are turned into a {@link TensorMap} by the {@link FeatureConverter}, fed into the
 * session, and the tensor fetched under {@link #getOutputName()} is turned back into
 * predictions by the {@link OutputConverter}.
 *
 * @param <T> The output type.
 */
public abstract class TensorFlowModel<T extends Output<T>> extends Model<T> implements AutoCloseable {
    private static final Logger logger = Logger.getLogger(TensorFlowModel.class.getName());
    private static final long serialVersionUID = 200;

    /** Maximum number of examples sent to the session in one call by {@link #innerPredict}. */
    protected int batchSize;
    /** Name of the graph operation fetched to produce predictions. */
    protected final String outputName;
    /** Converts feature vectors into the tensors fed to the session. */
    protected final FeatureConverter featureConverter;
    /** Converts the fetched output tensor into predictions. */
    protected final OutputConverter<T> outputConverter;
    /** The imported graph; set to {@code null} by {@link #close()}. */
    protected transient Graph modelGraph;
    /** The session running {@link #modelGraph}; set to {@code null} by {@link #close()}. */
    protected transient Session session;
    /** True once {@link #close()} has completed. */
    protected transient boolean closed;

    /**
     * Builds a model by importing the supplied graph definition into a new graph and opening a
     * session on it.
     * <p>
     * If importing the graph or opening the session fails, the newly created graph is closed
     * before the exception is rethrown, so a failed construction does not leak it.
     * <p>
     * Note: the decompiler reports that this constructor was originally {@code protected}; it is
     * kept {@code public} here so existing callers are unaffected.
     *
     * @param str                 The model name, passed to the superclass.
     * @param modelProvenance     The model provenance, passed to the superclass.
     * @param immutableFeatureMap The feature domain, passed to the superclass.
     * @param immutableOutputInfo The output domain, passed to the superclass.
     * @param graphDef            The graph definition to import.
     * @param i                   The initial batch size (not validated here).
     * @param str2                The name of the operation to fetch as the model output.
     * @param featureConverter    Converter from feature vectors to input tensors.
     * @param outputConverter     Converter from the output tensor to predictions.
     */
    public TensorFlowModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, GraphDef graphDef, int i, String str2, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, outputConverter.generatesProbabilities());
        Graph graph = new Graph();
        try {
            graph.importGraphDef(graphDef);
            this.session = new Session(graph);
        } catch (RuntimeException e) {
            // No session was created, so nothing else will ever close this graph.
            graph.close();
            throw e;
        }
        this.modelGraph = graph;
        this.closed = false;
        this.batchSize = i;
        this.outputName = str2;
        this.featureConverter = featureConverter;
        this.outputConverter = outputConverter;
    }

    /**
     * Predicts the output for a single example.
     * <p>
     * Both the input tensors and the fetched output tensor are closed before this method
     * returns, whether it returns normally or throws.
     *
     * @param example The example to predict.
     * @return The prediction produced by the output converter.
     * @throws IllegalStateException If the model has been closed.
     */
    public Prediction<T> predict(Example<T> example) {
        if (this.closed) {
            throw new IllegalStateException("Can't use a closed model, the state has gone.");
        }
        SGDVector vector = SparseVector.createSparseVector(example, this.featureIDMap, false);
        // Resources are closed in reverse order: the output tensor first, then the inputs.
        try (TensorMap inputs = this.featureConverter.convert(vector);
             Tensor output = (Tensor) inputs.feedInto(this.session.runner()).fetch(this.outputName).run().get(0)) {
            return this.outputConverter.convertToPrediction(output, this.outputIDInfo, vector.numActiveElements(), example);
        }
    }

    /**
     * Predicts the outputs for a sequence of examples, sending them to the session in batches
     * of {@link #getBatchSize()} examples; a final partial batch is sent if any examples remain.
     * <p>
     * If the batch size is not positive the size check below never matches, so all examples are
     * sent as a single batch.
     *
     * @param iterable The examples to predict.
     * @return The predictions, in iteration order.
     * @throws IllegalStateException If the model has been closed and there is at least one example.
     */
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> iterable) {
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Example<T>> batch = new ArrayList<>();
        for (Example<T> example : iterable) {
            batch.add(example);
            if (batch.size() == this.batchSize) {
                predictions.addAll(predictBatch(batch));
                batch.clear();
            }
        }
        if (!batch.isEmpty()) {
            predictions.addAll(predictBatch(batch));
        }
        return predictions;
    }

    /**
     * Runs one batch of examples through the session.
     * <p>
     * Both the input tensors and the fetched output tensor are closed before this method
     * returns, whether it returns normally or throws.
     *
     * @param list The examples forming the batch.
     * @return One prediction per example, as produced by the output converter.
     * @throws IllegalStateException If the model has been closed.
     */
    private List<Prediction<T>> predictBatch(List<Example<T>> list) {
        if (this.closed) {
            throw new IllegalStateException("Can't use a closed model, the state has gone.");
        }
        int size = list.size();
        List<SGDVector> vectors = new ArrayList<>(size);
        // Number of active features per example, handed to the output converter.
        int[] numActive = new int[size];
        for (int i = 0; i < size; i++) {
            SparseVector vector = SparseVector.createSparseVector(list.get(i), this.featureIDMap, false);
            numActive[i] = vector.numActiveElements();
            vectors.add(vector);
        }
        // Resources are closed in reverse order: the output tensor first, then the inputs.
        try (TensorMap inputs = this.featureConverter.convert(vectors);
             Tensor output = (Tensor) inputs.feedInto(this.session.runner()).fetch(this.outputName).run().get(0)) {
            return this.outputConverter.convertToBatchPrediction(output, this.outputIDInfo, numActive, list);
        }
    }

    /**
     * Gets the current batch size used by {@link #innerPredict}.
     *
     * @return The batch size.
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the batch size used by {@link #innerPredict}.
     *
     * @param i The new batch size.
     * @throws IllegalArgumentException If the batch size is not positive.
     */
    public void setBatchSize(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, found " + i);
        }
        this.batchSize = i;
    }

    /**
     * Always returns an empty map; this class does not report top features.
     *
     * @param i Ignored.
     * @return An empty map.
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    /**
     * Always returns an empty optional; this class does not generate excuses.
     *
     * @param example Ignored.
     * @return An empty optional.
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Gets the name of the operation fetched to produce predictions.
     *
     * @return The output operation name.
     */
    public String getOutputName() {
        return this.outputName;
    }

    /**
     * Exports the model as a TensorFlow saved model bundle.
     * <p>
     * The exported signature has one input per name in the feature converter's input name set
     * and a single output named {@link #getOutputName()}; each is bound to output 0 of the graph
     * operation with the same name.
     *
     * @param str The export location handed to {@link SavedModelBundle#exporter(String)}.
     * @throws IOException           If the export fails.
     * @throws IllegalStateException If the model has been closed.
     */
    public void exportModel(String str) throws IOException {
        if (this.closed) {
            throw new IllegalStateException("Can't serialize a closed model, the state has gone.");
        }
        Signature.Builder builder = Signature.builder();
        for (String inputName : this.featureConverter.inputNamesSet()) {
            builder.input(inputName, this.modelGraph.operation(inputName).output(0));
        }
        builder.output(this.outputName, this.modelGraph.operation(this.outputName).output(0));
        Signature signature = builder.build();
        SavedModelBundle.exporter(str).withFunction(SessionFunction.create(signature, this.session)).export();
    }

    /**
     * Closes the session and then the graph, clears both references and marks the model as
     * closed. Calling this more than once is harmless because the references are null after
     * the first call.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        // The session is closed before the graph it was created from.
        if (this.session != null) {
            this.session.close();
            this.session = null;
        }
        if (this.modelGraph != null) {
            this.modelGraph.close();
            this.modelGraph = null;
        }
        this.closed = true;
    }
}
