"""

This is a module that contains functions responsible for mutating    
the trajectory file for alanine scanning in MMPBSA.py. It must be    
included with MMPBSA.py to ensure proper functioning of alanine
scanning.                                                            

Last updated: 04/17/2010                                    

                           GPL LICENSE INFO                             

Copyright (C) 2009  Dwight McGee, Billy Miller III, and Jason Swails

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.
   
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA.
"""

from MMPBSA_mods.exceptions import MutateError

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

def _getCoords(line, coordsperline, coordsize):
   """ Returns the coordinates of a line in a mdcrd file """
   holder = []
   location = 0
   for i in range(coordsperline):
      try:
         tmp = float(line[location:location+coordsize])
         holder.append(tmp)
      except:
         pass
      location += coordsize

   return holder

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

def _scaledistance(coords, dist):
   """ Scales the distance between 2 3-D cartesian coordinates to the specified
       distance 
   """
   from math import sqrt

   if len(coords) != 6:
      raise MutateError('_scaledistance requires x,y,z coords for 2 atoms')

   coords[3] -= coords[0]  # set first 3 coordinates as origin
   coords[4] -= coords[1]
   coords[5] -= coords[2]

   actualdist = sqrt(coords[3]*coords[3] + coords[4]*coords[4] + 
                     coords[5]*coords[5])

   scalefactor = dist / actualdist # determine scale factor

   coords[3] *= scalefactor  # scale original coordinates
   coords[4] *= scalefactor
   coords[5] *= scalefactor

   coords[3] += coords[0] # move back to original place
   coords[4] += coords[1]
   coords[5] += coords[2]

   return coords  # return the coordinates

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

def _getnumatms(resname):
   """ Returns the number of atoms in a given Amino acid residue """
   if resname in 'GLY':
      return 7
   if resname in ('ALA', 'CYM', 'CYX'):
      return 10
   if resname in ('CYS', 'SER'):
      return 11
   if resname in ('ASP'):
      return 12
   if resname in ('ASN', 'PRO', 'THR'):
      return 14
   if resname in ('GLU'):
      return 15
   if resname in ('GLH', 'VAL'):
      return 16
   if resname in ('GLN', 'HID', 'HIE', 'MET'):
      return 17
   if resname in ('HIP'):
      return 18
   if resname in ('ILE', 'LEU'):
      return 19
   if resname in ('PHE'):
      return 20
   if resname in ('LYN', 'TYR'):
      return 21
   if resname in ('LYS'):
      return 22
   if resname in ('ARG', 'TRP'):
      return 24
   
   raise MutateError(('Unrecognized residue! Add %s to _getnumatms(resname) ' +
                      'in alamdcrd.py and reinstall MMPBSA.py' % resname))

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

def _ressymbol(resname):
   """ Return 1-letter symbol of give amino acid """

   if resname == 'ALA':
      return 'A'
   elif resname == 'ARG':
      return 'R'
   elif resname == 'ASN':
      return 'N'
   elif resname in ['ASP','ASH']:
      return 'D'
   elif resname in ['CYS','CYX','CYM']:
      return 'C'
   elif resname in ['GLU','GLH']:
      return 'E'
   elif resname == 'GLN':
      return 'Q'
   elif resname == 'GLY':
      return 'G'
   elif resname in ['HIP','HID','HIE']:
      return 'H'
   elif resname == 'ILE':
      return 'I'
   elif resname == 'LEU':
      return 'L'
   elif resname in ['LYN','LYS']:
      return 'K'
   elif resname == 'MET':
      return 'M'
   elif resname == 'PHE':
      return 'F'
   elif resname == 'PRO':
      return 'P'
   elif resname == 'SER':
      return 'S'
   elif resname == 'THR':
      return 'T'
   elif resname == 'TRP':
      return 'W'
   elif resname == 'TYR':
      return 'Y'
   elif resname == 'VAL':
      return 'V'
   else:
      return resname

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

class MutantMdcrd(object):
   """ Class for an alanine-mutated amber trajectory file. 
       ASCII only (no netcdf)

       trajname -- path of the original ASCII trajectory
       prm1     -- topology of the original (wild-type) system
       prm2     -- topology of the single alanine mutant

       Both topology objects must provide `name`, `ptr('natom')`, and a
       `parm_data` dict holding 'RESIDUE_LABEL' and 'RESIDUE_POINTER' lists.
   """

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def __init__(self, trajname, prm1, prm2):
      self.traj = trajname
      self.orig_prm = prm1
      self.new_prm = prm2
      # 1-based index of the single residue that differs between topologies
      self.mutres = self.FindMutantResidue()

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def __str__(self):
      # e.g. 'L42A': original 1-letter code, residue number, 'A'
      return '%s%d%s' % (_ressymbol(
             self.orig_prm.parm_data['RESIDUE_LABEL'][self.mutres-1]), 
             self.mutres, _ressymbol('ALA'))

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def FindMutantResidue(self):
      """ Finds which residue is the alanine mutant in a pair of prmtop files

      Returns the 1-based index of the only residue whose label differs
      between the original and mutant topologies.

      Raises MutateError if the residue counts differ, if a differing
      residue is not ALA in the mutant, or if there is not exactly one
      differing residue.
      """
      origres = self.orig_prm.parm_data['RESIDUE_LABEL']
      newres = self.new_prm.parm_data['RESIDUE_LABEL']

      if len(origres) != len(newres):
         raise MutateError(('Mutant prmtop (%s) has a different number of ' +
                            'residues than the original (%s)!') %
                            (self.new_prm.name, self.orig_prm.name))

      diffs = [i for i in range(len(origres)) if origres[i] != newres[i]]

      # A non-ALA substitution is reported even when several residues differ
      for i in diffs:
         if newres[i] != 'ALA':
            raise MutateError('Mutant residue %s is %s but must be ALA!' %
                              (i+1, newres[i]))

      if not diffs:
         raise MutateError(('Your mutant prmtop (%s) has the same sequence ' +
                            'as the original (%s)!') % (self.new_prm.name, 
                                                       self.orig_prm.name))
      if len(diffs) > 1:
         raise MutateError('Your mutant prmtop (%s) can only have one mutation!'
                           % self.new_prm.name)

      return diffs[0] + 1

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def MutateTraj(self, newname):
      """ Mutates a given mdcrd file based on 2 prmtops

      Writes to `newname` a copy of the trajectory in which the mutated
      residue's coordinates are replaced by alanine coordinates (see
      _mutate). The title line gets a note appended. Lines before the
      residue are copied verbatim; from the residue onward, coordinates are
      re-wrapped at 10 per line in '%8.3f' format.

      Raises MutateError for GLY (not supported) and for the first residue
      of the system (the frame parser cannot mutate it).
      NOTE(review): frames are delimited purely by coordinate count, so
      trajectories that contain periodic-box lines do not appear to be
      supported.
      """
      mutres = self.mutres
      labels = self.orig_prm.parm_data['RESIDUE_LABEL']
      pointers = self.orig_prm.parm_data['RESIDUE_POINTER']

      orig_resname = labels[mutres-1]
      resstart = pointers[mutres-1]
      if mutres < len(pointers):
         nextresstart = pointers[mutres]
      else:
         # The last residue has no successor pointer; it runs to the final atom
         nextresstart = self.orig_prm.ptr('natom') + 1
      number_atoms_mut = self.new_prm.ptr('natom')

      if orig_resname == 'GLY':
         raise MutateError('You are trying to mutate GLY to ALA! ' +
                           'Not currently supported.')
      if resstart <= 1:
         # The parser only starts collecting residue coordinates after at
         # least one preceding atom has been copied; residue 1 would pass
         # through unmutated and give a frame with the wrong atom count.
         raise MutateError('Mutating the first residue of the system is ' +
                           'not currently supported.')

      # Open the input first so a missing trajectory leaves no empty output,
      # and make sure both handles are closed even if mutation fails.
      mdcrd = open(self.traj, 'r')
      try:
         new_mdcrd = open(newname, 'w')
         try:
            self._write_mutant(mdcrd, new_mdcrd, orig_resname, resstart,
                               nextresstart, number_atoms_mut)
         finally:
            new_mdcrd.close()
      finally:
         mdcrd.close()

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def _write_mutant(self, mdcrd, new_mdcrd, orig_resname, resstart,
                     nextresstart, number_atoms_mut):
      """ Stream every frame of `mdcrd` into `new_mdcrd`, mutating one residue

      resstart, nextresstart -- first atom of the mutated residue and of the
                                residue after it (1-based atom indices)
      number_atoms_mut       -- atom count of the mutant topology; used to
                                detect the end of each output frame
      """
      coordsperline = 10 # number of coordinates in each line
      coordsize = 8      # how large coordinates are in characters

      # 0-based index of the last coordinate that precedes the residue
      last_kept = resstart * 3 - 4
      # number of coordinates that belong to the residue being mutated
      nres_coords = 3 * (nextresstart - resstart)
      # number of coordinates in one frame of the mutant trajectory
      frame_coords = number_atoms_mut * 3

      coords_done = 0      # coordinates emitted so far in the current frame
      coords_tomutate = []

      # location is 0 before modified coordinates, 1 during modified 
      # coordinates, and 2 after modified coordinates
      location = 0

      for counter, line in enumerate(mdcrd):
         # First line is always a comment
         if counter == 0:
            new_mdcrd.write('%-80s' % (line.strip() + 
                            ' and mutated by MMPBSA.py for alanine scanning'))
            continue

         if coords_done <= last_kept and \
                  coords_done + coordsperline >= last_kept:
            # Line holding the last preceding coordinates and (possibly) the
            # first coordinates of the residue
            location = 1
            for value in _getCoords(line, coordsperline, coordsize):
               if coords_done <= last_kept:
                  if coords_done % coordsperline == 0:
                     new_mdcrd.write('\n')
                  new_mdcrd.write('%8.3f' % value)
                  coords_done += 1
               else:
                  coords_tomutate.append(value)

         elif location == 1:
            words = _getCoords(line, coordsperline, coordsize)
            if len(coords_tomutate) + len(words) < nres_coords:
               # The residue continues onto the next line
               coords_tomutate.extend(words)
               continue

            location = 2
            needed = nres_coords - len(coords_tomutate)
            coords_tomutate.extend(words[:needed])
            tail = words[needed:]   # coordinates after the residue on this line
            for value in self._mutate(orig_resname, coords_tomutate) + tail:
               if coords_done % coordsperline == 0:
                  new_mdcrd.write('\n')
               new_mdcrd.write('%8.3f' % value)
               coords_done += 1
            coords_tomutate = []

            # The residue may end the frame (e.g. last residue of the system)
            if coords_done == frame_coords:
               coords_done = 0
               location = 0

         elif location == 2:
            for value in _getCoords(line, coordsperline, coordsize):
               if coords_done % coordsperline == 0:
                  new_mdcrd.write('\n')
               new_mdcrd.write('%8.3f' % value)
               coords_done += 1

            if coords_done == frame_coords:
               coords_done = 0
               location = 0

         else:
            # Before the residue: copy the line verbatim
            if coords_done % coordsperline == 0:
               new_mdcrd.write('\n')
            new_mdcrd.write(line.rstrip('\n'))
            coords_done += coordsperline

      new_mdcrd.write('\n')

   #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

   def _mutate(self, resname, coords):
      """ Return alanine coordinates built from the coordinates of `resname`

      coords -- flat [x, y, z, x, y, z, ...] list for every atom of the
                residue, in topology order

      The backbone and CB-side atoms are kept. Each CB (or, for PRO, N/CB)
      substituent that alanine lacks is replaced by a hydrogen placed along
      the original bond at C-H (1.09) or N-H (1.01) distance. The trailing
      C, O (and OXT for a C-terminal residue) are copied unchanged.

      Raises MutateError if the atom count does not match the mid-chain,
      N-terminal (+2 H) or C-terminal (+1 OXT) form, or if the residue
      type cannot be mutated.
      """
      # Residues grouped by which atoms must become alanine hydrogens
      list_one = ('ARG', 'ASH', 'ASN', 'ASP', 'CYM', 'CYS', 'CYX', 'GLH',
                  'GLN', 'GLU', 'HID', 'HIE', 'HIP', 'LEU', 'LYN', 'LYS',
                  'MET', 'PHE', 'SER', 'TRP', 'TYR')
      list_two = ('ILE', 'THR', 'VAL')
      list_three = ('PRO',)

      chdist = 1.09
      nhdist = 1.01

      natoms = _getnumatms(resname)
      if natoms * 3 == len(coords):
         startindex = 0
         cterm = False
      elif (natoms + 2) * 3 == len(coords):
         # N-terminus: two extra hydrogens on N shift the side chain by 2
         startindex = 2
         cterm = False
      elif (natoms + 1) * 3 == len(coords):
         # C-terminus: OXT follows C and O, so the side chain is not shifted
         startindex = 0
         cterm = True
      else:
         raise MutateError(('Mismatch in atom # in residue %s. (%d in ' +
                            'alamdcrd.py and %d passed in)') % (resname, 
                            natoms, len(coords)))

      def atom(index):
         # x, y, z of mid-chain atom position `index`, shifted past any extra
         # N-terminal hydrogens
         first = (index + startindex) * 3
         return list(coords[first:first + 3])

      if resname in list_one:
         # keep N..HB3, turn CB->CG (position 7) into a CB-H hydrogen
         new_coords = list(coords[:(7 + startindex) * 3])
         new_coords += _scaledistance(atom(4) + atom(7), chdist)[3:]

      elif resname in list_two:
         # keep N..HB, turn both CB branches (positions 6 and 10) into H
         new_coords = list(coords[:(6 + startindex) * 3])
         new_coords += _scaledistance(atom(4) + atom(6), chdist)[3:]
         new_coords += _scaledistance(atom(4) + atom(10), chdist)[3:]

      elif resname in list_three:
         # keep N (plus N-terminal H's), then an N-H built along N->CD
         new_coords = list(coords[:(1 + startindex) * 3])
         new_coords += _scaledistance(list(coords[0:3]) + atom(1), nhdist)[3:]
         new_coords += atom(10) + atom(11)            # CA, HA
         new_coords += atom(7) + atom(8) + atom(9)    # CB, HB2, HB3
         new_coords += _scaledistance(atom(7) + atom(4), chdist)[3:]

      else:
         raise MutateError("Residue %s not recognized! Can't mutate." % resname)

      # Trailing backbone: C, O (and OXT at the C-terminus)
      if cterm:
         new_coords += list(coords[-9:])
      else:
         new_coords += list(coords[-6:])

      return new_coords

#+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
