/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;

import Jama.Matrix;
import cern.jet.random.Normal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.MultivariateGaussian;
import org.broadinstitute.sting.gatk.walkers.variantrecalibration.VariantDatum;
import org.broadinstitute.sting.utils.MathUtils;

/**
 * A mixture of {@link MultivariateGaussian} components together with the
 * hyper-parameters and empirical prior (mean vector and matrix) that are handed
 * to each component during the maximization step.
 *
 * <p>Typical use, as far as this class shows: construct, call
 * {@link #initializeRandomModel}, alternate {@link #expectationStep} and
 * {@link #maximizationStep}, then {@link #evaluateFinalModelParameters} and
 * {@link #precomputeDenominatorForEvaluation} before calling the
 * {@code evaluateDatum*} methods.
 *
 * <p>No synchronization is performed by this class.
 */
public class GaussianMixtureModel {
    protected static final Logger logger = Logger.getLogger(GaussianMixtureModel.class);

    /** Number of random draws taken for each missing annotation when marginalizing. */
    private static final int NUM_ITER_PER_MISSING_ANNOTATION = 10;

    private final List<MultivariateGaussian> gaussians;
    private final double shrinkage;
    private final double dirichletParameter;
    private final double priorCounts;
    private final double[] empiricalMu;
    private final Matrix empiricalSigma;

    /** Set to {@code true} by {@link #precomputeDenominatorForEvaluation()}; {@code false} after construction. */
    public boolean isModelReadyForEvaluation;

    /**
     * Creates a model with {@code numGaussians} components, each constructed for
     * {@code numAnnotations} dimensions.
     *
     * <p>The empirical mean is the zero vector and the empirical sigma is the
     * inverse of {@code 200 * I} (identity of size {@code numAnnotations}).
     *
     * @param numGaussians       number of mixture components to create
     * @param numAnnotations     dimensionality of each component
     * @param shrinkage          stored and passed to each Gaussian (as {@code hyperParameter_b} and to {@code maximizeGaussian})
     * @param dirichletParameter stored and passed to each Gaussian (as {@code hyperParameter_lambda} and to {@code maximizeGaussian})
     * @param priorCounts        stored and passed to each Gaussian (as {@code hyperParameter_a} and to {@code maximizeGaussian})
     */
    public GaussianMixtureModel(int numGaussians, int numAnnotations, double shrinkage, double dirichletParameter, double priorCounts) {
        this.gaussians = new ArrayList<MultivariateGaussian>(numGaussians);
        for (int iii = 0; iii < numGaussians; ++iii) {
            this.gaussians.add(new MultivariateGaussian(numAnnotations));
        }
        this.shrinkage = shrinkage;
        this.dirichletParameter = dirichletParameter;
        this.priorCounts = priorCounts;
        this.empiricalMu = new double[numAnnotations];
        this.empiricalSigma = new Matrix(numAnnotations, numAnnotations);
        this.isModelReadyForEvaluation = false;
        Arrays.fill(this.empiricalMu, 0.0);
        final int lastIndex = this.empiricalMu.length - 1;
        this.empiricalSigma.setMatrix(0, lastIndex, 0, lastIndex,
                Matrix.identity(this.empiricalMu.length, this.empiricalMu.length).times(200.0).inverse());
    }

    /**
     * Randomly initializes every component, refines the means with k-means, and
     * then sets uniform mixture weights, a random sigma and the hyper-parameters
     * on each component.
     *
     * <p>The order of calls into the shared random generator is significant for
     * reproducibility: all random means first, then k-means (which may re-draw
     * means for empty clusters), then one random sigma per component.
     *
     * @param data                the data used by the k-means refinement
     * @param numKMeansIterations number of k-means iterations to run
     */
    public void initializeRandomModel(List<VariantDatum> data, int numKMeansIterations) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.initializeRandomMu(GenomeAnalysisEngine.getRandomGenerator());
        }
        logger.info("Initializing model with " + numKMeansIterations + " k-means iterations...");
        this.initializeMeansUsingKMeans(data, numKMeansIterations);

        final double uniformWeight = 1.0 / (double) this.gaussians.size();
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.pMixtureLog10 = Math.log10(uniformWeight);
            gaussian.sumProb = uniformWeight;
            gaussian.initializeRandomSigma(GenomeAnalysisEngine.getRandomGenerator());
            gaussian.hyperParameter_a = this.priorCounts;
            gaussian.hyperParameter_b = this.shrinkage;
            gaussian.hyperParameter_lambda = this.dirichletParameter;
        }
    }

    /**
     * Runs {@code numIterations} rounds of k-means: assign every datum to the
     * component with the smallest squared distance from its mean, then recompute
     * each mean as the average of its assigned data. A component with no
     * assigned data gets a fresh random mean instead.
     */
    private void initializeMeansUsingKMeans(List<VariantDatum> data, int numIterations) {
        for (int ttt = 0; ttt < numIterations; ++ttt) {
            // Assignment pass.
            for (VariantDatum datum : data) {
                double minDistance = Double.MAX_VALUE;
                MultivariateGaussian minGaussian = null;
                datum.assignment = null;
                for (MultivariateGaussian gaussian : this.gaussians) {
                    final double dist = gaussian.calculateDistanceFromMeanSquared(datum);
                    if (dist < minDistance) {
                        minDistance = dist;
                        minGaussian = gaussian;
                    }
                }
                // NOTE(review): if no component yields a distance below Double.MAX_VALUE
                // (no components, or NaN distances) the assignment stays null and the
                // update pass below throws a NullPointerException.
                datum.assignment = minGaussian;
            }

            // Update pass.
            for (MultivariateGaussian gaussian : this.gaussians) {
                gaussian.zeroOutMu();
                int numAssigned = 0;
                for (VariantDatum datum : data) {
                    if (datum.assignment.equals(gaussian)) {
                        ++numAssigned;
                        gaussian.incrementMu(datum);
                    }
                }
                if (numAssigned != 0) {
                    gaussian.divideEqualsMu(numAssigned);
                } else {
                    // Empty cluster: re-seed its mean randomly.
                    gaussian.initializeRandomMu(GenomeAnalysisEngine.getRandomGenerator());
                }
            }
        }
    }

    /**
     * Expectation step: for every datum, evaluates its log10 value under each
     * component, normalizes those values across components with
     * {@code MathUtils.normalizeFromLog10}, and hands each component its
     * normalized value via {@code assignPVarInGaussian}.
     *
     * <p>Before touching the data, every component precomputes its
     * variational-Bayes denominator from the sum of all components'
     * {@code hyperParameter_lambda}.
     *
     * @param data the data to evaluate
     */
    public void expectationStep(List<VariantDatum> data) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.precomputeDenominatorForVariationalBayes(this.getSumHyperParameterLambda());
        }
        for (VariantDatum datum : data) {
            ArrayList<Double> pVarInGaussianLog10 = new ArrayList<Double>(this.gaussians.size());
            for (MultivariateGaussian gaussian : this.gaussians) {
                pVarInGaussianLog10.add(gaussian.evaluateDatumLog10(datum));
            }
            final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10(pVarInGaussianLog10);
            int iii = 0;
            for (MultivariateGaussian gaussian : this.gaussians) {
                gaussian.assignPVarInGaussian(pVarInGaussianNormalized[iii++]);
            }
        }
    }

    /**
     * Maximization step: asks every component to re-estimate itself from
     * {@code data} using this model's empirical mean/sigma and hyper-parameters.
     *
     * @param data the data passed to each component's {@code maximizeGaussian}
     */
    public void maximizationStep(List<VariantDatum> data) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.maximizeGaussian(data, this.empiricalMu, this.empiricalSigma, this.shrinkage, this.dirichletParameter, this.priorCounts);
        }
    }

    /** Returns the sum of {@code hyperParameter_lambda} over all components. */
    private double getSumHyperParameterLambda() {
        double sum = 0.0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            sum += gaussian.hyperParameter_lambda;
        }
        return sum;
    }

    /**
     * Lets every component compute its final parameters from {@code data} and
     * then renormalizes the mixture weights.
     *
     * @param data the data passed to each component's {@code evaluateFinalModelParameters}
     */
    public void evaluateFinalModelParameters(List<VariantDatum> data) {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.evaluateFinalModelParameters(data);
        }
        this.normalizePMixtureLog10();
    }

    /**
     * Recomputes every component's {@code pMixtureLog10} from its
     * {@code sumProb} relative to the total over all components.
     *
     * @return the sum over components of the absolute difference between the
     *         new and the previous {@code pMixtureLog10}
     */
    public double normalizePMixtureLog10() {
        double sumPK = 0.0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            sumPK += gaussian.sumProb;
        }

        double[] pGaussianLog10 = new double[this.gaussians.size()];
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            pGaussianLog10[gaussianIndex++] = Math.log10(gaussian.sumProb / sumPK);
        }
        pGaussianLog10 = MathUtils.normalizeFromLog10(pGaussianLog10, true);

        double sumDiff = 0.0;
        gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            sumDiff += Math.abs(pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10);
            gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
        }
        return sumDiff;
    }

    /**
     * Asks every component to precompute its evaluation denominator and marks
     * the model as ready by setting {@link #isModelReadyForEvaluation}.
     */
    public void precomputeDenominatorForEvaluation() {
        for (MultivariateGaussian gaussian : this.gaussians) {
            gaussian.precomputeDenominatorForEvaluation();
        }
        this.isModelReadyForEvaluation = true;
    }

    /**
     * Evaluates {@code datum} under the mixture.
     *
     * <p>If any entry of {@code datum.isNull} is {@code true} the result comes
     * from {@link #evaluateDatumMarginalized(VariantDatum)}; otherwise it is the
     * log10-sum over components of {@code pMixtureLog10 + evaluateDatumLog10(datum)}.
     *
     * @param datum the datum to evaluate
     * @return the value produced by {@code MathUtils.log10sumLog10} (or by the marginalized evaluation)
     */
    public double evaluateDatum(VariantDatum datum) {
        for (boolean isNull : datum.isNull) {
            if (isNull) {
                return this.evaluateDatumMarginalized(datum);
            }
        }
        return this.evaluateMixtureLog10(datum, new double[this.gaussians.size()]);
    }

    /**
     * Evaluates a single annotation of {@code datum} under a one-dimensional
     * normal per component, weighted by each component's {@code pMixtureLog10}.
     *
     * @param datum the datum to evaluate
     * @param iii   index of the annotation to evaluate
     * @return {@code null} if {@code datum.isNull[iii]} is set; otherwise the
     *         log10-sum over components
     */
    public Double evaluateDatumInOneDimension(VariantDatum datum, int iii) {
        if (datum.isNull[iii]) {
            return null;
        }
        final Normal normal = new Normal(0.0, 1.0, null);
        final double[] pVarInGaussianLog10 = new double[this.gaussians.size()];
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            // NOTE(review): the second argument of Normal.setState is documented by Colt
            // as a standard deviation; if sigma is a covariance matrix, sigma(iii, iii)
            // is a variance. Left unchanged to preserve results -- confirm intent.
            normal.setState(gaussian.mu[iii], gaussian.sigma.get(iii, iii));
            pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + Math.log10(normal.pdf(datum.annotations[iii]));
        }
        return MathUtils.log10sumLog10(pVarInGaussianLog10);
    }

    /**
     * Evaluates a datum with missing annotations by random sampling: for every
     * annotation flagged in {@code datum.isNull}, draws
     * {@value #NUM_ITER_PER_MISSING_ANNOTATION} values from the shared random
     * generator's {@code nextGaussian()}, evaluates the mixture for each draw,
     * and returns the log10 of the average (in real space) over all draws.
     *
     * <p><b>Side effect:</b> {@code datum.annotations[iii]} is overwritten for
     * every missing index and is left at the last drawn value.
     *
     * <p>If no annotation is flagged as missing there is nothing to sample, so
     * the mixture is evaluated directly instead of returning NaN from a
     * division by zero draws.
     *
     * @param datum the datum to evaluate; its missing annotations are overwritten
     * @return log10 of the averaged mixture value
     */
    public double evaluateDatumMarginalized(VariantDatum datum) {
        int numRandomDraws = 0;
        double sumPVarInGaussian = 0.0;
        final double[] pVarInGaussianLog10 = new double[this.gaussians.size()];
        for (int iii = 0; iii < datum.annotations.length; ++iii) {
            if (!datum.isNull[iii]) {
                continue;
            }
            for (int ttt = 0; ttt < NUM_ITER_PER_MISSING_ANNOTATION; ++ttt) {
                datum.annotations[iii] = GenomeAnalysisEngine.getRandomGenerator().nextGaussian();
                sumPVarInGaussian += Math.pow(10.0, this.evaluateMixtureLog10(datum, pVarInGaussianLog10));
                ++numRandomDraws;
            }
        }
        if (numRandomDraws == 0) {
            // Nothing was missing: fall back to the exact evaluation.
            return this.evaluateMixtureLog10(datum, pVarInGaussianLog10);
        }
        return Math.log10(sumPVarInGaussian / (double) numRandomDraws);
    }

    /**
     * Fills {@code buffer} with {@code pMixtureLog10 + evaluateDatumLog10(datum)}
     * for every component and returns their log10-sum.
     *
     * @param datum  the datum to evaluate with its current annotation values
     * @param buffer scratch array of length {@code gaussians.size()}; overwritten
     */
    private double evaluateMixtureLog10(VariantDatum datum, double[] buffer) {
        int gaussianIndex = 0;
        for (MultivariateGaussian gaussian : this.gaussians) {
            buffer[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10(datum);
        }
        return MathUtils.log10sumLog10(buffer);
    }
}

