package org.tribuo.json;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.EmptyDatasetProvenance;
import org.tribuo.provenance.impl.EmptyTrainerProvenance;

/* loaded from: input_file:org/tribuo/json/StripProvenance.class */
public final class StripProvenance {
    private static final Logger logger = Logger.getLogger(StripProvenance.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/tribuo/json/StripProvenance$ModelTuple.class */
    /**
     * Immutable pair holding a converted model together with the provenance
     * that was written into it.
     *
     * @param <T> the output type of the wrapped model.
     */
    public static class ModelTuple<T extends Output<T>> {
        /** The model carrying the cleaned provenance. */
        public final Model<T> model;
        /** The cleaned provenance stored in {@link #model}. */
        public final ModelProvenance provenance;

        /**
         * Creates a tuple of a model and its provenance.
         *
         * @param model the converted model.
         * @param provenance the provenance associated with the converted model.
         */
        public ModelTuple(Model<T> model, ModelProvenance provenance) {
            this.model = model;
            this.provenance = provenance;
        }
    }

    /* loaded from: input_file:org/tribuo/json/StripProvenance$ProvenanceTypes.class */
    /**
     * The kinds of provenance which can be selected for removal via the
     * {@code remove-provenances} option.
     */
    public enum ProvenanceTypes {
        /** Replace the dataset provenance with an empty dataset provenance. */
        DATASET,
        /** Replace the trainer provenance with an empty trainer provenance. */
        TRAINER,
        /** Clear the instance provenance map and reset the training time to {@code OffsetDateTime.MIN}. */
        INSTANCE,
        /** Build the new model provenance with its final (system tracking) flag set to {@code false}. */
        SYSTEM,
        /** Shorthand which triggers every one of the removals above. */
        ALL;
    }

    /* loaded from: input_file:org/tribuo/json/StripProvenance$StripProvenanceOptions.class */
    /**
     * Command line options for {@link StripProvenance}.
     * <p>
     * Each field is bound to a command line flag through its {@code @Option}
     * annotation; the annotation's {@code usage} text is the user-facing
     * description of the flag.
     */
    public static class StripProvenanceOptions implements Options {

        /** When set, a hash of the original provenance is written into the stripped model's instance provenance. */
        @Option(charName = 'h', longName = "store-provenance-hash", usage = "Stores a hash of the model provenance in the stripped model.")
        public boolean storeHash;

        /** The file the model is loaded from; required (main prints usage and exits when it is null). */
        @Option(charName = 'i', longName = "input-model-path", usage = "The model to load.")
        public File inputModel;

        /** The file the stripped model is written to; required (main prints usage and exits when it is null). */
        @Option(charName = 'o', longName = "output-model-path", usage = "The location to write out the stripped model.")
        public File outputModel;

        /** Optional file which receives the JSON form of the original provenance; skipped when null. */
        @Option(charName = 'p', longName = "provenance-path", usage = "Write out the stripped provenance as json.")
        public File provenanceFile;

        /** The provenance categories to remove; defaults to the empty set, i.e. nothing is removed. */
        @Option(charName = 'r', longName = "remove-provenances", usage = "The provenances to remove")
        public EnumSet<ProvenanceTypes> removeProvenances = EnumSet.noneOf(ProvenanceTypes.class);

        /** The hash function used to digest the provenance JSON; defaults to the ObjectProvenance default hash type. */
        @Option(charName = 't', longName = "hash-type", usage = "The hash type to use.")
        public ProvenanceUtil.HashType hashType = ObjectProvenance.DEFAULT_HASH_TYPE;

        /** Selects the protobuf model format instead of Java serialization. */
        @Option(longName = "model-protobuf", usage = "Read and write protobuf formatted models.")
        public boolean protobuf;

        /**
         * Returns the one-line description of this program shown alongside the option usage.
         *
         * @return the program description.
         */
        public String getOptionsDescription() {
            return "A program for removing Provenance information from a Tribuo Model or SequenceModel.";
        }
    }

    /** Not instantiable: this class only exposes static entry points. */
    private StripProvenance() {}

    /**
     * Builds a copy of the supplied model provenance with the requested parts removed.
     * <p>
     * Removed dataset and trainer provenances are replaced by their empty
     * implementations. Removing the instance provenance yields an empty
     * instance map and a training time of {@link OffsetDateTime#MIN}.
     *
     * @param modelProvenance the provenance to clean.
     * @param str the hex hash of the original provenance, stored under
     *            {@code original-provenance-hash} when {@code storeHash} is set.
     * @param stripProvenanceOptions the options selecting what to remove.
     * @return the cleaned provenance.
     */
    private static ModelProvenance cleanProvenance(ModelProvenance modelProvenance, String str, StripProvenanceOptions stripProvenanceOptions) {
        EnumSet<ProvenanceTypes> toRemove = stripProvenanceOptions.removeProvenances;
        boolean removeEverything = toRemove.contains(ProvenanceTypes.ALL);

        DatasetProvenance datasetProvenance;
        if (removeEverything || toRemove.contains(ProvenanceTypes.DATASET)) {
            datasetProvenance = new EmptyDatasetProvenance();
        } else {
            datasetProvenance = modelProvenance.getDatasetProvenance();
        }

        TrainerProvenance trainerProvenance;
        if (removeEverything || toRemove.contains(ProvenanceTypes.TRAINER)) {
            trainerProvenance = new EmptyTrainerProvenance();
        } else {
            trainerProvenance = modelProvenance.getTrainerProvenance();
        }

        boolean removeInstance = removeEverything || toRemove.contains(ProvenanceTypes.INSTANCE);
        HashMap instanceValues;
        OffsetDateTime trainingTime;
        if (!removeInstance) {
            instanceValues = new HashMap(modelProvenance.getInstanceProvenance().getMap());
            trainingTime = modelProvenance.getTrainingTime();
        } else {
            instanceValues = new HashMap();
            trainingTime = OffsetDateTime.MIN;
        }

        if (stripProvenanceOptions.storeHash) {
            logger.info("Writing provenance hash into instance map.");
            HashProvenance originalHash = new HashProvenance(stripProvenanceOptions.hashType, "original-provenance-hash", str);
            instanceValues.put("original-provenance-hash", originalHash);
        }

        // The last constructor argument is true only when neither ALL nor SYSTEM was requested.
        boolean keepSystem = !removeEverything && !toRemove.contains(ProvenanceTypes.SYSTEM);
        return new ModelProvenance(modelProvenance.getClassName(), trainingTime, datasetProvenance, trainerProvenance, instanceValues, keepSystem);
    }

    /**
     * Builds a copy of the supplied ensemble provenance with the requested parts removed.
     * <p>
     * Removed dataset and trainer provenances are replaced by their empty
     * implementations. Removing the instance provenance yields an empty
     * instance map and a training time of {@link OffsetDateTime#MIN}.
     * The member provenances are taken as given from {@code listProvenance}.
     *
     * @param ensembleModelProvenance the ensemble provenance to clean.
     * @param listProvenance the (already cleaned) provenances of the ensemble members.
     * @param str the hex hash of the original provenance, stored under
     *            {@code original-provenance-hash} when {@code storeHash} is set.
     * @param stripProvenanceOptions the options selecting what to remove.
     * @return the cleaned ensemble provenance.
     */
    private static EnsembleModelProvenance cleanEnsembleProvenance(EnsembleModelProvenance ensembleModelProvenance, ListProvenance<ModelProvenance> listProvenance, String str, StripProvenanceOptions stripProvenanceOptions) {
        EnumSet<ProvenanceTypes> toRemove = stripProvenanceOptions.removeProvenances;
        boolean removeEverything = toRemove.contains(ProvenanceTypes.ALL);

        DatasetProvenance datasetProvenance;
        if (removeEverything || toRemove.contains(ProvenanceTypes.DATASET)) {
            datasetProvenance = new EmptyDatasetProvenance();
        } else {
            datasetProvenance = ensembleModelProvenance.getDatasetProvenance();
        }

        TrainerProvenance trainerProvenance;
        if (removeEverything || toRemove.contains(ProvenanceTypes.TRAINER)) {
            trainerProvenance = new EmptyTrainerProvenance();
        } else {
            trainerProvenance = ensembleModelProvenance.getTrainerProvenance();
        }

        boolean removeInstance = removeEverything || toRemove.contains(ProvenanceTypes.INSTANCE);
        HashMap instanceValues;
        OffsetDateTime trainingTime;
        if (!removeInstance) {
            instanceValues = new HashMap(ensembleModelProvenance.getInstanceProvenance().getMap());
            trainingTime = ensembleModelProvenance.getTrainingTime();
        } else {
            instanceValues = new HashMap();
            trainingTime = OffsetDateTime.MIN;
        }

        if (stripProvenanceOptions.storeHash) {
            logger.info("Writing provenance hash into instance map.");
            HashProvenance originalHash = new HashProvenance(stripProvenanceOptions.hashType, "original-provenance-hash", str);
            instanceValues.put("original-provenance-hash", originalHash);
        }

        // NOTE(review): unlike cleanProvenance, ProvenanceTypes.SYSTEM is never consulted here — confirm this is intended.
        return new EnsembleModelProvenance(ensembleModelProvenance.getClassName(), trainingTime, datasetProvenance, trainerProvenance, instanceValues, listProvenance);
    }

    /**
     * Creates a copy of the supplied model which carries cleaned provenance.
     * <p>
     * Ensembles are processed recursively: every member is converted first
     * (using the same hash string), then the ensemble itself is copied with
     * the cleaned member models and provenances. The copy is named
     * {@code deprovenanced} when the source model has an empty name, otherwise
     * {@code <name>-deprovenanced}.
     * <p>
     * The copy is made by reflectively invoking the model's {@code copy} method.
     *
     * @param model the model to convert.
     * @param str the hex hash of the original provenance.
     * @param stripProvenanceOptions the options selecting what to remove.
     * @param <T> the output type of the model.
     * @return the copied model paired with its cleaned provenance.
     * @throws InvocationTargetException if the {@code copy} method itself throws.
     * @throws IllegalAccessException if the {@code copy} method cannot be accessed.
     * @throws NoSuchMethodException if no suitable {@code copy} method exists on the model's class hierarchy.
     */
    @SuppressWarnings("unchecked") // Method.invoke returns Object; the copy is cast back to the model's own type.
    private static <T extends Output<T>> ModelTuple<T> convertModel(Model<T> model, String str, StripProvenanceOptions stripProvenanceOptions) throws InvocationTargetException, IllegalAccessException, NoSuchMethodException {
        if (model instanceof EnsembleModel) {
            EnsembleModel<T> ensemble = (EnsembleModel<T>) model;
            EnsembleModelProvenance oldProvenance = ensemble.getProvenance();
            List<ModelProvenance> memberProvenances = new ArrayList<>();
            List<Model<T>> memberModels = new ArrayList<>();
            for (Model<T> member : ensemble.getModels()) {
                ModelTuple<T> converted = convertModel(member, str, stripProvenanceOptions);
                memberProvenances.add(converted.provenance);
                memberModels.add(converted.model);
            }
            EnsembleModelProvenance cleanedEnsembleProvenance = cleanEnsembleProvenance(oldProvenance, new ListProvenance<>(memberProvenances), str, stripProvenanceOptions);

            Method ensembleCopy;
            try {
                // Signature the original code looked up; kept first so existing behaviour is unchanged.
                ensembleCopy = findCopyMethod(model.getClass(), String.class, ModelProvenance.class, List.class);
            } catch (NoSuchMethodException originalSignatureMissing) {
                // getDeclaredMethod matches parameter types exactly, so a copy method declared with the
                // narrower EnsembleModelProvenance type is invisible to the lookup above.
                // NOTE(review): presumably this is how EnsembleModel declares it — confirm against EnsembleModel.
                ensembleCopy = findCopyMethod(model.getClass(), String.class, EnsembleModelProvenance.class, List.class);
            }
            String ensembleName = model.getName().isEmpty() ? "deprovenanced" : model.getName() + "-deprovenanced";
            EnsembleModel<T> ensembleCopyResult = (EnsembleModel<T>) invokeCopy(ensembleCopy, model, ensembleName, cleanedEnsembleProvenance, memberModels);
            return new ModelTuple<>(ensembleCopyResult, cleanedEnsembleProvenance);
        }

        ModelProvenance cleanedProvenance = cleanProvenance(model.getProvenance(), str, stripProvenanceOptions);
        Method modelCopy = findCopyMethod(model.getClass(), String.class, ModelProvenance.class);
        String newName = model.getName().isEmpty() ? "deprovenanced" : model.getName() + "-deprovenanced";
        Model<T> copiedModel = (Model<T>) invokeCopy(modelCopy, model, newName, cleanedProvenance);
        return new ModelTuple<>(copiedModel, cleanedProvenance);
    }

    /**
     * Finds a {@code copy} method with exactly the given parameter types, searching the
     * supplied class first and then each superclass in turn, so inherited copy methods are found.
     */
    private static Method findCopyMethod(Class<?> modelClass, Class<?>... parameterTypes) throws NoSuchMethodException {
        for (Class<?> current = modelClass; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredMethod("copy", parameterTypes);
            } catch (NoSuchMethodException ignored) {
                // Not declared at this level; continue with the superclass.
            }
        }
        throw new NoSuchMethodException(modelClass.getName() + ".copy");
    }

    /**
     * Invokes the reflected copy method with accessibility forced on, restoring the
     * previous accessibility flag even when the invocation throws.
     */
    private static Object invokeCopy(Method copyMethod, Object target, Object... args) throws InvocationTargetException, IllegalAccessException {
        boolean wasAccessible = copyMethod.isAccessible();
        copyMethod.setAccessible(true);
        try {
            return copyMethod.invoke(target, args);
        } finally {
            copyMethod.setAccessible(wasAccessible);
        }
    }

    /**
     * Command line entry point.
     * <p>
     * Loads a model, serializes its provenance to JSON and hashes it, optionally
     * writes that JSON to the provenance file, builds a copy of the model with
     * the requested provenance removed, and writes the copy to the output path.
     * When {@code model-protobuf} is set the model is both read and written in
     * protobuf format; otherwise Java serialization is used for both.
     * <p>
     * If the input or output path is missing, the usage is logged and the JVM
     * exits with status 1. I/O, class loading and reflection failures are
     * logged at SEVERE level and are not rethrown.
     *
     * @param strArr the command line arguments, parsed into a {@link StripProvenanceOptions}.
     * @param <T> the output type of the model being processed.
     */
    @SuppressWarnings("unchecked") // the loaded object is cast to Model<T>; T cannot be verified at runtime.
    public static <T extends Output<T>> void main(String[] strArr) {
        LabsLogFormatter.setAllLogFormatters();
        StripProvenanceOptions stripProvenanceOptions = new StripProvenanceOptions();
        try {
            ConfigurationManager configurationManager = new ConfigurationManager(strArr, stripProvenanceOptions);
            if (stripProvenanceOptions.inputModel == null || stripProvenanceOptions.outputModel == null) {
                logger.info(configurationManager.usage());
                System.exit(1);
            }
            try {
                logger.info("Loading model from " + stripProvenanceOptions.inputModel);
                Model<T> model;
                if (stripProvenanceOptions.protobuf) {
                    model = (Model<T>) Model.deserializeFromFile(stripProvenanceOptions.inputModel.toPath());
                } else {
                    // WARNING: Java native deserialization can execute code embedded in the stream;
                    // only load model files from trusted sources.
                    try (ObjectInputStream objectInputStream = IOUtil.getObjectInputStream(stripProvenanceOptions.inputModel)) {
                        model = (Model<T>) objectInputStream.readObject();
                    }
                }

                ModelProvenance originalProvenance = model.getProvenance();
                logger.info("Marshalling provenance and creating JSON.");
                JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
                String originalJson = jsonProvenanceSerialization.marshalAndSerialize(originalProvenance);
                logger.info("Hashing JSON file");
                byte[] digest = stripProvenanceOptions.hashType.getDigest().digest(originalJson.getBytes(StandardCharsets.UTF_8));
                String provenanceHash = ProvenanceUtil.bytesToHexString(digest);
                logger.info("Provenance hash = " + provenanceHash);

                if (stripProvenanceOptions.provenanceFile != null) {
                    logger.info("Writing JSON provenance to " + stripProvenanceOptions.provenanceFile.toString());
                    try (PrintWriter printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(stripProvenanceOptions.provenanceFile), StandardCharsets.UTF_8))) {
                        printWriter.println(originalJson);
                    }
                }

                ModelTuple<T> converted = convertModel(model, provenanceHash, stripProvenanceOptions);
                logger.info("Writing model to " + stripProvenanceOptions.outputModel);
                if (stripProvenanceOptions.protobuf) {
                    // The option is documented as reading AND writing protobuf, so the output
                    // must use the same format as the input.
                    converted.model.serializeToFile(stripProvenanceOptions.outputModel.toPath());
                } else {
                    try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(stripProvenanceOptions.outputModel))) {
                        objectOutputStream.writeObject(converted.model);
                    }
                }

                logger.info("Marshalling provenance and creating JSON.");
                String cleanedJson = jsonProvenanceSerialization.marshalAndSerialize(converted.provenance);
                logger.info("Old provenance = \n" + originalJson);
                logger.info("New provenance = \n" + cleanedJson);
            } catch (FileNotFoundException e) {
                logger.log(Level.SEVERE, "Failed to find the input file.", e);
            } catch (UnsupportedEncodingException e) {
                logger.log(Level.SEVERE, "Unsupported encoding exception.", e);
            } catch (IOException e) {
                logger.log(Level.SEVERE, "IO error when reading or writing a file.", e);
            } catch (ClassNotFoundException e) {
                logger.log(Level.SEVERE, "The model and/or provenance classes are not on the classpath.", e);
            } catch (IllegalAccessException e) {
                logger.log(Level.SEVERE, "Failed to modify protection on inner copy method on Model.", e);
            } catch (NoSuchMethodException e) {
                logger.log(Level.SEVERE, "Model.copy method missing on a class which extends Model.", e);
            } catch (InvocationTargetException e) {
                logger.log(Level.SEVERE, "Failed to invoke inner copy method on Model.", e);
            }
        } catch (UsageException e) {
            logger.info(e.getMessage());
        }
    }
}
