package org.tribuo.interop.oci;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.ser.impl.SimpleFilterProvider;
import com.oracle.bmc.datascience.DataScienceClient;
import com.oracle.bmc.datascience.model.CreateModelDeploymentDetails;
import com.oracle.bmc.datascience.model.CreateModelDetails;
import com.oracle.bmc.datascience.model.FixedSizeScalingPolicy;
import com.oracle.bmc.datascience.model.InstanceConfiguration;
import com.oracle.bmc.datascience.model.Metadata;
import com.oracle.bmc.datascience.model.ModelConfigurationDetails;
import com.oracle.bmc.datascience.model.ModelDeployment;
import com.oracle.bmc.datascience.model.SingleModelDeploymentConfigurationDetails;
import com.oracle.bmc.datascience.requests.CreateModelArtifactRequest;
import com.oracle.bmc.datascience.requests.CreateModelDeploymentRequest;
import com.oracle.bmc.datascience.requests.CreateModelRequest;
import com.oracle.bmc.http.internal.ExplicitlySetFilter;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.CopyOption;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.anomaly.Event;
import org.tribuo.classification.Label;
import org.tribuo.clustering.ClusterID;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.regression.Regressor;

/* loaded from: input_file:org/tribuo/interop/oci/OCIUtil.class */
/**
 * Static helpers for uploading Tribuo models (exported as ONNX) to OCI Data Science and for
 * creating model deployments from uploaded models.
 * <p>
 * This is a static utility holder; its constructor is private.
 */
public abstract class OCIUtil {
    private static final Logger logger = Logger.getLogger(OCIUtil.class.getName());

    /** Conda environment slugs must be one or more word characters ({@code [a-zA-Z_0-9]}). */
    private static final Pattern CONDA_NAME_PATTERN = Pattern.compile("\\w+");

    /**
     * Characters allowed in a conda environment path: word characters, space, hyphen, '@', '.',
     * ':' and '/'. The hyphen is escaped so it is a literal; unescaped it would form the range
     * ' '..'@', which also admits characters such as '#', '"', ';' and parentheses.
     */
    private static final Pattern CONDA_ENV_PATH_PATTERN = Pattern.compile("[\\w \\-@.:/]*");

    /** Conda paths must be strictly shorter than this many characters. */
    private static final int MAX_OBJECT_STORAGE_LENGTH = 1024;

    /** Required prefix for conda environment paths. */
    private static final String OCI_PROTOCOL = "oci://";

    // Fragments of the generated runtime.yaml; concatenated in order by buildRuntimeYaml.
    private static final String RUNTIME_YAML_HEADER = "# Copyright (c) 2021, Oracle and/or its affiliates.  All rights reserved.\n# This software is available under the Apache License 2.0 as shown at http://www.apache.org/licenses/LICENSE-2.0.\nMODEL_ARTIFACT_VERSION: '3.0'\nMODEL_DEPLOYMENT:\n  INFERENCE_CONDA_ENV:\n";
    private static final String RUNTIME_YAML_ENV_SLUG = "    INFERENCE_ENV_SLUG: ";
    private static final String RUNTIME_YAML_ENV_TYPE = "    INFERENCE_ENV_TYPE: data_science\n";
    private static final String RUNTIME_YAML_ENV_PATH = "    INFERENCE_ENV_PATH: ";
    private static final String RUNTIME_YAML_PYTHON_VERSION = "    INFERENCE_PYTHON_VERSION: '3.7'";

    /**
     * Identifies the OCI Data Science compartment and project used when creating models and
     * deployments.
     */
    public static final class OCIDSConfig {
        /** Identifier of the compartment, sent as the request's compartment id. */
        public final String compartmentID;
        /** Identifier of the Data Science project, sent as the request's project id. */
        public final String projectID;

        /**
         * Constructs a Data Science configuration. No validation is performed.
         *
         * @param compartmentID The compartment identifier.
         * @param projectID The project identifier.
         */
        public OCIDSConfig(String compartmentID, String projectID) {
            this.compartmentID = compartmentID;
            this.projectID = projectID;
        }
    }

    /**
     * Settings used to export a model to ONNX, package it as an OCI model artifact and register
     * it with OCI Data Science.
     */
    public static final class OCIModelArtifactConfig {
        /** The compartment and project to create the model in. */
        public final OCIDSConfig dsConfig;
        /** Display name of the created OCI model. */
        public final String modelName;
        /** Description of the created OCI model. */
        public final String modelDescription;
        /** Domain passed to the ONNX export. */
        public final String onnxDomain;
        /** Model version passed to the ONNX export. */
        public final int onnxModelVersion;
        /** Conda environment slug written into runtime.yaml; validated when the artifact is built. */
        public final String condaName;
        /** Conda environment path (must start with "oci://"); validated when the artifact is built. */
        public final String condaPath;

        /**
         * Constructs a model artifact configuration. The conda settings are validated later,
         * when the artifact is built.
         *
         * @param dsConfig The Data Science compartment and project.
         * @param modelName The model display name.
         * @param modelDescription The model description.
         * @param onnxDomain The ONNX domain.
         * @param onnxModelVersion The ONNX model version.
         * @param condaName The conda environment slug.
         * @param condaPath The conda environment path.
         */
        public OCIModelArtifactConfig(OCIDSConfig dsConfig, String modelName, String modelDescription, String onnxDomain, int onnxModelVersion, String condaName, String condaPath) {
            this.dsConfig = dsConfig;
            this.modelName = modelName;
            this.modelDescription = modelDescription;
            this.onnxDomain = onnxDomain;
            this.onnxModelVersion = onnxModelVersion;
            this.condaName = condaName;
            this.condaPath = condaPath;
        }
    }

    /**
     * Settings used to create an OCI Data Science model deployment.
     */
    public static final class OCIModelDeploymentConfig {
        /** The compartment and project to create the deployment in. */
        public final OCIDSConfig dsConfig;
        /** Load balancer bandwidth in Mbps; at least 10. */
        public final int bandwidth;
        /** Number of instances in the fixed-size scaling policy; at least 1. */
        public final int instanceCount;
        /** Display name of the deployment; non-empty. */
        public final String deploymentName;
        /** Compute instance shape name; non-empty. */
        public final String shape;
        /** Identifier of the model to deploy; non-empty. */
        public final String modelID;

        /**
         * Constructs a deployment configuration.
         *
         * @param dsConfig The Data Science compartment and project.
         * @param modelID The identifier of the model to deploy.
         * @param deploymentName The deployment display name.
         * @param shape The compute instance shape name.
         * @param bandwidth The bandwidth in Mbps, must be 10 or greater.
         * @param instanceCount The number of instances, must be positive.
         * @throws IllegalArgumentException If any argument is out of range, null or empty.
         */
        public OCIModelDeploymentConfig(OCIDSConfig dsConfig, String modelID, String deploymentName, String shape, int bandwidth, int instanceCount) {
            if (instanceCount < 1) {
                throw new IllegalArgumentException("Instance count must be positive, found " + instanceCount);
            }
            if (bandwidth < 10) {
                throw new IllegalArgumentException("Bandwidth must be 10 or greater, found " + bandwidth);
            }
            if (deploymentName == null || deploymentName.isEmpty()) {
                throw new IllegalArgumentException("Must supply valid deployment name");
            }
            if (shape == null || shape.isEmpty()) {
                throw new IllegalArgumentException("Must supply valid instance shape");
            }
            if (modelID == null || modelID.isEmpty()) {
                throw new IllegalArgumentException("Must supply valid modelID");
            }
            this.dsConfig = dsConfig;
            this.modelID = modelID;
            this.deploymentName = deploymentName;
            this.shape = shape;
            this.bandwidth = bandwidth;
            this.instanceCount = instanceCount;
        }
    }

    /**
     * Model use-case types; {@link #modelType} is recorded as the "UseCaseType" metadata value
     * when a model is created.
     */
    public enum OCIModelType {
        BINARY_CLASSIFICATION("binary_classification"),
        REGRESSION("regression"),
        MULTINOMIAL_CLASSIFICATION("multinomial_classification"),
        CLUSTERING("clustering"),
        RECOMMENDER("recommender"),
        DIMENSIONALITY_REDUCTION("dimensionality_reduction/representation"),
        TIME_SERIES("time_series_forecasting"),
        ANOMALY_DETECTION("anomaly_detection"),
        TOPIC_MODELLING("topic_modelling"),
        NER("ner"),
        SENTIMENT_ANALYSIS("sentiment_analysis"),
        IMAGE_CLASSIFICATION("image_classification"),
        OBJECT_LOCALIZATION("object_localization"),
        OTHER("other");

        /** The metadata string for this use-case type. */
        public final String modelType;

        OCIModelType(String modelType) {
            this.modelType = modelType;
        }
    }

    private OCIUtil() {
    }

    /**
     * Copies a classpath resource (resolved relative to this class's package) into the root of
     * the supplied file system under the same name.
     *
     * @param fileSystem The destination file system.
     * @param resourceName The resource name, also used as the destination file name.
     * @throws IOException If the resource cannot be found, read or written.
     */
    private static void storeResource(FileSystem fileSystem, String resourceName) throws IOException {
        try (InputStream resource = OCIUtil.class.getResourceAsStream(resourceName)) {
            if (resource == null) {
                throw new IOException("Failed to find resource '" + resourceName + "' relative to " + OCIUtil.class.getName());
            }
            storeStream(fileSystem, resourceName, resource);
        }
    }

    /**
     * Copies a stream into a new file at the root of the supplied file system. The stream is
     * not closed; the caller owns it.
     *
     * @param fileSystem The destination file system.
     * @param fileName The destination file name.
     * @param inputStream The data to copy.
     * @throws IOException If the copy fails or the destination already exists.
     */
    private static void storeStream(FileSystem fileSystem, String fileName, InputStream inputStream) throws IOException {
        Files.copy(inputStream, fileSystem.getPath("/", fileName));
    }

    /**
     * Creates a Jackson mapper that indents output and registers OCI's
     * {@code explicitlySetFilter}, for logging OCI SDK model objects.
     *
     * @return A configured object mapper.
     */
    public static ObjectMapper createObjectMapper() {
        ObjectMapper mapper = new ObjectMapper().enable(SerializationFeature.INDENT_OUTPUT);
        mapper.setFilterProvider(new SimpleFilterProvider().addFilter("explicitlySetFilter", ExplicitlySetFilter.INSTANCE));
        return mapper;
    }

    /**
     * Exports a Tribuo model to ONNX, packages it, and creates and uploads it as an OCI Data
     * Science model.
     *
     * @param model The model to upload; its output type must be Label, Regressor, Event,
     *              MultiLabel or ClusterID.
     * @param client The Data Science client.
     * @param mapper The mapper used for logging responses and serializing hyperparameters.
     * @param config The artifact configuration.
     * @param <T> The model's output type.
     * @param <U> The model type.
     * @return The id of the created OCI model.
     * @throws IOException If the ONNX export, artifact creation or upload fails.
     * @throws IllegalArgumentException If the output type is unsupported or the conda settings are invalid.
     */
    public static <T extends Output<T>, U extends Model<T> & ONNXExportable> String createModel(U model, DataScienceClient client, ObjectMapper mapper, OCIModelArtifactConfig config) throws IOException {
        // Reject unsupported output types before doing any export work.
        OCIModelType modelType = inferModelType(model);
        Path onnxFile = Files.createTempFile("model", ".onnx");
        // Register cleanup before writing so the temp file is removed even if the export fails.
        onnxFile.toFile().deleteOnExit();
        model.saveONNXModel(config.onnxDomain, config.onnxModelVersion, onnxFile);
        ModelProvenance provenance = model.getProvenance();
        return createModel(onnxFile, provenance, modelType, client, mapper, config);
    }

    /**
     * Maps a model's output type onto an OCI use-case type.
     *
     * @param model The model to inspect.
     * @return The corresponding use-case type.
     * @throws IllegalArgumentException If the output type is not supported.
     */
    private static OCIModelType inferModelType(Model<?> model) {
        if (model.validate(Label.class)) {
            return OCIModelType.MULTINOMIAL_CLASSIFICATION;
        } else if (model.validate(Regressor.class)) {
            return OCIModelType.REGRESSION;
        } else if (model.validate(Event.class)) {
            return OCIModelType.ANOMALY_DETECTION;
        } else if (model.validate(MultiLabel.class)) {
            return OCIModelType.OTHER;
        } else if (model.validate(ClusterID.class)) {
            return OCIModelType.CLUSTERING;
        }
        throw new IllegalArgumentException("Unsupported model type " + model.toString());
    }

    /**
     * Builds a zip model artifact in a temporary file. The zip contains the {@code score.py}
     * classpath resource, a generated {@code runtime.yaml} and the supplied ONNX file as
     * {@code model.onnx}.
     *
     * @param onnxFile The ONNX model file.
     * @param config The artifact configuration supplying the conda settings.
     * @return The path to the temporary zip file; the caller is responsible for deleting it.
     * @throws IOException If the archive cannot be written.
     * @throws IllegalArgumentException If the conda name or path is invalid.
     */
    protected static Path createModelArtifact(Path onnxFile, OCIModelArtifactConfig config) throws IOException {
        // Validate the conda settings before creating any files.
        byte[] runtimeYaml = buildRuntimeYaml(config.condaName, config.condaPath).getBytes(StandardCharsets.UTF_8);

        Path artifact = Files.createTempFile("oci-ds-model-deployment", ".zip");
        URI zipUri = URI.create("jar:" + artifact.toUri());
        // Remove the empty placeholder so the zip provider can create a fresh archive at this path.
        Files.delete(artifact);
        try (FileSystem zipFs = FileSystems.newFileSystem(zipUri, Collections.singletonMap("create", "true"))) {
            storeResource(zipFs, "score.py");
            storeStream(zipFs, "runtime.yaml", new ByteArrayInputStream(runtimeYaml));
            try (InputStream onnxStream = Files.newInputStream(onnxFile)) {
                storeStream(zipFs, "model.onnx", onnxStream);
            }
        } catch (IOException | RuntimeException e) {
            // Don't leave a partially written archive behind.
            try {
                Files.deleteIfExists(artifact);
            } catch (IOException cleanupFailure) {
                e.addSuppressed(cleanupFailure);
            }
            throw e;
        }
        return artifact;
    }

    /**
     * Creates an OCI Data Science model from an ONNX file and uploads its artifact.
     *
     * @param onnxFile The ONNX model file.
     * @param provenance The provenance used to fill the Algorithm and hyperparameters metadata.
     * @param modelType The use-case type recorded in the metadata.
     * @param client The Data Science client.
     * @param mapper The mapper used for logging responses and serializing hyperparameters.
     * @param config The artifact configuration.
     * @return The id of the created OCI model.
     * @throws IOException If the artifact cannot be built or read, or serialization fails.
     * @throws IllegalArgumentException If the conda settings are invalid.
     */
    public static String createModel(Path onnxFile, ModelProvenance provenance, OCIModelType modelType, DataScienceClient client, ObjectMapper mapper, OCIModelArtifactConfig config) throws IOException {
        Path artifact = createModelArtifact(onnxFile, config);
        artifact.toFile().deleteOnExit();

        String hyperparameters = mapper.writeValueAsString(ProvenanceUtil.convertToMap(provenance.getTrainerProvenance()));
        List<Metadata> metadata = new ArrayList<>();
        metadata.add(Metadata.builder().key("UseCaseType").value(modelType.modelType).build());
        metadata.add(Metadata.builder().key("Framework").value("other").build());
        metadata.add(Metadata.builder().key("FrameworkVersion").value("4.3.0").build());
        metadata.add(Metadata.builder().key("Algorithm").value(provenance.getTrainerProvenance().getClassName()).build());
        metadata.add(Metadata.builder().key("hyperparameters").value(hyperparameters).build());

        CreateModelDetails details = CreateModelDetails.builder()
                .definedMetadataList(metadata)
                .compartmentId(config.dsConfig.compartmentID)
                .projectId(config.dsConfig.projectID)
                .displayName(config.modelName)
                .description(config.modelDescription)
                .build();
        CreateModelRequest request = CreateModelRequest.builder().createModelDetails(details).build();
        com.oracle.bmc.datascience.model.Model ociModel = client.createModel(request).getModel();
        logger.info("\n\nCreate model response: \n" + mapper.writeValueAsString(ociModel));
        String modelId = ociModel.getId();
        logger.info("Model ID = " + modelId);

        // Close the upload stream once the request completes.
        try (InputStream artifactStream = Files.newInputStream(artifact)) {
            CreateModelArtifactRequest artifactRequest = CreateModelArtifactRequest.builder()
                    .modelArtifact(artifactStream)
                    .modelId(modelId)
                    // Content-Disposition filenames should be bare names, not local filesystem paths.
                    .contentDisposition("attachment; filename=\"" + artifact.getFileName() + "\"")
                    .build();
            logger.info("Create artifact response: \n" + client.createModelArtifact(artifactRequest));
        }
        return modelId;
    }

    /**
     * Creates a single-model deployment with a fixed-size scaling policy.
     *
     * @param config The deployment configuration.
     * @param client The Data Science client.
     * @param mapper The mapper used for logging the response.
     * @return The model deployment URL reported in the response.
     * @throws IOException If the response cannot be serialized for logging.
     */
    public static String deploy(OCIModelDeploymentConfig config, DataScienceClient client, ObjectMapper mapper) throws IOException {
        InstanceConfiguration instanceConfig = InstanceConfiguration.builder()
                .instanceShapeName(config.shape)
                .build();
        FixedSizeScalingPolicy scalingPolicy = FixedSizeScalingPolicy.builder()
                .instanceCount(config.instanceCount)
                .build();
        ModelConfigurationDetails modelConfig = ModelConfigurationDetails.builder()
                .modelId(config.modelID)
                .instanceConfiguration(instanceConfig)
                .bandwidthMbps(config.bandwidth)
                .scalingPolicy(scalingPolicy)
                .build();
        CreateModelDeploymentDetails deploymentDetails = CreateModelDeploymentDetails.builder()
                .projectId(config.dsConfig.projectID)
                .displayName(config.deploymentName)
                .compartmentId(config.dsConfig.compartmentID)
                .modelDeploymentConfigurationDetails(SingleModelDeploymentConfigurationDetails.builder()
                        .modelConfigurationDetails(modelConfig)
                        .build())
                .build();
        CreateModelDeploymentRequest request = CreateModelDeploymentRequest.builder()
                .createModelDeploymentDetails(deploymentDetails)
                .build();
        ModelDeployment modelDeployment = client.createModelDeployment(request).getModelDeployment();
        logger.info("Create deployment response: \n" + mapper.writeValueAsString(modelDeployment));
        return modelDeployment.getModelDeploymentUrl();
    }

    /**
     * Builds the runtime.yaml contents for a model artifact.
     *
     * @param condaName The conda environment slug.
     * @param condaPath The conda environment path.
     * @return The YAML text.
     * @throws IllegalArgumentException If the name or path fails validation.
     */
    protected static String buildRuntimeYaml(String condaName, String condaPath) {
        if (!validateCondaName(condaName)) {
            throw new IllegalArgumentException("Invalid conda name '" + condaName + "'");
        }
        if (!validateCondaPath(condaPath)) {
            throw new IllegalArgumentException("Invalid conda path '" + condaPath + "'");
        }
        return RUNTIME_YAML_HEADER
                + RUNTIME_YAML_ENV_SLUG + condaName + "\n"
                + RUNTIME_YAML_ENV_TYPE
                + RUNTIME_YAML_ENV_PATH + condaPath + "\n"
                + RUNTIME_YAML_PYTHON_VERSION;
    }

    /**
     * Checks that a conda environment slug is non-null and consists of one or more word characters.
     *
     * @param condaName The slug to check.
     * @return True if the slug is valid.
     */
    protected static boolean validateCondaName(String condaName) {
        return condaName != null && CONDA_NAME_PATTERN.matcher(condaName).matches();
    }

    /**
     * Checks that a conda environment path is non-null, starts with "oci://", is shorter than
     * 1024 characters and uses only the allowed characters.
     *
     * @param condaPath The path to check.
     * @return True if the path is valid.
     */
    protected static boolean validateCondaPath(String condaPath) {
        // Cheap checks first so the regex never runs on over-long input.
        return condaPath != null
                && condaPath.startsWith(OCI_PROTOCOL)
                && condaPath.length() < MAX_OBJECT_STORAGE_LENGTH
                && CONDA_ENV_PATH_PATTERN.matcher(condaPath).matches();
    }
}
