package org.tribuo.interop.onnx.extractors;

import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.TextPipeline;
import org.tribuo.impl.ArrayExample;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.util.tokens.impl.wordpiece.Wordpiece;
import org.tribuo.util.tokens.impl.wordpiece.WordpieceBasicTokenizer;
import org.tribuo.util.tokens.impl.wordpiece.WordpieceTokenizer;

/* loaded from: input_file:org/tribuo/interop/onnx/extractors/BERTFeatureExtractor.class */
public class BERTFeatureExtractor<T extends Output<T>> implements AutoCloseable, TextFeatureExtractor<T>, TextPipeline {
    /** Class-wide logger. */
    private static final Logger logger = Logger.getLogger(BERTFeatureExtractor.class.getName());
    /** Name of the model input carrying the wordpiece ids; {@code postConfig} rejects models without it. */
    public static final String INPUT_IDS = "input_ids";
    /** Name of the model input carrying the attention mask; {@code postConfig} rejects models without it. */
    public static final String ATTENTION_MASK = "attention_mask";
    /** Name of the model input carrying the token type ids; {@code postConfig} rejects models without it. */
    public static final String TOKEN_TYPE_IDS = "token_type_ids";
    /** Name of the per-token model output; {@code postConfig} requires it to be a 3 dimensional tensor. */
    public static final String TOKEN_OUTPUT = "output_0";
    /** Name of the sequence level model output; {@code postConfig} requires it to be a 2 dimensional tensor. */
    public static final String CLS_OUTPUT = "output_1";
    /** Default classification token, also used as the key looked up in a TemplateProcessing tokenizer config. */
    public static final String CLASSIFICATION_TOKEN = "[CLS]";
    /** Default separator token, also used as the key looked up in a TemplateProcessing tokenizer config. */
    public static final String SEPARATOR_TOKEN = "[SEP]";
    /** Default unknown token, replaced by the tokenizer config value in {@code postConfig}. */
    public static final String UNKNOWN_TOKEN = "[UNK]";
    /** Metadata key under which each per-token example stores its token text. */
    public static final String TOKEN_METADATA = "Token";
    /** Attention mask fill value; matches the literal 1 passed to {@code createTensor} for the mask tensor. */
    public static final long MASK_VALUE = 1;
    /** Token type fill value; matches the literal 0 passed to {@code createTensor} for the token type tensor. */
    public static final long TOKEN_TYPE_VALUE = 0;

    /** Factory supplying the unknown output attached to unlabelled examples. */
    @Config(mandatory = true, description = "Output factory to use.")
    private OutputFactory<T> outputFactory;

    /** Location of the ONNX model; its string form is handed to {@code OrtEnvironment.createSession}. */
    @Config(mandatory = true, description = "Path to the BERT model in ONNX format")
    private Path modelPath;

    /** Location of the tokenizer JSON parsed by {@code loadTokenizer}. */
    @Config(mandatory = true, description = "Path to the tokenizer config")
    private Path tokenizerPath;

    /** Maximum sequence length including the two special tokens; inputs may hold at most {@code maxLength - 2} tokens. */
    @Config(description = "Maximum length in wordpieces")
    private int maxLength;

    /** Pooling strategy applied by {@code extractFeatures(List)}. */
    @Config(description = "Type of pooling to use when returning a single embedding for the input sequence")
    private OutputPooling pooling;

    /** When true, {@code postConfig} calls {@code addCUDA()} on the session options before creating the session. */
    @Config(description = "Use CUDA")
    private boolean useCUDA;
    /** Vocabulary mapping each wordpiece to its integer id, loaded from the tokenizer config. */
    private Map<String, Integer> tokenIDs;
    /** Token whose id is written to the first position of every input. */
    private String classificationToken;
    /** Token whose id is written to the last position of every input. */
    private String separatorToken;
    /** Token whose id replaces any wordpiece missing from {@link #tokenIDs}. */
    private String unknownToken;
    /** Wordpiece tokenizer built in {@code postConfig}. */
    private WordpieceTokenizer tokenizer;
    /** Embedding dimension, read from the last dimension of the {@code output_0} shape. */
    private int bertDim;
    /** Feature names, one per embedding dimension, produced by {@code generateFeatureNames}. */
    private String[] featureNames;
    /** ONNX Runtime environment obtained in {@code postConfig}. */
    private OrtEnvironment env;
    /** ONNX Runtime session wrapping the model at {@link #modelPath}. */
    private OrtSession session;
    /** Set once {@link #close()} has completed so that repeated calls are no-ops. */
    private boolean closed;

    /**
     * Command line options populated by the {@code ConfigurationManager} in {@code main}.
     */
    public static class BERTFeatureExtractorOptions implements Options {

        /** The extractor instance used to embed each input line. */
        @Option(charName = 'b', longName = "bert", usage = "BERTFeatureExtractor instance")
        public BERTFeatureExtractor<?> bert;

        /** Text file read by {@code main}, one document per line. */
        @Option(charName = 'i', longName = "input-file", usage = "Input file to read, one doc per line")
        public Path inputFile;

        /** Destination of the JSON written by {@code main}. */
        @Option(charName = 'o', longName = "output-file", usage = "Output json file.")
        public Path outputFile;
    }

    /**
     * Plain holder for everything produced by one BERT invocation: the
     * wordpieces, the three input arrays and both embedding outputs.
     * <p>
     * The decompiler reports that this type was package-private before its
     * access modifier was widened.
     */
    public static final class BERTResult {
        /** The wordpiece tokens that were embedded. */
        public final List<String> tokens;
        /** The token ids sent to the model. */
        public final long[] ids;
        /** The attention mask sent to the model. */
        public final long[] masks;
        /** The token type ids sent to the model. */
        public final long[] tokenTypes;
        /** The sequence level embedding, narrowed to single precision. */
        public final float[] clsEmbedding;
        /** The per-token embeddings. */
        public final float[][] tokenEmbeddings;

        /**
         * Stores the supplied references as-is (no defensive copies) and
         * narrows the sequence embedding from double to float.
         *
         * @param list the wordpiece tokens.
         * @param jArr the token ids.
         * @param jArr2 the attention mask.
         * @param jArr3 the token type ids.
         * @param dArr the sequence level embedding in double precision.
         * @param fArr the per-token embeddings.
         */
        BERTResult(List<String> list, long[] jArr, long[] jArr2, long[] jArr3, double[] dArr, float[][] fArr) {
            this.tokens = list;
            this.ids = jArr;
            this.masks = jArr2;
            this.tokenTypes = jArr3;
            float[] narrowed = new float[dArr.length];
            int index = 0;
            for (double value : dArr) {
                narrowed[index++] = (float) value;
            }
            this.clsEmbedding = narrowed;
            this.tokenEmbeddings = fArr;
        }
    }

    /**
     * Strategies for collapsing the model outputs into a single embedding,
     * as dispatched on by {@code extractFeatures(List)}.
     */
    public enum OutputPooling {
        /** Return the vector read from the {@code output_1} tensor. */
        CLS,
        /** Return the mean of the per-token vectors read from {@code output_0}. */
        MEAN,
        /** Return the element-wise average of the CLS vector and the token mean. */
        CLS_AND_MEAN
    }

    /**
     * Immutable view of the tokenizer settings parsed from the tokenizer
     * JSON file.
     * <p>
     * The decompiler reports that this type was package-private before its
     * access modifier was widened.
     */
    public static final class TokenizerConfig {
        /** Vocabulary mapping each wordpiece to its id. */
        final Map<String, Integer> tokenIDs;
        /** Token substituted for out-of-vocabulary wordpieces. */
        final String unknownToken;
        /** Token placed at the start of every input. */
        final String classificationToken;
        /** Token placed at the end of every input. */
        final String separatorToken;
        /** Whether the normalizer lowercases the text. */
        final boolean lowercase;
        /** Whether the normalizer strips accents. */
        final boolean stripAccents;
        /** Longest word, in characters, the wordpiece model will try to split. */
        final int maxInputCharsPerWord;

        /**
         * Captures the supplied values without copying.
         *
         * @param map the vocabulary.
         * @param str the unknown token.
         * @param str2 the classification token.
         * @param str3 the separator token.
         * @param z the lowercase flag.
         * @param z2 the strip-accents flag.
         * @param i the maximum number of characters per word.
         */
        TokenizerConfig(Map<String, Integer> map, String str, String str2, String str3, boolean z, boolean z2, int i) {
            // Vocabulary and special tokens.
            this.tokenIDs = map;
            this.unknownToken = str;
            this.classificationToken = str2;
            this.separatorToken = str3;
            // Normalizer and wordpiece settings.
            this.lowercase = z;
            this.stripAccents = z2;
            this.maxInputCharsPerWord = i;
        }
    }

    /**
     * Creates an unconfigured extractor holding only the default settings.
     * The mandatory configuration fields are left unset and
     * {@code postConfig} is not invoked here.
     */
    private BERTFeatureExtractor() {
        // Default special tokens; postConfig overwrites them from the tokenizer config.
        this.unknownToken = UNKNOWN_TOKEN;
        this.separatorToken = SEPARATOR_TOKEN;
        this.classificationToken = CLASSIFICATION_TOKEN;
        // Default model settings.
        this.useCUDA = false;
        this.pooling = OutputPooling.CLS;
        this.maxLength = 512;
        this.closed = false;
    }

    /**
     * Builds an extractor with the default settings: CLS pooling, a maximum
     * length of 512 wordpieces and no CUDA. The model and tokenizer are
     * loaded before this constructor returns.
     *
     * @param outputFactory the output factory to use.
     * @param path the path to the BERT model in ONNX format.
     * @param path2 the path to the tokenizer config.
     */
    public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path path, Path path2) {
        this(outputFactory, path, path2, OutputPooling.CLS, 512, false);
    }

    /**
     * Builds a fully specified extractor and loads the model and tokenizer
     * via {@link #postConfig()}.
     *
     * @param outputFactory the output factory to use.
     * @param path the path to the BERT model in ONNX format.
     * @param path2 the path to the tokenizer config.
     * @param outputPooling the pooling applied when a single embedding is requested.
     * @param i the maximum length in wordpieces, including the two special tokens.
     * @param z whether to use CUDA.
     */
    public BERTFeatureExtractor(OutputFactory<T> outputFactory, Path path, Path path2, OutputPooling outputPooling, int i, boolean z) {
        // Start from the shared defaults, then apply the caller's choices.
        this();
        this.outputFactory = outputFactory;
        this.modelPath = path;
        this.tokenizerPath = path2;
        this.pooling = outputPooling;
        this.maxLength = i;
        this.useCUDA = z;
        postConfig();
    }

    /**
     * Loads the ONNX model and the tokenizer, and validates that the model
     * exposes the inputs and outputs this class relies on.
     * <p>
     * The model must have exactly two tensor outputs, {@code output_0}
     * (3 dimensions) and {@code output_1} (2 dimensions), sharing a fixed,
     * positive embedding dimension, and exactly three inputs named
     * {@code input_ids}, {@code attention_mask} and {@code token_type_ids}.
     * If anything fails after the session was created, the session is
     * closed again so that no native resources are leaked.
     *
     * @throws PropertyException if the model is invalid, or if ONNX Runtime
     *         or the tokenizer reader fails.
     */
    public void postConfig() throws PropertyException {
        OrtSession created = null;
        boolean configured = false;
        try {
            this.env = OrtEnvironment.getEnvironment();
            // The options are intentionally not closed here; they are left open for the session created from them.
            OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
            if (this.useCUDA) {
                sessionOptions.addCUDA();
            }
            created = this.env.createSession(this.modelPath.toString(), sessionOptions);
            this.session = created;

            // Validate the outputs.
            Map<String, NodeInfo> outputInfo = this.session.getOutputInfo();
            if (outputInfo.size() != 2) {
                throw new PropertyException("", "modelPath", "Invalid model, expected 2 outputs, found " + outputInfo.size());
            }
            NodeInfo tokenNode = outputInfo.get(TOKEN_OUTPUT);
            if (tokenNode == null || !(tokenNode.getInfo() instanceof TensorInfo)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called 'output_0'");
            }
            long[] tokenShape = ((TensorInfo) tokenNode.getInfo()).getShape();
            if (tokenShape.length != 3) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find output_0 with 3 dimensions, found :" + Arrays.toString(tokenShape));
            }
            // A non-positive or oversized dimension cannot be used to size the
            // feature arrays, so reject it here with a readable message.
            if (tokenShape[2] <= 0 || tokenShape[2] > Integer.MAX_VALUE) {
                throw new PropertyException("", "modelPath", "Invalid model, expected output_0 to have a fixed positive embedding dimension, found :" + Arrays.toString(tokenShape));
            }
            this.bertDim = (int) tokenShape[2];
            NodeInfo clsNode = outputInfo.get(CLS_OUTPUT);
            if (clsNode == null || !(clsNode.getInfo() instanceof TensorInfo)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find tensor output called 'output_1'");
            }
            long[] clsShape = ((TensorInfo) clsNode.getInfo()).getShape();
            if (clsShape.length != 2) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find output_1 with 2 dimensions, found :" + Arrays.toString(clsShape));
            }
            if (clsShape[1] != this.bertDim) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find two outputs with the same embedding dimension, instead found " + this.bertDim + " and " + clsShape[1]);
            }

            // Validate the inputs.
            Map<String, NodeInfo> inputInfo = this.session.getInputInfo();
            if (inputInfo.size() != 3) {
                throw new PropertyException("", "modelPath", "Invalid model, expected 3 inputs, found " + inputInfo.size());
            }
            if (!inputInfo.containsKey(ATTENTION_MASK)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called 'attention_mask'");
            }
            if (!inputInfo.containsKey(INPUT_IDS)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called 'input_ids'");
            }
            if (!inputInfo.containsKey(TOKEN_TYPE_IDS)) {
                throw new PropertyException("", "modelPath", "Invalid model, expected to find an input called 'token_type_ids'");
            }
            this.featureNames = generateFeatureNames(this.bertDim);

            // Load the tokenizer; its special tokens replace the defaults.
            TokenizerConfig tokenizerConfig = loadTokenizer(this.tokenizerPath);
            Wordpiece wordpiece = new Wordpiece(tokenizerConfig.tokenIDs.keySet(), tokenizerConfig.unknownToken, tokenizerConfig.maxInputCharsPerWord);
            this.tokenIDs = tokenizerConfig.tokenIDs;
            this.unknownToken = tokenizerConfig.unknownToken;
            this.classificationToken = tokenizerConfig.classificationToken;
            this.separatorToken = tokenizerConfig.separatorToken;
            this.tokenizer = new WordpieceTokenizer(wordpiece, new WordpieceBasicTokenizer(), tokenizerConfig.lowercase, tokenizerConfig.stripAccents, Collections.emptySet());
            configured = true;
        } catch (OrtException e) {
            throw new PropertyException(e, "", "modelPath", "Failed to load model, ORT threw: ");
        } catch (IOException e2) {
            throw new PropertyException(e2, "", "tokenizerPath", "Failed to load tokenizer, Jackson threw: ");
        } finally {
            if (!configured && created != null) {
                // Release the native session; the original failure keeps propagating.
                try {
                    created.close();
                } catch (OrtException closeFailure) {
                    logger.warning("Failed to close the ORT session after a configuration error: " + closeFailure);
                }
                if (this.session == created) {
                    this.session = null;
                }
            }
        }
    }

    /**
     * Returns the provenance of this extractor, built from its configured
     * fields.
     * <p>
     * This restores the method name that the decompiler recorded as the
     * original ({@code getProvenance}) before renaming it to avoid a clash
     * with the compiler generated bridge method.
     *
     * @return the provenance of this feature extractor.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this, "FeatureExtractor");
    }

    /**
     * Decompiler generated alias for {@link #getProvenance()}, kept so that
     * existing callers of this name continue to compile.
     *
     * @return the provenance of this feature extractor.
     * @deprecated use {@link #getProvenance()}.
     */
    @Deprecated
    public ConfiguredObjectProvenance m19getProvenance() {
        return getProvenance();
    }

    /**
     * Closes the current session and replaces it with a new one created
     * from the same model path using the supplied options.
     *
     * @param sessionOptions the options for the replacement session.
     * @throws OrtException if closing the old session or creating the new one fails.
     */
    public void reconfigureOrtSession(OrtSession.SessionOptions sessionOptions) throws OrtException {
        OrtSession previous = this.session;
        previous.close();
        String modelLocation = this.modelPath.toString();
        this.session = this.env.createSession(modelLocation, sessionOptions);
    }

    /**
     * Returns the configured maximum length in wordpieces; the two special
     * tokens count towards it.
     *
     * @return the maximum length.
     */
    public int getMaxLength() {
        return maxLength;
    }

    /**
     * Returns a read-only view of the wordpieces known to the tokenizer.
     *
     * @return an unmodifiable view of the vocabulary.
     */
    public Set<String> getVocab() {
        Set<String> vocabulary = this.tokenIDs.keySet();
        return Collections.unmodifiableSet(vocabulary);
    }

    /**
     * Generates one feature name per embedding dimension, of the form
     * {@code D=007}, zero padded to the number of digits in {@code i}.
     * <p>
     * The names are formatted with {@link java.util.Locale#ROOT} so that they
     * always use ASCII digits; with the default locale the same model could
     * produce different feature names on JVMs whose locale uses other digits.
     *
     * @param i the embedding dimension.
     * @return the feature names, in dimension order.
     */
    private static String[] generateFeatureNames(int i) {
        int width = Integer.toString(i).length();
        String pattern = "D=%0" + width + "d";
        String[] names = new String[i];
        for (int dimension = 0; dimension < i; dimension++) {
            names[dimension] = String.format(java.util.Locale.ROOT, pattern, dimension);
        }
        return names;
    }

    /**
     * Parses a tokenizer JSON file into a {@link TokenizerConfig}.
     * <p>
     * Reads the normalizer flags ({@code lowercase}, {@code strip_accents}),
     * the separator and classification tokens from the post processor
     * (either {@code TemplateProcessing} or {@code BertProcessing}), and the
     * unknown token, {@code max_input_chars_per_word} and vocabulary from
     * the model section. Every missing element is reported as an
     * {@link IllegalStateException} rather than a bare
     * {@link NullPointerException}.
     *
     * @param path the tokenizer JSON file.
     * @return the parsed tokenizer configuration.
     * @throws IOException if the file cannot be read or parsed as JSON.
     * @throws IllegalStateException if the JSON lacks a required element or
     *         names an unrecognised post processor type.
     */
    static TokenizerConfig loadTokenizer(Path path) throws IOException {
        JsonNode root = new ObjectMapper(new JsonFactory()).readTree(path.toFile());

        // Normalizer flags.
        JsonNode normalizer = requireField(root, "normalizer", "Failed to parse tokenizer json, did not find the normalizer");
        boolean lowercase = requireField(normalizer, "lowercase", "Failed to parse tokenizer json, did not find normalizer:lowercase").asBoolean();
        boolean stripAccents = requireField(normalizer, "strip_accents", "Failed to parse tokenizer json, did not find normalizer:strip_accents").asBoolean();

        // Special tokens; their location depends on the post processor type.
        JsonNode postProcessor = requireField(root, "post_processor", "Failed to parse tokenizer json, did not find the post processor");
        String processorType = requireField(postProcessor, "type", "Failed to parse tokenizer json, did not find post_processor:type").asText();
        String separator;
        String classification;
        if ("TemplateProcessing".equals(processorType)) {
            JsonNode specialTokens = requireField(postProcessor, "special_tokens", "Failed to parse tokenizer json, did not find the special tokens.");
            JsonNode separatorEntry = requireField(specialTokens, SEPARATOR_TOKEN, "Failed to parse tokenizer json, did not find separator token.");
            separator = requireFirst(requireField(separatorEntry, "tokens", "Failed to parse tokenizer json, did not find separator token."),
                    "Failed to parse tokenizer json, did not find separator token.").asText();
            JsonNode classificationEntry = requireField(specialTokens, CLASSIFICATION_TOKEN, "Failed to parse tokenizer json, did not find classification token.");
            classification = requireFirst(requireField(classificationEntry, "tokens", "Failed to parse tokenizer json, did not find classification token."),
                    "Failed to parse tokenizer json, did not find classification token.").asText();
        } else if ("BertProcessing".equals(processorType)) {
            separator = requireFirst(requireField(postProcessor, "sep", "Failed to parse tokenizer json, did not find separator token."),
                    "Failed to parse tokenizer json, did not find separator token.").asText();
            classification = requireFirst(requireField(postProcessor, "cls", "Failed to parse tokenizer json, did not find classification token."),
                    "Failed to parse tokenizer json, did not find classification token.").asText();
        } else {
            throw new IllegalStateException("Failed to parse tokenizer json, did not recognise post_processor:type " + processorType);
        }

        // Wordpiece model settings.
        JsonNode model = requireField(root, "model", "Failed to parse tokenizer json, did not find the model");
        String unknown = requireField(model, "unk_token", "Failed to parse tokenizer json, did not extract unknown token").asText();
        if (unknown == null || unknown.isEmpty()) {
            throw new IllegalStateException("Failed to parse tokenizer json, did not extract unknown token");
        }
        int maxInputCharsPerWord = requireField(model, "max_input_chars_per_word", "Failed to parse tokenizer json, did not extract max_input_chars_per_word").asInt();
        if (maxInputCharsPerWord == 0) {
            throw new IllegalStateException("Failed to parse tokenizer json, did not extract max_input_chars_per_word");
        }

        // Vocabulary; -1 is the sentinel for an entry that is not an integer.
        JsonNode vocab = requireField(model, "vocab", "Failed to parse tokenizer json, did not extract vocab");
        Map<String, Integer> vocabulary = new HashMap<>();
        Iterator<Map.Entry<String, JsonNode>> entries = vocab.fields();
        while (entries.hasNext()) {
            Map.Entry<String, JsonNode> entry = entries.next();
            int id = entry.getValue().asInt(-1);
            if (id == -1) {
                throw new IllegalStateException("Failed to parse tokenizer json, could not extract vocab item '" + entry.getKey() + "'");
            }
            vocabulary.put(entry.getKey(), id);
        }
        return new TokenizerConfig(vocabulary, unknown, classification, separator, lowercase, stripAccents, maxInputCharsPerWord);
    }

    /**
     * Returns the named child of {@code parent}, failing with the supplied
     * message when the parent or the child is absent.
     */
    private static JsonNode requireField(JsonNode parent, String field, String message) {
        JsonNode child = parent == null ? null : parent.get(field);
        if (child == null) {
            throw new IllegalStateException(message);
        }
        return child;
    }

    /**
     * Returns the first element of {@code array}, failing with the supplied
     * message when there is none.
     */
    private static JsonNode requireFirst(JsonNode array, String message) {
        JsonNode first = array.get(0);
        if (first == null) {
            throw new IllegalStateException(message);
        }
        return first;
    }

    /**
     * Converts wordpieces into a {@code [1, tokens + 2]} tensor of ids,
     * with the classification token id first and the separator token id
     * last. Wordpieces missing from the vocabulary map to the unknown
     * token id.
     *
     * @param list the wordpiece tokens.
     * @return a tensor holding the token ids; the caller must close it.
     * @throws OrtException if the tensor cannot be created.
     */
    private OnnxTensor convertTokens(List<String> list) throws OrtException {
        // Two extra slots hold the classification and separator tokens.
        long[] ids = new long[list.size() + 2];
        ids[0] = this.tokenIDs.get(this.classificationToken).intValue();
        // Only dereferenced when a wordpiece is out of vocabulary.
        Integer unknownId = this.tokenIDs.get(this.unknownToken);
        int position = 1;
        for (String token : list) {
            Integer id = this.tokenIDs.get(token);
            ids[position] = (id != null ? id : unknownId).intValue();
            position++;
        }
        ids[position] = this.tokenIDs.get(this.separatorToken).intValue();
        return OnnxTensor.createTensor(this.env, new long[][]{ids});
    }

    /**
     * Creates a {@code [1, i]} tensor in which every element equals
     * {@code j}.
     *
     * @param i the number of elements.
     * @param j the value stored in every element.
     * @return the filled tensor; the caller must close it.
     * @throws OrtException if the tensor cannot be created.
     */
    private OnnxTensor createTensor(int i, long j) throws OrtException {
        long[] values = new long[i];
        Arrays.fill(values, j);
        return OnnxTensor.createTensor(this.env, new long[][]{values});
    }

    /**
     * Reads the next {@code i} floats from the buffer, advancing its
     * position, and returns them widened to doubles.
     *
     * @param floatBuffer the buffer to read from.
     * @param i the number of values to read.
     * @return the values read, as doubles.
     */
    private static double[] extractFeatures(FloatBuffer floatBuffer, int i) {
        double[] widened = new double[i];
        float[] raw = new float[i];
        floatBuffer.get(raw);
        int index = 0;
        for (float value : raw) {
            widened[index++] = value;
        }
        return widened;
    }

    /**
     * Reads the next {@code i} floats from the buffer, advancing its
     * position, and adds them element-wise into {@code dArr}.
     *
     * @param floatBuffer the buffer to read from.
     * @param i the number of values to read.
     * @param dArr the accumulator, updated in place.
     */
    private static void addFeatures(FloatBuffer floatBuffer, int i, double[] dArr) {
        float[] chunk = new float[i];
        floatBuffer.get(chunk);
        int index = 0;
        for (float value : chunk) {
            dArr[index] += value;
            index++;
        }
    }

    /**
     * Embeds the tokens into a single example labelled with the output
     * factory's unknown output.
     *
     * @param list the wordpiece tokens.
     * @return an example holding the pooled embedding.
     */
    public Example<T> extractExample(List<String> list) {
        T unknownOutput = this.outputFactory.getUnknownOutput();
        return extractExample(list, unknownOutput);
    }

    /**
     * Embeds the tokens into a single example carrying the supplied output.
     *
     * @param list the wordpiece tokens.
     * @param t the output to attach to the example.
     * @return an example holding the pooled embedding.
     */
    public Example<T> extractExample(List<String> list, T t) {
        double[] featureValues = extractFeatures(list);
        return new ArrayExample<>(t, this.featureNames, featureValues);
    }

    /**
     * Runs the model over the tokens and pools its outputs into a single
     * embedding according to the configured {@link OutputPooling}.
     * <p>
     * All tensors and the session result are closed before returning, in
     * the reverse order of their creation.
     *
     * @param list the wordpiece tokens; at most {@code maxLength - 2} of them.
     * @return the pooled embedding, of length {@code bertDim}.
     * @throws IllegalArgumentException if there are too many tokens.
     * @throws IllegalStateException if ONNX Runtime fails or the pooling type is unknown.
     */
    double[] extractFeatures(List<String> list) {
        if (list.size() > this.maxLength - 2) {
            throw new IllegalArgumentException("Too many tokens, expected " + (this.maxLength - 2) + " found " + list.size());
        }
        // The model input also carries the classification and separator tokens.
        int sequenceLength = list.size() + 2;
        try (OnnxTensor inputIds = convertTokens(list);
             OnnxTensor attentionMask = createTensor(sequenceLength, MASK_VALUE);
             OnnxTensor tokenTypes = createTensor(sequenceLength, TOKEN_TYPE_VALUE)) {
            Map<String, OnnxTensor> inputs = new HashMap<>(3);
            inputs.put(INPUT_IDS, inputIds);
            inputs.put(ATTENTION_MASK, attentionMask);
            inputs.put(TOKEN_TYPE_IDS, tokenTypes);
            try (OrtSession.Result run = this.session.run(inputs)) {
                switch (this.pooling) {
                    case CLS:
                        return extractCLSVector(run);
                    case MEAN:
                        return extractTokenVector(run, list.size(), true);
                    case CLS_AND_MEAN:
                        double[] clsVector = extractCLSVector(run);
                        double[] meanVector = extractTokenVector(run, list.size(), true);
                        double[] averaged = new double[this.bertDim];
                        for (int i = 0; i < this.bertDim; i++) {
                            averaged[i] = (clsVector[i] + meanVector[i]) / 2.0d;
                        }
                        return averaged;
                    default:
                        throw new IllegalStateException("Unknown pooling type " + this.pooling);
                }
            }
        } catch (OrtException e) {
            throw new IllegalStateException("ORT failed to execute: ", e);
        }
    }

    /**
     * Reads the sequence level embedding from the {@code output_1} tensor
     * of the model result.
     *
     * @param result the model output.
     * @return the first {@code bertDim} values of the tensor, as doubles.
     * @throws IllegalStateException if the output is missing, is not a
     *         tensor, or is not a float tensor.
     */
    private double[] extractCLSVector(OrtSession.Result result) {
        OnnxValue value = result.get(CLS_OUTPUT).orElseThrow(
                () -> new IllegalStateException("Failed to read output_1 from the BERT response"));
        if (!(value instanceof OnnxTensor)) {
            throw new IllegalStateException("Expected OnnxTensor, found " + value.getClass());
        }
        OnnxTensor tensor = (OnnxTensor) value;
        FloatBuffer buffer = tensor.getFloatBuffer();
        if (buffer == null) {
            throw new IllegalStateException("Expected a float tensor, found " + tensor.getInfo().toString());
        }
        return extractFeatures(buffer, this.bertDim);
    }

    /**
     * Sums, and optionally averages, the per-token embeddings read from the
     * {@code output_0} tensor. The first vector in the tensor is skipped and
     * the following {@code i} vectors are accumulated.
     * <p>
     * Note that when {@code i} is zero and {@code z} is true the division
     * below is 0.0 / 0 and every element of the result is NaN.
     *
     * @param result the model output.
     * @param i the number of token vectors to accumulate.
     * @param z if true the sum is divided by {@code i}.
     * @return the summed or averaged embedding, of length {@code bertDim}.
     * @throws IllegalStateException if the output is missing, is not a
     *         tensor, or is not a float tensor.
     */
    private double[] extractTokenVector(OrtSession.Result result, int i, boolean z) {
        OnnxValue value = result.get(TOKEN_OUTPUT).orElseThrow(
                () -> new IllegalStateException("Failed to read output_0 from the BERT response"));
        if (!(value instanceof OnnxTensor)) {
            throw new IllegalStateException("Expected OnnxTensor, found " + value.getClass());
        }
        OnnxTensor tensor = (OnnxTensor) value;
        FloatBuffer buffer = tensor.getFloatBuffer();
        if (buffer == null) {
            throw new IllegalStateException("Expected a float tensor, found " + tensor.getInfo().toString());
        }
        double[] pooled = new double[this.bertDim];
        // Skip the first vector, which belongs to the classification token.
        buffer.position(this.bertDim);
        for (int token = 0; token < i; token++) {
            addFeatures(buffer, this.bertDim, pooled);
        }
        if (z) {
            for (int dimension = 0; dimension < this.bertDim; dimension++) {
                pooled[dimension] = pooled[dimension] / i;
            }
        }
        return pooled;
    }

    /**
     * Embeds the tokens into a sequence example in which every token is
     * labelled with the output factory's unknown output.
     *
     * @param list the wordpiece tokens.
     * @param z if true the classification and separator positions are left out of the sequence.
     * @return a sequence example with one example per retained position.
     */
    public SequenceExample<T> extractSequenceExample(List<String> list, boolean z) {
        List<T> unknownOutputs = new ArrayList<>();
        int remaining = list.size();
        while (remaining > 0) {
            unknownOutputs.add(this.outputFactory.getUnknownOutput());
            remaining--;
        }
        return extractSequenceExample(list, unknownOutputs, z);
    }

    /**
     * Embeds the tokens into a sequence example, pairing the i-th token's
     * embedding with the i-th supplied output. Each example stores its token
     * text under the {@link #TOKEN_METADATA} key.
     * <p>
     * When {@code z} is false the sequence additionally starts with an
     * example for the classification position and ends with one for the
     * separator position, both labelled with the unknown output.
     *
     * @param list the wordpiece tokens; at most {@code maxLength - 2} of them.
     * @param list2 the outputs, at least one per token.
     * @param z if true the classification and separator positions are left out of the sequence.
     * @return a sequence example with one example per retained position.
     * @throws IllegalArgumentException if there are too many tokens or too few outputs.
     * @throws IllegalStateException if ONNX Runtime fails or returns an unexpected output.
     */
    public SequenceExample<T> extractSequenceExample(List<String> list, List<T> list2, boolean z) {
        if (list.size() > this.maxLength - 2) {
            throw new IllegalArgumentException("Too many tokens, expected " + (this.maxLength - 2) + " found " + list.size());
        }
        // Fail before running the model rather than part way through building the sequence.
        if (list2.size() < list.size()) {
            throw new IllegalArgumentException("Too few outputs, expected " + list.size() + " found " + list2.size());
        }
        // The model input also carries the classification and separator tokens.
        int sequenceLength = list.size() + 2;
        try (OnnxTensor inputIds = convertTokens(list);
             OnnxTensor attentionMask = createTensor(sequenceLength, MASK_VALUE);
             OnnxTensor tokenTypes = createTensor(sequenceLength, TOKEN_TYPE_VALUE)) {
            Map<String, OnnxTensor> inputs = new HashMap<>(3);
            inputs.put(INPUT_IDS, inputIds);
            inputs.put(ATTENTION_MASK, attentionMask);
            inputs.put(TOKEN_TYPE_IDS, tokenTypes);
            try (OrtSession.Result run = this.session.run(inputs)) {
                OnnxValue value = run.get(TOKEN_OUTPUT).orElseThrow(
                        () -> new IllegalStateException("Failed to read output_0 from the BERT response"));
                if (!(value instanceof OnnxTensor)) {
                    throw new IllegalStateException("Expected OnnxTensor, found " + value.getClass());
                }
                OnnxTensor tensor = (OnnxTensor) value;
                FloatBuffer buffer = tensor.getFloatBuffer();
                if (buffer == null) {
                    throw new IllegalStateException("Expected a float tensor, found " + tensor.getInfo().toString());
                }

                List<Example<T>> examples = new ArrayList<>();
                if (z) {
                    // Skip the classification vector.
                    buffer.position(this.bertDim);
                } else {
                    ArrayExample<T> clsExample = new ArrayExample<>(this.outputFactory.getUnknownOutput(), this.featureNames, extractFeatures(buffer, this.bertDim));
                    clsExample.setMetadataValue(TOKEN_METADATA, CLASSIFICATION_TOKEN);
                    examples.add(clsExample);
                }
                for (int i = 0; i < list.size(); i++) {
                    ArrayExample<T> tokenExample = new ArrayExample<>(list2.get(i), this.featureNames, extractFeatures(buffer, this.bertDim));
                    tokenExample.setMetadataValue(TOKEN_METADATA, list.get(i));
                    examples.add(tokenExample);
                }
                if (!z) {
                    ArrayExample<T> sepExample = new ArrayExample<>(this.outputFactory.getUnknownOutput(), this.featureNames, extractFeatures(buffer, this.bertDim));
                    sepExample.setMetadataValue(TOKEN_METADATA, SEPARATOR_TOKEN);
                    examples.add(sepExample);
                }
                return new SequenceExample<>(examples);
            }
        } catch (OrtException e) {
            throw new IllegalStateException("ORT failed to execute: ", e);
        }
    }

    /**
     * Closes the session and then the environment. Once a call has
     * completed, later calls do nothing.
     *
     * @throws OrtException if ONNX Runtime fails while closing.
     */
    @Override // java.lang.AutoCloseable
    public void close() throws OrtException {
        if (!this.closed) {
            this.session.close();
            this.env.close();
            this.closed = true;
        }
    }

    /**
     * Tokenizes the text and embeds it into a single example carrying the
     * supplied output.
     *
     * @param t the output to attach to the example.
     * @param str the text to embed.
     * @return an example holding the pooled embedding.
     */
    public Example<T> extract(T t, String str) {
        List<String> wordpieces = tokenize(str);
        return extractExample(wordpieces, t);
    }

    /**
     * Tokenizes and embeds the text, returning one feature per embedding
     * dimension named {@code <str>-<feature name>}.
     *
     * @param str the prefix prepended to every feature name.
     * @param str2 the text to embed.
     * @return the features, in dimension order.
     */
    public List<Feature> process(String str, String str2) {
        double[] values = extractFeatures(tokenize(str2));
        List<Feature> features = new ArrayList<>(values.length);
        String prefix = str + "-";
        int index = 0;
        for (double value : values) {
            features.add(new Feature(prefix + this.featureNames[index], value));
            index++;
        }
        return features;
    }

    /**
     * Splits the text into wordpieces and truncates the result so that it
     * fits the model once the two special tokens are added.
     *
     * @param str the text to tokenize.
     * @return at most {@code maxLength - 2} wordpieces.
     */
    List<String> tokenize(String str) {
        List<String> split = this.tokenizer.split(str);
        // Two positions are reserved for the classification and separator tokens.
        int limit = this.maxLength - 2;
        if (split.size() > limit) {
            // Report the length actually kept; the message used to say maxLength + 2.
            logger.fine("Truncating sentence to " + limit + " from " + split.size());
            split = split.subList(0, limit);
        }
        return split;
    }

    /**
     * Tokenizes the text, runs the model and collects the inputs together
     * with both outputs into a {@link BERTResult}.
     * <p>
     * All tensors and the session result are closed before returning, in
     * the reverse order of their creation.
     *
     * @param str the text to embed.
     * @return the tokens, model inputs and embeddings.
     * @throws OrtException if ONNX Runtime fails.
     * @throws IllegalStateException if an expected output is missing from the result.
     */
    private BERTResult bert(String str) throws OrtException {
        List<String> tokens = tokenize(str);
        // The model input also carries the classification and separator tokens.
        int sequenceLength = tokens.size() + 2;
        try (OnnxTensor inputIds = convertTokens(tokens);
             OnnxTensor attentionMask = createTensor(sequenceLength, MASK_VALUE);
             OnnxTensor tokenTypes = createTensor(sequenceLength, TOKEN_TYPE_VALUE)) {
            // Copy the inputs out of the tensors before the tensors are closed.
            long[] ids = ((long[][]) inputIds.getValue())[0];
            long[] masks = ((long[][]) attentionMask.getValue())[0];
            long[] types = ((long[][]) tokenTypes.getValue())[0];
            Map<String, OnnxTensor> inputs = new HashMap<>(3);
            inputs.put(INPUT_IDS, inputIds);
            inputs.put(ATTENTION_MASK, attentionMask);
            inputs.put(TOKEN_TYPE_IDS, tokenTypes);
            try (OrtSession.Result run = this.session.run(inputs)) {
                double[] clsEmbedding = extractCLSVector(run);
                OnnxValue tokenOutput = run.get(TOKEN_OUTPUT).orElseThrow(
                        () -> new IllegalStateException("Failed to read output_0 from the BERT response"));
                float[][] tokenEmbeddings = ((float[][][]) tokenOutput.getValue())[0];
                return new BERTResult(tokens, ids, masks, types, clsEmbedding, tokenEmbeddings);
            }
        }
    }

    /**
     * Command line entry point: embeds every line of the input file and
     * writes the results as a single JSON array to the output file.
     * <p>
     * The input is read and the output written as UTF-8, and the extractor
     * is closed once the output has been written or an error occurs.
     *
     * @param strArr the command line arguments, parsed into {@link BERTFeatureExtractorOptions}.
     * @throws IOException if the input cannot be read or the output cannot be written.
     * @throws OrtException if ONNX Runtime fails.
     */
    public static void main(String[] strArr) throws IOException, OrtException {
        BERTFeatureExtractorOptions options = new BERTFeatureExtractorOptions();
        // Constructing the manager populates the options object.
        new ConfigurationManager(strArr, options);
        try (BERTFeatureExtractor<?> extractor = options.bert) {
            List<String> lines = Files.readAllLines(options.inputFile, StandardCharsets.UTF_8);
            List<BERTResult> results = new ArrayList<>(lines.size());
            for (String line : lines) {
                results.add(extractor.bert(line));
            }
            ObjectMapper mapper = new ObjectMapper();
            // Write with an explicit charset so non-ASCII tokens survive on every platform.
            try (BufferedWriter writer = Files.newBufferedWriter(options.outputFile, StandardCharsets.UTF_8)) {
                writer.write(mapper.writeValueAsString(results));
            }
        }
    }
}
