/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.genotyper;

import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.genotyper.AlleleFrequencyCalculationModel;
import org.broadinstitute.sting.gatk.walkers.genotyper.UnifiedArgumentCollection;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.SimpleTimer;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.variantcontext.Allele;
import org.broadinstitute.sting.utils.variantcontext.Genotype;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/**
 * Allele-frequency model that computes, for every possible alternate-allele count k, the log10
 * likelihood of the data via a per-sample recurrence, and assigns per-sample genotypes.
 *
 * <p>Two variants of the recurrence are available (see {@link ExactCalculation}): an O(N^2)-memory
 * "gold standard" that fills the whole matrix, and a linear-memory version that keeps only three
 * rows and may stop early once the likelihood has dropped far below its running maximum.
 */
public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel {
    private static final boolean DEBUG = false;
    private static final boolean PRINT_LIKELIHOODS = false;
    private static final int N_CYCLES = 1;
    private SimpleTimer timerExpt = new SimpleTimer("linearExactBanded");
    private SimpleTimer timerGS = new SimpleTimer("linearExactGS");
    private static final boolean COMPARE_TO_GS = false;
    /** linearExact stops once log10 L(k) falls this many log10 units below the best L seen so far. */
    private static final double MAX_LOG10_ERROR_TO_STOP_EARLY = 6.0;
    private boolean SIMPLE_GREEDY_GENOTYPER = false;
    /** A sample whose genotype likelihoods sum to at least this value is treated as uninformative (no-call). */
    private static final double SUM_GL_THRESH_NOCALL = -0.001;
    /** Step of the lookup table used by approximateLog10SumLog10; TODO confirm it matches MathUtils.jacobianLogTable. */
    private static final double JACOBIAN_LOG_TABLE_STEP = 0.1;
    /** Differences at or above this are approximated by the larger term alone; TODO confirm table covers [0, this]. */
    private static final double JACOBIAN_LOG_TABLE_MAX_DIFF = 10.0;
    private final ExactCalculation calcToUse;

    protected ExactAFCalculationModel(UnifiedArgumentCollection UAC, int N, Logger logger, PrintStream verboseWriter) {
        super(UAC, N, logger, verboseWriter);
        this.calcToUse = UAC.EXACT_CALCULATION_TYPE;
    }

    /**
     * Fills {@code log10AlleleFrequencyPosteriors[k]} with log10 L(k) + log10 prior(k).
     *
     * <p>For multi-allelic sites, each alternate allele is evaluated on its own (treating the
     * others as absent). The posteriors of the allele whose most likely allele count is largest
     * are the ones left in the output array.
     *
     * <p>NOTE(review): the indices chosen here assume the likelihood vector is laid out as a
     * row-wise upper triangle (AA, AB, AC, ..., BB, BC, ...), whereas {@code assignGenotypes}
     * decodes genotype indices in VCF order (AA, AB, BB, AC, BC, CC). The two agree only for
     * biallelic sites; confirm which layout the likelihood producers use.
     *
     * @param tracker unused by this model
     * @param ref unused by this model
     * @param GLs per-sample genotypes carrying the likelihood vectors
     * @param alleles all alleles at the site (reference included)
     * @param log10AlleleFrequencyPriors log10 prior for each alternate-allele count
     * @param log10AlleleFrequencyPosteriors output array; its incoming contents are used as the
     *     starting state for every per-allele pass
     */
    @Override
    public void getLog10PNonRef(RefMetaDataTracker tracker, ReferenceContext ref, Map<String, Genotype> GLs, Set<Allele> alleles, double[] log10AlleleFrequencyPriors, double[] log10AlleleFrequencyPosteriors) {
        final int numAlleles = alleles.size();
        final boolean multiAllelic = numAlleles > 2;

        // linearExact may stop early and leave trailing entries untouched. Restoring the caller's
        // initial contents before each extra pass keeps one allele's values from leaking into the next.
        final double[] initialPosteriors = multiAllelic ? log10AlleleFrequencyPosteriors.clone() : null;
        final double[][] posteriorCache = multiAllelic ? new double[numAlleles - 1][] : null;
        final double[] bestAFguess = multiAllelic ? new double[numAlleles - 1] : null;

        int idxDiag = numAlleles;   // index of the homozygous-alt entry for the current allele
        int incr = numAlleles - 1;  // distance to the next diagonal entry (shrinks by one per allele)
        for (int k = 1; k < numAlleles; ++k) {
            final int idxAA = 0;
            final int idxAB = k;
            final int idxBB = idxDiag;
            idxDiag += incr--;

            if (k > 1) {
                System.arraycopy(initialPosteriors, 0, log10AlleleFrequencyPosteriors, 0, log10AlleleFrequencyPosteriors.length);
            }

            switch (this.calcToUse) {
                case N2_GOLD_STANDARD:
                    this.gdaN2GoldStandard(GLs, log10AlleleFrequencyPriors, log10AlleleFrequencyPosteriors, idxAA, idxAB, idxBB);
                    break;
                case LINEAR_EXPERIMENTAL:
                    this.linearExact(GLs, log10AlleleFrequencyPriors, log10AlleleFrequencyPosteriors, idxAA, idxAB, idxBB);
                    break;
            }

            if (multiAllelic) {
                posteriorCache[k - 1] = log10AlleleFrequencyPosteriors.clone();
                bestAFguess[k - 1] = MathUtils.maxElementIndex(log10AlleleFrequencyPosteriors);
            }
        }

        if (multiAllelic) {
            // Keep the posteriors of the allele whose most likely allele count is the largest.
            final int mostLikelyAlleleIdx = MathUtils.maxElementIndex(bestAFguess);
            System.arraycopy(posteriorCache[mostLikelyAlleleIdx], 0, log10AlleleFrequencyPosteriors, 0, log10AlleleFrequencyPosteriors.length);
        }
    }

    /** Returns true when the likelihoods sum strictly below {@link #SUM_GL_THRESH_NOCALL}; NaN sums are not informative. */
    private static boolean isInformative(double[] gls) {
        return MathUtils.sum(gls) < SUM_GL_THRESH_NOCALL;
    }

    /**
     * Collects the likelihood vectors of all informative samples.
     *
     * @return a list whose element 0 is a dummy {0, 0, 0} entry, so that the j-th informative
     *     sample sits at index j (matching the 1-based sample index of the recurrence)
     */
    private static final ArrayList<double[]> getGLs(Map<String, Genotype> GLs) {
        final ArrayList<double[]> genotypeLikelihoods = new ArrayList<double[]>();
        genotypeLikelihoods.add(new double[]{0.0, 0.0, 0.0});
        for (Genotype sample : GLs.values()) {
            if (!sample.hasLikelihoods()) {
                continue;
            }
            final double[] gls = sample.getLikelihoods().getAsVector();
            if (isInformative(gls)) {
                genotypeLikelihoods.add(gls);
            }
        }
        return genotypeLikelihoods;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public int linearExactBanded(Map<String, Genotype> GLs, double[] log10AlleleFrequencyPriors, double[] log10AlleleFrequencyPosteriors) {
        throw new UnsupportedOperationException("linearExactBanded is not implemented");
    }

    /**
     * Linear-memory version of the recurrence. It considers only informative samples (see
     * {@link #getGLs}) and keeps just the rows for k, k-1 and k-2.
     *
     * <p>Iteration stops after the first k whose log10 L(k) is more than
     * {@link #MAX_LOG10_ERROR_TO_STOP_EARLY} below the running maximum. Entries of the posterior
     * array beyond the returned index are not written.
     *
     * @return the last allele count k for which a posterior was written
     */
    public int linearExact(Map<String, Genotype> GLs, double[] log10AlleleFrequencyPriors, double[] log10AlleleFrequencyPosteriors, int idxAA, int idxAB, int idxBB) {
        final ArrayList<double[]> genotypeLikelihoods = ExactAFCalculationModel.getGLs(GLs);
        final int numSamples = genotypeLikelihoods.size() - 1;
        final int numChr = 2 * numSamples;

        final ExactACCache logY = new ExactACCache(numSamples + 1);
        logY.getkMinus0()[0] = 0.0;

        double maxLog10L = Double.NEGATIVE_INFINITY;
        int lastK = -1;
        for (int k = 0; k <= numChr; ++k) {
            final double[] kMinus0 = logY.getkMinus0();
            if (k == 0) {
                // With no alternate alleles every sample must be homozygous reference.
                for (int j = 1; j <= numSamples; ++j) {
                    kMinus0[j] = kMinus0[j - 1] + genotypeLikelihoods.get(j)[idxAA];
                }
            } else {
                final double[] kMinus1 = logY.getkMinus1();
                final double[] kMinus2 = logY.getkMinus2();
                for (int j = 1; j <= numSamples; ++j) {
                    final double[] gl = genotypeLikelihoods.get(j);
                    final double logDenominator = MathUtils.log10Cache[2 * j] + MathUtils.log10Cache[2 * j - 1];
                    // Each term is guarded so it only reads cells of row j-1 that are feasible (count <= 2(j-1)).
                    double aa = Double.NEGATIVE_INFINITY;
                    double ab = Double.NEGATIVE_INFINITY;
                    if (k < 2 * j - 1) {
                        aa = MathUtils.log10Cache[2 * j - k] + MathUtils.log10Cache[2 * j - k - 1] + kMinus0[j - 1] + gl[idxAA];
                    }
                    if (k < 2 * j) {
                        ab = MathUtils.log10Cache[2 * k] + MathUtils.log10Cache[2 * j - k] + kMinus1[j - 1] + gl[idxAB];
                    }
                    final double log10Numerator;
                    if (k > 1) {
                        final double bb = MathUtils.log10Cache[k] + MathUtils.log10Cache[k - 1] + kMinus2[j - 1] + gl[idxBB];
                        log10Numerator = approximateLog10SumLog10(aa, ab, bb);
                    } else {
                        log10Numerator = approximateLog10SumLog10(aa, ab);
                    }
                    kMinus0[j] = log10Numerator - logDenominator;
                }
            }

            final double log10LofK = kMinus0[numSamples];
            log10AlleleFrequencyPosteriors[k] = log10LofK + log10AlleleFrequencyPriors[k];
            lastK = k;
            maxLog10L = Math.max(maxLog10L, log10LofK);
            logY.rotate();
            if (log10LofK < maxLog10L - MAX_LOG10_ERROR_TO_STOP_EARLY) {
                break;
            }
        }
        return lastK;
    }

    /** Approximate log10(10^a + 10^b + 10^c), computed pairwise. */
    static final double approximateLog10SumLog10(double a, double b, double c) {
        return approximateLog10SumLog10(approximateLog10SumLog10(a, b), c);
    }

    /**
     * Approximate log10(10^small + 10^big) using a lookup table for the correction term.
     * The argument order does not matter; the values are swapped if needed.
     */
    static final double approximateLog10SumLog10(double small, double big) {
        if (small > big) {
            final double t = big;
            big = small;
            small = t;
        }
        if (small == Double.NEGATIVE_INFINITY || big == Double.NEGATIVE_INFINITY) {
            return big;
        }
        if (big >= small + JACOBIAN_LOG_TABLE_MAX_DIFF) {
            // The smaller term contributes negligibly.
            return big;
        }
        final int ind = (int) Math.round((big - small) / JACOBIAN_LOG_TABLE_STEP);
        return big + MathUtils.jacobianLogTable[ind];
    }

    /**
     * Assigns a genotype to every sample that has likelihoods.
     *
     * <p>For biallelic sites (unless the greedy genotyper is enabled), the informative samples'
     * genotypes are chosen jointly by a Viterbi pass. That pass is constrained to sum to
     * {@code AFofMaxLikelihood} alternate alleles. Otherwise each sample takes its own
     * maximum-likelihood genotype. Samples with likelihoods that are not informative get a
     * no-call with quality -1.
     *
     * @throws UserException if {@code vc} has no alternate allele
     */
    @Override
    public Map<String, Genotype> assignGenotypes(VariantContext vc, double[] log10AlleleFrequencyPosteriors, int AFofMaxLikelihood) {
        if (!vc.isVariant()) {
            throw new UserException("The VCF record passed in does not contain an ALT allele at " + vc.getChr() + ":" + vc.getStart());
        }
        final Map<String, Genotype> GLs = vc.getGenotypes();
        final boolean useGreedy = this.SIMPLE_GREEDY_GENOTYPER || !vc.isBiallelic();

        final ArrayList<String> sampleIndices = new ArrayList<String>();
        int[][] tracebackArray = null;
        if (useGreedy) {
            sampleIndices.addAll(GLs.keySet());
        } else {
            // pathMetricArray[s][k]: best log10 score over the first s samples with k alt alleles in total.
            final double[][] pathMetricArray = new double[GLs.size() + 1][AFofMaxLikelihood + 1];
            tracebackArray = new int[GLs.size() + 1][AFofMaxLikelihood + 1];
            for (double[] row : pathMetricArray) {
                for (int k = 0; k <= AFofMaxLikelihood; ++k) {
                    row[k] = -1.0E30;
                }
            }
            pathMetricArray[0][0] = 0.0;

            int sampleIdx = 0;
            for (Map.Entry<String, Genotype> sample : GLs.entrySet()) {
                if (!sample.getValue().hasLikelihoods()) {
                    continue;
                }
                final double[] likelihoods = sample.getValue().getLikelihoods().getAsVector();
                if (!isInformative(likelihoods)) {
                    continue;
                }
                sampleIndices.add(sample.getKey());
                final double[] prev = pathMetricArray[sampleIdx];
                for (int k = 0; k <= AFofMaxLikelihood; ++k) {
                    // Genotype index g adds g alt alleles; remember which predecessor column won.
                    double bestMetric = prev[k] + likelihoods[0];
                    int bestIndex = k;
                    if (k > 0 && prev[k - 1] + likelihoods[1] > bestMetric) {
                        bestMetric = prev[k - 1] + likelihoods[1];
                        bestIndex = k - 1;
                    }
                    if (k > 1 && prev[k - 2] + likelihoods[2] > bestMetric) {
                        bestMetric = prev[k - 2] + likelihoods[2];
                        bestIndex = k - 2;
                    }
                    pathMetricArray[sampleIdx + 1][k] = bestMetric;
                    tracebackArray[sampleIdx + 1][k] = bestIndex;
                }
                ++sampleIdx;
            }
        }

        final HashMap<String, Genotype> calls = new HashMap<String, Genotype>();
        int startIdx = AFofMaxLikelihood;
        // Walk samples in reverse so the Viterbi traceback can be followed from the final column.
        for (int k = sampleIndices.size(); k > 0; --k) {
            final String sample = sampleIndices.get(k - 1);
            final Genotype g = GLs.get(sample);
            if (!g.hasLikelihoods()) {
                continue;
            }
            final double[] likelihoods = g.getLikelihoods().getAsVector();
            final int bestGTguess;
            if (useGreedy) {
                bestGTguess = Utils.findIndexOfMaxEntry(likelihoods);
            } else {
                final int newIdx = tracebackArray[k][startIdx];
                bestGTguess = startIdx - newIdx;
                startIdx = newIdx;
            }
            final double qual = computeGenotypeQuality(likelihoods, bestGTguess);
            calls.put(sample, new Genotype(sample, allelesForGenotypeIndex(vc, bestGTguess), qual, null, g.getAttributes(), false));
        }

        // Uninformative samples are overwritten (or added) as no-calls.
        for (Map.Entry<String, Genotype> sample : GLs.entrySet()) {
            final Genotype g = sample.getValue();
            if (!g.hasLikelihoods() || isInformative(g.getLikelihoods().getAsVector())) {
                continue;
            }
            final ArrayList<Allele> noCall = new ArrayList<Allele>();
            noCall.add(Allele.NO_CALL);
            noCall.add(Allele.NO_CALL);
            calls.put(sample.getKey(), new Genotype(sample.getKey(), noCall, -1.0, null, g.getAttributes(), false));
        }
        return calls;
    }

    /**
     * Quality of the chosen genotype: its log10 likelihood minus the best competing one.
     * When the chosen genotype is not the most likely (possible on the Viterbi path), the quality
     * is -log10(1 - p) instead, where p is its normalized probability.
     */
    private static double computeGenotypeQuality(double[] likelihoods, int bestGTguess) {
        double runnerUp = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < likelihoods.length; ++i) {
            if (i != bestGTguess && likelihoods[i] >= runnerUp) {
                runnerUp = likelihoods[i];
            }
        }
        double qual = likelihoods[bestGTguess] - runnerUp;
        if (qual < 0.0) {
            final double[] normalized = MathUtils.normalizeFromLog10(likelihoods);
            qual = -1.0 * Math.log10(1.0 - normalized[bestGTguess]);
        }
        return qual;
    }

    /**
     * Decodes a genotype index into its two alleles, enumerating pairs (i, j) with i <= j in the
     * order j outer, i inner. That gives AA, AB, BB, AC, BC, CC, ... Allele 0 is the reference.
     * Returns an empty list if the index is out of range.
     */
    private static ArrayList<Allele> allelesForGenotypeIndex(VariantContext vc, int gtIndex) {
        final ArrayList<Allele> result = new ArrayList<Allele>();
        int idx = 0;
        for (int j = 0; j < vc.getNAlleles(); ++j) {
            for (int i = 0; i <= j; ++i) {
                if (idx++ == gtIndex) {
                    result.add(i == 0 ? vc.getReference() : vc.getAlternateAllele(i - 1));
                    result.add(j == 0 ? vc.getReference() : vc.getAlternateAllele(j - 1));
                    return result;
                }
            }
        }
        return result;
    }

    /**
     * Full-matrix version of the recurrence over every sample that has likelihoods.
     *
     * <p>Samples without likelihoods are skipped without consuming a row; previously they left
     * an all -infinity row that nulled every subsequent row. Posterior entries for allele counts
     * the processed samples cannot reach stay at -infinity.
     *
     * @return 2 * GLs.size(), the last index written in the posterior array
     */
    public int gdaN2GoldStandard(Map<String, Genotype> GLs, double[] log10AlleleFrequencyPriors, double[] log10AlleleFrequencyPosteriors, int idxAA, int idxAB, int idxBB) {
        final int numSamples = GLs.size();
        final int numChr = 2 * numSamples;
        final double[][] logYMatrix = new double[1 + numSamples][1 + numChr];
        for (int i = 0; i <= numSamples; ++i) {
            for (int k = 0; k <= numChr; ++k) {
                logYMatrix[i][k] = Double.NEGATIVE_INFINITY;
            }
        }
        logYMatrix[0][0] = 0.0;

        int j = 0; // number of samples folded into the matrix so far
        for (Genotype sample : GLs.values()) {
            if (!sample.hasLikelihoods()) {
                continue;
            }
            ++j;
            final double[] gl = sample.getLikelihoods().getAsVector();
            final double logDenominator = MathUtils.log10Cache[2 * j] + MathUtils.log10Cache[2 * j - 1];
            final double[] prev = logYMatrix[j - 1];
            final double[] cur = logYMatrix[j];
            cur[0] = prev[0] + gl[idxAA];
            for (int k = 1; k <= 2 * j; ++k) {
                final double[] logNumerator = new double[]{
                    k < 2 * j - 1 ? MathUtils.log10Cache[2 * j - k] + MathUtils.log10Cache[2 * j - k - 1] + prev[k] + gl[idxAA] : Double.NEGATIVE_INFINITY,
                    k < 2 * j ? MathUtils.log10Cache[2 * k] + MathUtils.log10Cache[2 * j - k] + prev[k - 1] + gl[idxAB] : Double.NEGATIVE_INFINITY,
                    k > 1 ? MathUtils.log10Cache[k] + MathUtils.log10Cache[k - 1] + prev[k - 2] + gl[idxBB] : Double.NEGATIVE_INFINITY
                };
                cur[k] = MathUtils.softMax(logNumerator) - logDenominator;
            }
        }
        for (int k = 0; k <= numChr; ++k) {
            log10AlleleFrequencyPosteriors[k] = logYMatrix[j][k] + log10AlleleFrequencyPriors[k];
        }
        return numChr;
    }

    /** Debug helper: prints the last matrix row, the priors and their sum for k = 0..numChr. */
    private static final void printLikelihoods(int numChr, double[][] logYMatrix, double[] log10AlleleFrequencyPriors) {
        final int j = logYMatrix.length - 1;
        System.out.printf("-----------------------------------%n");
        for (int k = 0; k <= numChr; ++k) {
            final double posterior = logYMatrix[j][k] + log10AlleleFrequencyPriors[k];
            System.out.printf("  %4d\t%8.2f\t%8.2f\t%8.2f%n", k, logYMatrix[j][k], log10AlleleFrequencyPriors[k], posterior);
        }
    }

    /**
     * Three rolling rows (for allele counts k, k-1 and k-2) used by {@link #linearExact}.
     * {@link #rotate()} recycles the oldest row as the new current row without reallocating.
     */
    private static final class ExactACCache {
        double[] kMinus2;
        double[] kMinus1;
        double[] kMinus0;

        private static final double[] create(int n) {
            return new double[n];
        }

        public ExactACCache(int n) {
            this.kMinus2 = ExactACCache.create(n);
            this.kMinus1 = ExactACCache.create(n);
            this.kMinus0 = ExactACCache.create(n);
        }

        /** Shifts k-1 to k-2 and k to k-1, and reuses the old k-2 array as the new k row. */
        public final void rotate() {
            final double[] tmp = this.kMinus2;
            this.kMinus2 = this.kMinus1;
            this.kMinus1 = this.kMinus0;
            this.kMinus0 = tmp;
        }

        public final double[] getkMinus2() {
            return this.kMinus2;
        }

        public final double[] getkMinus1() {
            return this.kMinus1;
        }

        public final double[] getkMinus0() {
            return this.kMinus0;
        }
    }

    /** Which implementation of the recurrence {@link #getLog10PNonRef} uses. */
    public static enum ExactCalculation {
        N2_GOLD_STANDARD,
        LINEAR_EXPERIMENTAL;
    }
}

