#!/usr/local/biotools/python/2.7/bin/python
"""RNASEQ QC plots for RIN versus sensitivity for incoming samples with respect to previously seen UHR samples"""
__author__ = "Numrah Fadra"
__email__ = "Fadra.Numrah@mayo.edu"
__status__ = "Validation"


import argparse
import csv
import itertools
import os
import shlex
import shutil
import subprocess
import sys
import textwrap
import time

# Raise csv's per-field size limit to sys.maxsize so that very long fields in
# the tab-separated inputs read below are not rejected by the csv reader.
csv.field_size_limit(sys.maxsize)

def ValidatePath(path):
    """Return *path* unchanged when it exists on disk, otherwise raise.

    Used as an argparse ``type=`` callback for every path-like option.

    Raises:
        Exception: if ``os.path.exists(path)`` is false.
    """
    if os.path.exists(path):
        return path
    message = ("ERROR: File/directory " + path +
               " does not exist or the path is incorrect \n Please check the path")
    raise Exception(message)


def ArgumentParser():
    """Build the command-line parser for the sensitivity-plot script.

    Path options are checked with ValidatePath. argparse also applies
    ``type`` to string defaults, so the default config/ and ref/ locations
    (relative to the current directory) must exist unless overridden.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    help_text = "Estimated sensitivity plots"
    parser = argparse.ArgumentParser(description=help_text, epilog=help_text)
    parser.add_argument("--bam", help="Name of the bam file to plot the sensitivity",
                        required=True, type=ValidatePath)
    parser.add_argument("--outputdir", help="Path to output directory",
                        default=os.getcwd(), type=ValidatePath)
    # type=int rejects a non-numeric threshold at parse time instead of after
    # GATK has already run; the default stays the int 10.
    parser.add_argument("--coverage", help="coverage threshold for genes covered",
                        default=10, type=int)
    parser.add_argument("--gene_list", help="Coordinate file containing NM numbers and distance to polyA",
                        default="config/clinical_genes_coords.bed", type=ValidatePath)
    parser.add_argument("--uhr_dir", help="Path to directory containing UHR files, avg.UHR.all.cumulative.txt and avg.UHR.all.choppy.txt",
                        default="config", type=ValidatePath)
    parser.add_argument("--gatk_path", help="Path to GATK",
                        default="config/GenomeAnalysisTK.jar", type=ValidatePath)
    parser.add_argument("--reference", help="Path to reference",
                        default="ref/allchr.fa", type=ValidatePath)
    parser.add_argument("--java", help="Path to java", default="java")
    parser.add_argument("--R", help="Path to Rscript", default="Rscript")

    return parser


parser = ArgumentParser()

try:
    args = parser.parse_args()
except Exception as e:
    # ValidatePath raises a plain Exception, which argparse lets propagate
    # (it only turns ArgumentTypeError/TypeError/ValueError into usage errors).
    # Report it on stderr and exit non-zero; the previous bare exit() returned
    # status 0, so a calling pipeline could not tell the run had failed.
    sys.exit(str(e))

# Expose the parsed options as module-level names. Several functions below
# (RunGATK, R_plotting) read java, R and output_dir directly.
path_to_bam = args.bam
clinical_list = args.gene_list
ref = args.reference
gatk = args.gatk_path
java = args.java
R = args.R
output_dir = args.outputdir
cov = args.coverage
uhr_dir = args.uhr_dir


def RunGATK(clinical_list,path_to_bam,ref,gatk,output_dir):
    """Run GATK DepthOfCoverage over the clinical gene coordinates.

    The per-locus output goes to ``<output_dir>/<bam name>cvg_length``, where
    <bam name> is the BAM basename without its extension. There is no
    separator before ``cvg_length``; CreateCvgTableFile reads that exact name.

    Uses the module-level ``java`` command chosen on the command line.
    Exits the program via sys.exit if java cannot be started, if GATK returns
    non-zero, or if the output file is missing afterwards.
    """
    name = os.path.splitext(os.path.basename(path_to_bam))[0]
    outfile = output_dir + '/' + name + 'cvg_length'
    # Pass an argument list (no shell) so that paths containing spaces reach
    # GATK intact. shlex.split keeps a multi-word --java value such as
    # "java -Xmx4g" working, as it did with the old shell string.
    gatk_command = shlex.split(java) + ['-jar', gatk, '-T', 'DepthOfCoverage',
                                        '-R', ref, '-o', outfile,
                                        '-I', path_to_bam, '-L', clinical_list]
    sys.stdout.flush()  # keep our progress messages ahead of GATK's output
    try:
        error = subprocess.call(gatk_command)
    except OSError as e:
        # The java executable itself could not be launched.
        sys.exit("Error code:" + str(e) + " Trying to execute GATK DepthOfCoverage. Make sure java and the path to GATK are in your PATH")
    if error != 0:
        sys.exit("Error code:" + str(error) + " Trying to execute GATK DepthOfCoverage. Make sure java and the path to GATK are in your PATH")
    if not os.path.exists(outfile):
        # Previously referenced the undefined name `Outfile`, which raised NameError.
        sys.exit(outfile + " generated by GATK DepthOfCoverage not present")




		
##########################One module#############################################################################

# Merge the clinical gene coordinates with the GATK per-locus depth and write the result to <name>.cvg_length.table.txt

def CreateCvgTableFile(output_dir,name,clinical_list):
    """Join the clinical coordinate list with the GATK per-locus depth.

    Reads:
      * ``clinical_list``: tab-separated. Column 0 is the chromosome, column 2
        is the position matched against GATK's locus, and column 3 is
        ``<gene>-<distance>``.
      * ``<output_dir>/<name>cvg_length``: GATK DepthOfCoverage output. The
        first row is a header, column 0 is ``chrom:pos``, and columns 1-3 are
        copied into the table.

    Writes ``<output_dir>/<name>.cvg_length.table.txt`` (tab-separated): a
    header row, then the sorted rows Gene, Distance, Total_depth, Avg_Depth,
    Depth_<name>. Only clinical coordinates that GATK reported are included.
    The file is overwritten, so re-running does not duplicate rows.
    """
    table_path = output_dir + '/' + name + '.cvg_length.table.txt'
    gatk_path = output_dir + '/' + name + 'cvg_length'

    # "chrom:pos" -> [gene, distance]; a repeated coordinate keeps the last gene.
    coord_to_gene = {}
    with open(clinical_list, 'r') as handle:
        for row in csv.reader(handle, delimiter="\t"):
            if len(row) < 4:
                continue  # blank or truncated line
            gene_and_dist = row[3].split('-')
            coord_to_gene[row[0] + ':' + row[2]] = [gene_and_dist[0], gene_and_dist[1]]

    # "chrom:pos" -> [gene, distance, col1, col2, col3 of the GATK row].
    matched = {}
    with open(gatk_path, 'r') as handle:
        reader = csv.reader(handle, delimiter="\t")
        next(reader, None)  # skip the GATK header row
        for row in reader:
            if not row:
                continue
            locus_parts = row[0].split(':')
            locus = locus_parts[0] + ':' + locus_parts[1]
            if locus in coord_to_gene:
                matched[locus] = coord_to_gene[locus] + [row[1], row[2], row[3]]

    with open(table_path, 'w') as out:
        writer = csv.writer(out, delimiter="\t")
        writer.writerow(['Gene', 'Distance', 'Total_depth', 'Avg_Depth', 'Depth_' + name])
        writer.writerows(sorted(matched.values()))

	

#### Pull out genes from cvg_length.table.txt from UHR sample file:

def CreateCvgExprFile(output_dir,name,cov):
    """Split genes by their depth at distance 300 in the coverage table.

    Reads ``<output_dir>/<name>.cvg_length.table.txt``, which has a header row,
    Distance in column 1 and Depth_<name> in column 4. A gene is "expressed"
    if its distance-300 row has depth >= int(cov), and "low expressed" if that
    row has depth < int(cov).

    Writes (overwriting, so re-runs do not duplicate rows):
      * ``<name>.cvg_length.expr.txt``: the header plus every row of each
        expressed gene, in table order.
      * ``<name>_low_expressed_300``: every row of each low-expressed gene,
        in table order, with no header.
    """
    with open(output_dir + '/' + name + '.cvg_length.table.txt', 'r') as handle:
        table = list(csv.reader(handle, delimiter="\t"))
    header = table[0]
    body = [row for row in table[1:] if row]
    threshold = int(cov)

    # Sets give O(1) membership tests in the filtering passes below.
    expressed = set()
    low_expressed = set()
    for row in body:
        if int(row[1]) == 300:
            if int(row[4]) >= threshold:
                expressed.add(row[0])
            else:
                low_expressed.add(row[0])

    with open(output_dir + '/' + name + '_low_expressed_300', 'w') as handle:
        writer = csv.writer(handle, delimiter="\t")
        writer.writerows(row for row in body if row[0] in low_expressed)

    with open(output_dir + '/' + name + '.cvg_length.expr.txt', 'w') as handle:
        writer = csv.writer(handle, delimiter="\t")
        writer.writerow(header)
        writer.writerows(row for row in body if row[0] in expressed)


######################################## Next, for each distance from 300 bp to 5 kb, calculate what percentage of genes
######################################## reach the coverage threshold (--coverage, default 10 reads) in this sample

def calculate_cvg_profile(output_dir,name,cov,uhr_dir):
    """Append this sample's sensitivity profile to copies of the UHR tables.

    Copies ``<uhr_dir>/avg.UHR.all.cumulative.txt`` and
    ``<uhr_dir>/avg.UHR.all.choppy.txt`` into output_dir, prefixed with
    ``name``. It then appends rows ``[distance, pct, name, name]`` computed
    from ``<name>.cvg_length.expr.txt`` (header row; Distance in column 1,
    Depth_<name> in column 4):

      * cumulative, for d in 400..5000 step 100: the % of rows with
        300 <= Distance <= d whose depth >= int(cov);
      * per distance ("choppy"), for d in 300..5000 step 100: the % of rows
        with Distance == d whose depth >= int(cov).

    Percentages are rounded to 2 decimals. A bin with no rows is written as
    "NA", which R's read.table treats as missing, instead of raising
    ZeroDivisionError.
    """
    prefix = output_dir + "/" + name
    cumulative_path = prefix + ".avg.UHR.cumulative.txt"
    choppy_path = prefix + ".avg.UHR.choppy.graph.txt"
    shutil.copy(uhr_dir + "/avg.UHR.all.cumulative.txt", cumulative_path)
    shutil.copy(uhr_dir + "/avg.UHR.all.choppy.txt", choppy_path)

    threshold = int(cov)
    with open(prefix + '.cvg_length.expr.txt', 'r') as handle:
        rows = list(csv.reader(handle, delimiter="\t"))[1:]
    # Parse (distance, depth) once instead of once per distance bin.
    points = [(int(row[1]), int(row[4])) for row in rows if row]

    def percent(hits, total):
        # Empty bin -> "NA" rather than dividing by zero.
        if total == 0:
            return "NA"
        return round((float(hits) / float(total)) * 100.0, 2)

    # Append to the copied UHR tables so that this sample is plotted alongside them.
    with open(cumulative_path, "a") as handle:
        writer = csv.writer(handle, delimiter="\t")
        for i in range(400, 5100, 100):
            in_bin = [depth for dist, depth in points if 300 <= dist <= i]
            hits = sum(1 for depth in in_bin if depth >= threshold)
            writer.writerow([i, percent(hits, len(in_bin)), name, name])

    with open(choppy_path, "a") as handle:
        writer = csv.writer(handle, delimiter="\t")
        for i in range(300, 5100, 100):
            at_dist = [depth for dist, depth in points if dist == i]
            hits = sum(1 for depth in at_dist if depth >= threshold)
            writer.writerow([i, percent(hits, len(at_dist)), name, name])

# Generate the R script for the three QC plots, run it with Rscript, and check that the PNG files were created
def R_plotting(name):
    """Write ``<output_dir>/<name>.plots.R``, run it with Rscript, and check the PNGs.

    The generated script loads ggplot2 and plyr and draws:
      * ``<name>.sensitivity.cdf.png``: cumulative sensitivity, from
        ``<name>.avg.UHR.cumulative.txt``;
      * ``<name>.sensitivity.png``: per-distance sensitivity, from
        ``<name>.avg.UHR.choppy.graph.txt``;
      * ``<name>.median.decay.png``: median coverage decay, from
        ``<name>.cvg_length.expr.txt``.

    Reads the module-level ``output_dir`` and ``R`` (Rscript command) set from
    the command line. Exits the program via sys.exit if Rscript cannot be
    started, returns non-zero, or any PNG is missing afterwards.
    """
    prefix = str(output_dir) + "/" + name

    def read_table(var, path):
        # -> <var>=read.table(file="<path>",header=TRUE)
        return var + '=read.table(file="' + path + '",header=TRUE)'

    def png_device(path):
        # -> png("<path>",height=1000,width=1000,res=100)
        return 'png("' + path + '",height=1000,width=1000,res=100)'

    line1 = "library(ggplot2)"
    line2 = "library(plyr)"
    line3 = read_table('data', prefix + ".avg.UHR.cumulative.txt")
    line3B = read_table('data', prefix + ".avg.UHR.choppy.graph.txt")
    line3C = read_table('cvg', prefix + ".cvg_length.expr.txt")
    line4 = png_device(prefix + ".sensitivity.cdf.png")
    line4B = png_device(prefix + ".sensitivity.png")
    line4C = png_device(prefix + ".median.decay.png")

    # x-axis breaks and labels shared by the two sensitivity plots.
    line5 = textwrap.dedent('''\
        dist=c(500,1000,1500,2000,2500,3000,3500,4000,4500,5000)
        labels=c('0.5kb\\n(SS18-SSX1)','1.0kb\\n(FUS-DDIT3)','1.5kb\\n(EZR-ROS1)','2.0kb\\n(NPM1-ALK)','2.5kb\\n(PML-RARA)',
                  '3.0kb\\n(EWSR1-NR4A3)','3.5kb\\n(BRD4-NUT)','4.0kb\\n(PAX3-NCOA1)','4.5kb\\n(TMPRSS2-ERG)','5.0kb\\n(BCR-ABL1)')
        ''')

    # Cumulative sensitivity plot
    line7 = textwrap.dedent('''\
        labels.rin=data.frame(sample_name=unique(data$sample_name),pct=data[which(data$distance==4600),]$pct)
        ggplot(data=data,aes(x=distance,y=pct,group=factor(sample_name),colour=factor(samples)))+
        geom_text(data=labels.rin,aes(label=sample_name,x=4600,y=pct+1),size=4,colour="black")+
        geom_point(aes(color=factor(samples)),size=2) +
        ggtitle("Cumulative sensitivity") +
        stat_smooth()+
        scale_x_continuous(limits=c(300,5000),breaks=dist,labels=labels) +
        scale_y_continuous(limits=c(0,100))+
        xlab("Distance to poly-A") +
        ylab("Cumulative sensitivity") +
        labs(colour="Samples") +
        theme(axis.text.x=element_text(face="bold",color="#993333",size=10,angle=30),
            axis.text.y=element_text(face="bold",color="#993333",size=10),
            axis.title.x=element_text(size=14,face="bold"),
            axis.title.y=element_text(size=14,face="bold"),
            plot.title=element_text(color="black",size=14,face="bold"))
        ''')

    # Sensitivity plot
    line7B = textwrap.dedent('''\
        labels.rin=data.frame(sample_name=unique(data$sample_name),pct=data[which(data$distance==4600),]$pct)
        ggplot(data=data,aes(x=distance,y=pct,group=factor(sample_name),colour=factor(samples)))+
        geom_text(data=labels.rin,aes(label=sample_name,x=4600,y=pct+4),size=4,colour="black")+
        geom_point(aes(color=factor(samples)),size=2) +
        ggtitle("Sensitivity") +
        stat_smooth()+
        scale_x_continuous(limits=c(300,5000),breaks=dist,labels=labels) +
        scale_y_continuous(limits=c(0,100))+
        xlab("Distance to poly-A") +
        ylab("Sensitivity") +
        labs(colour="Samples") +
        theme(axis.text.x=element_text(face="bold",color="#993333",size=10,angle=30),
            axis.text.y=element_text(face="bold",color="#993333",size=10),
            axis.title.x=element_text(size=14,face="bold"),
            axis.title.y=element_text(size=14,face="bold"),
            plot.title=element_text(color="black",size=14,face="bold"))
        ''')

    # Decay plot: median Total_depth per distance (300..5000), with a
    # log-linear fit whose slope (x1000) is reported as the decay rate.
    line7C = textwrap.dedent('''\
        nums=c(3:50)*100
        median.cvg=data.frame(dist=nums,
        cvg=sapply(nums,function(x) median(cvg[which(cvg$Distance==x),]$Total_depth) ))

        decay.model=lm (log(median.cvg$cvg)~median.cvg$dist)
        decay.rate=format(decay.model$coefficients[[2]]*1000,digits=2)
        r.squared=format(summary(decay.model)$r.squared,digits=2)

        legend=as.character(as.expression(
        substitute(italic("Decay rate")~"="~dr~","~italic(R)^2~"="~r2,
               list(dr=decay.rate,r2=r.squared))))

        ggplot(data=median.cvg,aes(x=dist,y=cvg))+
        ggtitle("Median coverage decay") +
        scale_y_log10()+
        geom_point()+
        stat_smooth(method="lm")+
        xlab("Distance from 3' end")+ylab("Median coverage(log-scale)")+
        geom_text(aes(label=legend,x=3000,y=50),size=6,parse=TRUE)+
        theme(axis.text.y=element_text(face="bold",color="#993333",size=10),
            axis.title.x=element_text(size=14,face="bold"),
            axis.title.y=element_text(size=14,face="bold"),
            plot.title=element_text(color="black",size=14,face="bold"))
        ''')

    line8 = 'dev.off()'

    final_lines = "\n".join([line1, line2, line3, line4, line5, line7, line8]) + "\n"
    final_lines2 = "\n".join([line3B, line4B, line5, line7B, line8]) + "\n"
    final_lines3 = "\n".join([line3C, line4C, line7C, line8])

    print("\nMaking R file for cumulative plot")
    print("\nPrinting to R file for sensitivity plot")
    script_path = prefix + ".plots.R"
    with open(script_path, "w") as script:
        script.write(final_lines)
        script.write(final_lines2)
        script.write(final_lines3)

    # Argument list (no shell) so that paths with spaces survive. shlex.split
    # keeps a multi-word --R value working, as it did with the shell string.
    R_command = shlex.split(R) + ['--vanilla', script_path]
    sys.stdout.flush()  # keep our progress messages ahead of R's output
    try:
        error = subprocess.call(R_command)
    except OSError as e:
        # The Rscript executable itself could not be launched.
        sys.exit("Error code:" + str(e) + " Trying to execute Rscript. Make sure Rscript is in the path and R_libs has the path to ggplot2")
    if error != 0:
        sys.exit("Error code:" + str(error) + " Trying to execute Rscript. Make sure Rscript is in the path and R_libs has the path to ggplot2")

    for suffix, message in ((".sensitivity.cdf.png", "Cumulative sensitivity plot not generated"),
                            (".sensitivity.png", "Sensitivity plot not generated"),
                            (".median.decay.png", "Median decay plot not generated")):
        if not os.path.exists(prefix + suffix):
            sys.exit(message)


# Pipeline: GATK per-locus depth -> coverage table -> expressed/low-expressed
# split -> sensitivity profile -> R plots. Each stage reads files written by
# the previous one in output_dir.
RunGATK(clinical_list, path_to_bam, ref, gatk, output_dir)

name = os.path.splitext(os.path.basename(path_to_bam))[0]
for banner, stage, stage_args in (
        ("\n making calls to CreatTable function.....................\n",
         CreateCvgTableFile, (output_dir, name, clinical_list)),
        ("\nmaking calls to CreateCvgExprFile function....................\n",
         CreateCvgExprFile, (output_dir, name, cov)),
        ("\nmaking calls to cvg_calc_profile function..................\n",
         calculate_cvg_profile, (output_dir, name, cov, uhr_dir)),
        ("\nGenerating R file for plotting purposes\n",
         R_plotting, (name,)),
):
    print(banner)
    stage(*stage_args)



