/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.annotator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.AlignmentContextUtils;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.AnnotatorCompatibleWalker;
import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.InfoFieldAnnotation;
import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.StandardAnnotation;
import org.broadinstitute.sting.gatk.walkers.genotyper.IndelGenotypeLikelihoodsCalculationModel;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeaderLineType;
import org.broadinstitute.sting.utils.codecs.vcf.VCFInfoHeaderLine;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.genotype.Haplotype;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
import org.broadinstitute.sting.utils.sam.AlignmentUtils;
import org.broadinstitute.sting.utils.variantcontext.Allele;
import org.broadinstitute.sting.utils.variantcontext.Genotype;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;

public class HaplotypeScore
extends InfoFieldAnnotation
implements StandardAnnotation {
    // Debug switch; nothing in this class reads it.
    private static final boolean DEBUG = false;
    // Upper bound (10) on the number of positions taken on each side of the locus
    // when annotate() sizes the haplotype window; smaller reference windows shrink it.
    private static final int MIN_CONTEXT_WING_SIZE = 10;
    // Cap (50) on how many consensus haplotypes computeHaplotypes() keeps while
    // merging per-read haplotypes.
    private static final int MAX_CONSENSUS_HAPLOTYPES_TO_CONSIDER = 50;
    // Placeholder base ('.', byte value 46) for haplotype positions with no usable read
    // base; it is treated as compatible with any base when haplotypes are merged or scored.
    private static final char REGEXP_WILDCARD = '.';

    /**
     * Computes the HaplotypeScore INFO annotation for a variant site.
     *
     * Candidate haplotypes are derived from the pileup of all stratified contexts joined
     * together. Each genotype entry of {@code vc} that has a context with a pileup is then
     * scored (SNPs against the candidate haplotypes, indel/mixed sites via the indel
     * likelihood map) and the scores are averaged.
     *
     * @param tracker            not used by this annotation
     * @param walker             not used by this annotation
     * @param ref                reference context supplying the window size and the locus
     * @param stratifiedContexts alignment contexts, looked up with the keys of
     *                           {@code vc.getGenotypes()}
     * @param vc                 the variant being annotated
     * @return a one-entry map from the annotation key to the mean score formatted with four
     *         decimals, or {@code null} when there are no contexts, the site is a
     *         multi-allelic SNP, no pileup is available, the variant type is unsupported, or
     *         indel scoring is unavailable
     */
    @Override
    public Map<String, Object> annotate(RefMetaDataTracker tracker, AnnotatorCompatibleWalker walker, ReferenceContext ref, Map<String, AlignmentContext> stratifiedContexts, VariantContext vc) {
        if (stratifiedContexts.isEmpty()) {
            return null;
        }
        // Only biallelic SNPs are scored; multi-allelic SNPs get no annotation.
        if (vc.isSNP() && !vc.isBiallelic()) {
            return null;
        }

        AlignmentContext context = AlignmentContextUtils.joinContexts(stratifiedContexts.values());

        // The haplotype window is symmetric around the locus and never wider than the
        // reference window that was supplied.
        int contextWingSize = Math.min(((int)ref.getWindow().size() - 1) / 2, MIN_CONTEXT_WING_SIZE);
        int contextSize = contextWingSize * 2 + 1;
        // Midpoint of the locus interval.
        int locus = ref.getLocus().getStart() + (ref.getLocus().getStop() - ref.getLocus().getStart()) / 2;

        ReadBackedPileup pileup = extendedOrBasePileup(context);
        if (pileup == null) {
            return null;
        }

        List<Haplotype> haplotypes = this.computeHaplotypes(pileup, contextSize, locus, vc);
        MathUtils.RunningAverage scoreRA = new MathUtils.RunningAverage();
        if (haplotypes != null) {
            for (Map.Entry<String, Genotype> genotype : vc.getGenotypes().entrySet()) {
                AlignmentContext thisContext = stratifiedContexts.get(genotype.getKey());
                if (thisContext == null) {
                    continue;
                }
                ReadBackedPileup thisPileup = extendedOrBasePileup(thisContext);
                if (thisPileup == null) {
                    continue;
                }
                if (vc.isSNP()) {
                    scoreRA.add(this.scoreReadsAgainstHaplotypes(haplotypes, thisPileup, contextSize, locus));
                } else if (vc.isIndel() || vc.isMixed()) {
                    Double indelScore = this.scoreIndelsAgainstHaplotypes(thisPileup);
                    if (indelScore == null) {
                        return null;
                    }
                    scoreRA.add(indelScore);
                } else {
                    // Any other variant type is not annotated.
                    return null;
                }
            }
        }

        Map<String, Object> map = new HashMap<String, Object>();
        // Locale.ROOT pins the decimal separator to '.', so the emitted value does not
        // depend on the JVM's default locale (which could otherwise produce "1,2345").
        map.put(this.getKeyNames().get(0), String.format(Locale.ROOT, "%.4f", scoreRA.mean()));
        return map;
    }

    /**
     * Picks the pileup to work with: the extended-event pileup when the context has one,
     * otherwise the base pileup, otherwise {@code null}.
     */
    private static ReadBackedPileup extendedOrBasePileup(AlignmentContext context) {
        if (context.hasExtendedEventPileup()) {
            return context.getExtendedEventPileup();
        }
        if (context.hasBasePileup()) {
            return context.getBasePileup();
        }
        return null;
    }

    /**
     * Derives the haplotypes used for scoring from a pileup.
     *
     * Every pileup element contributes one candidate haplotype. Candidates are taken in
     * descending quality-sum order and either merged into a compatible consensus haplotype
     * or kept as a new consensus, with the consensus set capped at
     * {@code MAX_CONSENSUS_HAPLOTYPES_TO_CONSIDER}.
     *
     * @param pileup      reads overlapping the site
     * @param contextSize width of the haplotype window
     * @param locus       position the window is centred on
     * @param vc          variant; its alternate-allele count plus one is the number of
     *                    haplotypes returned
     * @return the requested number of haplotypes (the first rebuilt with quality 60, the rest
     *         with quality 20; the best one is repeated if too few consensus haplotypes
     *         exist), or {@code null} when no consensus haplotype could be formed
     */
    private List<Haplotype> computeHaplotypes(ReadBackedPileup pileup, int contextSize, int locus, VariantContext vc) {
        int wanted = vc.getAlternateAlleles().size() + 1;
        PriorityQueue<Haplotype> candidates = new PriorityQueue<Haplotype>(100, new HaplotypeComparator());
        PriorityQueue<Haplotype> consensus = new PriorityQueue<Haplotype>(MAX_CONSENSUS_HAPLOTYPES_TO_CONSIDER, new HaplotypeComparator());

        for (PileupElement element : pileup) {
            candidates.add(this.getHaplotypeFromRead(element, contextSize, locus));
        }

        Haplotype candidate;
        while ((candidate = candidates.poll()) != null) {
            boolean merged = false;
            // Last consensus haplotype visited that was incompatible with the candidate.
            // PriorityQueue iteration order is unspecified, so this is not necessarily
            // the lowest-quality entry.
            Haplotype lastIncompatible = null;
            for (Haplotype existing : consensus) {
                Haplotype combined = this.getConsensusHaplotype(candidate, existing);
                if (combined == null) {
                    lastIncompatible = existing;
                    continue;
                }
                merged = true;
                if (combined.getQualitySum() > existing.getQualitySum()) {
                    // Mutating the queue mid-iteration is safe only because the loop
                    // exits immediately afterwards.
                    consensus.remove(existing);
                    consensus.add(combined);
                }
                break;
            }
            if (merged) {
                continue;
            }
            if (consensus.size() < MAX_CONSENSUS_HAPLOTYPES_TO_CONSIDER) {
                consensus.add(candidate);
            } else if (lastIncompatible != null && candidate.getQualitySum() > lastIncompatible.getQualitySum()) {
                // Queue is full: the candidate displaces the last incompatible entry seen.
                consensus.remove(lastIncompatible);
                consensus.add(candidate);
            }
        }

        if (consensus.isEmpty()) {
            return null;
        }

        Haplotype best = consensus.poll();
        List<Haplotype> result = new ArrayList<Haplotype>();
        result.add(new Haplotype(best.getBasesAsBytes(), 60));
        for (int k = 1; k < wanted; ++k) {
            Haplotype next = consensus.poll();
            // Ran out of distinct consensus haplotypes: reuse the best one.
            Haplotype source = next != null ? next : best;
            result.add(new Haplotype(source.getBasesAsBytes(), 20));
        }
        return result;
    }

    /**
     * Extracts the window of a read around the locus as a haplotype.
     *
     * Window positions the read does not cover keep the wildcard base with quality 0.
     * Covered positions take the read's alignment-space base and a quality that is capped
     * at the mapping quality and zeroed when below 5.
     *
     * @param p           pileup element providing the read, its offset and mapping quality
     * @param contextSize width of the haplotype window
     * @param locus       position the window is centred on
     * @return a haplotype of length {@code contextSize} with per-position qualities
     */
    private Haplotype getHaplotypeFromRead(PileupElement p, int contextSize, int locus) {
        SAMRecord read = p.getRead();

        // Bases and qualities are projected into alignment space before indexing.
        byte[] alignedBases = AlignmentUtils.readToAlignmentByteArray(read.getCigar(), read.getReadBases());
        byte[] alignedQuals = AlignmentUtils.readToAlignmentByteArray(read.getCigar(), read.getBaseQualities());
        int centreOffset = AlignmentUtils.calcAlignmentByteArrayOffset(read.getCigar(), p.getOffset(), read.getAlignmentStart(), locus);
        int windowStart = centreOffset - (contextSize - 1) / 2;

        byte[] haplotypeBases = new byte[contextSize];
        Arrays.fill(haplotypeBases, (byte)REGEXP_WILDCARD);
        // Freshly allocated, so every quality already starts at 0.0.
        double[] baseQualities = new double[contextSize];

        // Restrict the walk to window positions that fall inside the aligned read.
        int firstIndex = Math.max(0, -windowStart);
        int endIndex = Math.min(contextSize, alignedBases.length - windowStart);
        for (int i = firstIndex; i < endIndex; ++i) {
            int src = windowStart + i;
            // 68 is ASCII 'D'; presumably the deletion marker written by
            // readToAlignmentByteArray, mapped to a fixed quality of 16 - TODO confirm.
            if (alignedQuals[src] == 68) {
                alignedQuals[src] = 16;
            }
            if (!BaseUtils.isRegularBase(alignedBases[src])) {
                alignedBases[src] = (byte)REGEXP_WILDCARD;
                alignedQuals[src] = 0;
            }
            // A base is never trusted more than the read's mapping quality.
            alignedQuals[src] = (byte)Math.min(alignedQuals[src], p.getMappingQual());
            if (alignedQuals[src] < 5) {
                alignedQuals[src] = 0;
            }
            haplotypeBases[i] = alignedBases[src];
            baseQualities[i] = alignedQuals[src];
        }
        return new Haplotype(haplotypeBases, baseQualities);
    }

    /**
     * Merges two haplotypes of equal length into one consensus haplotype.
     *
     * A wildcard position is compatible with any base. Where only one side has a real base,
     * that base and its quality are used; where both agree, the qualities are added.
     *
     * @param haplotypeA first haplotype
     * @param haplotypeB second haplotype
     * @return the merged haplotype, or {@code null} if the two disagree at any position where
     *         both carry a non-wildcard base
     * @throws ReviewedStingException if the haplotypes differ in length
     */
    private Haplotype getConsensusHaplotype(Haplotype haplotypeA, Haplotype haplotypeB) {
        byte[] basesA = haplotypeA.getBasesAsBytes();
        byte[] basesB = haplotypeB.getBasesAsBytes();
        if (basesA.length != basesB.length) {
            throw new ReviewedStingException("Haplotypes a and b must be of same length");
        }

        int length = basesA.length;
        byte[] mergedBases = new byte[length];
        double[] mergedQuals = new double[length];
        double[] qualsA = haplotypeA.getQuals();
        double[] qualsB = haplotypeB.getQuals();

        for (int i = 0; i < length; ++i) {
            byte baseA = basesA[i];
            byte baseB = basesB[i];
            boolean wildA = baseA == REGEXP_WILDCARD;
            boolean wildB = baseB == REGEXP_WILDCARD;

            if (!wildA && !wildB && baseA != baseB) {
                // Two real bases that disagree: the haplotypes cannot be merged.
                return null;
            }

            if (wildA && wildB) {
                mergedBases[i] = (byte)REGEXP_WILDCARD;
                mergedQuals[i] = 0.0;
            } else if (wildA) {
                mergedBases[i] = baseB;
                mergedQuals[i] = qualsB[i];
            } else if (wildB) {
                mergedBases[i] = baseA;
                mergedQuals[i] = qualsA[i];
            } else {
                // Both sides support the same base, so their evidence accumulates.
                mergedBases[i] = baseA;
                mergedQuals[i] = qualsA[i] + qualsB[i];
            }
        }
        return new Haplotype(mergedBases, mergedQuals);
    }

    /**
     * Scores every read of a pileup against all haplotypes and sums, over the reads, the
     * smallest score each read obtains.
     *
     * @param haplotypes  haplotypes to score against
     * @param pileup      reads to score
     * @param contextSize width of the haplotype window
     * @param locus       position the window is centred on
     * @return the sum of per-read minimum scores (0.0 for an empty pileup)
     */
    private double scoreReadsAgainstHaplotypes(List<Haplotype> haplotypes, ReadBackedPileup pileup, int contextSize, int locus) {
        double total = 0.0;
        for (PileupElement element : pileup) {
            double[] perHaplotype = new double[haplotypes.size()];
            int index = 0;
            for (Haplotype haplotype : haplotypes) {
                perHaplotype[index++] = this.scoreReadAgainstHaplotype(element, contextSize, haplotype, locus);
            }
            // Each read is charged only for the haplotype it fits best.
            total += MathUtils.arrayMin(perHaplotype);
        }
        return total;
    }

    /**
     * Scores one read against one haplotype over the window around the locus.
     *
     * For every covered window position whose effective quality (base quality capped at the
     * mapping quality) is at least 5, with {@code e = QualityUtils.qualToErrorProb(qual)}: a
     * matching base adds {@code e} to the mismatch total, a non-matching base adds
     * {@code 1 - e / 3}, and {@code e} is always added to the expected total. A wildcard in
     * the haplotype counts as a match.
     *
     * @param p           pileup element providing the read, its offset and mapping quality
     * @param contextSize width of the haplotype window
     * @param haplotype   haplotype to compare with
     * @param locus       position the window is centred on
     * @return mismatch total minus expected total
     */
    private double scoreReadAgainstHaplotype(PileupElement p, int contextSize, Haplotype haplotype, int locus) {
        byte[] haplotypeBases = haplotype.getBasesAsBytes();
        SAMRecord read = p.getRead();

        // Bases and qualities are projected into alignment space before indexing.
        byte[] alignedBases = AlignmentUtils.readToAlignmentByteArray(read.getCigar(), read.getReadBases());
        byte[] alignedQuals = AlignmentUtils.readToAlignmentByteArray(read.getCigar(), read.getBaseQualities());
        int centreOffset = AlignmentUtils.calcAlignmentByteArrayOffset(read.getCigar(), p.getOffset(), read.getAlignmentStart(), locus);
        int windowStart = centreOffset - (contextSize - 1) / 2;

        double expected = 0.0;
        double mismatches = 0.0;

        // Restrict the walk to window positions that fall inside the aligned read.
        int firstIndex = Math.max(0, -windowStart);
        int endIndex = Math.min(contextSize, alignedBases.length - windowStart);
        for (int i = firstIndex; i < endIndex; ++i) {
            int src = windowStart + i;
            byte haplotypeBase = haplotypeBases[i];
            boolean matched = alignedBases[src] == haplotypeBase || haplotypeBase == REGEXP_WILDCARD;

            byte qual = alignedQuals[src];
            // 68 is ASCII 'D'; presumably the deletion marker written by
            // readToAlignmentByteArray, mapped to a fixed quality of 16 - TODO confirm.
            if (qual == 68) {
                qual = 16;
            }
            qual = (byte)Math.min(qual, p.getMappingQual());
            if (qual >= 5) {
                double e = QualityUtils.qualToErrorProb(qual);
                expected += e;
                if (matched) {
                    mismatches += e;
                } else {
                    mismatches += 1.0 - e / 3.0;
                }
            }
        }
        return mismatches - expected;
    }

    /**
     * Scores a pileup at an indel site from the likelihood map exposed by
     * {@code IndelGenotypeLikelihoodsCalculationModel.getIndelLikelihoodMap()}.
     *
     * For each pileup element present in the map, the stored per-allele values are negated
     * and the smallest negated value is added to the total. Elements missing from the map
     * are skipped.
     *
     * @param pileup reads to score
     * @return the summed score, or {@code null} if no likelihood map is available
     */
    private Double scoreIndelsAgainstHaplotypes(ReadBackedPileup pileup) {
        HashMap<PileupElement, LinkedHashMap<Allele, Double>> likelihoodsByRead = IndelGenotypeLikelihoodsCalculationModel.getIndelLikelihoodMap();
        if (likelihoodsByRead == null) {
            return null;
        }

        double total = 0.0;
        for (PileupElement element : pileup) {
            if (!likelihoodsByRead.containsKey(element)) {
                continue;
            }
            LinkedHashMap<Allele, Double> perAllele = likelihoodsByRead.get(element);
            double[] negated = new double[perAllele.size()];
            int index = 0;
            for (Double likelihood : perAllele.values()) {
                negated[index++] = -likelihood.doubleValue();
            }
            // Only the minimum negated value per read contributes.
            total += MathUtils.arrayMin(negated);
        }
        return total;
    }

    /**
     * Names the INFO key produced by this annotation.
     *
     * @return a one-element list containing {@code "HaplotypeScore"}
     */
    @Override
    public List<String> getKeyNames() {
        String key = "HaplotypeScore";
        return Arrays.asList(key);
    }

    /**
     * Describes the INFO header line for this annotation: a single Float value under the
     * key {@code "HaplotypeScore"}.
     *
     * @return a one-element list holding that header line
     */
    @Override
    public List<VCFInfoHeaderLine> getDescriptions() {
        VCFInfoHeaderLine headerLine = new VCFInfoHeaderLine(
                "HaplotypeScore",
                1,
                VCFHeaderLineType.Float,
                "Consistency of the site with at most two segregating haplotypes");
        return Arrays.asList(headerLine);
    }

    /**
     * Orders haplotypes by descending quality sum, so that a {@code PriorityQueue} built
     * with this comparator hands out the haplotype with the largest quality sum first.
     * Haplotypes whose quality sums are neither greater nor smaller compare as equal.
     */
    private class HaplotypeComparator
    implements Comparator<Haplotype> {
        private HaplotypeComparator() {
        }

        @Override
        public int compare(Haplotype a, Haplotype b) {
            // A larger quality sum sorts earlier.
            if (a.getQualitySum() > b.getQualitySum()) {
                return -1;
            }
            return a.getQualitySum() < b.getQualitySum() ? 1 : 0;
        }
    }
}

