package org.tribuo.interop.tensorflow;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.types.TFloat32;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.interop.tensorflow.protos.DenseFeatureConverterProto;
import org.tribuo.interop.tensorflow.protos.FeatureConverterProto;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.protos.ProtoSerializableClass;
import org.tribuo.protos.ProtoSerializableField;
import org.tribuo.protos.ProtoUtil;

/**
 * A {@link FeatureConverter} which writes every feature into a dense {@code float} tensor.
 * <p>
 * A single input is converted into a tensor of shape {@code [1, dimension]}, and a batch into a
 * tensor of shape {@code [batchSize, dimension]}. The tensor is stored in a {@link TensorMap} under
 * the configured input (placeholder) name. Positions with no corresponding feature are left at
 * {@code 0.0f}.
 */
@ProtoSerializableClass(serializedDataClass = DenseFeatureConverterProto.class, version = 0)
public class DenseFeatureConverter implements FeatureConverter {
    private static final long serialVersionUID = 1L;

    private static final Logger logger = Logger.getLogger(DenseFeatureConverter.class.getName());

    /** Protobuf serialization version understood by {@link #deserializeFromProto}. */
    public static final int CURRENT_VERSION = 0;

    /** Dense dimension above which a conversion logs a "large dense example" warning. */
    public static final int THRESHOLD = 1000000;

    /** Maximum number of large-dimension warnings logged by a single converter instance. */
    public static final int WARNING_THRESHOLD = 10;

    // Shared by both innerTransform overloads, so the warning budget is per converter instance.
    private int warningCount = 0;

    @Config(mandatory = true, description = "TensorFlow Placeholder Input name.")
    @ProtoSerializableField
    private String inputName;

    /**
     * No-arg constructor, presumably used by OLCUT's configuration system, which then injects
     * {@link #inputName} (the field is annotated {@code @Config}).
     */
    private DenseFeatureConverter() {
    }

    /**
     * Builds a DenseFeatureConverter.
     *
     * @param inputName The name of the TensorFlow placeholder the tensor is fed into.
     */
    public DenseFeatureConverter(String inputName) {
        this.inputName = inputName;
    }

    /**
     * Deserialization factory.
     *
     * @param version   The serialized object version.
     * @param className The class name (unused).
     * @param message   The serialized data.
     * @return The deserialized converter.
     * @throws InvalidProtocolBufferException If the message does not contain a
     *                                        {@link DenseFeatureConverterProto}.
     * @throws IllegalArgumentException       If the version is newer than {@link #CURRENT_VERSION}
     *                                        or negative.
     */
    public static DenseFeatureConverter deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        DenseFeatureConverterProto proto = message.unpack(DenseFeatureConverterProto.class);
        return new DenseFeatureConverter(proto.getInputName());
    }

    /**
     * Serializes this converter into a {@link FeatureConverterProto}.
     *
     * @return The protobuf representation of this converter.
     */
    @Override
    public FeatureConverterProto serialize() {
        return ProtoUtil.serialize(this);
    }

    /**
     * Alias of {@link #serialize()} under the name the decompiler gave it. Retained so existing
     * callers keep compiling.
     *
     * @return The protobuf representation of this converter.
     * @deprecated Use {@link #serialize()}.
     */
    @Deprecated
    public FeatureConverterProto m1serialize() {
        return serialize();
    }

    /**
     * Densifies an example into a float array indexed by feature id.
     * <p>
     * Features whose names are unknown to the feature map (id of -1) are skipped. A warning is
     * logged (at most {@link #WARNING_THRESHOLD} times per instance) when the feature map is
     * larger than {@link #THRESHOLD}.
     *
     * @param example    The example to convert.
     * @param featureMap The feature map that assigns each feature name its column.
     * @return An array of length {@code featureMap.size()} holding the example's feature values.
     */
    float[] innerTransform(Example<?> example, ImmutableFeatureMap featureMap) {
        int dimension = featureMap.size();
        if (this.warningCount < WARNING_THRESHOLD && dimension > THRESHOLD) {
            logger.warning("Large dense example requested, featureIDMap.size() = " + dimension + ", example.size() = " + example.size());
            this.warningCount++;
        }
        float[] values = new float[dimension];
        for (Feature feature : example) {
            int id = featureMap.getID(feature.getName());
            if (id >= 0) {
                values[id] = (float) feature.getValue();
            }
        }
        return values;
    }

    /**
     * Densifies a vector into a float array of length {@code vector.size()}.
     * <p>
     * A warning is logged (at most {@link #WARNING_THRESHOLD} times per instance) when the
     * dimension is larger than {@link #THRESHOLD}.
     *
     * @param vector The vector to convert.
     * @return The vector's values narrowed to {@code float}.
     */
    private float[] innerTransform(SGDVector vector) {
        int dimension = vector.size();
        if (this.warningCount < WARNING_THRESHOLD && dimension > THRESHOLD) {
            logger.warning("Large dense example requested, dimension = " + dimension + ", numActiveElements = " + vector.numActiveElements());
            this.warningCount++;
        }
        float[] values = new float[dimension];
        if (vector instanceof DenseVector) {
            DenseVector dense = (DenseVector) vector;
            for (int i = 0; i < values.length; i++) {
                values[i] = (float) dense.get(i);
            }
        } else {
            // Sparse vectors only iterate their active elements; every other slot stays 0.0f.
            for (VectorTuple tuple : vector) {
                values[tuple.index] = (float) tuple.value;
            }
        }
        return values;
    }

    /**
     * Converts a single example into a {@code [1, featureMap.size()]} float tensor.
     *
     * @param example    The example to convert.
     * @param featureMap The feature map used to assign columns.
     * @return A TensorMap holding the tensor under this converter's input name.
     */
    @Override
    public TensorMap convert(Example<?> example, ImmutableFeatureMap featureMap) {
        float[] values = innerTransform(example, featureMap);
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(1, values.length), DataBuffers.of(values));
        return new TensorMap(this.inputName, tensor);
    }

    /**
     * Converts a batch of examples into a {@code [list.size(), featureMap.size()]} float tensor.
     *
     * @param examples   The examples to convert; row {@code i} holds {@code examples.get(i)}.
     * @param featureMap The feature map used to assign columns.
     * @return A TensorMap holding the tensor under this converter's input name.
     */
    @Override
    public TensorMap convert(List<? extends Example<?>> examples, ImmutableFeatureMap featureMap) {
        TFloat32 batch = TFloat32.tensorOf(Shape.of(examples.size(), featureMap.size()));
        try {
            long row = 0;
            for (Example<?> example : examples) {
                batch.set(NdArrays.vectorOf(innerTransform(example, featureMap)), row);
                row++;
            }
        } catch (RuntimeException e) {
            // The tensor owns native memory; release it rather than leaking it on failure.
            batch.close();
            throw e;
        }
        return new TensorMap(this.inputName, batch);
    }

    /**
     * Converts a single vector into a {@code [1, vector.size()]} float tensor.
     *
     * @param vector The vector to convert.
     * @return A TensorMap holding the tensor under this converter's input name.
     */
    @Override
    public TensorMap convert(SGDVector vector) {
        float[] values = innerTransform(vector);
        TFloat32 tensor = TFloat32.tensorOf(Shape.of(1, values.length), DataBuffers.of(values));
        return new TensorMap(this.inputName, tensor);
    }

    /**
     * Converts a batch of vectors into a {@code [list.size(), dimension]} float tensor. The
     * dimension is taken from the first vector.
     *
     * @param vectors The vectors to convert; all must have the same dimension.
     * @return A TensorMap holding the tensor under this converter's input name.
     * @throws IllegalArgumentException If {@code vectors} is empty, or the vectors differ in
     *                                  dimension.
     */
    @Override
    public TensorMap convert(List<? extends SGDVector> vectors) {
        if (vectors.isEmpty()) {
            // The tensor width comes from the vectors themselves, so an empty batch has no shape.
            throw new IllegalArgumentException("Cannot convert an empty list of vectors, the dimension is unknown.");
        }
        int dimension = vectors.get(0).size();
        TFloat32 batch = TFloat32.tensorOf(Shape.of(vectors.size(), dimension));
        try {
            long row = 0;
            for (SGDVector vector : vectors) {
                if (vector.size() != dimension) {
                    throw new IllegalArgumentException("All vectors must have the same dimension, expected " + dimension + ", found " + vector.size() + " at index " + row);
                }
                batch.set(NdArrays.vectorOf(innerTransform(vector)), row);
                row++;
            }
        } catch (RuntimeException e) {
            // The tensor owns native memory; release it rather than leaking it on failure.
            batch.close();
            throw e;
        }
        return new TensorMap(this.inputName, batch);
    }

    /**
     * Returns the names of the inputs this converter produces.
     *
     * @return A singleton set containing the configured input name.
     */
    @Override
    public Set<String> inputNamesSet() {
        return Collections.singleton(this.inputName);
    }

    @Override
    public String toString() {
        return "DenseFeatureConverter(inputName='" + this.inputName + "')";
    }

    /**
     * Returns the configuration provenance of this converter.
     *
     * @return The provenance.
     */
    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "FeatureConverter");
    }

    /**
     * Alias of {@link #getProvenance()} under the name the decompiler gave it. Retained so
     * existing callers keep compiling.
     *
     * @return The provenance.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m2getProvenance() {
        return getProvenance();
    }
}
