package org.tribuo.data.csv;

import com.opencsv.CSVParserWriter;
import com.opencsv.RFC4180ParserBuilder;
import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.VariableIDInfo;
import org.tribuo.data.columnar.processors.response.BinaryResponseProcessor;

/* loaded from: input_file:org/tribuo/data/csv/CSVSaver.class */
/**
 * Writes a {@link Dataset} to a CSV file.
 *
 * <p>The file starts with a header row holding the response names followed by the feature names
 * (in feature-id order), then contains one row per example. Columns for which an example has no
 * value are filled with {@link BinaryResponseProcessor#NEGATIVE_NAME}.
 */
public class CSVSaver implements Configurable {
    private static final Logger logger = Logger.getLogger(CSVSaver.class.getName());

    /** A response column name constant; it is not referenced inside this class. */
    public static final String DEFAULT_RESPONSE = "Response";

    // NOTE(review): the @Config fields are deliberately left non-final; presumably the
    // configuration system assigns them reflectively - confirm before making them final.
    @Config(description = "The column separator.")
    private char separator = ',';

    @Config(description = "The quote character.")
    private char quote = '"';

    /**
     * Builds a saver using the supplied separator and quote characters.
     *
     * @param separator the column separator.
     * @param quote the quote character.
     * @throws IllegalArgumentException if the two characters are the same.
     */
    public CSVSaver(char separator, char quote) {
        if (separator == quote) {
            throw new IllegalArgumentException("Quote and separator must be different characters.");
        }
        this.separator = separator;
        this.quote = quote;
    }

    /** Builds a saver which separates columns with {@code ','} and quotes with {@code '"'}. */
    public CSVSaver() {
        this(',', '"');
    }

    /**
     * Saves a dataset with a single response column.
     *
     * @param csvPath the file to write.
     * @param dataset the dataset to save.
     * @param responseName the header used for the response column.
     * @param <T> the output type.
     * @throws IOException if the file cannot be written.
     */
    public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, String responseName) throws IOException {
        save(csvPath, dataset, Collections.singleton(responseName));
    }

    /**
     * Saves a dataset, writing one response column per supplied response name.
     *
     * <p>When more than one response name is supplied, each example's serialized output is parsed
     * as comma separated {@code name=value} pairs and spread across the response columns;
     * otherwise the serialized output is written verbatim into the single response column.
     *
     * @param csvPath the file to write.
     * @param dataset the dataset to save.
     * @param responseNames the headers of the response columns, in iteration order.
     * @param <T> the output type.
     * @throws IOException if the file cannot be opened, written or closed.
     * @throws IllegalStateException if an example does not produce exactly one value per
     *     response column or per feature column, or names an unknown response.
     * @throws IllegalArgumentException if a multi-output example has a malformed serialized form.
     */
    public <T extends Output<T>> void save(Path csvPath, Dataset<T> dataset, Set<String> responseNames) throws IOException {
        boolean isMultiOutput = responseNames.size() > 1;
        ImmutableFeatureMap features = dataset.getFeatureIDMap();
        int numResponses = responseNames.size();
        int numColumns = features.size() + numResponses;

        // Header row: response names first, then the feature names in id order.
        String[] header = new String[numColumns];
        Map<String, Integer> responseToColumn = new HashMap<>();
        int column = 0;
        for (String responseName : responseNames) {
            header[column] = responseName;
            responseToColumn.put(responseName, column);
            column++;
        }
        for (int id = 0; id < features.size(); id++) {
            header[column++] = features.get(id).getName();
        }

        // try-with-resources closes the writer exactly once, on both the normal and failure paths.
        try (CSVParserWriter writer = new CSVParserWriter(
                Files.newBufferedWriter(csvPath, StandardCharsets.UTF_8),
                new RFC4180ParserBuilder().withSeparator(separator).withQuoteChar(quote).build(),
                "\n")) {
            writer.writeNext(header);

            for (Example<T> example : dataset) {
                String[] responseValues = isMultiOutput
                        ? densifyMultiOutput(example, responseToColumn)
                        : densifySingleOutput(example);
                if (responseValues.length != numResponses) {
                    // Reached when no response names were supplied: fail with a clear message
                    // rather than overflowing the row array in the copy below.
                    throw new IllegalStateException(String.format(
                            "Invalid example: had %d response values, expected %d.",
                            responseValues.length, numResponses));
                }
                String[] featureValues = generateFeatureArray(example, features);
                if (featureValues.length != features.size()) {
                    throw new IllegalStateException(String.format("Invalid example: had %d features, expected %d.", featureValues.length, features.size()));
                }

                // Responses and features share one row, in the same order as the header.
                String[] row = new String[numColumns];
                System.arraycopy(responseValues, 0, row, 0, responseValues.length);
                System.arraycopy(featureValues, 0, row, responseValues.length, featureValues.length);
                writer.writeNext(row);
            }

            // writeNext does not declare IOException, so a failed write would otherwise go
            // unnoticed; checkError flushes and reports whether any write has failed.
            if (writer.checkError()) {
                throw new IOException("Failed to write CSV file '" + csvPath + "'");
            }
        }
    }

    /**
     * Converts a single-dimensional output into its one-element response row.
     *
     * @param example the example to convert.
     * @param <T> the output type.
     * @return an array holding the output's serialized form.
     */
    private static <T extends Output<T>> String[] densifySingleOutput(Example<T> example) {
        return new String[]{example.getOutput().getSerializableForm(false)};
    }

    /**
     * Spreads a multi-dimensional output across the response columns.
     *
     * <p>The serialized output is split on {@code ','} into {@code name=value} elements. Columns
     * which are not mentioned keep the {@link BinaryResponseProcessor#NEGATIVE_NAME} fill value.
     *
     * @param example the example to convert.
     * @param responseToColumn maps each response name to its column index.
     * @param <T> the output type.
     * @return one value per response column.
     * @throws IllegalArgumentException if an element is not of the form {@code name=value}.
     * @throws IllegalStateException if a non-empty response name is not in {@code responseToColumn}.
     */
    private static <T extends Output<T>> String[] densifyMultiOutput(Example<T> example, Map<String, Integer> responseToColumn) {
        String[] denseOutput = new String[responseToColumn.size()];
        Arrays.fill(denseOutput, BinaryResponseProcessor.NEGATIVE_NAME);

        String serialized = example.getOutput().getSerializableForm(false);
        if (serialized.isEmpty()) {
            // Nothing was set on this example, so every response keeps the fill value.
            return denseOutput;
        }

        for (String element : serialized.split(",")) {
            String[] nameAndValue = element.split("=");
            if (nameAndValue.length != 2) {
                throw new IllegalArgumentException("Bad serialized string element: '" + element + "'");
            }
            String name = nameAndValue[0];
            String value = nameAndValue[1];
            int index = responseToColumn.getOrDefault(name, -1);
            if (index != -1) {
                denseOutput[index] = value;
            } else if (!name.isEmpty()) {
                // An unknown response with an empty name is silently skipped.
                throw new IllegalStateException(String.format("Invalid example: unknown response name '%s'. (known response names: %s)", name, responseToColumn.keySet()));
            }
        }
        return denseOutput;
    }

    /**
     * Builds the dense feature row for an example, indexed by feature id.
     *
     * <p>Features which are not present in the feature map are ignored; ids with no feature in
     * the example are written as {@link BinaryResponseProcessor#NEGATIVE_NAME}. If several
     * features resolve to the same id, the last one iterated wins.
     *
     * @param example the example to convert.
     * @param features the feature map supplying the feature ids.
     * @param <T> the output type.
     * @return one value per feature in {@code features}.
     */
    private static <T extends Output<T>> String[] generateFeatureArray(Example<T> example, ImmutableFeatureMap features) {
        String[] output = new String[features.size()];
        Arrays.fill(output, BinaryResponseProcessor.NEGATIVE_NAME);

        for (Feature feature : example) {
            VariableIDInfo info = features.get(feature.getName());
            if (info == null) {
                continue;
            }
            int id = info.getID();
            // Only ids inside the row are written; anything else is dropped.
            if (id >= 0 && id < output.length) {
                output[id] = Double.toString(feature.getValue());
            }
        }
        return output;
    }
}
