package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.SplittableRandom;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.ColumnarExplainer;
import org.tribuo.classification.explanations.Explanation;
import org.tribuo.data.columnar.FieldProcessor;
import org.tribuo.data.columnar.ResponseProcessor;
import org.tribuo.data.columnar.RowProcessor;
import org.tribuo.impl.ArrayExample;
import org.tribuo.impl.ListExample;
import org.tribuo.math.la.SparseVector;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;
import org.tribuo.util.Util;
import org.tribuo.util.tokens.Token;
import org.tribuo.util.tokens.Tokenizer;

/* loaded from: input_file:org/tribuo/classification/explanations/lime/LIMEColumnar.class */
/**
 * LIME explainer for classification models whose inputs are produced from a row of named columns by a
 * {@link RowProcessor}.
 * <p>
 * Each field processor of the row processor is assigned to a group according to its
 * {@link FieldProcessor.GeneratedFeatureType}:
 * <ul>
 *     <li>{@code CATEGORICAL} and {@code REAL} fields are "tabular". They are sampled feature by feature, using the
 *     model's per-feature statistics.</li>
 *     <li>{@code BINARISED_CATEGORICAL} fields are sampled jointly. Each draw comes from an empirical CDF over the
 *     field's indicator features plus an implicit "none active" bucket.</li>
 *     <li>{@code TEXT} fields are tokenized and perturbed by randomly dropping tokens.</li>
 * </ul>
 * The perturbed rows are labelled by the wrapped model and used to fit a sparse regression explanation.
 * <p>
 * Row processors that use {@code FeatureProcessor}s are rejected.
 */
public class LIMEColumnar extends LIMEBase implements ColumnarExplainer<Regressor> {
    /** Private copy of the row processor used to turn input rows into examples. */
    private final RowProcessor<Label> generator;
    /** Binarised categorical field processors, keyed by field name (or "field#namespace"). */
    private final Map<String, FieldProcessor> binarisedFields;
    /** Categorical and real-valued field processors, keyed by field name. */
    private final Map<String, FieldProcessor> tabularFields;
    /** Text field processors, keyed by field name. */
    private final Map<String, FieldProcessor> textFields;
    private final ResponseProcessor<Label> responseProcessor;
    /** Indicator feature infos for each binarised field, in CDF slot order. */
    private final Map<String, List<VariableInfo>> binarisedInfos;
    /** Sampling CDF for each binarised field; the last slot is the "no indicator active" bucket. */
    private final Map<String, double[]> binarisedCDFs;
    private final ImmutableFeatureMap binarisedDomain;
    private final ImmutableFeatureMap tabularDomain;
    private final ImmutableFeatureMap textDomain;
    private final Tokenizer tokenizer;
    /** Each thread lazily receives its own clone of {@link #tokenizer}. */
    private final ThreadLocal<Tokenizer> tokenizerThreadLocal;

    /**
     * Constructs a LIME explainer for a model that consumes columnar data.
     *
     * @param rng                The source of randomness, passed to {@link LIMEBase}.
     * @param model              The model to explain.
     * @param explanationTrainer The sparse trainer used to fit the explanation model.
     * @param numSamples         The number of perturbed samples generated per explanation.
     * @param exampleGenerator   The row processor that converts input rows into examples. It is copied. If the copy
     *                           is not yet configured, its regex mappings are expanded against {@code model}.
     * @param tokenizer          The tokenizer applied to text fields. It is cloned once per thread. A clone failure
     *                           surfaces as an {@link IllegalArgumentException} when the clone is first needed.
     * @throws IllegalArgumentException If any of the following holds:
     *                                  <ul>
     *                                      <li>the model has features that no field processor claims;</li>
     *                                      <li>the row processor has feature processors;</li>
     *                                      <li>a field processor has an unsupported feature type.</li>
     *                                  </ul>
     * @throws IllegalStateException    If a binarised field has a non-binary feature, or its feature counts sum to
     *                                  more than {@code numTrainingExamples}.
     */
    public LIMEColumnar(SplittableRandom rng, Model<Label> model, SparseTrainer<Regressor> explanationTrainer,
                        int numSamples, RowProcessor<Label> exampleGenerator, Tokenizer tokenizer) {
        super(rng, model, explanationTrainer, numSamples);
        this.binarisedFields = new HashMap<>();
        this.tabularFields = new HashMap<>();
        this.textFields = new HashMap<>();
        this.generator = exampleGenerator.copy();
        this.responseProcessor = this.generator.getResponseProcessor();
        this.tokenizer = tokenizer;
        this.tokenizerThreadLocal = ThreadLocal.withInitial(() -> {
            try {
                return this.tokenizer.clone();
            } catch (CloneNotSupportedException e) {
                throw new IllegalArgumentException("Tokenizer not cloneable", e);
            }
        });
        if (!this.generator.isConfigured()) {
            this.generator.expandRegexMapping(model);
        }
        this.binarisedInfos = new HashMap<>();

        // Every model feature must be claimed by exactly one field processor, matched on its "<field>@" prefix.
        // Claimed features are removed from this list, so anything left at the end is unsupported.
        List<VariableInfo> unclaimed = new ArrayList<>();
        for (VariableInfo info : model.getFeatureIDMap()) {
            unclaimed.add(info);
        }
        List<VariableInfo> binarisedInfoList = new ArrayList<>();
        List<VariableInfo> tabularInfoList = new ArrayList<>();
        List<VariableInfo> textInfoList = new ArrayList<>();
        for (Map.Entry<String, FieldProcessor> entry : this.generator.getFieldProcessors().entrySet()) {
            String fieldName = entry.getKey();
            FieldProcessor processor = entry.getValue();
            switch (processor.getFeatureType()) {
                case BINARISED_CATEGORICAL: {
                    int numNamespaces = processor.getNumNamespaces();
                    if (numNamespaces > 1) {
                        // Each namespace is treated as an independent binarised field named "<field>#<ns>".
                        for (int ns = 0; ns < numNamespaces; ns++) {
                            String namespacedName = fieldName + "#" + ns;
                            this.binarisedFields.put(namespacedName, processor);
                            claimBinarisedFeatures(fieldName, namespacedName + "@", unclaimed,
                                    this.binarisedInfos.computeIfAbsent(namespacedName, k -> new ArrayList<>()),
                                    binarisedInfoList);
                        }
                    } else {
                        this.binarisedFields.put(fieldName, processor);
                        claimBinarisedFeatures(fieldName, fieldName + "@", unclaimed,
                                this.binarisedInfos.computeIfAbsent(fieldName, k -> new ArrayList<>()),
                                binarisedInfoList);
                    }
                    break;
                }
                case CATEGORICAL:
                case REAL:
                    this.tabularFields.put(fieldName, processor);
                    claimFeatures(fieldName + "@", unclaimed, tabularInfoList);
                    break;
                case TEXT:
                    this.textFields.put(fieldName, processor);
                    claimFeatures(fieldName + "@", unclaimed, textInfoList);
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported feature type " + processor.getFeatureType());
            }
        }
        if (!unclaimed.isEmpty()) {
            throw new IllegalArgumentException("Found " + unclaimed.size() + " unsupported features.");
        }
        if (!this.generator.getFeatureProcessors().isEmpty()) {
            throw new IllegalArgumentException("LIMEColumnar does not support FeatureProcessors.");
        }
        this.tabularDomain = new ImmutableFeatureMap(tabularInfoList);
        this.textDomain = new ImmutableFeatureMap(textInfoList);
        this.binarisedDomain = new ImmutableFeatureMap(binarisedInfoList);
        this.binarisedCDFs = new HashMap<>();
        for (Map.Entry<String, List<VariableInfo>> entry : this.binarisedInfos.entrySet()) {
            this.binarisedCDFs.put(entry.getKey(), buildBinarisedCDF(entry.getKey(), entry.getValue()));
        }
    }

    /**
     * Moves every info whose name starts with {@code prefix} from {@code unclaimed} into {@code sink}.
     */
    private static void claimFeatures(String prefix, List<VariableInfo> unclaimed, List<VariableInfo> sink) {
        ListIterator<VariableInfo> it = unclaimed.listIterator();
        while (it.hasNext()) {
            VariableInfo info = it.next();
            if (info.getName().startsWith(prefix)) {
                sink.add(info);
                it.remove();
            }
        }
    }

    /**
     * Moves every info whose name starts with {@code prefix} from {@code unclaimed} into both {@code fieldInfos} and
     * {@code domainInfos}. Each claimed info must be categorical with exactly one observed value.
     *
     * @throws IllegalStateException If a claimed feature has a number of unique observations other than one.
     */
    private static void claimBinarisedFeatures(String processorName, String prefix, List<VariableInfo> unclaimed,
                                               List<VariableInfo> fieldInfos, List<VariableInfo> domainInfos) {
        ListIterator<VariableInfo> it = unclaimed.listIterator();
        while (it.hasNext()) {
            VariableInfo info = it.next();
            if (info.getName().startsWith(prefix)) {
                CategoricalInfo categorical = (CategoricalInfo) info;
                if (categorical.getUniqueObservations() != 1) {
                    throw new IllegalStateException("Processor " + processorName + ", should have been binary, but had " + categorical.getUniqueObservations() + " unique values");
                }
                fieldInfos.add(categorical);
                domainInfos.add(categorical);
                it.remove();
            }
        }
    }

    /**
     * Builds the sampling CDF for one binarised field.
     * <p>
     * Slot {@code i < infos.size()} holds the count of the i-th indicator feature. The final slot holds the
     * remainder of {@code numTrainingExamples}, which sampling treats as "no indicator active".
     *
     * @throws IllegalStateException If the indicator counts sum to more than {@code numTrainingExamples}.
     */
    private double[] buildBinarisedCDF(String fieldName, List<VariableInfo> infos) {
        long[] counts = new long[infos.size() + 1];
        long observed = 0;
        for (int k = 0; k < infos.size(); k++) {
            counts[k] = infos.get(k).getCount();
            observed += counts[k];
        }
        long absent = this.numTrainingExamples - observed;
        if (absent < 0) {
            throw new IllegalStateException("Processor " + fieldName + " purports to be a BINARISED_CATEGORICAL, but had overlap in it's elements");
        }
        counts[infos.size()] = absent;
        return Util.generateCDF(counts, this.numTrainingExamples);
    }

    /**
     * Explains the model's prediction for a single input row.
     *
     * @param input The row, as a map from field name to raw string value.
     * @return The LIME explanation.
     */
    @Override
    public Explanation<Regressor> explain(Map<String, String> input) {
        return explainWithSamples(input).getA();
    }

    /**
     * Explains the model's prediction for a single input row, and also returns the perturbed samples that the
     * explanation was fitted on.
     * <p>
     * If there are neither text nor binarised fields, this delegates to
     * {@link LIMEBase#explainWithSamples(Example)}.
     *
     * @param input The row, as a map from field name to raw string value.
     * @return The explanation and the sampled training data.
     * @throws IllegalArgumentException If the row processor cannot generate an example from {@code input}.
     */
    protected Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Map<String, String> input) {
        Optional<Example<Label>> optExample = this.generator.generateExample(input, false);
        if (!optExample.isPresent()) {
            throw new IllegalArgumentException("Label not found in input " + input.toString());
        }
        Example<Label> example = optExample.get();
        if (this.textDomain.size() == 0 && this.binarisedCDFs.size() == 0) {
            return explainWithSamples(example);
        }
        Prediction<Label> prediction = this.innerModel.predict(example);
        ArrayExample<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction));
        for (Feature f : example) {
            if (this.tabularDomain.getID(f.getName()) != -1) {
                labelledExample.add(f);
            }
        }
        // Built before any non-tabular features are added; it only covers the tabular domain.
        SparseVector tabularVector = SparseVector.createSparseVector(labelledExample, this.tabularDomain, false);

        // The samples carry binarised indicator features. The explained point must carry its own indicators too,
        // otherwise it is explained as if every binarised field were in its "none active" state.
        for (Feature f : example) {
            if (this.binarisedDomain.getID(f.getName()) != -1) {
                labelledExample.add(f);
            }
        }

        // Tokenize each present text field and add one indicator feature per token to the explained point.
        Map<String, String> textValues = new HashMap<>();
        Map<String, List<Token>> textTokens = new HashMap<>();
        for (Map.Entry<String, FieldProcessor> entry : this.textFields.entrySet()) {
            String value = input.get(entry.getKey());
            if (value != null) {
                List<Token> tokens = this.tokenizerThreadLocal.get().tokenize(value);
                for (int t = 0; t < tokens.size(); t++) {
                    labelledExample.add(nameFeature(entry.getKey(), tokens.get(t).text, t), 1.0);
                }
                textValues.put(entry.getKey(), value);
                textTokens.put(entry.getKey(), tokens);
            }
        }

        List<Example<Regressor>> samples = sampleData(tabularVector, textValues, textTokens);
        SparseModel<Regressor> explainer = trainExplainer(labelledExample, samples);
        // Evaluate the explainer on the samples plus the explained point itself.
        List<Prediction<Regressor>> predictions = new ArrayList<>(explainer.predict(samples));
        predictions.add(explainer.predict(labelledExample));
        return new Pair<>(new LIMEExplanation(explainer, prediction, evaluator.evaluate(explainer, predictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory))), samples);
    }

    /**
     * Builds the name of the indicator feature for the token at {@code index} in field {@code fieldName}.
     * The name has the form {@code <fieldName>@<token>@idx<index>}.
     */
    protected String nameFeature(String fieldName, String token, int index) {
        return fieldName + "@" + token + "@idx" + index;
    }

    /**
     * Generates {@code numSamples} perturbed examples around the input.
     * <p>
     * Each sample is built as follows:
     * <ul>
     *     <li>The tabular and binarised features are resampled.</li>
     *     <li>Each token of each text field is dropped with probability 1/2.</li>
     *     <li>The perturbed row is labelled with the wrapped model.</li>
     *     <li>The sample is weighted by its combined tabular and text proximity to the input.</li>
     * </ul>
     *
     * @param tabularVector The input's tabular features, indexed by {@link #tabularDomain}.
     * @param textValues    Raw text for each text field present in the input.
     * @param textTokens    Tokens of each entry in {@code textValues}; must have the same keys.
     * @return The labelled, weighted samples. Their features are the sampled tabular/binarised features plus one
     *         0/1 indicator per token.
     */
    private List<Example<Regressor>> sampleData(SparseVector tabularVector, Map<String, String> textValues,
                                                Map<String, List<Token>> textTokens) {
        List<Example<Regressor>> samples = new ArrayList<>();
        Random innerRng = new Random(this.rng.nextLong());
        for (int s = 0; s < this.numSamples; s++) {
            List<Feature> tabularFeatures = new ArrayList<>();
            sampleTabularFeatures(tabularVector, innerRng, tabularFeatures);
            sampleBinarisedFeatures(innerRng, tabularFeatures);

            ListExample<Label> perturbedInput = new ListExample<>(LabelFactory.UNKNOWN_LABEL);
            perturbedInput.addAll(tabularFeatures);
            double tabularDistance = measureDistance(this.tabularDomain, this.numTrainingExamples, tabularVector,
                    SparseVector.createSparseVector(perturbedInput, this.tabularDomain, false));

            // textFeatures are fed to the model. perturbedFeatures are the 0/1 token indicators that the explanation
            // is fitted on.
            List<Feature> textFeatures = new ArrayList<>();
            List<Feature> perturbedFeatures = new ArrayList<>();
            double droppedTokens = 0.0;
            long numTokens = 0;
            for (Map.Entry<String, String> entry : textValues.entrySet()) {
                String field = entry.getKey();
                List<Token> tokens = textTokens.get(field);
                numTokens += tokens.size();
                int[] kept = new int[tokens.size()];
                char[] chars = entry.getValue().toCharArray();
                for (int t = 0; t < kept.length; t++) {
                    kept[t] = innerRng.nextInt(2);
                    if (kept[t] == 0) {
                        droppedTokens += 1.0;
                        Token token = tokens.get(t);
                        // Blank the dropped token with NUL characters, which are stripped just below.
                        Arrays.fill(chars, token.start, token.end, '\0');
                    }
                }
                String perturbedText = new String(chars).replace("\0", "");
                textFeatures.addAll(this.textFields.get(field).process(perturbedText));
                for (int t = 0; t < kept.length; t++) {
                    perturbedFeatures.add(new Feature(nameFeature(field, tokens.get(t).text, t), kept[t]));
                }
            }
            perturbedInput.addAll(textFeatures);

            Prediction<Label> samplePrediction = this.innerModel.predict(perturbedInput);
            double weight = combineWeights(tabularFeatures.size(), kernelDist(tabularDistance, this.kernelWidth),
                    perturbedFeatures.size(), droppedTokens, numTokens);
            ArrayExample<Regressor> sample = new ArrayExample<>(transformOutput(samplePrediction), (float) weight);
            sample.addAll(tabularFeatures);
            sample.addAll(perturbedFeatures);
            samples.add(sample);
        }
        return samples;
    }

    /**
     * Samples one value for each feature in {@link #tabularDomain} and appends the non-zero ones to {@code output}.
     * <p>
     * Categorical features are drawn from their observed frequencies. A real feature is present with probability
     * {@code count / numTrainingExamples}. When present, its value is drawn from a Gaussian centred on the input
     * value, with the feature's training variance.
     */
    private void sampleTabularFeatures(SparseVector tabularVector, Random innerRng, List<Feature> output) {
        for (VariableInfo info : this.tabularDomain) {
            double inputValue = tabularVector.get(((VariableIDInfo) info).getID());
            if (info instanceof CategoricalInfo) {
                double sample = ((CategoricalInfo) info).frequencyBasedSample(innerRng, this.numTrainingExamples);
                if (Math.abs(sample) > 1.0E-10) {
                    output.add(new Feature(info.getName(), sample));
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                // The cast forces floating-point division. If both operands are integral, plain '/' would truncate
                // to 0 for any feature not present in every training example.
                double presenceProbability = (double) realInfo.getCount() / this.numTrainingExamples;
                if (innerRng.nextDouble() < presenceProbability) {
                    double sample = innerRng.nextGaussian() * Math.sqrt(realInfo.getVariance()) + inputValue;
                    output.add(new Feature(info.getName(), sample));
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
    }

    /**
     * Draws one slot from each binarised field's CDF. Unless the slot is the final "no indicator active" bucket,
     * the chosen indicator feature is appended to {@code output} with value 1.
     */
    private void sampleBinarisedFeatures(Random innerRng, List<Feature> output) {
        for (Map.Entry<String, double[]> entry : this.binarisedCDFs.entrySet()) {
            double[] cdf = entry.getValue();
            int sampled = Util.sampleFromCDF(cdf, innerRng);
            if (sampled != cdf.length - 1) {
                output.add(new Feature(this.binarisedInfos.get(entry.getKey()).get(sampled).getName(), 1.0));
            }
        }
    }

    /**
     * Combines a sample's tabular and text proximities into its training weight.
     * <p>
     * Each term is weighted by its share of the sample's explanation features:
     * {@code 1 - (T*k + P*d) / (T + P)}, where
     * <ul>
     *     <li>{@code T} and {@code P} are the tabular and token-indicator feature counts;</li>
     *     <li>{@code k} is the kernelised tabular distance;</li>
     *     <li>{@code d} is the fraction of tokens dropped.</li>
     * </ul>
     * NOTE(review): {@code kernelDist} is defined in LIMEBase. If it returns a similarity rather than a distance,
     * subtracting it from 1 inverts the tabular term. Confirm the intended direction.
     */
    private static double combineWeights(int numTabular, double tabularKernel, int numText,
                                         double droppedTokens, long numTokens) {
        // No tokens (no text field present, or empty text) means no text distance, rather than 0/0 = NaN.
        double textDistance = numTokens == 0 ? 0.0 : droppedTokens / numTokens;
        int total = numTabular + numText;
        if (total == 0) {
            // There are no features to apportion the weight between, so use the tabular term alone.
            return 1.0 - tabularKernel;
        }
        return 1.0 - (numTabular * tabularKernel + numText * textDistance) / total;
    }

    /**
     * Raw-typed alias of {@link #explain(Map)}.
     * <p>
     * The decompiler generated this from the compiler's bridge method for {@code explain} and renamed it. It is
     * retained so the class's public methods are unchanged. Its {@code @Override} was dropped because the renamed
     * method no longer matches an interface method.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Explanation<Regressor> explain2(Map map) {
        return explain((Map<String, String>) map);
    }
}
