package org.tribuo.interop.oci;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider;
import com.oracle.bmc.datascience.DataScienceClient;
import com.oracle.bmc.http.internal.ExplicitlySetFilter;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.interop.oci.OCIUtil;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/oci/OCIModelCLI.class */
public abstract class OCIModelCLI {
    private static final Logger logger = Logger.getLogger(OCIModelCLI.class.getName());

    /* loaded from: input_file:org/tribuo/interop/oci/OCIModelCLI$OCIModelOptions.class */
    public static final class OCIModelOptions implements Options {

        @Option(charName = 'm', longName = "mode", usage = "Deploy or score an OCI DS model.")
        public Mode mode;

        @Option(charName = 'p', longName = "project-id", usage = "Project ID.")
        public String projectID;

        @Option(charName = 'c', longName = "compartment-id", usage = "Compartment ID.")
        public String compartmentID;

        @Option(charName = 'd', longName = "deploy-model-path", usage = "Path to the serialized model to deploy to OCI DS.")
        public Path modelPath;

        @Option(longName = "model-protobuf", usage = "Is the model stored in protobuf format?")
        public boolean modelProtobuf;

        @Option(longName = "model-id", usage = "The id of the model.")
        public String modelId;

        @Option(charName = 's', longName = "dataset-path", usage = "Path to the serialized dataset to score.")
        public Path datasetPath;

        @Option(longName = "dataset-protobuf", usage = "Is the serialized dataset a protobuf?")
        public boolean datasetProtobuf;

        @Option(charName = 'i', longName = "model-deployment-id", usage = "The id of the model deployment.")
        public String modelDeploymentId;

        @Option(longName = "oci-domain", usage = "The OCI endpoint domain.")
        public String endpointDomain;

        @Option(longName = "conda-name", usage = "OCI DS conda environment name.")
        public String condaName;

        @Option(longName = "conda-name", usage = "OCI DS conda environment path in object storage.")
        public String condaPath;

        @Option(longName = "model-display-name", usage = "Model display name.")
        public String modelDisplayName = "tribuo-test";

        @Option(longName = "model-instance-count", usage = "Number of model instances to deploy.")
        public int instanceCount = 1;

        @Option(longName = "model-bandwidth", usage = "Model bandwidth in MBPS.")
        public int bandwidth = 10;

        @Option(longName = "model-instance-shape", usage = "OCI shape to run the model on.")
        public String instanceShape = "VM.Standard2.1";

        @Option(longName = "oci-config-file", usage = "OCI config file path. If null use the default.")
        public Path ociConfigFile = null;

        @Option(longName = "oci-config-file-profile", usage = "OCI profile in the config file. If null use the default.")
        public String ociConfigProfile = null;

        /* loaded from: input_file:org/tribuo/interop/oci/OCIModelCLI$OCIModelOptions$Mode.class */
        public enum Mode {
            CREATE_AND_DEPLOY,
            DEPLOY,
            SCORE
        }

        public String getOptionsDescription() {
            return "OCIModelCLI deploys and scores a Tribuo Classification model using OCI Data Science Model Deployment.";
        }

        DataScienceClient makeClient() throws IOException {
            return new DataScienceClient(OCIModel.makeAuthProvider(this.ociConfigFile, this.ociConfigProfile));
        }
    }

    private OCIModelCLI() {
    }

    private static void createModelAndDeploy(OCIModelOptions oCIModelOptions) throws IOException, ClassNotFoundException {
        Model castModel;
        if (oCIModelOptions.modelProtobuf) {
            castModel = Model.deserializeFromFile(oCIModelOptions.modelPath).castModel(Label.class);
        } else {
            ObjectInputStream objectInputStream = new ObjectInputStream(Files.newInputStream(oCIModelOptions.modelPath, new OpenOption[0]));
            try {
                castModel = ((Model) objectInputStream.readObject()).castModel(Label.class);
                objectInputStream.close();
            } catch (Throwable th) {
                try {
                    objectInputStream.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
                throw th;
            }
        }
        if (!(castModel instanceof ONNXExportable)) {
            throw new IllegalArgumentException("Model not ONNXExportable, received " + castModel.toString());
        }
        ObjectMapper createObjectMapper = OCIUtil.createObjectMapper();
        DataScienceClient makeClient = oCIModelOptions.makeClient();
        OCIUtil.OCIDSConfig oCIDSConfig = new OCIUtil.OCIDSConfig(oCIModelOptions.compartmentID, oCIModelOptions.projectID);
        OCIUtil.deploy(new OCIUtil.OCIModelDeploymentConfig(oCIDSConfig, OCIUtil.createModel((ONNXExportable) castModel, makeClient, createObjectMapper, new OCIUtil.OCIModelArtifactConfig(oCIDSConfig, oCIModelOptions.modelDisplayName, "Deployed Tribuo Model", "org.tribuo.oci", 1, oCIModelOptions.condaName, oCIModelOptions.condaPath)), oCIModelOptions.modelDisplayName, oCIModelOptions.instanceShape, oCIModelOptions.bandwidth, oCIModelOptions.instanceCount), makeClient, createObjectMapper);
        makeClient.close();
    }

    private static void deploy(OCIModelOptions oCIModelOptions) throws IOException {
        DataScienceClient makeClient = oCIModelOptions.makeClient();
        ObjectMapper enable = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);
        enable.setFilterProvider(new SimpleFilterProvider().addFilter("explicitlySetFilter", ExplicitlySetFilter.INSTANCE));
        System.out.println("Deployment URL = " + OCIUtil.deploy(new OCIUtil.OCIModelDeploymentConfig(new OCIUtil.OCIDSConfig(oCIModelOptions.compartmentID, oCIModelOptions.projectID), oCIModelOptions.modelId, oCIModelOptions.modelDisplayName, oCIModelOptions.instanceShape, oCIModelOptions.bandwidth, oCIModelOptions.instanceCount), makeClient, enable));
        makeClient.close();
    }

    private static void modelScoring(OCIModelOptions oCIModelOptions) throws IOException, ClassNotFoundException {
        Dataset castDataset;
        if (oCIModelOptions.datasetProtobuf) {
            castDataset = Dataset.castDataset(Dataset.deserializeFromFile(oCIModelOptions.datasetPath), Label.class);
        } else {
            ObjectInputStream objectInputStream = new ObjectInputStream(Files.newInputStream(oCIModelOptions.datasetPath, new OpenOption[0]));
            try {
                castDataset = Dataset.castDataset((Dataset) objectInputStream.readObject(), Label.class);
                objectInputStream.close();
            } catch (Throwable th) {
                try {
                    objectInputStream.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
                throw th;
            }
        }
        ImmutableFeatureMap featureIDMap = castDataset.getFeatureIDMap();
        HashMap hashMap = new HashMap();
        Iterator it = featureIDMap.iterator();
        while (it.hasNext()) {
            VariableIDInfo variableIDInfo = (VariableInfo) it.next();
            hashMap.put(variableIDInfo.getName(), Integer.valueOf(variableIDInfo.getID()));
        }
        HashMap hashMap2 = new HashMap();
        for (Pair pair : castDataset.getOutputIDInfo()) {
            hashMap2.put((Label) pair.getB(), (Integer) pair.getA());
        }
        OCIModel createOCIModel = OCIModel.createOCIModel(castDataset.getOutputFactory(), hashMap, hashMap2, oCIModelOptions.ociConfigFile, oCIModelOptions.ociConfigProfile, oCIModelOptions.endpointDomain + oCIModelOptions.modelDeploymentId, new OCILabelConverter(true));
        System.out.println("Scoring using OCIModel - " + createOCIModel.toString());
        LabelEvaluator labelEvaluator = new LabelEvaluator();
        long currentTimeMillis = System.currentTimeMillis();
        LabelEvaluation evaluate = labelEvaluator.evaluate(createOCIModel, castDataset);
        long currentTimeMillis2 = System.currentTimeMillis();
        System.out.println("Scoring took - " + Util.formatDuration(currentTimeMillis, currentTimeMillis2));
        System.out.println(((currentTimeMillis2 - currentTimeMillis) / castDataset.size()) + "ms per example");
        System.out.println(evaluate.toString());
        System.out.println(evaluate.getConfusionMatrix().toString());
    }

    public static void main(String[] strArr) throws IOException, ClassNotFoundException {
        OCIModelOptions oCIModelOptions = new OCIModelOptions();
        try {
            new ConfigurationManager(strArr, oCIModelOptions, false);
            switch (oCIModelOptions.mode) {
                case CREATE_AND_DEPLOY:
                    createModelAndDeploy(oCIModelOptions);
                    return;
                case DEPLOY:
                    deploy(oCIModelOptions);
                    return;
                case SCORE:
                    modelScoring(oCIModelOptions);
                    return;
                default:
                    return;
            }
        } catch (UsageException e) {
            logger.info(e.getMessage());
        }
    }
}
