import string
import re

################rc of sequence########################
def rc(seq):
    """Return the reverse complement of a DNA sequence.

    seq -- sequence over the upper-case alphabet A, C, G, T, N.

    Returns a string holding the complement of every base of seq, in
    reverse order.  Raises KeyError if seq contains any other character
    (lower-case bases included).
    """
    comp = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N'}

    # join() assembles the result in a single pass instead of growing a
    # string one character at a time.
    return ''.join([comp[base] for base in reversed(seq)])

#######################################################
# read in the sequences function

def get_fasta(fastafile):
    """Read a FASTA file into a dictionary keyed by sequence name.

    fastafile -- path of the FASTA file to read.

    The name of a record is the first whitespace-delimited word that
    follows '>' on its header line.  Returns a dict mapping each name to
    {'name': name, 'seq': sequence}, where sequence is the concatenation
    of the stripped data lines of that record.  If a name occurs more than
    once, the last record with that name wins.

    Raises ValueError for a header line without a name, or for sequence
    data that appears before the first header.
    """
    seq = {}
    chunks = {}      # name -> list of sequence lines, joined once at the end
    name = None

    # 'with' guarantees the file is closed even if parsing fails.
    with open(fastafile, 'r') as f:
        for line in f:
            if line.startswith('>'):
                fields = line[1:].split()
                if not fields:
                    raise ValueError("FASTA header without a name in %s"
                                     % fastafile)
                name = fields[0]
                seq[name] = {'name': name, 'seq': ''}
                chunks[name] = []
            else:
                data = line.strip()
                if not data:
                    continue    # blank lines carry no sequence
                if name is None:
                    raise ValueError("sequence data before first header in %s"
                                     % fastafile)
                chunks[name].append(data)

    for name, parts in chunks.items():
        seq[name]['seq'] = ''.join(parts)

    return seq

#######################################################
# read in the sequences function

def get_pair_fasta(fastafile):
    """Read a FASTA file, keying the records by their order of appearance.

    fastafile -- path of the FASTA file to read.

    Returns a dict whose keys are the strings '1', '2', ... (the 1-based
    position of each record in the file) and whose values are
    {'name': header, 'seq': sequence}.  header is the complete header line
    without the leading '>' and surrounding whitespace; sequence is the
    concatenation of the stripped data lines of the record.

    Raises ValueError if sequence data appears before the first header.
    """
    seq = {}
    chunks = {}      # key -> list of sequence lines, joined once at the end
    key = None
    nseq = 0

    # 'with' guarantees the file is closed even if parsing fails.
    with open(fastafile, 'r') as f:
        for line in f:
            if line.startswith('>'):
                nseq = nseq + 1
                key = str(nseq)
                seq[key] = {'name': line[1:].strip(), 'seq': ''}
                chunks[key] = []
            else:
                data = line.strip()
                if not data:
                    continue    # blank lines carry no sequence
                if key is None:
                    raise ValueError("sequence data before first header in %s"
                                     % fastafile)
                chunks[key].append(data)

    for key, parts in chunks.items():
        seq[key]['seq'] = ''.join(parts)

    return seq

#######################################################

def _finish_subblock(subblock, blk_b1, blk_b2, last_offset, seq1, seq2):
    """Fill in the end coordinates, length and PID of an open sub-block.

    last_offset is the 0-based offset, relative to the start of the parent
    block, of the last position covered by the sub-block.  Returns subblock.
    """
    subblock['e1'] = blk_b1 + last_offset
    subblock['e2'] = blk_b2 + last_offset
    start1 = subblock['b1']
    start2 = subblock['b2']
    # e1 is inclusive, so the length is e1-b1+1, as for every other block.
    length = subblock['e1'] - start1 + 1
    matches = 0
    for j in range(length):
        if seq1[start1 - 1 + j] == seq2[start2 - 1 + j]:
            matches = matches + 1
    subblock['PID'] = (100 * matches) // length
    subblock['len'] = length
    return subblock


def process_block(blk, seq, len_threshold, pid_threshold):
    """Extract the sub-blocks of blk that satisfy length and identity limits.

    blk           -- dict with keys 'b1', 'b2' (1-based start in sequence 1
                     and 2), 'len' and 'PID' (integer percent identity).
    seq           -- dict with seq['1']['seq'] and seq['2']['seq']; blk's
                     coordinates index these strings directly (after the
                     1-based to 0-based shift), with no gap columns.
    len_threshold -- window length, and minimum length of a sub-block.
    pid_threshold -- minimum percent identity.

    Returns a list of block dicts (keys 'b1', 'b2', 'e1', 'e2', 'PID',
    'len'):
      * []    if blk is shorter than len_threshold;
      * [blk] if blk as a whole already reaches pid_threshold;
      * otherwise the maximal runs obtained by merging overlapping windows
        of len_threshold positions that each contain at least
        pid_threshold percent matches.

    Raises ValueError if len_threshold is smaller than 1.
    """
    if len_threshold < 1:
        raise ValueError("len_threshold must be at least 1")

    subblocks = []
    if blk['len'] < len_threshold:
        return subblocks

    if blk['PID'] >= pid_threshold:
        subblocks.append(blk)
        return subblocks

    b1 = blk['b1']
    b2 = blk['b2']
    blk_len = blk['len']
    seq1 = seq['1']['seq']
    seq2 = seq['2']['seq']
    # Minimum number of matching positions a window needs to be "good".
    min_matches = int(pid_threshold * len_threshold // 100)

    ## NOTE: always add -1 to seq index, since they go from 0... while b1
    ## goes from 1... !!!

    # Number of matches in each window, by sliding-window update: the window
    # starting at offset i covers offsets i .. i+len_threshold-1.
    matches = 0
    for j in range(len_threshold):
        if seq1[b1 - 1 + j] == seq2[b2 - 1 + j]:
            matches = matches + 1
    window_matches = [matches]
    for i in range(1, blk_len - len_threshold + 1):
        if seq1[b1 - 1 + i - 1] == seq2[b2 - 1 + i - 1]:
            matches = matches - 1          # position leaving the window
        if seq1[b1 - 1 + i + len_threshold - 1] == seq2[b2 - 1 + i + len_threshold - 1]:
            matches = matches + 1          # position entering the window
        window_matches.append(matches)

    current = None       # sub-block being built, or None
    last_offset = -1     # offset of the last position covered by current
    for i, count in enumerate(window_matches):
        if current is not None and last_offset < i:
            # This window starts beyond the open sub-block: close it.
            subblocks.append(_finish_subblock(current, b1, b2, last_offset,
                                              seq1, seq2))
            current = None
        if count >= min_matches:
            if current is None:
                current = {'b1': b1 + i, 'b2': b2 + i}
            # A good window extends the sub-block to the window's last base.
            last_offset = i + len_threshold - 1

    if current is not None:
        subblocks.append(_finish_subblock(current, b1, b2, last_offset,
                                          seq1, seq2))

    return subblocks
    
                   
#######################################################
def get_lagan_blk(seq_file):
    """Extract the gap-free aligned blocks of a pairwise LAGAN alignment.

    seq_file -- path of the two-sequence mfa alignment written by LAGAN;
                both aligned sequences are expected to have the same
                length, with '-' marking gaps.

    A block is a maximal run of alignment columns in which neither sequence
    has a gap.  Returns the blocks in alignment order as a list of dicts:
      'b1', 'e1' -- 1-based, inclusive start/end in ungapped sequence 1
      'b2', 'e2' -- the same for sequence 2
      'len'      -- e1 - b1 + 1
      'PID'      -- integer percent of identical columns in the block
    """
    seq = get_pair_fasta(seq_file)
    aln1 = seq['1']['seq']
    aln2 = seq['2']['seq']

    count1 = 0        # bases of sequence 1 seen so far (ungapped coordinate)
    count2 = 0        # bases of sequence 2 seen so far
    blk = None        # block currently open, or None
    matches = 0       # identical columns in the open block
    blks = []
    for i in range(len(aln1)):
        base1 = aln1[i]
        base2 = aln2[i]
        has1 = base1 != '-'
        has2 = base2 != '-'
        if has1:
            count1 = count1 + 1
        if has2:
            count2 = count2 + 1

        if has1 and has2:
            if blk is None:  ## start a new block
                blk = {'b1': count1, 'b2': count2}
                matches = 0
            if base1 == base2:
                matches = matches + 1
        elif blk is not None:   ## end a block
            # A counter has already advanced past the block only for a
            # sequence that has a base in this (gap) column.
            if has1:
                blk['e1'] = count1 - 1
            else:
                blk['e1'] = count1
            if has2:
                blk['e2'] = count2 - 1
            else:
                blk['e2'] = count2
            blk['len'] = blk['e1'] - blk['b1'] + 1
            blk['PID'] = int(100.0 * float(matches) / float(blk['len']))
            blks.append(blk)
            blk = None

    if blk is not None:  ## the alignment ends inside a block
        blk['e1'] = count1
        blk['e2'] = count2
        blk['len'] = blk['e1'] - blk['b1'] + 1
        blk['PID'] = int(100.0 * float(matches) / float(blk['len']))
        blks.append(blk)

    return blks

#############################################################################
def get_lagan_blk_with_constraints(seq_file,len_threshold,pid_threshold):
    """Extract aligned blocks of a LAGAN alignment that meet given limits.

    seq_file      -- path of the two-sequence mfa alignment written by
                     LAGAN; both aligned sequences are expected to have the
                     same length, with '-' marking gaps.
    len_threshold -- minimum block length (converted with int()).
    pid_threshold -- minimum percent identity (converted with int()).

    Every maximal gap-free run of alignment columns is turned into a block
    dict ('b1', 'e1', 'b2', 'e2' as 1-based inclusive coordinates in the
    ungapped sequences, plus 'len' and integer 'PID') and handed to
    process_block().  Returns the concatenation of the sub-blocks that
    process_block() returns, in alignment order.
    """
    seq = get_pair_fasta(seq_file)
    aln1 = seq['1']['seq']
    aln2 = seq['2']['seq']
    len_threshold = int(len_threshold)
    pid_threshold = int(pid_threshold)

    # Block coordinates count bases, not alignment columns, and
    # process_block() uses them as direct string indices, so it must be
    # given the sequences with the gap characters removed.
    ungapped = {
        '1': {'name': seq['1']['name'], 'seq': aln1.replace('-', '')},
        '2': {'name': seq['2']['name'], 'seq': aln2.replace('-', '')},
    }

    count1 = 0        # bases of sequence 1 seen so far (ungapped coordinate)
    count2 = 0        # bases of sequence 2 seen so far
    blk = None        # block currently open, or None
    matches = 0       # identical columns in the open block
    blks = []
    for i in range(len(aln1)):
        base1 = aln1[i]
        base2 = aln2[i]
        has1 = base1 != '-'
        has2 = base2 != '-'
        if has1:
            count1 = count1 + 1
        if has2:
            count2 = count2 + 1

        if has1 and has2:
            if blk is None:  ## start a new block
                blk = {'b1': count1, 'b2': count2}
                matches = 0
            if base1 == base2:
                matches = matches + 1
        elif blk is not None:   ## end a block
            # A counter has already advanced past the block only for a
            # sequence that has a base in this (gap) column.
            if has1:
                blk['e1'] = count1 - 1
            else:
                blk['e1'] = count1
            if has2:
                blk['e2'] = count2 - 1
            else:
                blk['e2'] = count2
            blk['len'] = blk['e1'] - blk['b1'] + 1
            blk['PID'] = int(100.0 * float(matches) / float(blk['len']))
            blks.extend(process_block(blk, ungapped,
                                      len_threshold, pid_threshold))
            blk = None

    if blk is not None:  ## the alignment ends inside a block
        blk['e1'] = count1
        blk['e2'] = count2
        blk['len'] = blk['e1'] - blk['b1'] + 1
        blk['PID'] = int(100.0 * float(matches) / float(blk['len']))
        blks.extend(process_block(blk, ungapped,
                                  len_threshold, pid_threshold))

    return blks

#############################################################################

#######################################################

def get_blastz_blk(blastz_file):
    """Extract the aligned blocks listed in a BLASTZ output file.

    blastz_file -- path of the BLASTZ output to read.

    Every line of the form '<whitespace>l <b1> <b2> <e1> <e2> <pid>' yields
    one block.  Returns the blocks in file order as a list of dicts with
    the integer keys 'b1', 'b2', 'e1', 'e2', 'PID' and
    'len' = e1 - b1 + 1.
    """
    # 'with' guarantees the file is closed once it has been read.
    with open(blastz_file, 'r') as f:
        text = f.read()

    blks = []
    # An 'l' record: whitespace, the letter l, whitespace, then the numbers.
    for line in re.findall(r'\sl\s\d.+?\n', text):
        fields = line.split()
        blk = {
            'b1': int(fields[1]),
            'b2': int(fields[2]),
            'e1': int(fields[3]),
            'e2': int(fields[4]),
            'PID': int(fields[5]),
        }
        blk['len'] = blk['e1'] - blk['b1'] + 1
        blks.append(blk)

    return blks

#############################################################################
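# Computes conservation statistics for annotated binding sites (read from a
# GFF-style file) against the conserved blocks of a two-species alignment,
# e.g. the blocks extracted from an alignment file produced by LAGAN.
#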

def binding_site_stats(blks, seq, gff_file, method):
    """Compute conservation statistics for annotated binding sites.

    blks     -- non-empty list of block dicts with keys 'b1', 'b2', 'e1',
                'e2' and 'len'; the first and last entries are taken as the
                extremes of the aligned range.
    seq      -- dict with seq['1']['seq'] and seq['2']['seq'], indexed
                directly by the 1-based block coordinates.
    gff_file -- path of a whitespace-separated site file; columns 2-5 of
                each line are read as module, name, begin and end (1-based,
                inclusive, in sequence-1 coordinates).  Blank lines and
                lines starting with '#' are skipped.
    method   -- 0: count every conserved position of a site;
                anything else: count them only when the site lies wholly
                inside a single block.

    Returns the tuple (conserved_sites, total_sites, conserved, seq_range,
    sites):
      conserved_sites -- conserved positions summed over all sites
      total_sites     -- total length of the sites lying inside the range
                         spanned by the blocks on sequence 1
      conserved       -- identical positions summed over all blocks
      seq_range       -- the smaller of the ranges spanned by the blocks on
                         sequence 1 and on sequence 2
      sites           -- one tab-separated string per site: module, name,
                         begin, end, length, conserved positions, fraction
                         conserved

    Raises IndexError if blks is empty.
    """
    seq1 = seq['1']['seq']
    seq2 = seq['2']['seq']

    ## total conserved in chunks: remember every sequence-1 position that
    ## is identical in both sequences
    conserved = 0
    hits = set()
    for blk in blks:
        s1 = blk['b1']
        s2 = blk['b2']
        for i in range(blk['len']):
            if seq1[s1 + i - 1] == seq2[s2 + i - 1]:
                conserved = conserved + 1
                hits.add(s1 + i)

    ### estimate range over which things are conserved
    rmin = blks[0]['b1']
    rmax = blks[-1]['e1']
    seq_range = min(rmax - rmin + 1, blks[-1]['e2'] - blks[0]['b2'] + 1)

    ##### compute binding site stats
    sites = []
    conserved_sites = 0
    total_sites = 0
    with open(gff_file, 'r') as fsites:   ## file containing sites
        for line in fsites:
            fields = line.split()
            if not fields or fields[0].startswith('#'):
                continue
            module = fields[1]
            name = fields[2]
            begin = int(fields[3])
            end = int(fields[4])
            lsite = end - begin + 1

            if begin >= rmin and end <= rmax:
                total_sites = total_sites + lsite

            if method == 0:
                countable = True
            else:
                ## site must be wholly contained in one block
                countable = False
                for blk in blks:
                    if begin >= blk['b1'] and end <= blk['e1']:
                        countable = True
                        break

            site_hits = 0
            if countable:
                for pos in range(begin, end + 1):
                    if pos in hits:
                        site_hits = site_hits + 1
            conserved_sites = conserved_sites + site_hits

            # Guard against a malformed record whose end precedes its begin.
            if lsite > 0:
                PID = float(site_hits) / float(lsite)
            else:
                PID = 0.0
            sites.append('\t'.join([module, name, repr(begin), repr(end),
                                    repr(lsite), repr(site_hits), repr(PID)]))

    return conserved_sites, total_sites, conserved, seq_range, sites

###############################################################################

def calc_PID(blks, seq, win, shift, prefx):
    """Write a sliding-window percent-identity profile to '<prefx>.PID'.

    blks  -- non-empty list of block dicts with keys 'b1', 'b2', 'e1' and
             'len'; the first and last entries give the range profiled.
    seq   -- dict with seq['1']['seq'] and seq['2']['seq'], indexed
             directly by the 1-based block coordinates.
    win   -- window length, in sequence-1 positions.
    shift -- distance between the starts of consecutive windows.
    prefx -- output path prefix; the profile goes to prefx + '.PID'.

    One line '<window start> <percent>' is written per window, where
    percent is the integer-truncated percentage of the window's positions
    that are identical inside some block.  Returns None.
    """
    rmin = blks[0]['b1']
    rmax = blks[-1]['e1']

    # Sequence-1 positions that are identical in both sequences.
    hits = set()
    for blk in blks:
        s1 = blk['b1']
        s2 = blk['b2']
        for i in range(blk['len']):
            if seq['1']['seq'][s1 + i - 1] == seq['2']['seq'][s2 + i - 1]:
                hits.add(s1 + i)

    # generate %ID profile
    # Floor division keeps the window count an integer on every Python
    # version.
    Nwin = (rmax - rmin + 1) // shift
    with open(prefx + '.PID', 'w') as fPID:
        start = rmin
        for _ in range(Nwin):
            matches = 0
            for j in range(start, start + win):
                if j in hits:
                    matches = matches + 1
            fPID.write('%d %f\n'
                       % (start, int(100.0 * float(matches) / float(win))))
            start = start + shift

    return
