#!/usr/bin/python

"""
alignment.py: Template classes for aligning a pair of sequences and for
aligning three or more sequences. For more than two sequences, a graph of
events can also be generated from the alignment.
"""

import difflib
from graph import Graph
import logging

__author__ = "Gabriel Zaccak"
__email__ = "jabra.zaccak@gmail.com"

# Log record layout: timestamp truncated to 19 characters, level name,
# file name, line number and message.
LOGFORMAT = "%(asctime).19s %(levelname)s %(filename)s: %(lineno)s %(message)s"
# Module-level logger used by the alignment classes in this file.
logger = logging.getLogger('utils.alignment')


def match_str(str_a, str_b, match_threshold):
  """Tell whether two strings are similar enough to count as a match.

  The similarity ratio is the total size of the matching blocks reported by
  difflib.SequenceMatcher divided by the length of the longer string.

  Args:
    str_a: first string.
    str_b: second string.
    match_threshold: minimum similarity ratio required for a match.

  Returns:
    True when the similarity ratio is greater than or equal to
    match_threshold, False otherwise.
  """
  longest = max(len(str_a), len(str_b))
  if not longest:
    # Two empty strings are identical: use a ratio of 1.0 instead of
    # dividing by zero.
    return 1.0 >= match_threshold
  matcher = difflib.SequenceMatcher(None, str_a, str_b)
  matched = sum(block.size for block in matcher.get_matching_blocks())
  return matched / float(longest) >= match_threshold


class PairAlignment(object):
  """Needleman-Wunsch global alignment template for a pair of sequences.

  The class works out of the box for sequences of strings. To align other
  kinds of objects, create a subclass and override is_gap, is_empty and
  matches so that they handle those objects.
  """

  # Text types handled natively: (str, unicode) on Python 2, str on Python 3.
  _STRING_TYPES = (str, type(u''))

  # Traceback moves recorded while the DP table is filled.
  _DIAGONAL, _DELETE, _INSERT = 0, 1, 2

  def __init__(self, match_award=30, mismatch_penalty=-30, gap_penalty=-5):
    """Store the scores used for a match, a mismatch and a gap."""
    self.match_award = match_award
    self.mismatch_penalty = mismatch_penalty
    self.gap_penalty = gap_penalty

  def is_gap(self, obj):
    """Return True when obj is a gap object (the string '-').

    Raises:
      NotImplementedError: obj is not a string; subclasses must override
        this method to support their own objects.
    """
    if isinstance(obj, self._STRING_TYPES):
      return obj == '-'
    raise NotImplementedError("You need to implement this!")

  def is_empty(self, obj):
    """Return True when obj is an empty string.

    Raises:
      NotImplementedError: obj is not a string; subclasses must override
        this method to support their own objects.
    """
    if isinstance(obj, self._STRING_TYPES):
      return not obj
    raise NotImplementedError("You need to implement this to suppurt your object!")

  def matches(self, index_a, objs_a, index_b, objs_b, match_threshold=1.0):
    """Compare objs_a[index_a] with objs_b[index_b].

    Returns:
      True when both objects are strings whose similarity, as computed by
      match_str, reaches match_threshold.

    Raises:
      NotImplementedError: one of the objects is not a string.
    """
    obj_a, obj_b = objs_a[index_a], objs_b[index_b]
    if isinstance(obj_a, self._STRING_TYPES) and \
       isinstance(obj_b, self._STRING_TYPES):
      return match_str(obj_a, obj_b, match_threshold)
    raise NotImplementedError("You need to implement this!")

  @staticmethod
  def zeros(shape):
    """Create a zero-filled matrix (list of lists) with shape=(m, n)."""
    return [[0] * shape[1] for _ in range(shape[0])]

  def match_score(self, index_a, objs_a, index_b, objs_b, match_threshold):
    """Score the pairing of objs_a[index_a] with objs_b[index_b].

    Returns:
      match_award when the objects match, gap_penalty when they do not match
      and one of them is a gap, mismatch_penalty otherwise.
    """
    if self.matches(index_a, objs_a, index_b, objs_b, match_threshold):
      return self.match_award
    if self.is_gap(objs_a[index_a]) or self.is_gap(objs_b[index_b]):
      return self.gap_penalty
    return self.mismatch_penalty

  def needle(self, objs_a, objs_b, match_threshold,
             gap_obj=u'-', empty_obj=u' '):
    """Compute the global alignment of two sequences.

    Args:
      objs_a: first sequence.
      objs_b: second sequence.
      match_threshold: threshold forwarded to matches.
      gap_obj: object inserted into an aligned sequence to mark a gap.
      empty_obj: object used in the symbol line where the aligned objects
        are not identical.

    Returns:
      The tuple produced by finalize:
      (aligned_a, symbol, aligned_b, score, identity).
    """
    m = len(objs_a)
    n = len(objs_b)
    score = PairAlignment.zeros((m + 1, n + 1))  # the DP table
    # Move chosen for each cell. Recording it while filling the table avoids
    # calling match_score (and therefore matches) a second time during the
    # traceback.
    moves = PairAlignment.zeros((m + 1, n + 1))

    # First column / first row: a prefix aligned against gaps only.
    for i in range(m + 1):
      score[i][0] = self.gap_penalty * i
    for j in range(n + 1):
      score[0][j] = self.gap_penalty * j

    for i in range(1, m + 1):
      for j in range(1, n + 1):
        diagonal = score[i-1][j-1] + \
                   self.match_score(i-1, objs_a, j-1, objs_b, match_threshold)
        delete = score[i-1][j] + self.gap_penalty
        insert = score[i][j-1] + self.gap_penalty
        best = max(diagonal, delete, insert)
        score[i][j] = best
        # Ties are resolved in the order diagonal, delete, insert.
        if best == diagonal:
          moves[i][j] = self._DIAGONAL
        elif best == delete:
          moves[i][j] = self._DELETE
        else:
          moves[i][j] = self._INSERT

    # Traceback from the bottom right cell; the alignments are built in
    # reverse order and flipped by finalize.
    align1, align2 = [], []
    i, j = m, n
    while i > 0 and j > 0:
      move = moves[i][j]
      if move == self._DIAGONAL:
        align1.append(objs_a[i-1])
        align2.append(objs_b[j-1])
        i -= 1
        j -= 1
      elif move == self._DELETE:
        align1.append(objs_a[i-1])
        align2.append(gap_obj)
        i -= 1
      else:
        align1.append(gap_obj)
        align2.append(objs_b[j-1])
        j -= 1

    # Finish tracing up to the top left cell
    while i > 0:
      align1.append(objs_a[i-1])
      align2.append(gap_obj)
      i -= 1
    while j > 0:
      align1.append(gap_obj)
      align2.append(objs_b[j-1])
      j -= 1

    return self.finalize(align1, align2, match_threshold, empty_obj)

  def finalize(self, align1, align2, match_threshold, empty_obj):
    """Calculate identity, score and the aligned sequences.

    Args:
      align1: first aligned sequence, in reverse order.
      align2: second aligned sequence, in reverse order.
      match_threshold: threshold forwarded to match_score.
      empty_obj: object placed in the symbol line where the two aligned
        objects are not identical.

    Returns:
      (align1, symbol, align2, score, identity) where align1 and align2 are
      in forward order, symbol holds the object for identical positions and
      empty_obj elsewhere, score is the summed alignment score and identity
      is the percentage of identical positions (0.0 for empty alignments).
    """
    align1, align2 = align1[::-1], align2[::-1]
    symbol = []
    score, identity = 0, 0

    for i, obj_a in enumerate(align1):
      obj_b = align2[i]
      if obj_a == obj_b:
        # identical objects: output the object itself
        symbol.append(obj_a)
        identity += 1
        score += self.match_score(i, align1, i, align2, match_threshold)
      elif not self.is_gap(obj_a) and not self.is_gap(obj_b):
        # different objects and none of them is a gap
        score += self.match_score(i, align1, i, align2, match_threshold)
        symbol.append(empty_obj)
      else:
        # one of them is a gap
        symbol.append(empty_obj)
        score += self.gap_penalty

    # Two empty sequences have no positions to compare; avoid dividing by 0.
    if align1:
      identity = float(identity) / len(align1) * 100
    else:
      identity = 0.0
    return align1, symbol, align2, score, identity


class MultipleAlignment(object):
  """Generalize PairAlignment to more than two sequences.

  Stories (sequences of events) are added one at a time: each new story is
  pair-aligned with the previous story and the result is merged into the
  list of already aligned rows. All rows of an alignment list are kept at
  the same length by padding them with gap objects.
  """

  # Text type used to render objects: unicode on Python 2, str on Python 3.
  _text_type = type(u'')

  def __init__(self, alignment_algo, match_threshold=0.8):
    """Store the pair alignment object and the default match threshold.

    Args:
      alignment_algo: object providing needle, matches, is_gap and is_empty
        (e.g. a PairAlignment instance).
      match_threshold: threshold passed to alignment_algo.matches / needle.
    """
    self.alignment_algo = alignment_algo
    self.match_threshold = match_threshold

  def _column_events(self, alignments, col):
    """Return [event, row, col, row_alignment] for each non-gap cell of col."""
    is_gap = self.alignment_algo.is_gap
    return [[row_alignment[col], row, col, row_alignment]
            for row, row_alignment in enumerate(alignments)
            if not is_gap(row_alignment[col])]

  def call_alignment_algo_on_stories(self,
                                     story1,
                                     story2,
                                     match_threshold,
                                     gap_obj,
                                     empty_obj):
    """Pair-align two stories with the configured alignment algorithm.

    Returns:
      [align1, align2]: the two aligned sequences returned by needle.
    """
    align1, _, align2, _, _ = self.alignment_algo.needle(story1,
                                                         story2,
                                                         match_threshold,
                                                         gap_obj,
                                                         empty_obj)
    return [align1, align2]

  def get_matching_indices(self, align1, align2):
    """Pair up the positions of two alignments of the same story.

    Both alignments are walked in parallel. When the current elements match,
    the index pair [i, j] is recorded and both positions advance. Otherwise a
    gap in align1 is skipped first, then a gap in align2.

    Example:
      align1 = a, -, b
      align2 = a, b
      result = [[0, 0], [2, 1]]

    Returns:
      The list of matching [i, j] index pairs, or [] as soon as two non-gap
      elements that do not match are met.
    """
    indices = []
    i, j = 0, 0
    while i < len(align1) and j < len(align2):
      if self.alignment_algo.matches(i, align1, j, align2, self.match_threshold):
        indices.append([i, j])
        i, j = i + 1, j + 1
      elif self.alignment_algo.is_gap(align1[i]):
        i += 1
      elif self.alignment_algo.is_gap(align2[j]):
        j += 1
      else:
        return []
    return indices

  def get_insert_indices(self, matching_indices, align1_size, align2_size):
    """Compute where gap objects must be added to line up two alignments.

    Args:
      matching_indices: [i, j] pairs as returned by get_matching_indices.
      align1_size: length of the first alignment.
      align2_size: length of the second alignment.

    Returns:
      [i_s, j_s] where i_s are the [position, count] inserts for the first
      alignment and j_s the [position, count] inserts for the second one.
    """
    i_s, j_s = [], []
    prev_pair = [-1, -1]
    for pair in matching_indices:
      # elements skipped in the first alignment need gaps in the second one
      diff_i = pair[0] - prev_pair[0] - 1
      if diff_i > 0:
        j_s.append([pair[1], diff_i])
      # elements skipped in the second alignment need gaps in the first one
      diff_j = pair[1] - prev_pair[1] - 1
      if diff_j > 0:
        i_s.append([pair[0], diff_j])
      prev_pair = pair
    # if the last matching index didn't reach the end of an alignment
    if prev_pair[0] + 1 < align1_size:
      j_s.append([prev_pair[1] + 1, align1_size - prev_pair[0] - 1])
    if prev_pair[1] + 1 < align2_size:
      i_s.append([prev_pair[0] + 1, align2_size - prev_pair[1] - 1])
    return [i_s, j_s]

  def get_gap_indices(self, align, inserts):
    """Get the gap indices of the alignment that relate to the inserts.

    The [position, count] inserts are expanded to single positions and
    merged against the (ascending) gap positions of align: a gap position is
    kept while it is smaller than or equal to the current insert position,
    and the insert cursor advances once the gap position has reached it.

    Args:
      align: aligned sequence to scan for gap objects.
      inserts: list of [position, count] pairs, in any order.

    Returns:
      List of the selected gap positions, in ascending order.
    """
    gap_positions = [i for i, obj in enumerate(align)
                     if self.alignment_algo.is_gap(obj)]
    # The merge below needs both lists in ascending order; callers may pass
    # the inserts sorted in descending order.
    insert_positions = sorted(insert[0] + offset
                              for insert in inserts
                              for offset in range(insert[1]))
    gaps = []
    i, j = 0, 0
    while i < len(gap_positions) and j < len(insert_positions):
      gap_pos, insert_pos = gap_positions[i], insert_positions[j]
      if gap_pos < insert_pos:
        gaps.append(gap_pos)
        i += 1
      elif gap_pos == insert_pos:
        gaps.append(gap_pos)
        i, j = i + 1, j + 1
      else:
        j += 1
    return gaps

  def print_alignments_2_logger(self, alignments, msg=""):
    """Write the current alignments to the logger at DEBUG level.

    Each element is rendered as text: elements shorter than 5 characters are
    padded to 5, the others are cut at / padded to 40 characters.
    """
    if not logger.isEnabledFor(logging.DEBUG):
      # skip the string building when nothing would be emitted
      return
    logger.debug("alignments %s", msg)
    for index, alignment in enumerate(alignments):
      logger.debug("lenght of alignment %d :  %d", index, len(alignment))
      cells = []
      for item in alignment:
        text = self._text_type(item)
        width = 5 if len(text) < 5 else 40
        # a negative pad count yields an empty string
        cells.append(text[:width] + u" " * (width - len(text)) + u", ")
      logger.debug(u"\n" + u"".join(cells))

  def augment_alignments(self, current_alignnments, new_story, gap_obj=u'-',
                         empty_obj=u' '):
    """Augment alignments with a new story example.

    The given alignments are deep-copied, so the caller's lists are not
    modified. The last aligned row, stripped of gap and empty objects, is
    used as the previous story to align new_story against.

    Returns:
      A new list of alignments that includes new_story.
    """
    from copy import deepcopy

    alignments = deepcopy(current_alignnments)
    prev_alignment = alignments[-1]
    # remove gaps from story
    prev_story = [event for event in prev_alignment
                  if not self.alignment_algo.is_gap(event) and
                  not self.alignment_algo.is_empty(event)]
    return self.augment_story_to_current_alignments(alignments,
                                                    new_story,
                                                    prev_story,
                                                    gap_obj,
                                                    empty_obj)

  def augment_story_to_current_alignments(self,
                                          alignments,
                                          new_story,
                                          prev_story,
                                          gap_obj,
                                          empty_obj):
    """Multiple alignment helper: merge a new story into the alignments.

    new_story is pair-aligned with prev_story, gaps are inserted so that the
    result lines up with the existing rows, the new row is appended and the
    alignments are cleaned (free order events, mergeable columns).

    Args:
      alignments: list of aligned rows, modified in place; a false value
        (None or empty) starts a fresh alignment list.
      new_story: story to add.
      prev_story: story the last aligned row was built from, without gaps.
      gap_obj: object used to mark gaps.
      empty_obj: object forwarded to the pair alignment algorithm.

    Returns:
      The list of aligned rows including the new story.
    """
    align1, align2 = self.call_alignment_algo_on_stories(
        prev_story, new_story, self.match_threshold, gap_obj, empty_obj)
    if not alignments:
      return [align1, align2]

    # add the new story by aligning it with the previous stories in the
    # previous aligned stories
    previous_alignment = alignments[-1]

    self.print_alignments_2_logger(alignments, "before change")
    logger.debug("previous alignment: %s", previous_alignment)
    logger.debug("align1: %s", align1)
    logger.debug("align2: %s", align2)

    # previous_alignment and align1 are the same story but with different gaps
    matching_indices = self.get_matching_indices(previous_alignment, align1)
    # given the matching indices get indices where to insert gaps to merge
    # the story into the previous alignments
    inserts = self.get_insert_indices(matching_indices,
                                      len(previous_alignment), len(align1))

    # sort in descending order so earlier inserts don't shift later ones
    inserts[0].sort(reverse=True)
    inserts[1].sort(reverse=True)

    logger.debug("matching indices %s", matching_indices)
    logger.debug("inserts indices %s", inserts)

    # insert gaps into the old alignments and into the new one
    for alignment in alignments:
      for insert in inserts[0]:
        for _ in range(insert[1]):
          alignment.insert(insert[0], gap_obj)
    for insert in inserts[1]:
      for _ in range(insert[1]):
        # insert before to avoid conflicts
        # NOTE(review): for insert[0] == 0 this is list.insert(-1, ...),
        # which inserts before the last element - confirm this is intended.
        align2.insert(insert[0] - 1, gap_obj)

    # Next we need to make sure that the gaps in align1 and non-gaps in align2
    # have no conflict; gap objects are added where necessary.
    gaps = self.get_gap_indices(previous_alignment, inserts[1])

    # gaps that need to be added
    gap_indices_to_maintain = []
    logger.debug("gap candidates indices %s", gaps)
    for gap_i in gaps:
      for alignment in alignments:
        if not self.alignment_algo.is_gap(alignment[gap_i]) and \
           not self.alignment_algo.matches(gap_i,
                                           alignment,
                                           gap_i,
                                           align2,
                                           self.match_threshold):
          gap_indices_to_maintain.append(gap_i)
          break

    logger.debug("gaps to maintain %s", gap_indices_to_maintain)
    # add the conflicted gaps to all alignments
    for alignment in alignments:
      for gap_index in reversed(gap_indices_to_maintain):
        alignment.insert(gap_index, gap_obj)
    for gap_index in reversed(gap_indices_to_maintain):
      align2.insert(gap_index + 1, gap_obj)

    # finally add the new story to the alignments
    alignments.append(align2)

    self.print_alignments_2_logger(alignments, "after inserting new alingment")
    # after adding the new alignment, check for free order elements and
    # merge them
    self.clean_alignments(alignments)
    self.print_alignments_2_logger(alignments, "after removing free order events")

    # after cleaning and moving things around there might be adjacent or
    # distant columns that can be merged
    self.merge_adjacent_columns(alignments)
    still_to_merge = True
    while still_to_merge:
      still_to_merge = self.merge_across_multiple_columns(alignments,
                                                          self.match_threshold)
      logger.debug("still to merge %s", still_to_merge)
    self.print_alignments_2_logger(alignments, "after final cleaning alingments")

    return alignments

  def get_alignments(self, stories_list, gap_obj=u'-', empty_obj=u' '):
    """Incrementally build the multiple alignment of a list of stories.

    Each story is pair-aligned with the story before it and merged into the
    alignments built so far. Leading stories that are false (e.g. empty) are
    skipped until a first usable story is found.

    Returns:
      List of lists of aligned stories using gap_obj for gaps, or None when
      fewer than two stories could be aligned.
    """
    alignments, prev_story = None, None
    for story_index, story in enumerate(stories_list):
      logger.debug("processing story index  %d", story_index)
      if not prev_story:
        prev_story = story
        continue
      alignments = self.augment_story_to_current_alignments(alignments,
                                                            story,
                                                            prev_story,
                                                            gap_obj,
                                                            empty_obj)
      prev_story = story
    return alignments

  def merge_adjacent_columns(self, alignments):
    """Merge, in place, adjacent columns that hold matching events.

    A column is merged into the column on its left when no row has an event
    in both columns and the first events of the two columns match.
    """
    if not alignments:
      return
    columns_to_merge = []
    prev_col = None
    for col_i in range(len(alignments[0])):
      col_events = self._column_events(alignments, col_i)
      # columns without events have nothing to compare
      if prev_col and col_events:
        prev_rows = set(evt[1] for evt in prev_col)
        # don't merge if a row has an event in both columns
        if not any(evt[1] in prev_rows for evt in col_events):
          first_prev, first_cur = prev_col[0], col_events[0]
          if self.alignment_algo.matches(first_prev[2],
                                         first_prev[3],
                                         first_cur[2],
                                         first_cur[3], self.match_threshold):
            columns_to_merge.append(col_i)
      prev_col = col_events

    logger.debug("columns to merge %s", columns_to_merge)
    # merge from right to left so the remaining indices stay valid
    for col_i in reversed(columns_to_merge):
      for alignment in alignments:
        del_candidate = alignment.pop(col_i)
        if not self.alignment_algo.is_gap(del_candidate):
          alignment[col_i - 1] = del_candidate

  def merge_across_multiple_columns(self, alignments, threshold=0.8):
    """Merge, in place, non-adjacent columns that hold matching events.

    For every column, each later column is a merge candidate when
      * no row has an event in both columns,
      * no row of the later column has an event in one of the candidate
        columns already visited in between, and
      * the first events of the two columns match exactly.

    Args:
      alignments: list of aligned rows, modified in place.
      threshold: accepted for interface compatibility. The comparison uses a
        fixed threshold of 1.0. NOTE(review): confirm whether this parameter
        was meant to replace the fixed value.

    Returns:
      True when at least one pair of columns was found to merge (the caller
      should run another pass), False otherwise.
    """
    if not alignments:
      return False
    num_columns = len(alignments[0])
    columns_to_merge = []
    for col_i in range(num_columns):
      col_events = self._column_events(alignments, col_i)
      if not col_events:
        continue
      col_rows = set(evt[1] for evt in col_events)
      # rows occupied by the candidate columns between col_i and col_new_i
      rows_in_between = set()
      for col_new_i in range(col_i + 1, num_columns):
        col_events_new = self._column_events(alignments, col_new_i)
        if not col_events_new:
          continue
        new_rows = set(evt[1] for evt in col_events_new)
        # don't merge if an event would be repeated in a row or if the merge
        # would move an event in front of an earlier event of the same row
        if new_rows & col_rows or new_rows & rows_in_between:
          continue
        first_cur, first_new = col_events[0], col_events_new[0]
        logger.debug("comparing column %d (%s) with column %d (%s)",
                     col_i, first_cur[0], col_new_i, first_new[0])
        if self.alignment_algo.matches(first_cur[2],
                                       first_cur[3],
                                       first_new[2],
                                       first_new[3], 1.0):
          columns_to_merge.append([col_i, col_new_i])
        rows_in_between |= new_rows

    logger.debug("columns to merge %s", columns_to_merge)
    # Remove the rightmost columns first so that the indices of the pending
    # merges stay valid, and never remove the same column twice. The sort is
    # stable, so pairs sharing a column keep their relative order.
    removed = set()
    pending = sorted(reversed(columns_to_merge),
                     key=lambda pair: pair[1], reverse=True)
    for col_i, col_i_new in pending:
      if col_i_new in removed:
        continue
      removed.add(col_i_new)
      for alignment in alignments:
        del_candidate = alignment.pop(col_i_new)
        if not self.alignment_algo.is_gap(del_candidate):
          alignment[col_i] = del_candidate

    return bool(columns_to_merge)

  def clean_alignments(self, alignments):
    """Clean the alignments, in place, so that the output depends less on
    the ordering of the input stories.

    Columns holding only gaps are removed. Columns holding a single event
    are "free order" candidates: matching candidates from different rows are
    paired and the event with the weaker anchoring is moved into the column
    of the event with the better anchoring.
    """
    if not alignments:
      return
    free_order_events = {}
    for col_i in reversed(range(len(alignments[0]))):
      col_events = self._column_events(alignments, col_i)
      if not col_events:
        # drop the empty column and shift the columns recorded so far
        for alignment in alignments:
          alignment.pop(col_i)
        shifted_events = {}
        for index, evt in free_order_events.items():
          evt[2] = evt[2] - 1
          shifted_events[index - 1] = evt
        free_order_events = shifted_events
      elif len(col_events) == 1:
        free_order_events[col_i] = col_events[0]

    # get clusters of size 2 of all matching free order events
    free_order_cluster = self.partition_free_order_events(free_order_events)

    # annotate every clustered event with its anchoring: the number of
    # surrounding anchors (index 4) and of anchors in its row (index 5)
    for cluster in free_order_cluster:
      if len(cluster) > 1:
        for cluster_item in cluster:
          surrounding_anchors = len(self.get_surrounding_anchors(cluster_item, alignments))
          total_anchors = len(self.get_total_row_anchors(cluster_item, alignments))
          cluster_item.append(surrounding_anchors)
          cluster_item.append(total_anchors)

    for cluster in free_order_cluster:
      if len(cluster) > 1:
        if len(cluster) > 2:
          logger.warning("something is wrong, clusters should be of length 2, %s", cluster)
        # best anchoring first
        sorted_cluster = sorted(cluster, key=lambda x: (x[4], x[5]), reverse=True)
        winning_column = sorted_cluster[0][2]
        del_column = sorted_cluster[1][2]
        del_row = sorted_cluster[1][1]

        # check the delete is legal and we are not ruining the order: the
        # moved event must only cross gaps in its own row
        merge_columns = sorted([del_column, winning_column])
        delete_flag = True
        for col_i in range(merge_columns[0] + 1, merge_columns[1]):
          if not self.alignment_algo.is_gap(alignments[del_row][col_i]):
            delete_flag = False
            break
        if not delete_flag:
          logger.debug("don't delete column")
          continue
        logger.debug("deleted column %d", del_column)
        if del_column < winning_column:
          winning_column -= 1
        for alignment in alignments:
          del_candidate = alignment.pop(del_column)
          if not self.alignment_algo.is_gap(del_candidate):
            alignment[winning_column] = del_candidate
        # keep the columns recorded in the remaining clusters in sync with
        # the removed column (cluster items are the free_order_events values)
        for evt in free_order_events.values():
          if evt[2] > del_column:
            evt[2] -= 1
        sorted_cluster[1][2] = winning_column

  def get_total_row_anchors(self, cluster, alignments):
    """Return the non-gap events of the row of a cluster item.

    Args:
      cluster: cluster item [event, row, col, ...].
      alignments: list of aligned rows.
    """
    row = cluster[1]
    return [event for event in alignments[row]
            if not self.alignment_algo.is_gap(event)]

  def get_surrounding_anchors(self, cluster, alignments):
    """Return the events directly surrounding a cluster item in its row.

    The row is walked to the left and then to the right of the item's
    column, collecting events until the first gap on each side.

    Args:
      cluster: cluster item [event, row, col, ...].
      alignments: list of aligned rows.
    """
    num_columns = len(alignments[0])
    row, col = cluster[1:3]
    row_events = alignments[row]
    anchors = []
    for col_i in reversed(range(0, col)):
      if self.alignment_algo.is_gap(row_events[col_i]):
        break
      anchors.append(row_events[col_i])
    for col_i in range(col + 1, num_columns):
      if self.alignment_algo.is_gap(row_events[col_i]):
        break
      anchors.append(row_events[col_i])
    return anchors

  def partition_free_order_events(self, free_order_events):
    """Create clusters of size 2 of matching free order events that belong
    to different rows.

    Args:
      free_order_events: dict mapping a column to its single event
        [event, row, col, row_alignment].

    Returns:
      List of clusters; a cluster is a list of one or two events. An event
      without a matching partner ends up in a cluster of its own.
    """
    free_order_cluster = []
    for value in free_order_events.values():
      found = False
      new_clusters_to_from = []
      for cluster in free_order_cluster:
        # the cluster may grow while it is scanned; the appended value never
        # pairs with itself because it shares its own row
        for cluster_item in cluster:
          if value[1] != cluster_item[1] and \
             self.alignment_algo.matches(cluster_item[2],
                                         cluster_item[3],
                                         value[2],
                                         value[3],
                                         self.match_threshold):
            found = True
            if len(cluster) < 2:
              cluster.append(value)
            else:
              new_clusters_to_from.append([cluster_item, value])
      if not found:
        free_order_cluster.append([value])
      else:
        free_order_cluster.extend(new_clusters_to_from)
    return free_order_cluster

  def build_graph(self, alignments):
    """Generate a graph object from multiple aligned sequences.

    A 'root' node is created, then one node per column that holds at least
    one event, named "<column number>: <first event>". Every row is linked
    from the last node it went through to the node of each column where it
    has an event, and finally to an 'end' node.

    Returns:
      The Graph; it is empty when there are no alignments or no columns.
    """
    graph = Graph()
    if not alignments or not alignments[0]:
      return graph
    is_gap = self.alignment_algo.is_gap
    graph.add_node(u'root')
    # for each alignment keep track of where in the graph it was connected last
    last_nodes = [u'root' for _ in alignments]
    for col in range(len(alignments[0])):
      non_gap_alignments = [row[col] for row in alignments
                            if not is_gap(row[col])]
      if not non_gap_alignments:
        # a column made of gaps only carries no event to add
        continue
      # nodes are numbered from 1, following the column order
      # TODO change naming scheme
      new_node = self._text_type(col + 1) + u": " + \
                 self._text_type(non_gap_alignments[0])
      graph.add_node(new_node, non_gap_alignments)
      for j, row in enumerate(alignments):
        if not is_gap(row[col]):
          graph.add_vertex(last_nodes[j], new_node)
          last_nodes[j] = new_node
    end_node = 'end'
    graph.add_node(end_node)
    for last_node in last_nodes:
      graph.add_vertex(last_node, end_node)
    return graph

