package org.tribuo.classification.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.tribuo.Example;
import org.tribuo.classification.Label;
import org.tribuo.impl.ArrayExample;

/* loaded from: input_file:org/tribuo/classification/example/GaussianLabelDataSource.class */
/**
 * A demo data source which draws a two class dataset from a pair of bivariate Gaussians.
 * <p>
 * The first {@code numSamples / 2} examples are labelled {@code FIRST_CLASS} and sampled from the
 * Gaussian described by {@code firstMean} and {@code firstCovarianceMatrix}; the remaining examples
 * are labelled {@code SECOND_CLASS} and sampled from {@code secondMean} and
 * {@code secondCovarianceMatrix}.
 * <p>
 * Each covariance matrix is supplied as 4 elements in row-major order, {@code [c00, c01, c10, c11]},
 * and must be finite, symmetric ({@code c01 == c10}) and positive semi-definite. Off-diagonal
 * entries may be negative.
 */
public final class GaussianLabelDataSource extends DemoLabelDataSource {

    @Config(mandatory = true, description = "2d mean of the first Gaussian.")
    private double[] firstMean;

    @Config(mandatory = true, description = "4 element covariance matrix of the first Gaussian.")
    private double[] firstCovarianceMatrix;

    @Config(mandatory = true, description = "2d mean of the second Gaussian.")
    private double[] secondMean;

    @Config(mandatory = true, description = "4 element covariance matrix of the second Gaussian.")
    private double[] secondCovarianceMatrix;

    /** Lower-triangular Cholesky factor of the first covariance, stored as {@code [l00, l10, l11]}. */
    private double[] firstCholesky;

    /** Lower-triangular Cholesky factor of the second covariance, stored as {@code [l00, l10, l11]}. */
    private double[] secondCholesky;

    /**
     * No-arg constructor, presumably used by the OLCUT configuration system to inject the
     * {@code @Config} fields before {@link #postConfig()} runs — confirm against OLCUT usage.
     */
    private GaussianLabelDataSource() {
    }

    /**
     * Creates a data source from explicit parameters and validates them.
     *
     * @param numSamples             Number of examples to generate (forwarded to {@link DemoLabelDataSource}).
     * @param seed                   RNG seed (forwarded to {@link DemoLabelDataSource}).
     * @param firstMean              2 element mean of the first Gaussian.
     * @param firstCovarianceMatrix  4 element row-major covariance matrix of the first Gaussian.
     * @param secondMean             2 element mean of the second Gaussian.
     * @param secondCovarianceMatrix 4 element row-major covariance matrix of the second Gaussian.
     * @throws PropertyException if a mean is not length 2, a covariance matrix is not length 4, or a
     *                           covariance matrix is not a finite symmetric positive semi-definite matrix.
     */
    public GaussianLabelDataSource(int numSamples, long seed, double[] firstMean, double[] firstCovarianceMatrix,
                                   double[] secondMean, double[] secondCovarianceMatrix) {
        super(numSamples, seed);
        // Defensive copies: later mutation of the caller's arrays must not desynchronise the
        // validated parameters, the cached Cholesky factors and toString().
        this.firstMean = firstMean.clone();
        this.firstCovarianceMatrix = firstCovarianceMatrix.clone();
        this.secondMean = secondMean.clone();
        this.secondCovarianceMatrix = secondCovarianceMatrix.clone();
        postConfig();
    }

    /**
     * Validates the means and covariance matrices, caches the Cholesky factors used for sampling,
     * then delegates to {@code super.postConfig()}.
     *
     * @throws PropertyException if a mean is not length 2, a covariance matrix is not length 4, or a
     *                           covariance matrix is not a finite symmetric positive semi-definite matrix.
     */
    @Override
    public void postConfig() {
        if (firstMean.length != 2) {
            throw new PropertyException("", "firstMean", "firstMean is not the right length");
        }
        if (secondMean.length != 2) {
            throw new PropertyException("", "secondMean", "secondMean is not the right length");
        }
        if (firstCovarianceMatrix.length != 4) {
            throw new PropertyException("", "firstCovarianceMatrix", "firstCovarianceMatrix is not the right length");
        }
        if (secondCovarianceMatrix.length != 4) {
            throw new PropertyException("", "secondCovarianceMatrix", "secondCovarianceMatrix is not the right length");
        }
        validateCovariance(firstCovarianceMatrix, "firstCovarianceMatrix", "First");
        validateCovariance(secondCovarianceMatrix, "secondCovarianceMatrix", "Second");
        firstCholesky = cholesky(firstCovarianceMatrix);
        secondCholesky = cholesky(secondCovarianceMatrix);
        super.postConfig();
    }

    /**
     * Checks that a 4 element row-major matrix is a valid 2x2 covariance matrix.
     * <p>
     * A symmetric 2x2 matrix is positive semi-definite iff both diagonal entries and the determinant
     * are non-negative; the off-diagonal entry itself may be negative.
     *
     * @param cov       The matrix as {@code [c00, c01, c10, c11]}; must already be length 4.
     * @param fieldName The config field name reported in the exception.
     * @param ordinal   "First" or "Second", used as the prefix of the exception message.
     * @throws PropertyException if the matrix is non-finite, asymmetric or not positive semi-definite.
     */
    private static void validateCovariance(double[] cov, String fieldName, String ordinal) {
        for (double value : cov) {
            if (!Double.isFinite(value)) {
                throw new PropertyException("", fieldName, ordinal + " covariance matrix contains non-finite values");
            }
        }
        if (cov[1] != cov[2]) {
            throw new PropertyException("", fieldName, ordinal + " covariance matrix is not a covariance matrix");
        }
        double determinant = (cov[3] * cov[0]) - (cov[1] * cov[1]);
        // Written as !(d >= 0) so that a NaN determinant (from overflowing products) is also rejected.
        if (cov[0] < 0.0 || cov[3] < 0.0 || !(determinant >= 0.0)) {
            throw new PropertyException("", fieldName, ordinal + " covariance matrix is not positive semi-definite");
        }
    }

    /**
     * Computes the lower-triangular Cholesky factor L of a validated 2x2 covariance matrix, such that
     * {@code L * L^T == cov}.
     *
     * @param cov A finite symmetric positive semi-definite matrix as {@code [c00, c01, c10, c11]}.
     * @return The non-zero entries of L as {@code [l00, l10, l11]}.
     */
    private static double[] cholesky(double[] cov) {
        double[] factor = new double[3];
        factor[0] = Math.sqrt(cov[0]);
        if (cov[0] == 0.0) {
            // Degenerate first dimension: positive semi-definiteness forces c01 == 0, so the second
            // dimension is independent. The general formula below would compute 0/0 = NaN here.
            factor[1] = 0.0;
            factor[2] = Math.sqrt(cov[3]);
        } else {
            // Same expressions as the closed-form 2x2 Cholesky: l10 = c01 / l00,
            // l11 = sqrt(c11 * c00 - c01^2) / l00.
            factor[1] = cov[1] / Math.sqrt(cov[0]);
            factor[2] = Math.sqrt((cov[3] * cov[0]) - (cov[1] * cov[1])) / Math.sqrt(cov[0]);
        }
        return factor;
    }

    /**
     * Generates {@code numSamples / 2} examples of {@code FIRST_CLASS} followed by the remaining
     * examples of {@code SECOND_CLASS}, each with two features sampled from the matching Gaussian.
     *
     * @return The generated examples.
     */
    @Override
    protected List<Example<Label>> generate() {
        int firstClassSize = numSamples / 2;
        List<Example<Label>> examples = new ArrayList<>();
        for (int i = 0; i < firstClassSize; i++) {
            examples.add(new ArrayExample<>(FIRST_CLASS, FEATURE_NAMES, sampleGaussian(rng, firstMean, firstCholesky)));
        }
        for (int i = firstClassSize; i < numSamples; i++) {
            examples.add(new ArrayExample<>(SECOND_CLASS, FEATURE_NAMES, sampleGaussian(rng, secondMean, secondCholesky)));
        }
        return examples;
    }

    /**
     * Draws one point from a bivariate Gaussian as {@code mean + L * z}, where z is a pair of
     * independent standard normal draws.
     *
     * @param random   The RNG; exactly two {@code nextGaussian()} calls are consumed.
     * @param mean     The 2 element mean.
     * @param cholesky The Cholesky factor as {@code [l00, l10, l11]}.
     * @return The sampled 2 element point.
     */
    private static double[] sampleGaussian(Random random, double[] mean, double[] cholesky) {
        double z1 = random.nextGaussian();
        double x = mean[0] + (z1 * cholesky[0]);
        double z2 = random.nextGaussian();
        double y = mean[1] + (z1 * cholesky[1]) + (z2 * cholesky[2]);
        return new double[]{x, y};
    }

    @Override
    public String toString() {
        return "GaussianGenerator(numSamples=" + this.numSamples + ",seed=" + this.seed + ",firstMean=" + Arrays.toString(this.firstMean) + ",firstCovarianceMatrix=" + Arrays.toString(this.firstCovarianceMatrix) + ",secondMean=" + Arrays.toString(this.secondMean) + ",secondCovarianceMatrix=" + Arrays.toString(this.secondCovarianceMatrix) + ')';
    }
}
