Tinkergpu:get rot rest

From biowiki
Jump to navigation Jump to search

#!/usr/bin/env python
'''
Get Tinker keywords for Boresch restraint

> get_rot_rest.py sample.xyz sample.key

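Read a trajectory (Tinker arc, or any other format mdtraj can load) and the
"ligand" keyword from a Tinker key file.
Print 6 distance/angle/torsion restraints and the standard-state correction.
Optional trailing arguments: "rotatable" also prints torsion restraints for the
ligand's rotatable bonds; "only" skips the six orientational restraints.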

'''

import numpy as np
import mdtraj as md
import sys
import os
import time
import scipy.optimize
from scipy.spatial import distance_matrix
from scipy.cluster.hierarchy import fcluster, leaders, linkage
from scipy.spatial.distance import pdist
from collections import Counter, defaultdict
from openbabel import openbabel
from openbabel import pybel

def get_adjlist(top):
    '''
    Build a heavy-atom adjacency list from a topology.

    top: object exposing `bonds`, an iterable of atom pairs whose atoms
         carry `.name` and `.index`
    return: defaultdict(list) mapping atom index -> bonded atom indices;
            bonds touching an atom whose name starts with 'H' are skipped
    '''
    neighbours = defaultdict(list)
    for first, second in top.bonds:
        if first.name.startswith('H') or second.name.startswith('H'):
            continue
        neighbours[first.index].append(second.index)
        neighbours[second.index].append(first.index)
    return neighbours

def get_tab_branch(top):
    '''
    Count, for every atom, the number of bonded heavy atoms.

    top: mdtraj Trajectory.Topology (needs `n_atoms` and `bonds`)
    return: integer numpy array of length top.n_atoms; entry i is the number
            of bonds of atom i in which neither atom name starts with 'H'
    '''
    # The builtin int replaces the np.int alias, which was removed in NumPy 1.24.
    nbr = np.zeros(top.n_atoms, dtype=int)
    for a1, a2 in top.bonds:
        if a1.name.startswith('H') or a2.name.startswith('H'):
            continue
        nbr[a1.index] += 1
        nbr[a2.index] += 1
    return nbr

def error_exit(msg):
    '''Print msg prefixed with "ERROR:" and terminate with exit status 1.'''
    print("ERROR:", msg)
    raise SystemExit(1)
def write_tinker_idx(idxs):
    ''' Encode indices in Tinker range notation

    Example:
    [1, 2, 3, 5, 7, 8] -> [-1, 3, 5, 7, 8]

    idxs: iterable of integers
    return: list of integers; a run of three or more consecutive values
            a..b is written as the pair (-a, b), shorter runs are listed as is
    '''
    spans = []
    for value in sorted(idxs):
        if spans and value <= spans[-1][1] + 1:
            # value continues (or repeats the end of) the current run
            spans[-1][1] = value
        else:
            spans.append([value, value])

    encoded = []
    for first, last in spans:
        width = last - first
        if width == 0:
            encoded.append(first)
        elif width == 1:
            encoded.extend((first, last))
        else:
            encoded.extend((-first, last))
    return encoded
        
def read_tinker_idx(args):
    ''' Decode Tinker range notation

    Example:
    ['5'] -> {5}
    ['-5', '7', '10', '-12', '15'] -> {5, 6, 7, 10, 12, 13, 14, 15}

    A negative value opens a range that is closed by the next positive
    value. Zero is ignored, as is a negative value seen while a range is
    already open or a range that is never closed.

    args: list of strings of integers
    return: a set of indices
    '''
    range_start = None
    members = set()
    for token in args:
        value = int(token)
        if value < 0:
            if range_start is None:
                range_start = -value
        elif value > 0:
            if range_start is None:
                members.add(value)
            else:
                members.update(range(range_start, value + 1))
                range_start = None
    return members

def read_ligidx(fkey):
    '''Collect ligand atom indices from a Tinker key file

    Every line that starts with "ligand" (case-insensitive) and has at least
    two whitespace-separated fields contributes its indices; commas are
    treated as separators and ranges are decoded by read_tinker_idx.

    fkey: path of the key file
    return: a set of indices, as written in the file
    '''
    found = set()
    with open(fkey, 'r') as handle:
        for raw in handle:
            if not raw.lower().startswith('ligand'):
                continue
            if len(raw.split()) < 2:
                continue
            tokens = raw.replace(',', ' ').split()[1:]
            found |= read_tinker_idx(tokens)
    return found


def target_disp(weights, coord, alpha=0.1):
    '''Loss for a weighted group of points

    Negative weights are clipped to zero and the weights are rescaled to a
    mean of one. The loss is the sum of
      - the L1 norm of the weighted coordinate sum,
      - alpha times the sum of absolute weighted coordinates,
      - alpha times the weights exceeding m, for each m = 2..19.

    weights: one weight per row of coord
    coord: (n, d) array of coordinates
    alpha: scale of the two penalty terms
    '''
    assert len(weights) == coord.shape[0]
    column = np.maximum(np.array(weights).reshape((-1, 1)), 0)
    assert sum(column) > 0
    column *= 1.0/np.mean(column)

    weighted = column*coord
    total = np.sum(np.abs(np.sum(weighted, axis=0)))
    total += alpha*np.sum(np.abs(weighted))

    for threshold in range(2, 20):
        exceeds = column > threshold
        total += alpha*np.sum(np.abs(column*exceeds))

    return total

def calc_coord(traj, idxs):
    '''
    Evaluate one internal coordinate for every frame.

    traj: mdtraj trajectory
    idxs: 2, 3 or 4 atom indices -> distance, angle or dihedral
    return: 1-D array with one value per frame; distances are multiplied by
            10 (NM_IN_ANG), angles and dihedrals by 180/pi (RAD_IN_DEG)
    '''
    NM_IN_ANG = 10.0
    RAD_IN_DEG = 180.0/np.pi
    kinds = {
        2: (md.compute_distances, NM_IN_ANG),
        3: (md.compute_angles, RAD_IN_DEG),
        4: (md.compute_dihedrals, RAD_IN_DEG),
    }

    values = np.zeros(traj.n_frames)
    atoms = np.array(idxs)
    n_atoms = len(atoms)
    if n_atoms in kinds:
        compute, scale = kinds[n_atoms]
        values = compute(traj, atoms.reshape(1, -1))
        values *= scale
    return values[:, 0]

def calc_idx_ortho(traj, i1, i2, idx3, nbr=None, method='long'):
    '''
    Pick, from the candidates idx3, a third atom for the axis i2 -> i1.

    Only the first frame of traj is used. For every candidate the vector
    from atom i2 is split into a component along the unit vector i2 -> i1
    and the remaining orthogonal component.

    traj: trajectory exposing `xyz` indexed as [frame, atom, xyz]
    i1, i2: atom indices defining the reference axis
    idx3: numpy array of candidate atom indices (must not be empty)
    nbr: per-atom number of bonded heavy atoms, indexable by atom index;
         defaults to 0 for every atom
    method: 'long'  -> prefer the highest nbr, then the largest orthogonal
                       component
            'short' -> prefer atoms with nbr < 1, then the smallest |cos| of
                       the angle to the axis (binned in steps of DELTA_COS),
                       then the smallest orthogonal component
    return: the selected element of idx3
    '''
    DELTA_COS = 0.3

    if nbr is None:
        nbr = defaultdict(int)
    t = traj
    a1 = t.xyz[0, [i1], :] - t.xyz[0, [i2], :]
    u1 = a1/np.linalg.norm(a1, axis=1)

    vec2 = t.xyz[0, idx3, :] - t.xyz[0, [i2], :]
    d2 = np.linalg.norm(vec2, axis=1)
    d2p = np.abs(np.sum(vec2 * u1, axis=1))
    # Clamp at zero: rounding can make the difference slightly negative for
    # (nearly) collinear atoms, and the resulting NaN would corrupt the sort.
    d2o = np.sqrt(np.maximum(d2**2.0 - d2p**2.0, 0.0))
    d2cos = d2p / np.maximum(1e-5, d2)

    # Structured arrays sort lexicographically by field, so the best
    # candidate ends up last. Builtin int/float replace the np.int/np.float
    # aliases, which were removed in NumPy 1.24.
    if method == 'short':
        # large angle (~90 deg), short distance
        keys = list(zip([nbr[_] < 1 for _ in idx3], -d2cos//DELTA_COS, -d2o))
        d2on = np.array(keys, dtype=[('n', int), ('c', float), ('r', float)])
    else:
        # large ortho vector
        keys = list(zip([nbr[_] for _ in idx3], d2o))
        d2on = np.array(keys, dtype=[('n', int), ('r', float)])
    i3 = idx3[np.argsort(d2on)[-1]]
    return i3

def write_xyz_from_md(traj, idx):
    '''
    Format the selected atoms of the first frame as XYZ text.

    traj: trajectory exposing `topology.atom(i).element.symbol` and `xyz`
    idx: sequence of atom indices, written in the given order
    return: string with an atom-count line, a comment line and one line per
            atom (element symbol and coordinates multiplied by 10)
    '''
    header = '%d\nExtracted from MD\n'%(len(idx))
    rows = []
    for i in idx:
        symbol = traj.topology.atom(i).element.symbol
        x, y, z = traj.xyz[0, i, :]*10
        rows.append('%5s %12.6f %12.6f %12.6f\n'%(symbol, x, y, z))
    return header + ''.join(rows)

def get_rotlist(traj, ligidx):
    '''
    List the rotatable bonds of the ligand as detected by Open Babel.

    The ligand atoms of the first frame are written as XYZ text, parsed with
    pybel, and every bond for which IsRotor() is true is reported.

    traj: mdtraj trajectory
    ligidx: iterable of ligand atom indices
    return: list of (index, index) tuples in trajectory numbering
    '''
    ordered = sorted(ligidx)
    xyz_text = write_xyz_from_md(traj, ordered)
    molecule = pybel.readstring('xyz', xyz_text)
    # Open Babel atom indices are 1-based positions in the XYZ text, hence
    # the -1 when mapping back to trajectory indices.
    return [
        (ordered[bond.GetBeginAtomIdx()-1], ordered[bond.GetEndAtomIdx()-1])
        for bond in openbabel.OBMolBondIter(molecule.OBMol)
        if bond.IsRotor()
    ]

def get_idx_pocket(traj, idx1, idx2, rcutoff=0.5):
    '''
    Choose three protein and three ligand anchor atoms and build the six
    internal coordinates (1 distance, 2 angles, 3 torsions) linking them.

    Only the first frame is used.

    traj: mdtraj trajectory
    idx1: numpy array of candidate protein atom indices
    idx2: numpy array of candidate ligand atom indices
    rcutoff: interface cutoff, also used as the clustering distance
             threshold (same unit as traj.xyz)
    return: (int_idxs, r0s) - six lists of atom indices and their current
            values as computed by calc_coord

    Exits through error_exit when fewer than three protein or three ligand
    atoms lie at the interface.
    '''
    t = traj
    nbr = get_tab_branch(t.topology)
    # Interface atoms come back sorted by increasing distance to the partner.
    prot_idx, lig_idx, _, _ = calc_idx_interface(traj, idx1, idx2, rcutoff=rcutoff)

    if len(prot_idx) == 0:
        error_exit('No interface atom found within %.3f nm'%rcutoff)
    if len(prot_idx) < 3 or len(lig_idx) < 3:
        # Three distinct anchors are required on each side.
        error_exit('Need at least 3 protein and 3 ligand interface atoms within %.3f nm (found %d and %d)'
                   %(rcutoff, len(prot_idx), len(lig_idx)))

    # Average-linkage clustering of the protein interface atoms; the first
    # (closest to the ligand) member of the largest cluster is the anchor.
    Z = linkage(pdist(t.xyz[0, prot_idx, :]), method='average')
    clst_idx = fcluster(Z, rcutoff, criterion='distance')
    clst_size_list = sorted(Counter(clst_idx).items(), key=lambda item: item[1])
    iclstm = clst_size_list[-1][0]
    iprot = prot_idx[clst_idx == iclstm][0]

    # Ligand anchor: nearest atom to iprot, preferring atoms with more than
    # one heavy neighbour (sorted first by the 'n' flag, then by distance).
    # Builtin float/int replace the aliases removed in NumPy 1.24.
    lig_dist = np.linalg.norm(t.xyz[0, lig_idx, :] - t.xyz[0, [iprot], :], axis=1)
    dist_nr2 = np.array(list(zip(lig_dist, [nbr[_] <= 1 for _ in lig_idx])),
                        dtype=[('r', float), ('n', int)])
    ord_nr2 = np.argsort(dist_nr2, order=('n', 'r'))
    ilig = lig_idx[ord_nr2[0]]

    # Atoms already chosen are removed from the candidates (as was already
    # done on the ligand side) so that no coordinate repeats an atom.
    iprot2 = calc_idx_ortho(traj, ilig, iprot, prot_idx[prot_idx != iprot], nbr)
    iprot3 = calc_idx_ortho(traj, iprot, iprot2,
                            prot_idx[(prot_idx != iprot) & (prot_idx != iprot2)], nbr)

    ilig2 = calc_idx_ortho(traj, iprot, ilig, lig_idx[lig_idx != ilig], nbr, method='short')
    ilig3 = calc_idx_ortho(traj, ilig, ilig2,
                           lig_idx[(lig_idx != ilig) & (lig_idx != ilig2)], nbr, method='short')

    #                             -r-
    #                        -a-      -a-
    # iprot3 ... iprot2 -t- iprot -t- ilig -t- ilig2 ... ilig3
    int_idxs = [[iprot, ilig]]
    int_idxs.extend([[iprot2, iprot, ilig], [iprot, ilig, ilig2]])
    int_idxs.extend([[iprot3, iprot2, iprot, ilig], [iprot2, iprot, ilig, ilig2], [iprot, ilig, ilig2, ilig3]])
    r0s = [calc_coord(t, _)[0] for _ in int_idxs]

    return int_idxs, r0s

def calc_idx_interface(traj, idx1, idx2, rcutoff=0.6):
    '''
    Find the atoms of two groups that lie within rcutoff of the other group.

    Only the first frame is used.

    traj: trajectory exposing `xyz` indexed as [frame, atom, xyz]
    idx1, idx2: numpy arrays of atom indices of the two groups
    rcutoff: distance cutoff (same unit as traj.xyz)
    return: (atoms1, atoms2, dist1, dist2) - for each group the atoms whose
            minimum distance to the other group is <= rcutoff, sorted by
            increasing minimum distance, followed by those distances
    '''
    frame0 = traj.xyz[0]
    pair_dist = distance_matrix(frame0[idx1, :], frame0[idx2, :])

    def _closest(indices, nearest):
        # Sort by distance, then keep as many entries as pass the cutoff.
        order = np.argsort(nearest)
        keep = order[:np.sum(nearest <= rcutoff)]
        return indices[keep], nearest[keep]

    atoms1, near1 = _closest(idx1, np.min(pair_dist, axis=1))
    atoms2, near2 = _closest(idx2, np.min(pair_dist, axis=0))
    return atoms1, atoms2, near1, near2

def write_rest(int_idxs, r0s, k0s, fmt='tinker'):
    '''
    Format restraints as Tinker keyword lines.

    int_idxs: list of 0-based atom index lists; 2, 3 or 4 atoms give a
              distance, angle or torsion restraint, longer entries are skipped
    r0s: reference value per restraint (written as both lower and upper bound)
    k0s: force constant per restraint
    fmt: only 'tinker' is supported
    return: the keyword lines as one string (atom indices written 1-based),
            or '' after printing a message when fmt is not supported
    '''
    keywords = ('', '', 'restrain-distance', 'restrain-angle', 'restrain-torsion')
    lines = []
    for atoms, ref, force in zip(int_idxs, r0s, k0s):
        n_atoms = len(atoms)
        if n_atoms >= len(keywords):
            continue
        if fmt != 'tinker':
            print("Format %s not supported"%fmt)
            return ''
        serials = ' '.join('%4d'%(a+1) for a in atoms)
        lines.append('%s %s %.6f %.6f %.6f\n'%(keywords[n_atoms], serials, force, ref, ref))
    return ''.join(lines)

def get_rottors(traj, idx1):
    '''
    Print Tinker torsion restraints for the rotatable bonds of the ligand.

    For every rotatable bond one heavy-atom neighbour is taken on each side
    to form a torsion; its first-frame value is restrained with a force
    constant of 0.01.

    traj: mdtraj trajectory
    idx1: ligand atom indices
    return: list of 4-atom index lists, one per rotatable bond

    Exits through error_exit when a bond atom has no other heavy neighbour.
    '''
    rot_bonds = get_rotlist(traj, idx1)
    heavy_nbrs = get_adjlist(traj.topology)

    tors = []
    for bond in rot_bonds:
        lo, hi = min(bond), max(bond)
        before = [a for a in heavy_nbrs[lo] if a != hi]
        after = [a for a in heavy_nbrs[hi] if a != lo]
        if not before or not after:
            error_exit("Cannot find heavy atom connected to rotable bond %d %d"%(lo, hi))
        tors.append([before[0], lo, hi, after[0]])

    r0s = [calc_coord(traj, quad)[0] for quad in tors]
    k0s = np.zeros_like(r0s) + 0.01
    print(write_rest(tors, r0s, k0s))
    return tors

def find_rotation_rest(fxyz, ligidx0, atomnames='CA', rcutoff=0.6, alpha=1.0, rest_rotbond=False, rotbond_only=False):
    '''
    Print six protein-ligand restraints and the restraint free energy.

    fxyz: trajectory file; read with md.load_arc, falling back to md.load
          on IOError
    ligidx0: 0-based ligand atom indices
    atomnames: atom name(s) for the mdtraj selection of protein candidates
    rcutoff, alpha: accepted for interface compatibility; not used here
    rest_rotbond: also print torsion restraints for rotatable ligand bonds
    rotbond_only: return right after the optional rotatable-bond output
    return: None; the restraints and "#dGrest(kcal/mol)" are printed

    Exits through error_exit when no ligand or protein candidates are found.
    '''
    try:
        t = md.load_arc(fxyz)
    except IOError:
        t = md.load(fxyz)
    if rest_rotbond:
        get_rottors(t, ligidx0)
    if rotbond_only:
        return

    # NOTE(review): 'name H' matches only atoms named exactly "H"; other
    # helpers in this file treat every name starting with 'H' as hydrogen.
    ligidx = np.array(sorted(set(ligidx0) - set(t.topology.select('name H'))))
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s'%(atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    if len(protidx0) == 0:
        # No distance test has happened yet: the name selection was empty.
        error_exit("No protein atoms named %s found outside the ligand"%atomnames)

    # NOTE(review): rcutoff is not forwarded, so get_idx_pocket runs with its
    # own default cutoff; confirm whether that is intended.
    int_idxs, r0s = get_idx_pocket(t, protidx0, ligidx)
    RAD_IN_DEG = 180/np.pi
    # 10.0 for the distance restraint, 10.0 per radian^2 expressed per
    # degree^2 for the five angle/torsion restraints.
    k0s = np.zeros_like(r0s) + 10.0/(RAD_IN_DEG)**2.0
    k0s[0] = 10.0
    print(write_rest(int_idxs, r0s, k0s))

    if len(k0s) != 6:
        error_exit("Expected 6 restraints, got %d"%len(k0s))

    RT = 8.314 * 298 / 4184  # gas constant * 298 K, J/mol converted to kcal/mol
    # https://doi.org/10.1021/jp0217839
    # Eq. (14)
    # RAD_IN_DEG**10 converts the five angular force constants back to per
    # radian^2; 1662 is presumably the standard-state volume in A^3 - TODO confirm.
    dgrest = RT*np.log(1662*8*np.pi**2.0*np.sqrt(np.prod(k0s)*RAD_IN_DEG**10.0) \
                        /(r0s[0]**2.0*np.sin(r0s[1]/RAD_IN_DEG)*np.sin(r0s[2]/RAD_IN_DEG)*(2*np.pi*RT)**3))

    print("#dGrest(kcal/mol) %.5f"%dgrest)



def find_grp_idx(fxyz, ligidx0, atomnames='CA', rcutoff=1.2, alpha=1.0):
    '''
    Select a group of protein atoms around the ligand and the ligand atom
    closest to that group's centre, formatted as Tinker "group" keywords.

    Only the first frame is used.

    fxyz: trajectory file; read with md.load_arc, falling back to md.load
          on IOError
    ligidx0: 0-based ligand atom indices
    atomnames: atom name(s) for the mdtraj selection of protein candidates
    rcutoff: protein candidates must lie within this distance of a ligand atom
    alpha: penalty weight passed to target_disp
    return: None when no protein candidate is within rcutoff; otherwise
            (r0, text) where r0 is the distance between the selected ligand
            atom and the protein group centre multiplied by 10, and text
            holds the "group 1"/"group 2" lines and a "#r_0=" comment

    Exits through error_exit when the ligand has no heavy atoms.
    '''
    try:
        t = md.load_arc(fxyz)
    except IOError:
        t = md.load(fxyz)
    ligidx = np.array(sorted(set(ligidx0) - set(t.topology.select('name H'))))
    if len(ligidx) == 0:
        error_exit("No ligand heavy atoms found")

    protidx0 = t.topology.select('name %s'%(atomnames))
    protidx0 = np.array(sorted(set(protidx0) - set(ligidx)))

    distmat = distance_matrix(t.xyz[0, ligidx, :], t.xyz[0, protidx0, :])
    distpro = np.min(distmat, axis=0)
    protidx = protidx0[distpro <= rcutoff]
    if len(protidx) == 0:
        return

    # argmin works on the flattened (ligand x protein) matrix; integer
    # division by the row length recovers the ligand row.
    imindist = np.argmin(distmat)
    iminlig = ligidx[imindist // len(protidx0)]

    # Reference point: the ligand atom closest to any protein candidate.
    ref_lig = t.xyz[0, [iminlig], :]

    # Optimise per-atom weights, then keep the atoms whose normalised weight
    # exceeds 40% of the mean positive weight.
    wts0 = np.ones(len(protidx))
    xyz1_prot = t.xyz[0, protidx, :] - ref_lig.reshape((1, -1))
    res = scipy.optimize.minimize(target_disp, wts0, args=(xyz1_prot, alpha))
    wts1 = np.maximum(0, np.array(res.x))
    wts1 *= 1.0/np.sum(wts1)
    wtm = np.mean(wts1[wts1 > 0])
    mask1 = wts1 > 0.4*wtm
    com_p1 = (np.mean(t.xyz[0, protidx[mask1], :], axis=0)).reshape((1, -1))

    xyz_lig = t.xyz[0, ligidx, :]
    dists = np.linalg.norm(xyz_lig - com_p1, axis=1)
    imin = np.argmin(dists)
    dmin = np.linalg.norm(xyz_lig[imin] - com_p1)

    idx_tinker = write_tinker_idx([_+1 for _ in protidx[mask1]])
    outp = ''
    sgrp = ''
    for n in idx_tinker:
        sgrp += ' %5d'%n
        # Never break a line right after a negative value: it opens a range
        # that the next value must close.
        if len(sgrp) > 50 and n > 0:
            outp += ('group 1 %s\n'%sgrp)
            sgrp = ''
    if sgrp:
        # Flush the remainder; without this the last (or only) chunk of
        # group 1 atoms was silently dropped.
        outp += ('group 1 %s\n'%sgrp)

    outp += ('group 2 %s\n'%(' '.join(['%5d'%(_) for _ in [ligidx[imin] + 1]])))
    outp += ("#r_0=%.3f"%(dmin*10))
    return dmin*10, outp


def main():
    '''
    Command-line entry point.

    argv[1]: trajectory/coordinate file
    argv[2]: Tinker key file containing the "ligand" keyword
    further arguments: "rotatable" also restrains rotatable ligand bonds,
                       "only" skips the orientational restraints
    '''
    if len(sys.argv) < 3:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit("Usage: get_rot_rest.py sample.xyz sample.key [rotatable] [only]")
    fxyz, fkey = sys.argv[1:3]
    args = sys.argv[3:]

    # Key-file indices are 1-based; the trajectory code works 0-based.
    ligidx = [_-1 for _ in sorted(read_ligidx(fkey))]
    if len(ligidx) == 0:
        error_exit("ligand keyword not found")

    rest_rotbond = 'rotatable' in args
    no_orient = 'only' in args
    find_rotation_rest(fxyz, ligidx, "CA C N", rest_rotbond=rest_rotbond, rotbond_only=no_orient)


# Guarded so that importing this module does not run the script.
if __name__ == '__main__':
    main()