#!/usr/bin/env python
#======================================================================
# pymdpbsa.py version 0.7 (May 2011)
# copyright Novartis Institutes for Biomedical Research
# Basel, Switzerland
# Author: Romain M. Wolf
#======================================================================
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.

#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#=======================================================================

import os, re, sys, math, tempfile, shutil
from subprocess import Popen, PIPE
from optparse import OptionParser
whitespace = re.compile(r'\s+')

"""
Usage: pymdpbsa.py [options] or --help for HELP

Computes the free energy of interaction terms for a receptor-ligand
complex, given a single MD trajectory (ptraj-readable) of the complex
and the individual parameter-topology files for the complex, 
the ligand, and the receptor.

This version calls the pbsa routine of AMBER Tools for PB, i.e.,
it will not work with AMBER Tools 1.2 or earlier...

There are five different options for the solvation term:
3 Generalized-Born methods (corresponding to igb = 1, 2, 5 in other
AMBER modules) with a simple linear function of the SASA for the
nonpolar part Gnp = 0.0072 * SASA;
2 Poisson-Boltzmann methods with different modes to treat
non-polar solvation terms:
   (a) a standard Gnp = 0.005 * SASA + 0.86 (fast);
   (b) the Ecav/Edisp split according to Luo et al. (slower);

When no solvation term is specified, a distance-dependent dielectric 
function with epsilon = 1 is used.

Entropy terms are NOT evaluated.

******************************************************************
NOTE: Use the PB options 3 or 4 with care, look at the 
pbsacontrol_solv3 and pbsacontrol_solv4 functions below and adjust
the settings to your taste...
******************************************************************

"""

#function definitions
#======================================================================
def nrgtable_gb(prmtop, project, P, start, stop, step, gb, sa):
#======================================================================
  """
  Writes the per-frame energy table '<project>.<X>.nrg' by calling ffgbsa
  once for every selected frame.

  * prmtop = prmtop file of the part for which the energy is to be computed
    (ligand or receptor or complex)
  * project = global project name (used to build all file names)
  * P = flag to define which part is computed:
      L for ligand alone
      R for receptor alone
      C (or anything else) for entire complex
  * start, stop, and step are the portions of the trajectory
    to be used in the evaluation; a start below 1 is raised to 1 and
    the frames visited are range(start, stop+step, step)
    NOTE: THESE VALUES MUST BE THE SAME AS IN THE TRAJECTORY SPLIT.
  * gb = flag to include Generalized-Born (1, 2, or 5) or not (0)
       no GB means distance-dependent dielectrics epsilon=r
  * sa = flag to include the SASA term (always 1, i.e., set, in this version)
    gb and sa are handed to ffgbsa unchanged

  For every frame x the file '<project>.<X>.pdb.<x>' is evaluated and one
  record 'x etot ebat evdw ecoul egb sasa' is written to the table.
  Returns nothing.
  """
  # anything that is not explicitly ligand or receptor counts as complex
  middle = P if P in ("L", "R") else "C"
  if start < 1:
    start = 1
  # 'with' closes the table even when ffgbsa raises half-way through
  with open("%s.%s.nrg"%(project, middle), "w") as table:
    for x in range(start, stop+step, step):
      # the (temporary) PDB file of the current frame
      pdbfile = "%s.%s.pdb.%d"%(project, middle, x)
      # compute the energy terms by calling ffgbsa
      etot, ebat, evdw, ecoul, egb, sasa = ffgbsa(prmtop, pdbfile, gb, sa)
      # write the values to the selected .nrg table
      table.write("%4i %10.2f %10.2f %10.2f %10.2f %10.2f %10.2f\n"
          %(x, etot, ebat, evdw, ecoul, egb, sasa))
      # flush per frame so a partial table survives an interrupted run
      table.flush()

#======================================================================
def nrgtable_pb(prmtop, project, P, start, stop, step):
#======================================================================
  """
  Writes the per-frame energy table '<project>.<X>.nrg' by calling pbsarun
  once for every selected frame.

  * prmtop = prmtop file of the part for which the energy is to be computed
    (ligand or receptor or complex)
  * project = global project name (used to build all file names)
  * P = flag to define which part is computed:
      L for ligand alone
      R for receptor alone
      C (or anything else) for entire complex
  * start, stop, and step are the portions of the trajectory
    to be used in the evaluation; a start below 1 is raised to 1 and
    the frames visited are range(start, stop+step, step)
    NOTE: THESE VALUES MUST BE THE SAME AS IN THE TRAJECTORY SPLIT.

  For every frame x the file '<project>.<X>.crd.<x>' is evaluated and one
  record 'x etot evdw ecoul epb ecav edisp' is written to the table.
  Returns nothing.
  """
  # anything that is not explicitly ligand or receptor counts as complex
  middle = P if P in ("L", "R") else "C"
  if start < 1:
    start = 1
  # 'with' closes the table even when pbsarun raises half-way through
  with open("%s.%s.nrg"%(project, middle), "w") as table:
    for x in range(start, stop+step, step):
      # the (temporary) CRD file of the current frame
      crdfile = "%s.%s.crd.%d"%(project, middle, x)
      # compute the energy terms by calling pbsa
      etot, evdw, ecoul, epb, ecav, edisp = pbsarun(prmtop, crdfile)
      # write the values to the selected .nrg table
      table.write("%4i %10.2f %10.2f %10.2f %10.2f %10.2f %10.2f\n"
          %(x, etot, evdw, ecoul, epb, ecav, edisp))
      # flush per frame so a partial table survives an interrupted run
      table.flush()
 
#======================================================================
def difftable_gb(project, start, step):
#======================================================================
  """
  Generates the interaction energy table '<project>.D.nrg'.

  Reads the three independent tables '<project>.L.nrg', '<project>.R.nrg'
  and '<project>.C.nrg' record by record and writes, for every frame,
  complex - receptor - ligand for the columns ebat, evdw, ecoul, egb and
  sasa, preceded by the sum of these five differences.

  * project = global project name (used to build the file names)
  * start, step = first frame number and frame increment; the n-th record
    (counting from 0) of the input tables is labelled start + n*step

  A frame whose sasa column is zero in any of the three tables is left
  out of the output.  The frame label is derived from the record position,
  so the frames following a rejected one keep their true frame number.
  If the tables differ in length, only the records present in all three
  are used.
  Returns nothing.
  """
  fvdw = 1.00 # weight for vdW term in final difference
  # this value might become a changeable option in future releases

  # read the three input tables completely; every handle is closed again
  tables = []
  for part in ("L", "R", "C"):
    with open("%s.%s.nrg"%(project, part), 'r') as nrg:
      tables.append(nrg.readlines())

  with open("%s.D.nrg"%project, 'w') as diffnrg:
    for idx, (ligcurrent, reccurrent, comcurrent) in enumerate(zip(*tables)):
      lvalues = ligcurrent.split()
      rvalues = reccurrent.split()
      cvalues = comcurrent.split()
      # here we reject diffs where sasa is zero in any of the nrg lines
      if float(cvalues[6]) == 0. or float(rvalues[6]) == 0. or float(lvalues[6]) == 0.:
        continue
      # columns 2..6 hold ebat, evdw, ecoul, egb, sasa
      dbat, dvdw, dcoul, dgb, dsasa = [
          float(cvalues[col])-float(rvalues[col])-float(lvalues[col])
          for col in range(2, 7)]
      dvdw = dvdw*fvdw
      dtot = dbat+dvdw+dcoul+dgb+dsasa #total
      current = start + idx*step
      diffnrg.write("%4i %10.2f %10.2f %10.2f %10.2f %10.2f %10.2f\n"
            %(current, dtot, dbat, dvdw, dcoul, dgb, dsasa))

#======================================================================
def difftable_pb(project, start, step):
#======================================================================
  """
  Generates the interaction energy table '<project>.D.nrg' for PB runs.

  Reads the three independent tables '<project>.L.nrg', '<project>.R.nrg'
  and '<project>.C.nrg' record by record and writes, for every frame,
  complex - receptor - ligand for each of the six energy columns
  (etot, evdw, ecoul, epb, ecav, edisp).

  * project = global project name (used to build the file names)
  * start, step = first frame number and frame increment; the n-th record
    (counting from 0) of the input tables is labelled start + n*step

  If the tables differ in length, only the records present in all three
  are used.
  Returns nothing.
  """
  # read the three input tables completely; every handle is closed again
  tables = []
  for part in ("L", "R", "C"):
    with open("%s.%s.nrg"%(project, part), 'r') as nrg:
      tables.append(nrg.readlines())

  # closing the output here guarantees it is complete before it is read back
  with open("%s.D.nrg"%project, 'w') as diffnrg:
    for idx, (ligcurrent, reccurrent, comcurrent) in enumerate(zip(*tables)):
      lvalues = ligcurrent.split()
      rvalues = reccurrent.split()
      cvalues = comcurrent.split()
      # columns 1..6 hold etot, evdw, ecoul, epb, ecav, edisp
      dtot, dvdw, dcoul, dpb, dcav, ddisp = [
          float(cvalues[col])-float(rvalues[col])-float(lvalues[col])
          for col in range(1, 7)]
      current = start + idx*step
      diffnrg.write("%4i %10.2f %10.2f %10.2f %10.2f %10.2f %10.2f\n"
            %(current, dtot, dvdw, dcoul, dpb, dcav, ddisp))

#======================================================================
def ffgbsa(prmtop, pdbfile, gb, sa):
#======================================================================
  """Computes force field terms, Generalized Born, and SASA

  Runs the external program 'ffgbsa' (a NAB application written for this
  purpose) as 'ffgbsa pdbfile prmtop gb sa' and parses its standard output.
  * prmtop = AMBER prmtop file for the molecular system;
  * pdbfile = PDB file for the molecular system;
  * gb = flag for Generalized Born: 1, 2, 5 for yes, else no
    (if gb not 1, 2, 5, i.e., is not used, a distance-dependent
     function eps = r is used instead)
  * sa is a flag to compute the solvent-accessible surface,
    in the current version, it MUST always be set to 1.
  gb and sa are placed on the command line and must therefore be strings.

  Returns energy components in tuple with 6 values in the order
     etot (total energy, including the sasa term added here)
     ebat (sum of bond, angle, torsion terms)
     evdw (the van der Waals energy)
     ecoul (the Coulomb term)
     egb (the Generalized-Born term, if asked for...)
     sasa (the SASA cavity energy term if molsurf works correctly...)

  Raises RuntimeError if the output contains no 'ff:' record.

  NOTE 1:
  The molsurf function in NAB usually works fine.
  In the rare case of a problem, SA is set to zero for the problematic
  frame, a warning is emitted, and later the frame is excluded from the
  statistical evaluations
  NOTE 2:
  Atom radii used for molsurf are set in accordance with the radii used
  in Holger Gohlke's MMPB/SA Perl scripts.
  All radii are then incremented by 1.4 Angstroem and the surface is computed
  with a probe of radius zero.
  NOTE 3:
  We currently use a surface tension term of 0.0072 to convert the SASA into kcal/mol.
  """

# surface tension
  fsurf = 0.0072
  # universal_newlines gives text output on Python 2 and 3 alike
  gibbs = Popen(["ffgbsa", pdbfile, prmtop, gb, sa], stdout = PIPE,
          universal_newlines = True)
  # communicate() drains the pipe while waiting; wait() followed by a read
  # can block forever once the child has filled the pipe buffer
  output = gibbs.communicate()[0]

# since the output of ffgbsa can vary in length, the relevant line numbers can vary,
# hence we scan for the first field in each line to make sure that the correct record is read
  found = False
# sasa = 0.00 marks a molsurf failure and will exclude this frame from
# the evaluation of averages later
  sasa = 0.00
  for line in output.splitlines():
    energies = line.split()
    if not energies:
      continue
    if energies[0] == "ff:":
      # the last 'ff:' record wins if there are several
      etot = float(energies[2])
      ebat = float(energies[3])
      evdw = float(energies[4])
      ecoul = float(energies[5])
      egb = float(energies[7])
      found = True
    elif energies[0] == "sasa:":
      sasa = float(energies[1]) * fsurf
  if not found:
    raise RuntimeError("no 'ff:' record in the ffgbsa output for %s"%pdbfile)
# notice to the screen that some problems occured in molsurf
  if sasa == 0.00:
    print("\nfailure in molsurf")
# total energy is the etot term from the energy call plus the sasa energy
  etot = etot + sasa
  return etot, ebat, evdw, ecoul, egb, sasa

#======================================================================
def pbsarun(prmtop, crdfile):
#======================================================================
  """
  Runs pbsa on the specified prmtop and crd file pair and returns the
  energy components fished from its output.

  * prmtop = prmtop file handed to pbsa with -p
  * crdfile = coordinate file handed to pbsa with -c

  Uses the command file 'pbsasfe.in' and the output file 'pbsasfe.out',
  both in the current working directory.  An output file left over from
  an earlier call is deleted first, so that a failed pbsa run cannot be
  mistaken for the result of the current frame.

  Returns the tuple (etot, evdw, ecoul, epb, ecav, edisp).
  The values are picked at fixed line offsets after the line whose first
  token is 'FINAL' and at fixed token positions within those lines.
  NOTE(review): these offsets presumably match the pbsa output layout of
  the AMBER Tools release this was written for -- verify when upgrading.

  Raises RuntimeError if the output contains no 'FINAL' line.
  """
  outfile = 'pbsasfe.out'
  if os.path.exists(outfile):
    os.remove(outfile)
  # argument list instead of a shell string: file names are passed verbatim
  Popen(["pbsa", "-O", "-i", "pbsasfe.in", "-o", outfile,
       "-p", prmtop, "-c", crdfile]).wait()
  with open(outfile, 'r') as result:
    lines = result.readlines()
  found = False
  for i, line in enumerate(lines):
    # split on whitespace runs; an indented line yields an empty first
    # element, which the token positions below rely on
    energies = whitespace.split(line)
    if len(energies) > 1 and energies[1] == "FINAL":
      etot  = float(whitespace.split(lines[i+4])[3])
      evdw  = float(whitespace.split(lines[i+6])[11])
      ecoul = float(whitespace.split(lines[i+7])[3])
      epb   = float(whitespace.split(lines[i+7])[6])
      ecav  = float(whitespace.split(lines[i+8])[2])
      edisp = float(whitespace.split(lines[i+8])[5])
      found = True
  if not found:
    raise RuntimeError("no FINAL results block in %s (prmtop %s, crd %s)"
             %(outfile, prmtop, crdfile))
  return etot, evdw, ecoul, epb, ecav, edisp

#======================================================================
def pbsacontrol_solv3():
#======================================================================
# for opt.solv = 3
  """
  Creates the pbsa command file 'pbsasfe.in' in the current working
  directory for the PB run with a plain SASA nonpolar term
  (Esasa = 0.005*SASA+0.86, see cavity_surften and cavity_offset below).
  Takes no arguments and returns nothing; an existing file is overwritten.
  See the pbsa documentation for the meaning of the individual settings.

  NOTE: users who want to change PBSA settings should
      do so in the text below...
  """
  settings = '''PB calculation (SASA only)
&cntrl 
ntx=1, imin=1, igb=10, inp=1,
/                     
&pb                     
epsout=80.0, epsin=1.0, space=0.5, bcopt=6, dprob=1.4,
cutnb=0, eneopt=2, 
accept=0.001, sprob=1.6, radiopt=0, fillratio=4,   
maxitn=1000, arcres=0.0625,
cavity_surften=0.005, cavity_offset=0.86
/
'''
  with open('pbsasfe.in', 'w') as control:
    control.write(settings)

#======================================================================
def pbsacontrol_solv4():
#======================================================================
# for opt.solv = 4
  """
  Creates the pbsa command file 'pbsasfe.in' in the current working
  directory for the PB run with the SAV/DISP nonpolar scheme
  (inp=2, use_sav=1).
  Takes no arguments and returns nothing; an existing file is overwritten.
  See the pbsa documentation for the meaning of the individual settings.
  Uses settings for SAV from Table 3 in Luo paper...
  identical as settings for solvation free energy...
  """
  settings = '''PB calculation SAV/DISP scheme
&cntrl 
ntx=1, imin=1, igb=10, inp=2 
/ 
&pb 
npbverb=0, istrng=0.0, epsout=80.0, epsin=1.0, 
radiopt=1, dprob=1.6,
space=0.5, nbuffer=0, fillratio=4.0, 
accept=0.001, arcres=0.0625,
cutnb=0, eneopt=2, 
decompopt=2, use_rmin=1, sprob=0.557, vprob=1.300, 
rhow_effect=1.129, use_sav=1, 
cavity_surften=0.0378, cavity_offset=-0.5692
/ 
'''
  with open('pbsasfe.in', 'w') as control:
    control.write(settings)

#======================================================================
def stats_gb(project, P):
#======================================================================
  """
  Generates statistics (average, standard deviation, and standard error
  of the mean) for the six energy columns of '<project>.<X>.nrg'.

  * project = global project name (used to build the file name)
  * P = table to evaluate: L, R or D; anything else selects C

  Records whose last column (sasa) is zero are frames for which molsurf
  failed; they are excluded from all statistics.

  Returns the tuple (average, stdev, stderr), each a tuple of six values
  in the column order of the table (tot, bat, vdw, coul, gb, sasa).
  The standard deviation uses n-1 in the denominator; with a single
  usable record stdev and stderr are all zero.

  Raises ValueError if the table holds no usable record.
  """
  middle = P if P in ("L", "R", "D") else "C"

  # column 0 is the frame number, columns 1..6 are the energies
  with open("%s.%s.nrg"%(project, middle), "r") as table:
    rows = [[float(value) for value in line.split()[1:7]] for line in table]
  # drop the frames flagged by a zero sasa term
  rows = [row for row in rows if row[5] != 0.]
  records = len(rows)
  if records == 0:
    raise ValueError("no usable records in %s.%s.nrg"%(project, middle))

  columns = list(zip(*rows))
  average = tuple(sum(column)/records for column in columns)

  # standard deviations and standard error of mean (SEM)
  di = records - 1
  if(di == 0):
    stdev = (0, 0, 0, 0, 0, 0)
    stderr = (0, 0, 0, 0, 0, 0)
  else:
    stdev = tuple(
        math.sqrt(sum((value-mean)*(value-mean) for value in column)/di)
        for column, mean in zip(columns, average))
    root = math.sqrt(records)
    stderr = tuple(deviation/root for deviation in stdev)
  return(average, stdev, stderr)

#======================================================================
def stats_pb(project, P):
#======================================================================
  """
  Generates statistics (average, standard deviation, and standard error
  of the mean) for the six energy columns of '<project>.<X>.nrg'.

  * project = global project name (used to build the file name)
  * P = table to evaluate: L, R or D; anything else selects C

  Returns the tuple (average, stdev, stderr), each a tuple of six values
  in the column order of the table (tot, vdw, coul, pb, cav, disp).
  The standard deviation uses n-1 in the denominator; with a single
  record stdev and stderr are all zero.

  Raises ValueError if the table holds no record.
  """
  middle = P if P in ("L", "R", "D") else "C"

  # column 0 is the frame number, columns 1..6 are the energies
  with open("%s.%s.nrg"%(project, middle), "r") as table:
    rows = [[float(value) for value in line.split()[1:7]] for line in table]
  records = len(rows)
  if records == 0:
    raise ValueError("no records in %s.%s.nrg"%(project, middle))

  columns = list(zip(*rows))
  average = tuple(sum(column)/records for column in columns)

  # standard deviations and standard error of mean (SEM)
  di = records - 1
  if(di == 0):
    stdev = (0, 0, 0, 0, 0, 0)
    stderr = (0, 0, 0, 0, 0, 0)
  else:
    stdev = tuple(
        math.sqrt(sum((value-mean)*(value-mean) for value in column)/di)
        for column, mean in zip(columns, average))
    root = math.sqrt(records)
    stderr = tuple(deviation/root for deviation in stdev)
  return(average, stdev, stderr)

#====================================================================== 
def summary_gb(opt):
#======================================================================
  """
  Writes the summary file '<project>.sum' for the GB / eps=r runs.

  * opt = parsed command line options; the attributes project, solv,
    start, stop, step, trajectory, comprm, recprm and ligprm are read

  After a header describing the run, one section is written for each of
  the tables L, R, C and D; every energy term is shown as
  'average (standard deviation, standard error)' as returned by stats_gb.
  A closing separator line follows the last section.
  Returns nothing.
  """
  if opt.solv == 0:
    genb = 'eps = r'
  else:
    genb = 'GB'

  headings = (
    ("L", "-----Ligand Energies---------------------------------------------------\n"),
    ("R", "-----Receptor Energies-------------------------------------------------\n"),
    ("C", "-----Complex Energies--------------------------------------------------\n"),
    ("D", "-----Interaction Energy Components-------------------------------------\n"),
  )

  # 'with' closes (and thereby flushes) the summary before the caller
  # copies it to its final place
  with open("%s.sum"%opt.project, "w") as final:
    final.write("\n=======================================================================\n")
    final.write("Summary Statistics for Project %s\n"%opt.project)
    final.write("Frames                : %d to %d (every %d)\n"%(opt.start,opt.stop,opt.step))
    final.write("Solvation             : %s (--solv=%d)\n"%(genb, opt.solv))
    final.write("Trajectory File       : %s\n"%opt.trajectory)
    final.write("Complex  parmtop File : %s\n"%opt.comprm)
    final.write("Receptor parmtop File : %s\n"%opt.recprm)
    final.write("Ligand   parmtop File : %s\n"%opt.ligprm)
    final.write("=======================================================================\n")

    for x, heading in headings:
      # avg, dev, sem are six-value tuples: tot, bat, vdw, coul, gb, sasa
      avg, dev, sem = stats_gb(opt.project, x)
      final.write(heading)
      final.write("Etot  = %10.2lf (%6.2lf, %6.2lf) "%(avg[0], dev[0], sem[0]))
      final.write("Ebat  = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[1], dev[1], sem[1]))
      final.write("Evdw  = %10.2lf (%6.2lf, %6.2lf) "%(avg[2], dev[2], sem[2]))
      final.write("Ecoul = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[3], dev[3], sem[3]))
      final.write("EGB   = %10.2lf (%6.2lf, %6.2lf) "%(avg[4], dev[4], sem[4]))
      final.write("Esasa = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[5], dev[5], sem[5]))
    final.write("=======================================================================\n")

#====================================================================== 
def summary_pb(opt):
#======================================================================
  """
  Writes the summary file '<project>.sum' for the PB runs.

  * opt = parsed command line options; the attributes project, solv,
    start, stop, step, trajectory, comprm, recprm and ligprm are read

  After a header describing the run, one section is written for each of
  the tables L, R, C and D; every energy term is shown as
  'average (standard deviation, standard error)' as returned by stats_pb.
  A separator line is written after every section.
  Returns nothing.
  """
  if opt.solv == 3:
    nonpolar = 'PB+SASA'
  else:
    nonpolar = 'PB+SAV+DISP'

  headings = (
    ("L", "-----Ligand Energies---------------------------------------------------\n"),
    ("R", "-----Receptor Energies-------------------------------------------------\n"),
    ("C", "-----Complex Energies--------------------------------------------------\n"),
    ("D", "-----Interaction Energy Components-------------------------------------\n"),
  )

  # 'with' closes (and thereby flushes) the summary before the caller
  # copies it to its final place
  with open("%s.sum"%opt.project, "w") as final:
    final.write("\n=======================================================================\n")
    final.write("Summary MDPBSA Statistics for Project %s\n"%opt.project)
    final.write("Solvation             : %s (--solv=%d)\n"%(nonpolar, opt.solv))
    final.write("Frames                : %d to %d (every %d)\n"%(opt.start,opt.stop,opt.step))
    final.write("Trajectory File       : %s\n"%opt.trajectory)
    final.write("Complex  parmtop File : %s\n"%opt.comprm)
    final.write("Receptor parmtop File : %s\n"%opt.recprm)
    final.write("Ligand   parmtop File : %s\n"%opt.ligprm)
    final.write("=======================================================================\n")

    for x, heading in headings:
      # avg, dev, sem are six-value tuples: tot, vdw, coul, pb, cav, disp
      avg, dev, sem = stats_pb(opt.project, x)
      final.write(heading)
      final.write("Etot  = %10.2lf (%6.2lf, %6.2lf) "%(avg[0], dev[0], sem[0]))
      final.write("Evdw  = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[1], dev[1], sem[1]))
      final.write("Ecoul = %10.2lf (%6.2lf, %6.2lf) "%(avg[2], dev[2], sem[2]))
      final.write("Epb   = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[3], dev[3], sem[3]))
      final.write("Ecav  = %10.2lf (%6.2lf, %6.2lf) "%(avg[4], dev[4], sem[4]))
      final.write("Edisp = %10.2lf (%6.2lf, %6.2lf)\n"%(avg[5], dev[5], sem[5]))
      # NOTE(review): the separator is emitted after every section here,
      # whereas the GB summary writes it once at the very end -- kept as is
      # so the file layout does not change; confirm which one is intended
      final.write("=======================================================================\n")

#======================================================================
def split2pdb(prmtop, trajectory, project, strip, P, start, stop, step):
#======================================================================

  """ Splits MD trajectories into PDB files via the AMBER tool ptraj
  call arguments:
  * prmtop = AMBER prmtop file used to record the trajectory(!);
  * trajectory = actual name of the trajectory file (any format that
    can be read by ptraj is accepted);
  * project = global project name (used to identify output files);
  * strip = AMBER mask telling ptraj which part of the molecule
    should be "stripped" (e.g. ':LIG' would strip the ligand atoms from
    the output, writing only the receptor coordinates to the output);
    strip is set by the main calling routine
  * P is a flag that selects the name of the output:
    P = L writes '<project>.L.pdb' (ligand),
    P = R writes '<project>.R.pdb' (receptor),
    anything else writes '<project>.C.pdb' (entire complex);
  * start, stop, step define the frame selection and are passed to the
    ptraj 'trajin' command in this order
     (e.g. 11, 21, 2 would write out frames 11, 13, 15, 17, 19, 21);

  The output and the exit status of ptraj are not inspected.
  Returns nothing.
  """
  part = P if P in ("L", "R") else "C"
  commands = "".join([
    "trajin %s %s %s %s\n" %(trajectory, start, stop, step),
    # fit to C-alphas with first frame as reference
    "rms first '@CA'\n",
    # strip the ligand or the receptor, depending on the mask, and write
    # the stripped PDB file for the frames
    "strip %s\n"%strip,
    "trajout %s.%s.pdb pdb nowrap\n" %(project, part),
    "go\n"])
  # argument list instead of a shell string, so prmtop is passed verbatim;
  # universal_newlines lets the pipes carry text on Python 2 and 3 alike
  traj = Popen(["ptraj", prmtop], stdin = PIPE, stdout = PIPE, stderr = PIPE,
         universal_newlines = True)
  # one communicate() call feeds the whole script and drains both pipes
  traj.communicate(commands)

#======================================================================
def split2crd(prmtop, trajectory, project, strip, P, start, stop, step):
#======================================================================
  """
  Splits MD trajectories into AMBER crd (restart) files via ptraj

  * prmtop = AMBER prmtop file used to record the trajectory;
  * trajectory = name of the trajectory file;
  * project = global project name (used to identify output files);
  * strip = AMBER mask of the atoms ptraj has to remove before writing;
  * P selects the name of the output: L writes '<project>.L.crd',
    R writes '<project>.R.crd', anything else '<project>.C.crd';
  * start, stop, step define the frame selection and are passed to the
    ptraj 'trajin' command in this order.

  Unlike split2pdb, no rms fit is applied before writing.
  The output and the exit status of ptraj are not inspected.
  Returns nothing.
  """
  # the trajout commands are kept exactly as they have always been sent
  trajout = {
    "L": "trajout %s.L.crd restart\n",
    "R": "trajout %s.R.crd  restart\n",
  }.get(P, "trajout %s.C.crd restart\n")
  commands = "".join([
    "trajin %s %s %s %s\n" %(trajectory, start, stop, step),
    # strip the ligand or the receptor, depending on the mask, and write
    # the stripped CRD file for the frames
    "strip %s\n"%strip,
    trajout %project,
    "go\n"])
  # argument list instead of a shell string, so prmtop is passed verbatim;
  # universal_newlines lets the pipes carry text on Python 2 and 3 alike
  traj = Popen(["ptraj", prmtop], stdin = PIPE, stdout = PIPE, stderr = PIPE,
         universal_newlines = True)
  # one communicate() call feeds the whole script and drains both pipes
  traj.communicate(commands)


#**********************************************************************
#         Main  
#**********************************************************************
if __name__ ==  "__main__":
  # Script entry point: parse the options, split the trajectory into
  # per-frame files for ligand, receptor and complex inside a temporary
  # directory, build the four energy tables and the summary, and copy
  # the results back to the directory the script was started from.

  parser = OptionParser()
  parser.add_option("--proj", metavar = "NAME", 
        dest = "project", 
        help = "global project name")
  parser.add_option("--traj", default = "traj.binpos", metavar = "FILE", 
        dest = "trajectory", 
        help = "MD trajectory file               (default: traj.binpos)")
  parser.add_option("--cprm", default = "com.prm", metavar = "FILE", 
        dest = "comprm", 
        help = "complex prmtop file              (default: com.prm)")
  parser.add_option("--lprm", default = "lig.prm", metavar = "FILE", 
        dest = "ligprm", 
        help = "ligand only prmtop file          (default: lig.prm)")
  parser.add_option("--rprm", default = "rec.prm", metavar = "FILE", 
        dest = "recprm", 
        help = "receptor only prmtop file        (default: rec.prm)")
  parser.add_option("--lig", default = "LIG", metavar = "STRING", 
        dest = "lig", 
        help = "residue name of ligand           (default: LIG)")
  parser.add_option("--start", default = 1, metavar = "INT", 
        type = "int", dest = "start", 
        help = "first MD frame to be used        (default: 1)")
  parser.add_option("--stop", default = 1, metavar = "INT", 
        type = "int", dest = "stop", 
        help = "last MD frame to be used         (default: 1)")
  parser.add_option("--step", default = 1, metavar = "INT", 
        type = "int", dest = "step", 
        help = "use every [step] MD frame        (default: 1)")
  # the padding inside the help text forces one solvation mode per line
  parser.add_option("--solv", default = 1, metavar = "INT",
        type = "int", dest = "solv",
        help = "0 for no solvation term (eps=r)                              "
         "1, 2, or 5 for GBSA                                                "
         "3 for PBSA                                                         "
         "4 for PBSA/dispersion            (default: 1)    ")
  parser.add_option("--clean", action = "store_true", dest = "clean",
        help = "clean up temporary files         (default: no clean)")
  if len(sys.argv) == 1:
    print("---------------------------------------------------------")
    print(" pymdpbsa version 0.7 (May 2011)")
    print("---------------------------------------------------------")
    parser.print_help()
    sys.exit(-1)

  (opt, args) = parser.parse_args()

# some checks before we run in trouble...
# check if all required data files are present...
  filelist = [opt.trajectory, opt.comprm, opt.recprm, opt.ligprm]
  for path in filelist:
    if not os.path.exists(path):
      print('...the file %s cannot be found...bye'%path)
      sys.exit(-1)

# the following executables must be in the path for pymdpbsa to work:
  for prog in ['ffgbsa', 'pbsa', 'ptraj']:
    try:
      # only the launch is of interest: an OSError means 'not found'
      Popen([prog, '-h'], stdin = PIPE, stdout = PIPE, stderr = PIPE)
    except OSError:
      print('the required routine %s cannot be found...bye'%prog)
      sys.exit(-1)

# resolve the input files before leaving the start directory, so that
# relative as well as absolute paths can be linked from the temporary one
  targets = [os.path.abspath(path) for path in filelist]

# prepare temporary directory where all computations are to be carried out...
  tmpdir = tempfile.mkdtemp('.tmpdir','%s_'%opt.project,'.')
  os.chdir(tmpdir)
  for i, target in enumerate(targets):
    os.symlink(target, 'link%d'%i)
  traj = 'link0'; cprm = 'link1'; rprm = 'link2'; lprm = 'link3'
# strips everything from trajectory except the ligand   
  striplig = ":*&!:%s"%opt.lig
# strips the ligand from the trajectory, leaving just the receptor
# (and possibly other entities considered as part of the receptor
# like single water molecules or ions)
  striprec = ":%s"%opt.lig
# we always use the SASA term at this point (might be changed in the future)
  sa = "1"
# check if the start, step, stop settings are reasonable:
# crude attempt to avoid obvious errors
  if opt.step < 1:
    # a zero or negative step would break the modulo and the frame loops
    opt.step = 1
  if opt.stop < 1:
    opt.stop = 1
  if opt.start < 1:
    opt.start = 1
  if opt.start > opt.stop:
    opt.start = opt.stop
  # pull stop back onto the last frame that start + n*step really reaches
  if (opt.stop-opt.start)%opt.step != 0:
    opt.stop = opt.stop - (opt.stop-opt.start)%opt.step

# decide of GB or PB, note the gb must be passed as string here!
  gbflag = 0
  if opt.solv < 1 or opt.solv > 5:
    gb = "0"
    gbflag = 1
  elif opt.solv == 1 or opt.solv == 2 or opt.solv == 5:
    gb = "%s"%opt.solv
    gbflag = 1
# prepare the correct pbsa command file if PB is chosen
  if opt.solv == 3:
    pbsacontrol_solv3()
  elif opt.solv == 4:
    pbsacontrol_solv4()   

# function calls: split trajectories, generate energy tables,
# make statistics and write out to the summary file

# ligand alone, receptor alone, complex (":ZZZ" strips nothing as long as
# no residue carries that name) -- always split from the complex prmtop,
# then evaluate with the prmtop of the part itself
  parts = (("L", striplig, lprm), ("R", striprec, rprm), ("C", ":ZZZ", cprm))
  for part, mask, partprm in parts:
    if gbflag == 1:
      split2pdb(cprm, traj, opt.project, mask, part, 
          opt.start, opt.stop, opt.step)
      nrgtable_gb(partprm, opt.project, part, 
             opt.start, opt.stop, opt.step, gb, sa)
    else:
      split2crd(cprm, traj, opt.project, mask, part, 
          opt.start, opt.stop, opt.step)
      nrgtable_pb(partprm, opt.project, part, 
             opt.start, opt.stop, opt.step)
# interaction energy -----
  if gbflag == 1:
    difftable_gb(opt.project, opt.start, opt.step)
  else:
    difftable_pb(opt.project, opt.start, opt.step)
# make statistics and generate summary output file
  if gbflag == 1:
    summary_gb(opt)
  else:
    summary_pb(opt)

# copy files to keep to the parent directory
  shutil.copy('%s.sum'%opt.project, '../%s.sum'%opt.project)
  for mid in ['L', 'R', 'C', 'D']:
    shutil.copy('%s.%s.nrg'%(opt.project,mid), '../%s.%s.nrg'%(opt.project,mid))

# remove the temporary directory if --clean option was specified
  os.chdir('../')
  if(opt.clean):
    shutil.rmtree(tmpdir)
  sys.exit(0)


