from __future__ import division

import numpy as np
import matplotlib.pyplot as plt
import sys

from Bio import SeqIO
from Bio import Entrez
from Bio.SeqUtils import GC


def seq_streams(parsed_gb, start_chrome_ind, end_chrome_ind,
                consec_bases=750):
    '''
    Return the 4 flanking regions found up & down stream of the
    start and end indices of a given feature.

    parsed_gb: any sliceable sequence (a plain str works)
    start_chrome_ind: index of start position
    end_chrome_ind: index of end position
    consec_bases: maximum length of each flanking region (default 750)

    Returns the tuple
    (start_upseq, start_downseq, end_upseq, end_downseq).

    Each upstream region ends right before its index and each
    downstream region begins 3 positions after it, so the 3 bases
    starting at an index are never part of that index's flanks. The
    intent is that the start and stop codons themselves are not returned.
    NOTE(review): this assumes both indices point at the first base of
    their codon -- confirm against the caller.

    A region that would extend past either end of the sequence is
    truncated to the bases that exist.
    '''
    # A negative lower bound would be read by Python as "count from the
    # end of the sequence" and yield an empty or unrelated slice, so the
    # upstream windows are clamped at position 0.
    start_up_from = max(start_chrome_ind - consec_bases, 0)
    end_up_from = max(end_chrome_ind - consec_bases, 0)

    # Downstream windows skip the 3 bases that begin at the index.
    start_down_from = start_chrome_ind + 3
    end_down_from = end_chrome_ind + 3

    start_upseq = parsed_gb[start_up_from:start_chrome_ind]
    start_downseq = parsed_gb[start_down_from:start_down_from + consec_bases]
    end_upseq = parsed_gb[end_up_from:end_chrome_ind]
    end_downseq = parsed_gb[end_down_from:end_down_from + consec_bases]
    return (start_upseq, start_downseq, end_upseq, end_downseq)

def feature_seq_extracter(parsed_gb, feature_string='CDS', consec_bases=750):
    '''
    Creates a list of tuples, one for each feature whose type equals
    feature_string (CDS by default). The 4 sequences are the flanking
    regions (FR) of the start and end positions, as produced by
    seq_streams.

    parsed_gb: record exposing .seq and .features (main passes the
        result of SeqIO.read on a GenBank handle)
    feature_string: feature type to keep
    consec_bases: flank length handed to seq_streams

    The tuple is as
    follows: gene id, start position, end position, start upstream FR,
    start downstream FR, end upstream FR, end downstream FR

    A feature without a 'gene' qualifier is identified by its
    'locus_tag' instead, or by 'unknown' when it has neither, so one
    unannotated feature does not abort the whole record with a KeyError.
    '''
    feat_seq_list = []
    # Slice a plain str: per the original author, slicing the record's
    # own sequence object is far slower.
    seq_str = str(parsed_gb.seq)
    for cur_feature in parsed_gb.features:
        if cur_feature.type != feature_string:
            continue
        qualifiers = cur_feature.qualifiers
        id_values = (qualifiers.get('gene')
                     or qualifiers.get('locus_tag')
                     or ['unknown'])
        gene_name = id_values[0]
        # int() is used instead of the legacy `.position` accessor.
        # NOTE(review): relies on Biopython positions being convertible
        # with int() -- confirm for the installed version.
        start_chrome_ind = int(cur_feature.location.start)
        end_chrome_ind = int(cur_feature.location.end)
        flanks = seq_streams(seq_str, start_chrome_ind,
                             end_chrome_ind, consec_bases)
        feat_seq_list.append(
            (gene_name, start_chrome_ind, end_chrome_ind) + tuple(flanks))
    return feat_seq_list

def hist_plots(gc_cont_list, plot_title, file_meta):
    '''
    Draw a histogram of GC content values and save it as a PNG file.

    gc_cont_list: sequence of GC values; it is converted to a float
        array before plotting. The four legend labels suggest rows of
        four values (5' up, 5' down, 3' up, 3' down) are expected.
    plot_title: title drawn on the plot, also used in the file name
    file_meta: prefix of the output file name

    The image is written to file_meta + '_' + plot_title + '.png' and
    the figure is closed afterwards so repeated calls do not pile up
    open figures.

    Note that plt.xkcd() is called without a context manager, so the
    sketch style it switches on remains active for later plots.
    '''
    fig = plt.figure()
    plt.xkcd()
    # Builtin float: np.float was only an alias for it and has been
    # removed from current NumPy releases.
    gc_content_array = np.array(gc_cont_list, float)
    plt.hist(gc_content_array)
    plt.title(plot_title)
    plt.ylabel('Count')
    plt.xlabel('%GC Content')
    plt.legend(['5`up', '5`down', '3`up', '3`down'])
    fig_handle = file_meta + '_' + plot_title + '.png'
    plt.savefig(fig_handle)
    plt.close(fig)
    
def main():
    '''
    Fetch GenBank records through Entrez and plot the GC content of the
    regions flanking every CDS.

    RefSeq IDs are read from the command line; when none are given,
    three built-in accessions are used. For each ID the record is
    downloaded and its CDS flanking regions extracted. Two histograms
    are then saved via hist_plots: one over all CDS entries
    ('all_cds') and one over the de-duplicated entries ('unique_cds').
    The raw CDS count, the number of distinct gene IDs and the number
    of distinct entries are printed.
    '''
    refseq_ids = sys.argv[1:] or ['NC_000022.11','NC_012920.1','NC_000024.10']

    Entrez.email = 'lefeverde@pitt.edu'
    for cur_id in refseq_ids:
        gb = Entrez.efetch(db='nucleotide',
            id=cur_id,
            rettype='gbwithparts',
            retmode='text')
        try:
            record = SeqIO.read(gb, 'gb')
        finally:
            # Release the download handle even when parsing fails.
            gb.close()

        cur_feat_list = feature_seq_extracter(record)
        # De-duplicate once and reuse for both the plot and the count.
        unique_feats = set(cur_feat_list)
        # The gene ID is the first field of every entry; a plain set
        # avoids building a string array and also copes with an empty
        # feature list.
        gene_ids = set(feat[0] for feat in cur_feat_list)

        gc_raw_tups = [(GC(i[3]), GC(i[4]), GC(i[5]), GC(i[6]))
                       for i in cur_feat_list]
        gc_unique_tups = [(GC(i[3]), GC(i[4]), GC(i[5]), GC(i[6]))
                          for i in unique_feats]

        # The % formatting must happen inside the call: applied to the
        # result of print() it fails on Python 3, where print returns None.
        print('\n%s raw CDS count: %i' % (cur_id, len(cur_feat_list)))
        hist_plots(gc_raw_tups, cur_id, 'all_cds')
        print('%s unique gene ID count: %i' % (cur_id, len(gene_ids)))
        print('%s unique : %i\n' % (cur_id, len(unique_feats)))
        hist_plots(gc_unique_tups, cur_id, 'unique_cds')

# Run the analysis only when this file is executed as a script, so that
# importing it for its helper functions does not start any downloads.
if __name__ == '__main__':
    main()


