package org.tribuo.interop;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.MutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/interop/ExternalModel.class */
/**
 * Base class for {@link Model}s which delegate prediction to a model held in an external library.
 * <p>
 * Tribuo feature ids are translated into the external model's feature ids using
 * {@link #featureForwardMapping}; {@link #featureBackwardMapping} holds the inverse mapping.
 * Subclasses provide the conversion steps:
 * {@link #convertFeatures}/{@link #convertFeaturesList} build the external input,
 * {@link #externalPrediction} invokes the external model, and the {@code convertOutput}
 * overloads turn the external output back into Tribuo {@link Prediction}s.
 *
 * @param <T> the Tribuo output type
 * @param <U> the external model's input type
 * @param <V> the external model's output type
 */
public abstract class ExternalModel<T extends Output<T>, U, V> extends Model<T> {
    private static final long serialVersionUID = 1L;

    /** The batch size used by {@link #innerPredict} unless changed via {@link #setBatchSize}. */
    public static final int DEFAULT_BATCH_SIZE = 16;

    /** Maps a Tribuo feature id to the corresponding external model feature id. */
    protected final int[] featureForwardMapping;

    /** Maps an external model feature id back to the corresponding Tribuo feature id. */
    protected final int[] featureBackwardMapping;

    private int batchSize = DEFAULT_BATCH_SIZE;

    /**
     * Constructs an external model, building the id mappings from a feature-name mapping.
     *
     * @param name the model name
     * @param provenance the model provenance
     * @param featureIDMap the Tribuo feature domain
     * @param outputIDInfo the output domain
     * @param generatesProbabilities whether the model produces probabilistic outputs
     * @param featureMapping maps each feature name in {@code featureIDMap} to an external feature id
     * @throws IllegalArgumentException if {@code featureMapping} is not the same size as
     *     {@code featureIDMap}, contains a feature name absent from {@code featureIDMap},
     *     contains an external id outside {@code [0, featureIDMap.size())}, or maps two
     *     features to the same external id
     */
    protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                            ImmutableOutputInfo<T> outputIDInfo, boolean generatesProbabilities,
                            Map<String, Integer> featureMapping) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        int numFeatures = featureIDMap.size();
        if (numFeatures != featureMapping.size()) {
            throw new IllegalArgumentException("The featureMapping must be the same size as the featureIDMap, found featureMapping.size()=" + featureMapping.size() + ", featureIDMap.size()=" + featureIDMap.size());
        }
        this.featureForwardMapping = new int[numFeatures];
        this.featureBackwardMapping = new int[numFeatures];
        // -1 marks "not yet assigned" so duplicate external ids can be detected below.
        Arrays.fill(this.featureForwardMapping, -1);
        Arrays.fill(this.featureBackwardMapping, -1);
        for (Map.Entry<String, Integer> entry : featureMapping.entrySet()) {
            int tribuoId = featureIDMap.getID(entry.getKey());
            int externalId = entry.getValue();
            if (tribuoId == -1) {
                throw new IllegalArgumentException("Found invalid feature name in mapping " + entry);
            }
            // Reject negative ids here too; otherwise they surface as an
            // ArrayIndexOutOfBoundsException on the duplicate check below.
            if (externalId < 0 || externalId >= numFeatures) {
                throw new IllegalArgumentException("Found invalid feature id in mapping " + entry);
            }
            if (this.featureBackwardMapping[externalId] != -1) {
                throw new IllegalArgumentException("Mapping for " + entry + " already exists as feature " + featureIDMap.get(this.featureBackwardMapping[externalId]));
            }
            this.featureForwardMapping[tribuoId] = externalId;
            this.featureBackwardMapping[externalId] = tribuoId;
        }
    }

    /**
     * Constructs an external model from precomputed id mappings.
     * <p>
     * Both arrays are defensively copied. They are not validated here; callers can check them
     * with {@link #validateFeatureMapping} first.
     *
     * @param name the model name
     * @param provenance the model provenance
     * @param featureIDMap the Tribuo feature domain
     * @param outputIDInfo the output domain
     * @param featureForwardMapping Tribuo feature id to external feature id
     * @param featureBackwardMapping external feature id to Tribuo feature id
     * @param generatesProbabilities whether the model produces probabilistic outputs
     */
    protected ExternalModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap,
                            ImmutableOutputInfo<T> outputIDInfo, int[] featureForwardMapping,
                            int[] featureBackwardMapping, boolean generatesProbabilities) {
        super(name, provenance, featureIDMap, outputIDInfo, generatesProbabilities);
        this.featureBackwardMapping = Arrays.copyOf(featureBackwardMapping, featureBackwardMapping.length);
        this.featureForwardMapping = Arrays.copyOf(featureForwardMapping, featureForwardMapping.length);
    }

    /**
     * Predicts a single example by converting its features, calling the external model,
     * and converting the result back.
     *
     * @param example the example to predict
     * @return the prediction
     */
    public Prediction<T> predict(Example<T> example) {
        SparseVector features = SparseVector.createSparseVector(example, this.featureIDMap, false);
        U externalInput = convertFeatures(renumberFeatureIndices(features));
        V externalOutput = externalPrediction(externalInput);
        return convertOutput(externalOutput, features.numActiveElements(), example);
    }

    /**
     * Predicts the examples in batches of {@link #getBatchSize()}, with a final partial batch
     * if needed.
     *
     * @param examples the examples to predict
     * @return the predictions, in the iteration order of {@code examples}
     */
    protected List<Prediction<T>> innerPredict(Iterable<Example<T>> examples) {
        List<Prediction<T>> predictions = new ArrayList<>();
        List<Example<T>> batch = new ArrayList<>();
        for (Example<T> example : examples) {
            batch.add(example);
            if (batch.size() == this.batchSize) {
                predictions.addAll(predictBatch(batch));
                batch.clear();
            }
        }
        if (!batch.isEmpty()) {
            predictions.addAll(predictBatch(batch));
        }
        return predictions;
    }

    /**
     * Predicts one batch with a single external model call.
     *
     * @param batch the examples in this batch
     * @return one prediction per example
     * @throws IllegalStateException if the subclass conversion returns a different number of
     *     predictions than there are examples
     */
    private List<Prediction<T>> predictBatch(List<Example<T>> batch) {
        List<SparseVector> renumbered = new ArrayList<>(batch.size());
        int[] numValidFeatures = new int[batch.size()];
        for (int i = 0; i < batch.size(); i++) {
            SparseVector features = SparseVector.createSparseVector(batch.get(i), this.featureIDMap, false);
            renumbered.add(renumberFeatureIndices(features));
            numValidFeatures[i] = features.numActiveElements();
        }
        V externalOutput = externalPrediction(convertFeaturesList(renumbered));
        List<Prediction<T>> predictions = convertOutput(externalOutput, numValidFeatures, batch);
        if (predictions.size() != renumbered.size()) {
            throw new IllegalStateException("Unexpected number of predictions received from external model batch, found " + predictions.size() + ", expected " + renumbered.size() + ".");
        }
        return predictions;
    }

    /**
     * Rewrites a vector's indices from Tribuo feature ids into external feature ids.
     *
     * @param features a vector indexed by Tribuo feature id
     * @return a vector of the same dimension indexed by external feature id
     */
    private SparseVector renumberFeatureIndices(SparseVector features) {
        int numActive = features.numActiveElements();
        int[] indices = new int[numActive];
        double[] values = new double[numActive];
        int i = 0;
        for (VectorTuple tuple : features) {
            indices[i] = this.featureForwardMapping[tuple.index];
            values[i] = tuple.value;
            i++;
        }
        return SparseVector.createSparseVector(features.size(), indices, values);
    }

    /**
     * Converts a single feature vector, already renumbered into external ids, into the
     * external model's input format.
     *
     * @param input the renumbered feature vector
     * @return the external input
     */
    protected abstract U convertFeatures(SparseVector input);

    /**
     * Converts a batch of feature vectors, already renumbered into external ids, into the
     * external model's input format.
     *
     * @param input the renumbered feature vectors
     * @return the external input
     */
    protected abstract U convertFeaturesList(List<SparseVector> input);

    /**
     * Runs the external model on the supplied input.
     *
     * @param input the external input
     * @return the external output
     */
    protected abstract V externalPrediction(U input);

    /**
     * Converts the external output for a single example into a Tribuo prediction.
     *
     * @param output the external output
     * @param numValidFeatures the number of active features in the example's feature vector
     * @param example the example that was predicted
     * @return the prediction
     */
    protected abstract Prediction<T> convertOutput(V output, int numValidFeatures, Example<T> example);

    /**
     * Converts the external output for a batch into Tribuo predictions.
     *
     * @param output the external output
     * @param numValidFeatures the number of active features for each example, in batch order
     * @param examples the examples that were predicted
     * @return one prediction per example, in batch order
     */
    protected abstract List<Prediction<T>> convertOutput(V output, int[] numValidFeatures, List<Example<T>> examples);

    /**
     * External models do not produce excuses.
     *
     * @param example the example (ignored)
     * @return {@link Optional#empty()}
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.empty();
    }

    /**
     * Gets the batch size used by {@link #innerPredict}.
     *
     * @return the current batch size
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the batch size used by {@link #innerPredict}.
     *
     * @param batchSize the new batch size, which must be positive
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, found " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /**
     * Creates an immutable feature map from a set of feature names, registering each name
     * as a {@link CategoricalInfo}.
     *
     * @param featureNames the feature names
     * @return an immutable feature map containing those names
     */
    protected static ImmutableFeatureMap createFeatureMap(Set<String> featureNames) {
        MutableFeatureMap featureMap = new MutableFeatureMap();
        for (String featureName : featureNames) {
            featureMap.put(new CategoricalInfo(featureName));
        }
        return new ImmutableFeatureMap(featureMap);
    }

    /**
     * Creates the output info for an external model by delegating to
     * {@link OutputFactory#constructInfoForExternalModel}.
     *
     * @param factory the output factory
     * @param outputs mapping from output to its external id
     * @param <T> the output type
     * @return the immutable output info
     */
    protected static <T extends Output<T>> ImmutableOutputInfo<T> createOutputInfo(OutputFactory<T> factory, Map<T, Integer> outputs) {
        return factory.constructInfoForExternalModel(outputs);
    }

    /**
     * Checks that the forward and backward mappings are mutual inverses covering every feature
     * in {@code featureIDMap}.
     * <p>
     * Returns {@code false}, rather than throwing, when a forward entry is out of range.
     *
     * @param forwardMapping Tribuo feature id to external feature id
     * @param backwardMapping external feature id to Tribuo feature id
     * @param featureIDMap the Tribuo feature domain
     * @return {@code true} if both arrays have length {@code featureIDMap.size()} and
     *     {@code backwardMapping[forwardMapping[i]] == i} for every {@code i}
     */
    protected static boolean validateFeatureMapping(int[] forwardMapping, int[] backwardMapping, ImmutableFeatureMap featureIDMap) {
        int numFeatures = featureIDMap.size();
        if (forwardMapping.length != numFeatures || backwardMapping.length != numFeatures) {
            return false;
        }
        // If every forward entry is in range and round-trips through the backward mapping,
        // then the forward mapping is injective on numFeatures ids and hence a bijection.
        for (int i = 0; i < numFeatures; i++) {
            int externalId = forwardMapping[i];
            if (externalId < 0 || externalId >= numFeatures || backwardMapping[externalId] != i) {
                return false;
            }
        }
        return true;
    }
}
