package org.tribuo.interop.oci;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.bmc.ConfigFileReader;
import com.oracle.bmc.auth.AuthenticationDetailsProvider;
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.http.signing.RequestSigningFilter;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.FileProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.OffsetDateTime;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.Invocation;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.interop.ExternalDatasetProvenance;
import org.tribuo.interop.ExternalModel;
import org.tribuo.interop.ExternalTrainerProvenance;
import org.tribuo.interop.oci.protos.OCIModelProto;
import org.tribuo.interop.oci.protos.OCIOutputConverterProto;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/interop/oci/OCIModel.class */
public final class OCIModel<T extends Output<T>> extends ExternalModel<T, DenseMatrix, DenseMatrix> implements AutoCloseable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = Logger.getLogger(OCIModel.class.getName());
    public static final int CURRENT_VERSION = 0;
    private final Path configFile;
    private final String profileName;
    private final String endpointURL;
    private final String modelDeploymentId;
    private final OCIOutputConverter<T> outputConverter;
    private transient AuthenticationDetailsProvider authProvider;
    private transient RequestSigningFilter requestSigningFilter;
    private transient Client jerseyClient;
    private transient WebTarget modelEndpoint;
    private transient ObjectMapper mapper;

    /* loaded from: input_file:org/tribuo/interop/oci/OCIModel$PredictionJson.class */
    public static final class PredictionJson {

        @JsonProperty("prediction")
        public double[][] prediction;

        @JsonCreator
        public PredictionJson(@JsonProperty("prediction") double[][] dArr) {
            this.prediction = dArr;
        }
    }

    OCIModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, Map<String, Integer> map, Path path, String str2, String str3, String str4, OCIOutputConverter<T> oCIOutputConverter) throws IOException {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, oCIOutputConverter.generatesProbabilities(), map);
        this.configFile = path;
        this.profileName = str2;
        this.endpointURL = str3;
        this.modelDeploymentId = str4;
        this.outputConverter = oCIOutputConverter;
        this.authProvider = makeAuthProvider(path, str2);
        this.mapper = new ObjectMapper();
        System.setProperty("sun.net.http.allowRestrictedHeaders", "true");
        this.requestSigningFilter = RequestSigningFilter.fromAuthProvider(this.authProvider);
        this.jerseyClient = ClientBuilder.newBuilder().build().register(this.requestSigningFilter);
        this.modelEndpoint = this.jerseyClient.target(str3 + str4).path("predict");
    }

    OCIModel(String str, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, int[] iArr, int[] iArr2, Path path, String str2, AuthenticationDetailsProvider authenticationDetailsProvider, String str3, String str4, OCIOutputConverter<T> oCIOutputConverter) {
        super(str, modelProvenance, immutableFeatureMap, immutableOutputInfo, iArr, iArr2, oCIOutputConverter.generatesProbabilities());
        this.configFile = path;
        this.profileName = str2;
        this.authProvider = authenticationDetailsProvider;
        this.endpointURL = str3;
        this.modelDeploymentId = str4;
        this.outputConverter = oCIOutputConverter;
        this.mapper = new ObjectMapper();
        System.setProperty("sun.net.http.allowRestrictedHeaders", "true");
        this.requestSigningFilter = RequestSigningFilter.fromAuthProvider(authenticationDetailsProvider);
        this.jerseyClient = ClientBuilder.newBuilder().build().register(this.requestSigningFilter);
        this.modelEndpoint = this.jerseyClient.target(str3 + str4).path("predict");
    }

    public static OCIModel<?> deserializeFromProto(int i, String str, Any any) throws InvalidProtocolBufferException, IOException {
        if (i < 0 || i > 0) {
            throw new IllegalArgumentException("Unknown version " + i + ", this class supports at most version 0");
        }
        OCIModelProto unpack = any.unpack(OCIModelProto.class);
        OCIOutputConverter oCIOutputConverter = (OCIOutputConverter) ProtoUtil.deserialize(unpack.getOutputConverter());
        ModelDataCarrier deserialize = ModelDataCarrier.deserialize(unpack.getMetadata());
        if (!deserialize.outputDomain().getOutput(0).getClass().equals(oCIOutputConverter.getTypeWitness())) {
            throw new IllegalStateException("Invalid protobuf, output domain does not match converter, found " + deserialize.outputDomain().getClass() + " and " + oCIOutputConverter.getTypeWitness());
        }
        Path path = Paths.get(unpack.getConfigFile(), new String[0]);
        ConfigFileAuthenticationDetailsProvider makeAuthProvider = makeAuthProvider(path, unpack.getProfileName());
        int[] primitiveInt = Util.toPrimitiveInt(unpack.getForwardFeatureMappingList());
        int[] primitiveInt2 = Util.toPrimitiveInt(unpack.getBackwardFeatureMappingList());
        if (validateFeatureMapping(primitiveInt, primitiveInt2, deserialize.featureDomain())) {
            return new OCIModel<>(deserialize.name(), deserialize.provenance(), deserialize.featureDomain(), deserialize.outputDomain(), primitiveInt, primitiveInt2, path, unpack.getProfileName(), makeAuthProvider, unpack.getEndpointUrl(), unpack.getModelDeploymentId(), oCIOutputConverter);
        }
        throw new IllegalStateException("Invalid protobuf, external<->Tribuo feature mapping does not form a bijection");
    }

    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        return Collections.emptyMap();
    }

    protected Model<T> copy(String str, ModelProvenance modelProvenance) {
        return new OCIModel(str, modelProvenance, this.featureIDMap, this.outputIDInfo, this.featureForwardMapping, this.featureBackwardMapping, this.configFile, this.profileName, this.authProvider, this.endpointURL, this.modelDeploymentId, this.outputConverter);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* JADX WARN: Type inference failed for: r0v1, types: [double[], double[][]] */
    /* renamed from: convertFeatures, reason: merged with bridge method [inline-methods] */
    public DenseMatrix m4convertFeatures(SparseVector sparseVector) {
        return DenseMatrix.createDenseMatrix((double[][]) new double[]{sparseVector.toArray()});
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [double[], double[][]] */
    protected DenseMatrix convertFeaturesList(List<SparseVector> list) {
        ?? r0 = new double[list.size()];
        for (int i = 0; i < list.size(); i++) {
            r0[i] = list.get(i).toArray();
        }
        return DenseMatrix.createDenseMatrix((double[][]) r0);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public DenseMatrix externalPrediction(DenseMatrix denseMatrix) {
        Invocation.Builder request = this.modelEndpoint.request();
        request.accept(new String[]{"application/json"});
        try {
            BufferedReader bufferedReader = new BufferedReader(new InputStreamReader((InputStream) request.buildPost(Entity.entity(formatMatrix(denseMatrix), "application/json")).invoke().getEntity(), StandardCharsets.UTF_8));
            try {
                StringBuilder sb = new StringBuilder();
                while (true) {
                    String readLine = bufferedReader.readLine();
                    if (readLine == null) {
                        String sb2 = sb.toString();
                        bufferedReader.close();
                        try {
                            return DenseMatrix.createDenseMatrix(((PredictionJson) this.mapper.readValue(sb2, PredictionJson.class)).prediction);
                        } catch (JsonProcessingException e) {
                            throw new IllegalStateException("Failed to parse json from deployed model endpoint, received '" + sb2 + "'", e);
                        }
                    }
                    sb.append(readLine);
                }
            } finally {
            }
        } catch (IOException e2) {
            throw new IllegalStateException("Failed to read response from input stream", e2);
        }
    }

    private static String formatMatrix(DenseMatrix denseMatrix) {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (int i = 0; i < denseMatrix.getDimension1Size(); i++) {
            sb.append(Arrays.toString(denseMatrix.getRow(i).toArray()));
            sb.append(',');
        }
        sb.setCharAt(sb.length() - 1, ']');
        return sb.toString();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public Prediction<T> convertOutput(DenseMatrix denseMatrix, int i, Example<T> example) {
        if (denseMatrix.getDimension1Size() != 1) {
            throw new IllegalStateException("Expected a single score vector, received " + denseMatrix.getDimension1Size());
        }
        return this.outputConverter.convertOutput(denseMatrix.getRow(0), i, example, this.outputIDInfo);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public List<Prediction<T>> convertOutput(DenseMatrix denseMatrix, int[] iArr, List<Example<T>> list) {
        return this.outputConverter.convertOutput(denseMatrix, iArr, list, this.outputIDInfo);
    }

    @Override // java.lang.AutoCloseable
    public void close() {
        this.jerseyClient.close();
    }

    /* renamed from: serialize, reason: merged with bridge method [inline-methods] */
    public ModelProto m5serialize() {
        ModelDataCarrier createDataCarrier = createDataCarrier();
        OCIModelProto.Builder newBuilder = OCIModelProto.newBuilder();
        newBuilder.setMetadata(createDataCarrier.serialize());
        newBuilder.addAllForwardFeatureMapping((Iterable) Arrays.stream(this.featureForwardMapping).boxed().collect(Collectors.toList()));
        newBuilder.addAllBackwardFeatureMapping((Iterable) Arrays.stream(this.featureBackwardMapping).boxed().collect(Collectors.toList()));
        newBuilder.setConfigFile(this.configFile.toString());
        newBuilder.setProfileName(this.profileName);
        newBuilder.setEndpointUrl(this.endpointURL);
        newBuilder.setModelDeploymentId(this.modelDeploymentId);
        newBuilder.setOutputConverter((OCIOutputConverterProto) this.outputConverter.serialize());
        ModelProto.Builder newBuilder2 = ModelProto.newBuilder();
        newBuilder2.setSerializedData(Any.pack(newBuilder.m102build()));
        newBuilder2.setClassName(OCIModel.class.getName());
        newBuilder2.setVersion(0);
        return newBuilder2.build();
    }

    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        System.setProperty("sun.net.http.allowRestrictedHeaders", "true");
        this.authProvider = makeAuthProvider(this.configFile, this.profileName);
        this.mapper = new ObjectMapper();
        this.requestSigningFilter = RequestSigningFilter.fromAuthProvider(this.authProvider);
        this.jerseyClient = ClientBuilder.newBuilder().build().register(this.requestSigningFilter);
        this.modelEndpoint = this.jerseyClient.target(this.endpointURL + this.modelDeploymentId).path("predict");
    }

    public static ConfigFileAuthenticationDetailsProvider makeAuthProvider(Path path, String str) throws IOException {
        return new ConfigFileAuthenticationDetailsProvider(path == null ? ConfigFileReader.parseDefault(str) : ConfigFileReader.parse(path.toString()));
    }

    public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, Path path, String str, OCIOutputConverter<T> oCIOutputConverter) {
        return createOCIModel(outputFactory, map, map2, path, null, str, oCIOutputConverter);
    }

    public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> outputFactory, Map<String, Integer> map, Map<T, Integer> map2, Path path, String str, String str2, OCIOutputConverter<T> oCIOutputConverter) {
        try {
            ImmutableFeatureMap createFeatureMap = ExternalModel.createFeatureMap(map.keySet());
            ImmutableOutputInfo createOutputInfo = ExternalModel.createOutputInfo(outputFactory, map2);
            OffsetDateTime now = OffsetDateTime.now();
            ExternalTrainerProvenance externalTrainerProvenance = new ExternalTrainerProvenance(str2.getBytes(StandardCharsets.UTF_8));
            ExternalDatasetProvenance externalDatasetProvenance = new ExternalDatasetProvenance("unknown-external-data", outputFactory, false, map.size(), map2.size());
            String[] split = str2.split("/");
            String str3 = "https://" + split[2] + "/";
            String str4 = split[3];
            HashMap hashMap = new HashMap();
            hashMap.put("configFile", new FileProvenance("configFile", path));
            hashMap.put("endpointURL", new StringProvenance("endpointURL", str2));
            hashMap.put("modelDeploymentId", new StringProvenance("modelDeploymentId", str4));
            return new OCIModel<>("oci-ds-model", new ModelProvenance(OCIModel.class.getName(), now, externalDatasetProvenance, externalTrainerProvenance, hashMap), createFeatureMap, createOutputInfo, map, path, str, str3, str4, oCIOutputConverter);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to load configuration from path " + path, e);
        }
    }

    /* renamed from: convertFeaturesList, reason: collision with other method in class */
    protected /* bridge */ /* synthetic */ Object m3convertFeaturesList(List list) {
        return convertFeaturesList((List<SparseVector>) list);
    }
}
