package org.tribuo.common.liblinear;

import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import de.bwaldvogel.liblinear.Linear;
import de.bwaldvogel.liblinear.SolverType;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.common.liblinear.protos.LibLinearModelProto;
import org.tribuo.common.liblinear.protos.LibLinearProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/common/liblinear/LibLinearModel.class */
/**
 * Base class for Tribuo models backed by one or more LibLinear
 * {@link de.bwaldvogel.liblinear.Model} instances.
 * <p>
 * Subclasses provide the feature weight extraction ({@link #getFeatureWeights()}) and the
 * excuse computation ({@link #innerGetExcuse(Example, double[][])}).
 *
 * @param <T> the output type
 */
public abstract class LibLinearModel<T extends Output<T>> extends Model<T> {
    private static final long serialVersionUID = 3;
    private static final Logger logger = Logger.getLogger(LibLinearModel.class.getName());

    /** The protobuf serialization version written by {@link #m1serialize()}. */
    public static final int CURRENT_VERSION = 0;

    /** The wrapped LibLinear models (stored by reference as passed to the constructor). */
    protected List<de.bwaldvogel.liblinear.Model> models;

    /**
     * Constructs a LibLinear-backed model.
     *
     * @param str the model name, passed to {@link Model}
     * @param modelProvenance the model provenance
     * @param immutableFeatureMap the feature domain
     * @param immutableOutputInfo the output domain
     * @param z flag passed through to the {@link Model} constructor — presumably
     *          "generates probabilities"; TODO confirm against {@link Model}
     * @param list the LibLinear models; stored by reference, not copied
     */
    protected LibLinearModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, boolean z, List<de.bwaldvogel.liblinear.Model> list) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, z);
        this.models = list;
        // Static call on Linear: this is not scoped to this instance.
        Linear.disableDebugOutput();
    }

    /**
     * Returns copies of the wrapped LibLinear models.
     * <p>
     * Each model is copied with {@link #copyModel}, and the result is wrapped in an
     * unmodifiable list, so callers cannot alter this model's internal state.
     *
     * @return an unmodifiable list of model copies, in the same order as {@link #models}
     */
    public List<de.bwaldvogel.liblinear.Model> getInnerModels() {
        List<de.bwaldvogel.liblinear.Model> copies = new ArrayList<>(this.models.size());
        for (de.bwaldvogel.liblinear.Model model : this.models) {
            copies.add(copyModel(model));
        }
        return Collections.unmodifiableList(copies);
    }

    /**
     * Computes an excuse for a single example using the current feature weights.
     *
     * @param example the example to explain
     * @return the excuse wrapped in an Optional (throws NPE if the subclass returns null)
     */
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        return Optional.of(innerGetExcuse(example, getFeatureWeights()));
    }

    /**
     * Computes excuses for a batch of examples.
     * <p>
     * The feature weights are computed once and shared across all examples.
     *
     * @param iterable the examples to explain
     * @return the excuses, in iteration order
     */
    public Optional<List<Excuse<T>>> getExcuses(Iterable<Example<T>> iterable) {
        double[][] featureWeights = getFeatureWeights();
        List<Excuse<T>> excuses = new ArrayList<>();
        for (Example<T> example : iterable) {
            excuses.add(innerGetExcuse(example, featureWeights));
        }
        return Optional.of(excuses);
    }

    /**
     * Copies a LibLinear model by writing it to an in-memory string with
     * {@link Linear#saveModel} and reading it back with {@link Linear#loadModel}.
     *
     * @param model the model to copy
     * @return an independent copy of the model
     * @throws IllegalStateException if the in-memory round trip raises an IOException
     */
    protected static de.bwaldvogel.liblinear.Model copyModel(de.bwaldvogel.liblinear.Model model) {
        try {
            StringWriter stringWriter = new StringWriter();
            Linear.saveModel(stringWriter, model);
            return Linear.loadModel(new StringReader(stringWriter.toString()));
        } catch (IOException e) {
            throw new IllegalStateException("IOException found when copying the model in memory via a String.", e);
        }
    }

    /**
     * Extracts the feature weights used for excuse generation.
     *
     * @return the weights; the layout is defined by the subclass
     */
    protected abstract double[][] getFeatureWeights();

    /**
     * Computes an excuse for one example from precomputed feature weights.
     *
     * @param example the example to explain
     * @param dArr the weights returned by {@link #getFeatureWeights()}
     * @return the excuse
     */
    protected abstract Excuse<T> innerGetExcuse(Example<T> example, double[][] dArr);

    /**
     * Serializes this model into a {@link ModelProto}.
     * <p>
     * The model metadata goes into a {@link LibLinearModelProto}. Each wrapped LibLinear
     * model is stored there as a byte string produced by Java object serialization. The
     * result is packed into a {@link ModelProto} tagged with this class name and
     * {@link #CURRENT_VERSION}.
     * <p>
     * NOTE(review): the decompiler recorded that this method was renamed from
     * {@code serialize} (merged with a bridge method). The name is kept unchanged to
     * preserve the interface. The same applies to the generated builder's {@code m47build()}.
     *
     * @return the serialized model
     * @throws IllegalStateException if a LibLinear model cannot be serialized
     */
    public ModelProto m1serialize() {
        LibLinearModelProto.Builder modelBuilder = LibLinearModelProto.newBuilder();
        modelBuilder.setMetadata(createDataCarrier().serialize());
        for (de.bwaldvogel.liblinear.Model model : this.models) {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            // try-with-resources closes (and so flushes) the object stream before the bytes
            // are read, and releases it even if writeObject fails.
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(model);
            } catch (IOException e) {
                throw new IllegalStateException("Could not serialize liblinear model to byte array", e);
            }
            modelBuilder.addModels(ByteString.copyFrom(bytes.toByteArray()));
        }
        ModelProto.Builder protoBuilder = ModelProto.newBuilder();
        protoBuilder.setSerializedData(Any.pack(modelBuilder.m47build()));
        protoBuilder.setClassName(getClass().getName());
        protoBuilder.setVersion(CURRENT_VERSION);
        return protoBuilder.build();
    }

    /**
     * Converts a LibLinear model into a {@link LibLinearProto} field by field.
     * <p>
     * The copied fields are bias, labels, class count, feature count, solver type name and
     * weights. Rho is copied only for one-class solvers.
     * <p>
     * NOTE(review): this private method is not called anywhere in this class.
     *
     * @param model the model to convert
     * @return the protobuf representation
     */
    private static LibLinearProto serializeModel(de.bwaldvogel.liblinear.Model model) {
        LibLinearProto.Builder builder = LibLinearProto.newBuilder();
        builder.setBias(model.getBias());
        List<Integer> labels = Arrays.stream(model.getLabels()).boxed().collect(Collectors.toList());
        builder.addAllLabel(labels);
        builder.setNrClass(model.getNrClass());
        builder.setNrFeature(model.getNrFeature());
        builder.setSolverType(model.getSolverType().name());
        List<Double> weights = Arrays.stream(model.getFeatureWeights()).boxed().collect(Collectors.toList());
        builder.addAllW(weights);
        // getDecfunRho is only requested for one-class solvers.
        if (model.getSolverType().isOneClass()) {
            builder.setRho(model.getDecfunRho());
        }
        return builder.m94build();
    }

    /**
     * Rebuilds a LibLinear model from a {@link LibLinearProto} by writing its fields
     * reflectively.
     * <p>
     * {@code rho} is written unconditionally. When the proto carries no rho (a non-one-class
     * model), {@code getRho()} presumably yields the proto default; confirm against the schema.
     * <p>
     * NOTE(review): this private method is not called anywhere in this class.
     *
     * @param libLinearProto the protobuf representation
     * @return the reconstructed model
     * @throws IllegalStateException if a field cannot be written
     */
    private static de.bwaldvogel.liblinear.Model deserializeModels(LibLinearProto libLinearProto) {
        de.bwaldvogel.liblinear.Model model = new de.bwaldvogel.liblinear.Model();
        setField(de.bwaldvogel.liblinear.Model.class, "bias", model, Double.valueOf(libLinearProto.getBias()));
        setField(de.bwaldvogel.liblinear.Model.class, "label", model, Util.toPrimitiveInt(libLinearProto.getLabelList()));
        setField(de.bwaldvogel.liblinear.Model.class, "nr_class", model, Integer.valueOf(libLinearProto.getNrClass()));
        setField(de.bwaldvogel.liblinear.Model.class, "nr_feature", model, Integer.valueOf(libLinearProto.getNrFeature()));
        setField(de.bwaldvogel.liblinear.Model.class, "solverType", model, SolverType.valueOf(libLinearProto.getSolverType()));
        setField(de.bwaldvogel.liblinear.Model.class, "w", model, Util.toPrimitiveDouble(libLinearProto.getWList()));
        setField(de.bwaldvogel.liblinear.Model.class, "rho", model, Double.valueOf(libLinearProto.getRho()));
        return model;
    }

    /**
     * Writes a value into a field declared directly on {@code cls}, bypassing access checks.
     *
     * @param cls the class declaring the field
     * @param str the field name
     * @param u the object to modify
     * @param obj the value to store
     * @param <U> the object type
     * @throws IllegalStateException if the field does not exist or cannot be written
     */
    private static <U> void setField(Class<U> cls, String str, U u, Object obj) {
        try {
            // getField only resolves public members. The setAccessible call below exists
            // precisely to reach non-public fields, so getDeclaredField is required.
            Field field = cls.getDeclaredField(str);
            field.setAccessible(true);
            field.set(u, obj);
            field.setAccessible(false);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            throw new IllegalStateException("Failed to write to field " + str, e);
        }
    }
}
