/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.sting.gatk.walkers.phasing;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.broadinstitute.sting.commandline.Argument;
import org.broadinstitute.sting.commandline.ArgumentCollection;
import org.broadinstitute.sting.commandline.Output;
import org.broadinstitute.sting.gatk.arguments.StandardVariantContextInputArgumentCollection;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.SampleUtils;
import org.broadinstitute.sting.utils.codecs.vcf.VCFFormatHeaderLine;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeader;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeaderLine;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeaderLineType;
import org.broadinstitute.sting.utils.codecs.vcf.VCFUtils;
import org.broadinstitute.sting.utils.codecs.vcf.VCFWriter;
import org.broadinstitute.sting.utils.text.XReadLines;
import org.broadinstitute.sting.utils.variantcontext.Allele;
import org.broadinstitute.sting.utils.variantcontext.Genotype;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import org.broadinstitute.sting.utils.variantcontext.VariantContextUtils;

/**
 * Phases trio genotypes by Mendelian transmission.
 *
 * <p>Each trio is supplied as a {@code mom+dad=child} pattern, either directly on the command line
 * or as the lines of a file named on the command line. At every biallelic site where all three
 * trio members are called, the walker scores all 27 combinations of hom-ref/het/hom-var genotypes
 * for mother, father and child. The score is the product of each sample's normalized genotype
 * likelihoods and a Mendelian prior, and the highest-scoring combination is kept; its genotypes
 * may differ from the input calls.
 *
 * <p>If all three chosen genotypes are heterozygous, the input genotypes are written unchanged,
 * because transmission cannot tell which child allele came from which parent. Otherwise the chosen
 * genotypes are written, phased when a transmission-consistent assignment exists. In that case the
 * child lists its maternal allele first and each parent lists its transmitted allele first. The
 * child's genotype is annotated with {@code TP}: the chosen combination's score divided by the sum
 * of all 27 scores.
 */
public class PhaseByTransmission
extends RodWalker<Integer, Integer> {
    @ArgumentCollection
    protected StandardVariantContextInputArgumentCollection variantCollection = new StandardVariantContextInputArgumentCollection();
    @Argument(shortName="f", fullName="familySpec", required=true, doc="Patterns for the family structure (usage: mom+dad=child).  Specify several trios by supplying this argument many times and/or a file containing many patterns.")
    public ArrayList<String> familySpecs = null;
    @Output
    protected VCFWriter vcfWriter = null;

    /** FORMAT key of the transmission-probability annotation written on the child's genotype. */
    private static final String TRANSMISSION_PROBABILITY_TAG_NAME = "TP";
    /** Value written to the output header's {@code source} line. */
    private static final String SOURCE_NAME = "PhaseByTransmission";
    /** Prior weight of each genotype configuration that violates Mendelian inheritance. */
    private static final double MENDELIAN_VIOLATION_PRIOR = 1.0E-8;
    /** Trios to phase; populated from {@link #familySpecs} in {@link #initialize()}. */
    private ArrayList<Trio> trios = new ArrayList<Trio>();

    /**
     * Parses the {@code -f} arguments into trios.
     *
     * <p>Each argument is first tried as a file path. If the file can be opened, every non-blank
     * line is parsed as a {@code mom+dad=child} pattern. Otherwise the argument itself is parsed
     * as a single pattern.
     *
     * @param familySpecs file paths and/or patterns; may be {@code null}
     * @return the parsed trios, never {@code null}
     * @throws IllegalArgumentException if a pattern is malformed
     */
    public ArrayList<Trio> getFamilySpecsFromCommandLineInput(ArrayList<String> familySpecs) {
        ArrayList<Trio> specs = new ArrayList<Trio>();
        if (familySpecs == null) {
            return specs;
        }
        for (String familySpec : familySpecs) {
            List<String> lines;
            try {
                lines = new XReadLines(new File(familySpec)).readLines();
            }
            catch (FileNotFoundException e) {
                // Could not be opened as a file: treat the argument as a pattern itself.
                specs.add(new Trio(familySpec));
                continue;
            }
            for (String line : lines) {
                // Tolerate blank lines (e.g. a trailing newline) in a family file.
                if (line.trim().length() == 0) {
                    continue;
                }
                specs.add(new Trio(line));
            }
        }
        return specs;
    }

    /**
     * Parses the trios, checks that their samples exist in the input VCF, and writes the output
     * header. The header is the input header lines plus the {@code TP} FORMAT line and a
     * {@code source} line.
     *
     * @throws IllegalArgumentException if a trio names a sample absent from the input VCF header
     */
    @Override
    public void initialize() {
        this.trios = this.getFamilySpecsFromCommandLineInput(this.familySpecs);
        ArrayList<String> rodNames = new ArrayList<String>();
        rodNames.add(this.variantCollection.variants.getName());
        Map<String, VCFHeader> vcfRods = VCFUtils.getVCFHeadersFromRods(this.getToolkit(), rodNames);
        Set<String> vcfSamples = SampleUtils.getSampleList(vcfRods, VariantContextUtils.GenotypeMergeType.REQUIRE_UNIQUE);
        // Fail fast on a misspelled or absent sample name. Otherwise the missing genotype would
        // only surface later, inside map().
        for (Trio trio : this.trios) {
            requireSampleInHeader(trio.getMother(), "mother", vcfSamples);
            requireSampleInHeader(trio.getFather(), "father", vcfSamples);
            requireSampleInHeader(trio.getChild(), "child", vcfSamples);
        }
        HashSet<VCFHeaderLine> headerLines = new HashSet<VCFHeaderLine>();
        headerLines.addAll(VCFUtils.getHeaderFields(this.getToolkit()));
        headerLines.add(new VCFFormatHeaderLine(TRANSMISSION_PROBABILITY_TAG_NAME, 1, VCFHeaderLineType.Float, "Probability that the phase is correct given that the genotypes are correct"));
        headerLines.add(new VCFHeaderLine("source", SOURCE_NAME));
        this.vcfWriter.writeHeader(new VCFHeader(headerLines, vcfSamples));
    }

    /** Throws {@link IllegalArgumentException} if {@code sample} is not in {@code vcfSamples}. */
    private static void requireSampleInHeader(String sample, String role, Set<String> vcfSamples) {
        if (!vcfSamples.contains(sample)) {
            throw new IllegalArgumentException("Family specification names " + role + " '" + sample + "', which is not a sample in the input VCF");
        }
    }

    /** Returns {@code g}'s genotype likelihoods passed through {@code MathUtils.normalizeFromLog10}. */
    private static double[] normalizedLikelihoods(Genotype g) {
        return MathUtils.normalizeFromLog10(g.getLikelihoods().getAsVector());
    }

    /**
     * Returns the index of {@code g} in a biallelic likelihood vector ordered hom-ref, het, hom-var.
     *
     * @throws IllegalArgumentException if {@code g} is none of hom-ref, het or hom-var
     */
    private static int likelihoodIndex(Genotype g) {
        if (g.isHomRef()) {
            return 0;
        }
        if (g.isHet()) {
            return 1;
        }
        if (g.isHomVar()) {
            return 2;
        }
        throw new IllegalArgumentException("Genotype " + g + " is not hom-ref, het or hom-var");
    }

    /**
     * Builds the hom-ref, het and hom-var genotypes, in that order, for {@code g}'s sample. Each
     * copies the sample's name, error, and attributes (including its likelihoods), is unfiltered,
     * and is unphased.
     */
    private ArrayList<Genotype> createAllThreeGenotypes(Allele refAllele, Allele altAllele, Genotype g) {
        ArrayList<Genotype> genotypes = new ArrayList<Genotype>();
        genotypes.add(createUnphasedCandidate(refAllele, refAllele, g));
        genotypes.add(createUnphasedCandidate(refAllele, altAllele, g));
        genotypes.add(createUnphasedCandidate(altAllele, altAllele, g));
        return genotypes;
    }

    /** Builds one unphased, unfiltered diploid genotype for {@code template}'s sample. */
    private Genotype createUnphasedCandidate(Allele first, Allele second, Genotype template) {
        ArrayList<Allele> alleles = new ArrayList<Allele>();
        alleles.add(first);
        alleles.add(second);
        return new Genotype(template.getSampleName(), alleles, template.getNegLog10PError(), null, template.getAttributes(), false);
    }

    /** Returns how many of {@code g}'s alleles are equal to {@code alleleToMatch}. */
    private int getNumberOfMatchingAlleles(Allele alleleToMatch, Genotype g) {
        int matchingAlleles = 0;
        for (Allele a : g.getAlleles()) {
            if (alleleToMatch.equals(a)) {
                ++matchingAlleles;
            }
        }
        return matchingAlleles;
    }

    /** Returns 1 if {@code parent} carries {@code allele} (and so can transmit it), else 0. */
    private int canTransmit(Allele allele, Genotype parent) {
        return this.getNumberOfMatchingAlleles(allele, parent) > 0 ? 1 : 0;
    }

    /**
     * Returns whether the child's genotype cannot be explained by inheritance from these parents.
     *
     * <p>Each parent transmits one allele. The child therefore cannot carry more copies of an
     * allele than there are parents carrying it.
     */
    private boolean isMendelianViolation(Allele refAllele, Allele altAllele, Genotype mom, Genotype dad, Genotype child) {
        int parentsCarryingRef = this.canTransmit(refAllele, mom) + this.canTransmit(refAllele, dad);
        int parentsCarryingAlt = this.canTransmit(altAllele, mom) + this.canTransmit(altAllele, dad);
        return parentsCarryingRef < this.getNumberOfMatchingAlleles(refAllele, child)
                || parentsCarryingAlt < this.getNumberOfMatchingAlleles(altAllele, child);
    }

    /**
     * Returns {@code [mother, father, child]} phased by transmission.
     *
     * <p>The first (maternal, paternal) allele pair, in a fixed enumeration order, that matches the
     * child's genotype (ignoring phase) becomes the phased child. The first allele of each phased
     * parent is the allele it transmitted. If no pair matches, the inputs are returned unchanged.
     */
    private ArrayList<Genotype> getPhasedGenotypes(Genotype mom, Genotype dad, Genotype child) {
        ArrayList<Genotype> finalGenotypes = new ArrayList<Genotype>();
        // Nested loops give a deterministic candidate order, unlike iterating a hash set.
        for (Allele momAllele : mom.getAlleles()) {
            for (Allele dadAllele : dad.getAlleles()) {
                ArrayList<Allele> phasedChildAlleles = new ArrayList<Allele>();
                phasedChildAlleles.add(momAllele);
                phasedChildAlleles.add(dadAllele);
                Genotype phasedChildGenotype = new Genotype(child.getSampleName(), phasedChildAlleles, child.getNegLog10PError(), child.getFilters(), child.getAttributes(), true);
                if (!child.sameGenotype(phasedChildGenotype, true)) {
                    continue;
                }
                finalGenotypes.add(this.phaseParent(mom, phasedChildGenotype.getAllele(0)));
                finalGenotypes.add(this.phaseParent(dad, phasedChildGenotype.getAllele(1)));
                finalGenotypes.add(phasedChildGenotype);
                return finalGenotypes;
            }
        }
        finalGenotypes.add(mom);
        finalGenotypes.add(dad);
        finalGenotypes.add(child);
        return finalGenotypes;
    }

    /** Returns {@code parent} as a phased genotype with {@code transmitted} listed first. */
    private Genotype phaseParent(Genotype parent, Allele transmitted) {
        Allele untransmitted = parent.getAllele(0).equals(transmitted) ? parent.getAllele(1) : parent.getAllele(0);
        ArrayList<Allele> phasedAlleles = new ArrayList<Allele>();
        phasedAlleles.add(transmitted);
        phasedAlleles.add(untransmitted);
        return new Genotype(parent.getSampleName(), phasedAlleles, parent.getNegLog10PError(), parent.getFilters(), parent.getAttributes(), true);
    }

    /**
     * Chooses the most probable trio genotype configuration and phases it.
     *
     * @return {@code [mother, father, child]}. The inputs are returned unchanged if any member is
     *     uncalled or the best configuration is all-heterozygous. Otherwise the result is the best
     *     configuration, phased where possible, with {@code TP} set on the child.
     */
    private ArrayList<Genotype> phaseTrioGenotypes(Allele ref, Allele alt, Genotype mother, Genotype father, Genotype child) {
        ArrayList<Genotype> finalGenotypes = new ArrayList<Genotype>();
        finalGenotypes.add(mother);
        finalGenotypes.add(father);
        finalGenotypes.add(child);
        if (!(mother.isCalled() && father.isCalled() && child.isCalled())) {
            return finalGenotypes;
        }
        ArrayList<Genotype> possibleMotherGenotypes = this.createAllThreeGenotypes(ref, alt, mother);
        ArrayList<Genotype> possibleFatherGenotypes = this.createAllThreeGenotypes(ref, alt, father);
        ArrayList<Genotype> possibleChildGenotypes = this.createAllThreeGenotypes(ref, alt, child);
        // Every candidate copies its sample's attributes, and hence its likelihoods. Normalize once
        // per sample instead of once per configuration.
        double[] motherLikelihoods = normalizedLikelihoods(mother);
        double[] fatherLikelihoods = normalizedLikelihoods(father);
        double[] childLikelihoods = normalizedLikelihoods(child);
        // 12 of the 27 configurations are Mendelian violations. Scores are divided by their sum
        // below, so these prior weights need not sum to one.
        double consistentPrior = 1.0 - 12.0 * MENDELIAN_VIOLATION_PRIOR;

        double norm = 0.0;
        double bestScore = 0.0;
        Genotype bestMotherGenotype = mother;
        Genotype bestFatherGenotype = father;
        Genotype bestChildGenotype = child;
        for (Genotype motherGenotype : possibleMotherGenotypes) {
            for (Genotype fatherGenotype : possibleFatherGenotypes) {
                for (Genotype childGenotype : possibleChildGenotypes) {
                    double prior = this.isMendelianViolation(ref, alt, motherGenotype, fatherGenotype, childGenotype) ? MENDELIAN_VIOLATION_PRIOR : consistentPrior;
                    double configurationLikelihood = motherLikelihoods[likelihoodIndex(motherGenotype)]
                            * fatherLikelihoods[likelihoodIndex(fatherGenotype)]
                            * childLikelihoods[likelihoodIndex(childGenotype)];
                    double score = prior * configurationLikelihood;
                    norm += score;
                    if (score > bestScore) {
                        bestScore = score;
                        bestMotherGenotype = motherGenotype;
                        bestFatherGenotype = fatherGenotype;
                        bestChildGenotype = childGenotype;
                    }
                }
            }
        }
        // All-het trios cannot be phased by transmission; keep the input genotypes.
        if (bestMotherGenotype.isHet() && bestFatherGenotype.isHet() && bestChildGenotype.isHet()) {
            return finalGenotypes;
        }
        HashMap<String, Object> attributes = new HashMap<String, Object>();
        attributes.putAll(bestChildGenotype.getAttributes());
        attributes.put(TRANSMISSION_PROBABILITY_TAG_NAME, bestScore / norm);
        bestChildGenotype = Genotype.modifyAttributes(bestChildGenotype, attributes);
        return this.getPhasedGenotypes(bestMotherGenotype, bestFatherGenotype, bestChildGenotype);
    }

    /**
     * Phases every trio at the current site and writes the resulting record. Non-biallelic records
     * are written unchanged. Nothing is written if there is no tracker or no variant here.
     */
    @Override
    public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
        if (tracker == null) {
            return null;
        }
        VariantContext vc = tracker.getFirstValue(this.variantCollection.variants, context.getLocation());
        if (vc == null) {
            return null;
        }
        if (!vc.isBiallelic()) {
            // Candidate genotypes are built from a single ref/alt pair and scored against a
            // three-entry likelihood vector, so only biallelic sites can be phased.
            this.vcfWriter.add(vc);
            return null;
        }
        // Work on a copy rather than mutating the map owned by the input VariantContext.
        Map<String, Genotype> genotypeMap = new HashMap<String, Genotype>(vc.getGenotypes());
        Allele refAllele = vc.getReference();
        Allele altAllele = vc.getAltAlleleWithHighestAlleleCount();
        for (Trio trio : this.trios) {
            Genotype mother = vc.getGenotype(trio.getMother());
            Genotype father = vc.getGenotype(trio.getFather());
            Genotype child = vc.getGenotype(trio.getChild());
            ArrayList<Genotype> trioGenotypes = this.phaseTrioGenotypes(refAllele, altAllele, mother, father, child);
            for (Genotype phased : trioGenotypes) {
                genotypeMap.put(phased.getSampleName(), phased);
            }
        }
        this.vcfWriter.add(VariantContext.modifyGenotypes(vc, genotypeMap));
        return null;
    }

    /** No reduction is performed; always returns {@code null}. */
    @Override
    public Integer reduceInit() {
        return null;
    }

    /** No reduction is performed; always returns {@code null}. */
    @Override
    public Integer reduce(Integer value, Integer sum) {
        return null;
    }

    /** Sample names of one mother/father/child trio. */
    private static class Trio {
        private final String mother;
        private final String father;
        private final String child;

        public Trio(String mother, String father, String child) {
            this.mother = mother;
            this.father = father;
            this.child = child;
        }

        /**
         * Parses a {@code mom+dad=child} pattern; surrounding whitespace of each name is trimmed.
         *
         * @throws IllegalArgumentException if the pattern does not contain exactly three non-empty names
         */
        public Trio(String familySpec) {
            String[] pieces = familySpec.split("[\\+\\=]");
            if (pieces.length != 3) {
                throw new IllegalArgumentException("Malformed family specification '" + familySpec + "'; expected mom+dad=child");
            }
            this.mother = requireName(pieces[0], familySpec);
            this.father = requireName(pieces[1], familySpec);
            this.child = requireName(pieces[2], familySpec);
        }

        /** Returns {@code piece} trimmed, rejecting an empty sample name. */
        private static String requireName(String piece, String familySpec) {
            String name = piece.trim();
            if (name.length() == 0) {
                throw new IllegalArgumentException("Malformed family specification '" + familySpec + "'; sample names must be non-empty");
            }
            return name;
        }

        public String getMother() {
            return this.mother;
        }

        public String getFather() {
            return this.father;
        }

        public String getChild() {
            return this.child;
        }
    }
}

