package org.tribuo.interop.tensorflow;

import com.google.protobuf.ByteString;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.GraphOperationBuilder;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.ByteDataBuffer;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Scope;
import org.tensorflow.types.family.TType;
import org.tribuo.interop.tensorflow.protos.TensorTupleProto;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/tensorflow/TensorFlowUtil.class */
/**
 * Static helpers for moving variable values in and out of a TensorFlow
 * {@link Graph} / {@link Session} pair.
 * <p>
 * The class cannot be instantiated; all functionality is exposed through
 * static methods and the nested {@link TensorTuple} value type.
 */
public abstract class TensorFlowUtil {
    private static final Logger logger = Logger.getLogger(TensorFlowUtil.class.getName());

    /** Operation type used to recognise variables when scanning a graph. */
    public static final String VARIABLE_V2 = "VariableV2";
    /** Operation type of the assignment ops created by {@link #annotateGraph(Graph, Session)}. */
    public static final String ASSIGN_OP = "Assign";
    /** Suffix (after a {@code /}) appended to a variable name to name its assignment op. */
    public static final String ASSIGN_PLACEHOLDER = "Assign_from_Placeholder";
    /** Operation type of the placeholder ops created by {@link #annotateGraph(Graph, Session)}. */
    public static final String PLACEHOLDER = "Placeholder";
    /** Attribute name set on the created placeholders to carry the data type. */
    public static final String DTYPE = "dtype";

    /**
     * A serializable snapshot of a tensor: the name of its type class, its
     * shape, and its raw bytes.
     * <p>
     * The arrays are stored and exposed without copying, so callers must not
     * mutate them after handing them over or reading them back.
     */
    public static final class TensorTuple implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Fully qualified name of the tensor type class. */
        public final String className;
        /** The tensor shape. */
        public final long[] shape;
        /** The raw tensor bytes. */
        public final byte[] data;

        /**
         * Makes a tuple from the supplied values. The arrays are not copied.
         *
         * @param className The name of the tensor type class.
         * @param shape The tensor shape.
         * @param data The raw tensor bytes.
         */
        public TensorTuple(String className, long[] shape, byte[] data) {
            this.className = className;
            this.shape = shape;
            this.data = data;
        }

        /**
         * Makes a tuple from its protobuf representation.
         *
         * @param tensorTupleProto The protobuf to read.
         */
        public TensorTuple(TensorTupleProto tensorTupleProto) {
            this.className = tensorTupleProto.getClassName();
            this.shape = Util.toPrimitiveLong(tensorTupleProto.getShapeList());
            this.data = tensorTupleProto.getData().toByteArray();
        }

        /**
         * Recreates a tensor from the stored class name, shape and bytes.
         * <p>
         * The caller owns the returned tensor and is responsible for closing it.
         *
         * @return A new tensor.
         * @throws IllegalStateException If the class cannot be loaded, or it is
         *         not a {@link TType}.
         */
        public Tensor rebuildTensor() {
            Class<?> loaded;
            try {
                loaded = Class.forName(this.className);
            } catch (ClassNotFoundException e) {
                throw new IllegalStateException("Failed to instantiate Tensor class", e);
            }
            if (!TType.class.isAssignableFrom(loaded)) {
                throw new IllegalStateException("Unexpected Tensor type, found " + this.className);
            }
            // Safe narrowing: guarded by the isAssignableFrom check above.
            Class<? extends TType> tensorClass = loaded.asSubclass(TType.class);
            return Tensor.of(tensorClass, Shape.of(this.shape), DataBuffers.of(this.data));
        }

        /**
         * Converts this tuple into its protobuf representation.
         *
         * @return The protobuf.
         */
        public TensorTupleProto serialize() {
            TensorTupleProto.Builder builder = TensorTupleProto.newBuilder();
            builder.setClassName(this.className);
            builder.addAllShape(Arrays.stream(this.shape).boxed().collect(Collectors.toList()));
            builder.setData(ByteString.copyFrom(this.data));
            return builder.build();
        }

        /**
         * Copies the type name, shape and raw bytes out of the supplied tensor.
         * The tensor is not closed by this method.
         *
         * @param tType The tensor to snapshot.
         * @return A tuple holding a copy of the tensor's bytes.
         * @throws IllegalArgumentException If the tensor's raw buffer is larger
         *         than {@link Integer#MAX_VALUE} bytes and so cannot fit in a
         *         Java array.
         */
        public static TensorTuple of(TType tType) {
            ByteDataBuffer buffer = tType.asRawTensor().data();
            long size = buffer.size();
            if (size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Cannot serialize Tensors bigger than Integer.MAX_VALUE, found " + size);
            }
            String typeName = tType.type().getName();
            long[] tensorShape = tType.shape().asArray();
            byte[] bytes = new byte[(int) size];
            buffer.read(bytes);
            return new TensorTuple(typeName, tensorShape, bytes);
        }
    }

    /**
     * Private constructor, this class only contains static members.
     */
    private TensorFlowUtil() {
    }

    /**
     * Closes every tensor in the collection.
     *
     * @param collection The tensors to close.
     */
    public static void closeTensorCollection(Collection<Tensor> collection) {
        for (Tensor tensor : collection) {
            tensor.close();
        }
    }

    /**
     * Adds, for every {@value #VARIABLE_V2} operation in the graph, a
     * placeholder (named by {@link #generatePlaceholderName(String)}) and an
     * {@value #ASSIGN_OP} operation named
     * {@code <variable>/Assign_from_Placeholder} which assigns the placeholder
     * to the variable.
     * <p>
     * The variables are fetched through the session so each placeholder can be
     * given the data type of the current variable value. If the graph contains
     * no variables this method does nothing.
     *
     * @param graph The graph to modify.
     * @param session The session used to read the current variable values.
     * @throws IllegalStateException If the session returns a different number
     *         of tensors than variables requested.
     */
    public static void annotateGraph(Graph graph, Session session) {
        List<GraphOperation> variables = findVariables(graph);
        if (variables.isEmpty()) {
            // Nothing to annotate, and nothing to fetch from the session.
            return;
        }
        List<Tensor> values = fetchVariables(session, variables, "annotate");
        try {
            Scope baseScope = graph.baseScope();
            for (int i = 0; i < variables.size(); i++) {
                GraphOperation variable = variables.get(i);
                String variableName = variable.name();

                GraphOperationBuilder placeholderBuilder =
                        graph.opBuilder(PLACEHOLDER, generatePlaceholderName(variableName), baseScope);
                placeholderBuilder.setAttr(DTYPE, values.get(i).dataType());
                GraphOperation placeholder = placeholderBuilder.build();

                GraphOperationBuilder assignBuilder =
                        graph.opBuilder(ASSIGN_OP, variableName + "/" + ASSIGN_PLACEHOLDER, baseScope);
                assignBuilder.addInput(variable.output(0));
                assignBuilder.addInput(placeholder.output(0));
                assignBuilder.build();
            }
        } finally {
            // Release the fetched values even if building an op fails.
            closeTensorCollection(values);
        }
    }

    /**
     * Builds the placeholder name used for the supplied variable name.
     *
     * @param str The variable name.
     * @return The variable name followed by {@code -tribuo-Placeholder}.
     */
    public static String generatePlaceholderName(String str) {
        return str + "-tribuo-" + PLACEHOLDER;
    }

    /**
     * Reads the current value of every {@value #VARIABLE_V2} operation in the
     * graph and copies it into a {@link TensorTuple}.
     *
     * @param graph The graph to scan for variables.
     * @param session The session used to read the variable values.
     * @return A mutable map from variable name to its marshalled value; empty
     *         if the graph contains no variables.
     * @throws IllegalStateException If the session returns a different number
     *         of tensors than variables requested.
     */
    public static Map<String, TensorTuple> extractMarshalledVariables(Graph graph, Session session) {
        Map<String, TensorTuple> marshalled = new HashMap<>();
        List<GraphOperation> variables = findVariables(graph);
        if (variables.isEmpty()) {
            // Nothing to fetch from the session.
            return marshalled;
        }
        List<Tensor> values = fetchVariables(session, variables, "serialise");
        try {
            for (int i = 0; i < variables.size(); i++) {
                marshalled.put(variables.get(i).name(), TensorTuple.of((TType) values.get(i)));
            }
        } finally {
            // Release the fetched values even if marshalling one of them fails.
            closeTensorCollection(values);
        }
        return marshalled;
    }

    /**
     * Writes the supplied values into the session by feeding each one to the
     * placeholder named by {@link #generatePlaceholderName(String)} and
     * running the matching {@code <variable>/Assign_from_Placeholder} target.
     * <p>
     * The placeholder and assignment ops are the ones created by
     * {@link #annotateGraph(Graph, Session)}. If the map is empty this method
     * does nothing.
     *
     * @param session The session to update.
     * @param map The variable values, keyed by variable name.
     */
    public static void restoreMarshalledVariables(Session session, Map<String, TensorTuple> map) {
        if (map.isEmpty()) {
            // No feeds or targets, so there is nothing to run.
            return;
        }
        Session.Runner runner = session.runner();
        List<Tensor> rebuilt = new ArrayList<>(map.size());
        try {
            for (Map.Entry<String, TensorTuple> entry : map.entrySet()) {
                String variableName = entry.getKey();
                TensorTuple tuple = entry.getValue();
                logger.log(Level.FINEST, () -> "Loading " + variableName + " of type " + tuple.className);
                Tensor tensor = tuple.rebuildTensor();
                // Track the tensor immediately so it is closed if a later step throws.
                rebuilt.add(tensor);
                runner.feed(generatePlaceholderName(variableName), tensor);
                runner.addTarget(variableName + "/" + ASSIGN_PLACEHOLDER);
            }
            runner.run();
        } finally {
            closeTensorCollection(rebuilt);
        }
    }

    /**
     * Collects the {@value #VARIABLE_V2} operations of the graph in iteration order.
     */
    private static List<GraphOperation> findVariables(Graph graph) {
        List<GraphOperation> variables = new ArrayList<>();
        Iterator<GraphOperation> operations = graph.operations();
        while (operations.hasNext()) {
            GraphOperation operation = operations.next();
            if (VARIABLE_V2.equals(operation.type())) {
                variables.add(operation);
            }
        }
        return variables;
    }

    /**
     * Fetches the supplied variables in one session run; on a size mismatch the
     * returned tensors are closed and an {@link IllegalStateException} naming
     * the {@code action} is thrown. The caller must close the returned tensors.
     */
    private static List<Tensor> fetchVariables(Session session, List<GraphOperation> variables, String action) {
        Session.Runner runner = session.runner();
        for (GraphOperation variable : variables) {
            runner.fetch(variable.name());
        }
        List<Tensor> values = runner.run();
        if (values.size() != variables.size()) {
            closeTensorCollection(values);
            throw new IllegalStateException("Failed to " + action + " all requested variables. Requested "
                    + variables.size() + ", found " + values.size());
        }
        return values;
    }
}
