package org.tribuo.ensemble;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.protos.core.WeightedEnsembleModelProto;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.impl.TimestampedTrainerProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

/* loaded from: input_file:org/tribuo/ensemble/WeightedEnsembleModel.class */
/**
 * An ensemble model that merges the predictions of its member models using an
 * {@link EnsembleCombiner} together with one weight per member.
 * <p>
 * Member predictions are handed to the combiner along with the weight vector. Excuses are
 * aggregated by summing each member's per-feature scores after scaling them by that member's weight.
 * @param <T> The output type.
 */
public final class WeightedEnsembleModel<T extends Output<T>> extends EnsembleModel<T> implements ONNXExportable {
    private static final long serialVersionUID = 1;

    /**
     * The protobuf serialization version written by {@link #mo14serialize()} and the highest
     * version accepted by {@link #deserializeFromProto(int, String, Any)}.
     */
    public static final int CURRENT_VERSION = 0;

    /**
     * The weight for each ensemble member, in the same order as the member model list.
     */
    protected final float[] weights;

    /**
     * The function used to merge the member predictions (and, for ONNX export, the member graphs).
     */
    protected final EnsembleCombiner<T> combiner;

    /**
     * Builds an ensemble that weights every member uniformly, i.e. {@code 1 / members.size()}.
     * @param name The model name.
     * @param provenance The ensemble provenance.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @param members The ensemble members.
     * @param combiner The combination function.
     */
    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureMap,
                                 ImmutableOutputInfo<T> outputInfo, List<Model<T>> members, EnsembleCombiner<T> combiner) {
        this(name, provenance, featureMap, outputInfo, members, combiner,
                Util.generateUniformVector(members.size(), 1.0f / members.size()));
    }

    /**
     * Builds an ensemble with the supplied per-member weights.
     * <p>
     * NOTE(review): this constructor does not check that {@code weights.length == members.size()};
     * the factory methods and the protobuf deserializer perform that check.
     * @param name The model name.
     * @param provenance The ensemble provenance.
     * @param featureMap The feature domain.
     * @param outputInfo The output domain.
     * @param members The ensemble members.
     * @param combiner The combination function.
     * @param weights The per-member weights. The array is copied, so later changes by the caller
     *                do not affect this model.
     */
    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureMap,
                                 ImmutableOutputInfo<T> outputInfo, List<Model<T>> members, EnsembleCombiner<T> combiner,
                                 float[] weights) {
        super(name, provenance, featureMap, outputInfo, members);
        this.weights = Arrays.copyOf(weights, weights.length);
        this.combiner = combiner;
    }

    /**
     * Deserialization factory.
     * @param version The serialized object version.
     * @param className The class name (unused by this implementation).
     * @param message The serialized data, expected to contain a {@link WeightedEnsembleModelProto}.
     * @return The deserialized ensemble.
     * @throws InvalidProtocolBufferException If {@code message} could not be unpacked as a
     *                                        {@link WeightedEnsembleModelProto}.
     * @throws IllegalArgumentException If {@code version} is not supported.
     * @throws IllegalStateException If the protobuf contents are inconsistent, for example a
     *                               non-ensemble provenance, an empty output domain, a
     *                               combiner/output type mismatch, no models, or a model/weight
     *                               count mismatch.
     */
    // Raw types are needed because the output type is only known at runtime; all members and the
    // combiner are checked against the same output class before the ensemble is built.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static WeightedEnsembleModel<?> deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
        if (version < 0 || version > CURRENT_VERSION) {
            throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
        }
        WeightedEnsembleModelProto proto = message.unpack(WeightedEnsembleModelProto.class);

        ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
        ModelProvenance rawProvenance = carrier.provenance();
        if (!(rawProvenance instanceof EnsembleModelProvenance)) {
            throw new IllegalStateException("Invalid protobuf, the provenance was not an EnsembleModelProvenance. Found " + rawProvenance);
        }
        EnsembleModelProvenance provenance = (EnsembleModelProvenance) rawProvenance;

        ImmutableOutputInfo<?> outputDomain = carrier.outputDomain();
        // Guard the getOutput(0) lookup below, which would otherwise fail with a NullPointerException.
        if (outputDomain.size() == 0) {
            throw new IllegalStateException("Invalid protobuf, the output domain was empty");
        }
        Class outputClass = outputDomain.getOutput(0).getClass();

        EnsembleCombiner<?> combiner = EnsembleCombiner.deserialize(proto.getCombiner());
        if (!outputClass.equals(combiner.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, combiner and output domain have a type mismatch, expected " + outputClass + " found " + combiner.getTypeWitness());
        }

        if (proto.getModelsCount() == 0) {
            throw new IllegalStateException("Invalid protobuf, no models were found in the ensemble");
        }
        if (proto.getModelsCount() != proto.getWeightsCount()) {
            throw new IllegalStateException("Invalid protobuf, different numbers of models and weights were found, " + proto.getModelsCount() + " models, " + proto.getWeightsCount() + " weights");
        }

        List<Model> members = new ArrayList<>(proto.getModelsCount());
        for (ModelProto memberProto : proto.getModelsList()) {
            Model<?> member = Model.deserialize(memberProto);
            if (!member.validate(outputClass)) {
                throw new IllegalStateException("Invalid protobuf, output type of model '" + member.toString() + "' did not match expected " + outputClass);
            }
            members.add(member);
        }

        float[] weights = Util.toPrimitiveFloat(proto.getWeightsList());
        return new WeightedEnsembleModel(carrier.name(), provenance, carrier.featureDomain(), outputDomain, members, combiner, weights);
    }

    /**
     * Predicts with every member and merges the member predictions using the combiner and the
     * ensemble weights.
     * @param example The example to predict.
     * @return The combined prediction.
     */
    @Override // org.tribuo.Model
    public Prediction<T> predict(Example<T> example) {
        List<Prediction<T>> memberPredictions = new ArrayList<>(models.size());
        for (Model<T> member : models) {
            memberPredictions.add(member.predict(example));
        }
        return combiner.combine(outputIDInfo, memberPredictions, weights);
    }

    /**
     * Builds an ensemble excuse from the members' excuses.
     * <p>
     * For each output key in the member excuses, each feature's score is multiplied by the
     * producing member's weight and summed across members. Each output's features are then
     * sorted by descending aggregated score.
     * @param example The example to explain.
     * @return An {@link EnsembleExcuse}, or {@link Optional#empty()} if no member produced any scores.
     */
    @Override // org.tribuo.ensemble.EnsembleModel, org.tribuo.Model
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        Map<String, Map<String, Double>> weightedScores = new HashMap<>();
        Prediction<T> prediction = predict(example);
        List<Excuse<T>> memberExcuses = new ArrayList<>();

        for (int i = 0; i < models.size(); i++) {
            Optional<Excuse<T>> excuse = models.get(i).getExcuse(example);
            if (!excuse.isPresent()) {
                continue;
            }
            memberExcuses.add(excuse.get());
            float memberWeight = weights[i];
            for (Map.Entry<String, List<Pair<String, Double>>> outputScores : excuse.get().getScores().entrySet()) {
                Map<String, Double> featureScores = weightedScores.computeIfAbsent(outputScores.getKey(), k -> new HashMap<>());
                for (Pair<String, Double> featureScore : outputScores.getValue()) {
                    featureScores.merge(featureScore.getA(), featureScore.getB() * memberWeight, Double::sum);
                }
            }
        }

        if (weightedScores.isEmpty()) {
            return Optional.empty();
        }

        Map<String, List<Pair<String, Double>>> rankedScores = new HashMap<>();
        for (Map.Entry<String, Map<String, Double>> outputEntry : weightedScores.entrySet()) {
            List<Pair<String, Double>> ranked = new ArrayList<>(outputEntry.getValue().size());
            for (Map.Entry<String, Double> featureEntry : outputEntry.getValue().entrySet()) {
                ranked.add(new Pair<>(featureEntry.getKey(), featureEntry.getValue()));
            }
            // Highest aggregated score first.
            ranked.sort((first, second) -> second.getB().compareTo(first.getB()));
            rankedScores.put(outputEntry.getKey(), ranked);
        }
        Excuse<T> ensembleExcuse = new EnsembleExcuse<>(example, prediction, rankedScores, memberExcuses);
        return Optional.of(ensembleExcuse);
    }

    /**
     * Creates a copy of this ensemble with new members and provenance, keeping this ensemble's
     * domains, combiner and weights.
     * @param name The new model name.
     * @param newProvenance The new provenance.
     * @param newModels The new ensemble members.
     * @return The copied ensemble.
     */
    @Override // org.tribuo.ensemble.EnsembleModel
    protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels) {
        if (newModels.size() != weights.length) {
            // The weights cannot be matched to the new members, so fall back to uniform weighting
            // (the behaviour this method previously had in every case).
            return new WeightedEnsembleModel<>(name, newProvenance, featureIDMap, outputIDInfo, newModels, combiner);
        }
        // Previously the weights were dropped here, so a copied weighted ensemble became uniformly weighted.
        return new WeightedEnsembleModel<>(name, newProvenance, featureIDMap, outputIDInfo, newModels, combiner, weights);
    }

    /**
     * Creates an ensemble from existing models, weighting every member uniformly.
     * @param name The ensemble name.
     * @param models The ensemble members.
     * @param combiner The combination function.
     * @param <T> The output type.
     * @return A weighted ensemble model.
     * @throws IllegalArgumentException If fewer than two models are supplied or the models'
     *                                  output or feature domains differ.
     */
    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner) {
        return createEnsembleFromExistingModels(name, models, combiner,
                Util.generateUniformVector(models.size(), 1.0f / models.size()));
    }

    /**
     * Creates an ensemble from existing models using the supplied weights.
     * <p>
     * The provenance records this class name, the current time, the first model's dataset
     * provenance, a fresh {@link TimestampedTrainerProvenance}, and the provenances of the members.
     * The feature and output domains are taken from the first model.
     * @param name The ensemble name.
     * @param models The ensemble members.
     * @param combiner The combination function.
     * @param weights The per-member weights, one per model.
     * @param <T> The output type.
     * @return A weighted ensemble model.
     * @throws IllegalArgumentException If fewer than two models are supplied, the number of
     *                                  weights differs from the number of models, or the models'
     *                                  output or feature domains differ.
     */
    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights) {
        int numModels = models.size();
        if (numModels < 2) {
            throw new IllegalArgumentException("Must supply at least 2 models, found " + numModels);
        }
        if (weights.length != numModels) {
            throw new IllegalArgumentException("Must supply one weight per model, models.size() = " + numModels + ", weights.length = " + weights.length);
        }

        Model<T> first = models.get(0);
        ImmutableOutputInfo<T> outputInfo = first.getOutputIDInfo();
        ImmutableFeatureMap featureMap = first.getFeatureIDMap();
        Set<T> outputDomain = outputInfo.getDomain();
        for (int i = 1; i < numModels; i++) {
            Model<T> other = models.get(i);
            if (!other.getOutputIDInfo().getDomain().equals(outputDomain)) {
                throw new IllegalArgumentException("Model output domains are not equal.");
            }
            if (!other.getFeatureIDMap().domainEquals(featureMap)) {
                throw new IllegalArgumentException("Model feature domains are not equal.");
            }
        }

        EnsembleModelProvenance provenance = new EnsembleModelProvenance(
                WeightedEnsembleModel.class.getName(),
                OffsetDateTime.now(),
                first.mo6getProvenance().getDatasetProvenance(),
                new TimestampedTrainerProvenance(),
                ListProvenance.createListProvenance(models));
        List<Model<T>> members = new ArrayList<>(models);
        return new WeightedEnsembleModel<>(name, provenance, featureMap, outputInfo, members, combiner, weights);
    }

    /**
     * Exports this ensemble as a standalone ONNX model with a float input sized to the feature
     * domain and a float output sized to the output domain.
     * @param domain The domain string, passed through to {@link ONNXExportable#buildModel}.
     * @param modelVersion The model version, passed through to {@link ONNXExportable#buildModel}.
     * @return The ONNX model proto.
     * @throws IllegalStateException If any ensemble member is not {@link ONNXExportable}.
     */
    @Override // org.tribuo.ONNXExportable
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext context = new ONNXContext();
        context.setName("WeightedEnsembleModel");
        ONNXPlaceholder input = context.floatInput(featureIDMap.size());
        writeONNXGraph(input).assignTo(context.floatOutput(outputIDInfo.size()));
        return ONNXExportable.buildModel(context, domain, modelVersion, this);
    }

    /**
     * Writes the ensemble's ONNX graph into the context of {@code input}.
     * <p>
     * Each member's graph output is unsqueezed with axes {@code [2]}. If the member's output IDs do
     * not match this ensemble's, the output is also reordered with a {@code Gather} along axis 1.
     * The member outputs are concatenated along axis 2 and passed, with the ensemble weights, to
     * the combiner's exporter. (Presumably each member emits a {@code [batch, numOutputs]} tensor
     * — TODO confirm against the members' exporters.)
     * @param input The graph input shared by all members.
     * @return The node produced by the combiner.
     * @throws IllegalStateException If any ensemble member is not {@link ONNXExportable}.
     */
    @Override // org.tribuo.ONNXExportable
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext context = input.onnxContext();
        ONNXInitializer unsqueezeAxes = context.array("unsqueeze_ensemble_output", new long[]{2});
        List<ONNXRef<?>> memberOutputs = new ArrayList<>(models.size());
        for (Model<T> member : models) {
            if (!(member instanceof ONNXExportable)) {
                throw new IllegalStateException("Ensemble member '" + member.toString() + "' is not ONNXExportable.");
            }
            ONNXNode unsqueezed = ((ONNXExportable) member).writeONNXGraph(input).apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
            if (member.getOutputIDInfo().domainAndIDEquals(outputIDInfo)) {
                memberOutputs.add(unsqueezed);
            } else {
                memberOutputs.add(unsqueezed.apply(ONNXOperators.GATHER,
                        context.array("ensemble_output_gather_indices", gatherIndices(member.getOutputIDInfo())),
                        Collections.singletonMap("axis", 1)));
            }
        }
        return combiner.exportCombiner(
                context.operation(ONNXOperators.CONCAT, memberOutputs, "ensemble_concat", Collections.singletonMap("axis", 2)),
                context.array("ensemble_weights", weights));
    }

    /**
     * Computes Gather indices that reorder a member's outputs into this ensemble's output-ID order.
     * Entry {@code e} holds the member's ID for the output whose ensemble ID is {@code e}.
     * @param memberInfo The member's output domain. It is assumed to contain the same outputs as
     *                   this ensemble's domain, which the factory methods check.
     * @return The gather indices, one per ensemble output.
     */
    private int[] gatherIndices(ImmutableOutputInfo<T> memberInfo) {
        int[] indices = new int[outputIDInfo.size()];
        for (int memberId = 0; memberId < indices.length; memberId++) {
            indices[outputIDInfo.getID(memberInfo.getOutput(memberId))] = memberId;
        }
        return indices;
    }

    /**
     * Serializes this ensemble, including its members, weights and combiner, into a
     * {@link ModelProto} tagged with this class name and {@link #CURRENT_VERSION}.
     * <p>
     * The method name was generated by the decompiler and corresponds to {@code serialize()}.
     * @return The serialized model.
     */
    @Override // org.tribuo.Model, org.tribuo.protos.ProtoSerializable
    public ModelProto mo14serialize() {
        ModelDataCarrier<T> carrier = createDataCarrier();
        WeightedEnsembleModelProto.Builder ensembleBuilder = WeightedEnsembleModelProto.newBuilder();
        ensembleBuilder.setMetadata(carrier.serialize());
        for (Model<T> member : models) {
            ensembleBuilder.addModels(member.mo14serialize());
        }
        ensembleBuilder.addAllWeights(Util.toBoxedFloats(weights));
        ensembleBuilder.setCombiner(combiner.mo14serialize());

        ModelProto.Builder modelBuilder = ModelProto.newBuilder();
        modelBuilder.setSerializedData(Any.pack(ensembleBuilder.m2410build()));
        modelBuilder.setClassName(WeightedEnsembleModel.class.getName());
        modelBuilder.setVersion(CURRENT_VERSION);
        return modelBuilder.m1370build();
    }
}
