diff --git a/environment.yml b/environment.yml index e48a81f..f65c4df 100644 --- a/environment.yml +++ b/environment.yml @@ -4,6 +4,7 @@ channels: dependencies: - click - MDAnalysis + - MDAnalysisTests - netCDF4 - openff-units - pip diff --git a/pyproject.toml b/pyproject.toml index 787ac87..f1c7451 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ test = [ "pytest-xdist", "pytest-cov", "pooch", + "MDAnalysisTests", ] [project.urls] diff --git a/src/openfe_analysis/rmsd.py b/src/openfe_analysis/rmsd.py index bc54701..0b705d3 100644 --- a/src/openfe_analysis/rmsd.py +++ b/src/openfe_analysis/rmsd.py @@ -5,10 +5,9 @@ import MDAnalysis as mda import netCDF4 as nc import numpy as np -import tqdm -from MDAnalysis.analysis import rms +from MDAnalysis.analysis import diffusionmap, rms +from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.transformations import unwrap -from numpy import typing as npt from .reader import FEReader from .transformations import Aligner, ClosestImageShift, NoJump @@ -53,7 +52,7 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers - Unwraps protein and ligand atom to be made whole - Shifts protein chains and the ligand to the image closest to the first protein chain (:class:`ClosestImageShift`) - - Aligns the entire system to minimise the protein RMSD (:class:`Aligner`) + - Aligns the entire system to minimize the protein RMSD (:class:`Aligner`) If only a ligand is present: @@ -88,7 +87,7 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers else: # if there's no protein # - make the ligand not jump periodic images between frames - # - align the ligand to minimise its RMSD + # - align the ligand to minimize its RMSD nope = NoJump(ligand) align = Aligner(ligand) @@ -100,9 +99,169 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers return u +class Protein2DRMSD(AnalysisBase): + """ + Flattened 2D RMSD matrix + + For all unique frame pairs ``(i, j)`` with ``i < j``, this function + computes the RMSD between atomic coordinates after optimal alignment. + Alignment is performed by centering each frame on its center of geometry, + followed by rotational and translational superposition using the QCP method. + + Parameters + ---------- + atomgroup: mda.AtomGroup + Protein atoms (e.g. CA selection) + weights: np.ndarray, optional + Per-atom weights to use in the RMSD calculation. If ``None``, + all atoms are weighted equally. + + Notes + ----- + All atom positions are accumulated in memory during the trajectory + iteration. For long trajectories or large systems this may result in + significant memory usage. Consider using the ``step`` argument to + ``run()`` to reduce the number of frames analyzed. + """ + + _analysis_algorithm_is_parallelizable = False + + def __init__(self, atomgroup: mda.AtomGroup, weights: Optional[np.ndarray] = None, **kwargs): + super().__init__(atomgroup.universe.trajectory, **kwargs) + self._weights = weights + self._ag = atomgroup + + def _prepare(self) -> None: + self._coords = np.zeros((self.n_frames, self._ag.n_atoms, 3), dtype=np.float64) + + def _single_frame(self) -> None: + self._coords[self._frame_index] = self._ag.positions + + def _conclude(self) -> None: + nframes = self._coords.shape[0] + + # Pre-allocate numpy arrays + n_pairs = nframes * (nframes - 1) // 2 + self.results.rmsd2d = np.empty(n_pairs) + + for idx, (i, j) in enumerate(itertools.combinations(range(nframes), 2)): + posi, posj = self._coords[i], self._coords[j] + self.results.rmsd2d[idx] = rms.rmsd( + posi, + posj, + self._weights, + center=True, + superposition=True, + ) + + +class RMSDAnalysis(AnalysisBase): + """ + 1D RMSD time series for an AtomGroup. + + Parameters + ---------- + atomgroup : MDAnalysis.AtomGroup + Atoms to compute RMSD for. + reference: Optional[MDAnalysis.AtomGroup] + Reference AtomGroup. If ``None``, the reference positions are taken + from the first analyzed frame, so ``run(start=10)`` measures RMSD + relative to frame 10, not frame 0. + mass_weighted : bool, optional + If True, compute mass-weighted RMSD. + center : bool, optional + If ``True``, subtract the center of geometry before computing RMSD. + Defaults to ``False`` as the trajectory is assumed to be pre-centered. + superposition : bool, optional + If ``True``, perform rotational superposition before computing RMSD. + Defaults to ``False`` as the trajectory is assumed to be pre-superposed. + """ + + _analysis_algorithm_is_parallelizable = False + + def __init__( + self, + atomgroup: mda.AtomGroup, + reference: Optional[mda.AtomGroup] = None, + mass_weighted: bool = False, + center: bool = False, + superposition: bool = False, + **kwargs, + ): + super().__init__(atomgroup.universe.trajectory, **kwargs) + + self._ag = atomgroup + self._reference = reference if reference is not None else self._ag + self._mass_weighted = mass_weighted + self._center = center + self._superposition = superposition + + def _prepare(self) -> None: + self.results.rmsd = np.zeros(self.n_frames, dtype=np.float64) + # reference is taken from the first analyzed frame, not necessarily frame 0 + self._reference_pos = self._reference.positions.copy() + + if self._mass_weighted: + self._weights = self._ag.masses / np.mean(self._ag.masses) + else: + self._weights = None + + def _single_frame(self) -> None: + self.results.rmsd[self._frame_index] = rms.rmsd( + self._ag.positions, + self._reference_pos, + self._weights, + center=self._center, + superposition=self._superposition, + ) + + +class LigandCOMDrift(AnalysisBase): + """ + Ligand center-of-mass displacement from initial position. + + Parameters + ---------- + atomgroup : mda.AtomGroup + Ligand atoms for which the center-of-mass drift is calculated. + + Notes + ----- + The initial position is taken from the first analyzed frame, so + ``run(start=10)`` measures drift relative to frame 10, not frame 0. + + PBC are not applied as the trajectory is assumed to have been + pre-processed, ensuring the ligand does not jump between periodic images. + Passing a box to apply the minimum image convention would give + incorrect results for ligands that have drifted more than half a box + length from their starting position. + """ + + _analysis_algorithm_is_parallelizable = False + + def __init__(self, atomgroup: mda.AtomGroup, **kwargs): + super().__init__(atomgroup.universe.trajectory, **kwargs) + self._ag = atomgroup + + def _prepare(self) -> None: + self.results.com_drift = np.zeros(self.n_frames, dtype=np.float64) + # initial COM is taken from the first analyzed frame, not necessarily frame 0 + self._initial_com = self._ag.center_of_mass() + + def _single_frame(self) -> None: + # no box argument, assumes the ligand stays in a consistent image; + # applying the minimum image convention could mask large drifts > half a box length + self.results.com_drift[self._frame_index] = mda.lib.distances.calc_bonds( + self._ag.center_of_mass(), + self._initial_com, + ) + + def gather_rms_data( - pdb_topology: pathlib.Path, dataset: pathlib.Path, skip: Optional[int] = None -) -> dict[str, list[Any]]: + pdb_topology: pathlib.Path, + dataset: pathlib.Path, + skip: Optional[int] = None, +) -> dict[str, list[float]]: """ Compute structural RMSD-based metrics for a multistate BFE simulation. @@ -161,105 +320,30 @@ def gather_rms_data( # max against 1 to avoid skip=0 case skip = max(n_frames // 500, 1) - pb = tqdm.tqdm(total=int(n_frames / skip) * n_lambda) - u_top = mda.Universe(pdb_topology) - for i in range(n_lambda): + for state_idx in range(n_lambda): # cheeky, but we can read the PDB topology once and reuse per universe # this then only hits the PDB file once for all replicas - u = make_Universe(u_top._topology, ds, state=i) + u = make_Universe(u_top._topology, ds, state=state_idx) prot = u.select_atoms("protein and name CA") ligand = u.select_atoms("resname UNK") - # save coordinates for 2D RMSD matrix - # TODO: Some smart guard to avoid allocating a silly amount of memory? - prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32) - - prot_start = prot.positions - ligand_start = ligand.positions - ligand_initial_com = ligand.center_of_mass() - ligand_weights = ligand.masses / np.mean(ligand.masses) - - this_protein_rmsd = [] - this_ligand_rmsd = [] - this_ligand_wander = [] - - for ts_i, ts in enumerate(u.trajectory[::skip]): - pb.update() - - if prot: - prot2d[ts_i, :, :] = prot.positions - this_protein_rmsd.append( - rms.rmsd( - prot.positions, - prot_start, - None, # prot_weights, - center=False, - superposition=False, - ) - ) - if ligand: - this_ligand_rmsd.append( - rms.rmsd( - ligand.positions, - ligand_start, - ligand_weights, - center=False, - superposition=False, - ) - ) - this_ligand_wander.append( - # distance between start and current ligand position - # ignores PBC, but we've already centered the traj - mda.lib.distances.calc_bonds(ligand.center_of_mass(), ligand_initial_com) - ) - if prot: - # can ignore weights here as it's all Ca - rmsd2d = twoD_RMSD(prot2d, w=None) # prot_weights) - output["protein_RMSD"].append(this_protein_rmsd) - output["protein_2D_RMSD"].append(rmsd2d) - if ligand: - output["ligand_RMSD"].append(this_ligand_rmsd) - output["ligand_wander"].append(this_ligand_wander) - - output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt) - - return output - + prot_rmsd = RMSDAnalysis(prot).run(step=skip) + output["protein_RMSD"].append(prot_rmsd.results.rmsd) -def twoD_RMSD(positions, w: Optional[npt.NDArray]) -> list[float]: - """ - Compute a flattened 2D RMSD matrix from a trajectory. + prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) + output["protein_2D_RMSD"].append(prot_rmsd2d.results.rmsd2d) - For all unique frame pairs ``(i, j)`` with ``i < j``, this function - computes the RMSD between atomic coordinates after optimal alignment. - - Parameters - ---------- - positions : np.ndarray - Atomic coordinates for all frames in the trajectory. - w : np.ndarray, optional - Per-atom weights to use in the RMSD calculation. If ``None``, - all atoms are weighted equally. - - Returns - ------- - list of float - Flattened list of RMSD values corresponding to all frame pairs - ``(i, j)`` with ``i < j``. - """ - nframes, _, _ = positions.shape - - output = [] - - for i, j in itertools.combinations(range(nframes), 2): - posi, posj = positions[i], positions[j] + if ligand: + lig_rmsd = RMSDAnalysis(ligand, mass_weighted=True).run(step=skip) + output["ligand_RMSD"].append(lig_rmsd.results.rmsd) - rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True) + lig_com_drift = LigandCOMDrift(ligand).run(step=skip) + output["ligand_wander"].append(lig_com_drift.results.com_drift) - output.append(rmsd) + output["time(ps)"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt return output diff --git a/src/openfe_analysis/tests/test_rmsd_mda_data.py b/src/openfe_analysis/tests/test_rmsd_mda_data.py new file mode 100644 index 0000000..83f95ef --- /dev/null +++ b/src/openfe_analysis/tests/test_rmsd_mda_data.py @@ -0,0 +1,80 @@ +import MDAnalysis as mda +import pytest +from MDAnalysisTests.datafiles import DCD, PSF +from numpy.testing import assert_allclose, assert_almost_equal + +from openfe_analysis.rmsd import RMSDAnalysis + + +@pytest.fixture +def mda_universe(): + return mda.Universe(PSF, DCD) + + +@pytest.fixture() +def correct_values(): + return [0, 4.68953] + + +@pytest.fixture() +def correct_values_mass(): + return [0, 4.74920] + + +def test_rmsd(mda_universe, correct_values): + prot = mda_universe.select_atoms("name CA") + prot_rmsd = RMSDAnalysis(prot, superposition=True).run(step=49) + assert_almost_equal( + prot_rmsd.results.rmsd, + correct_values, + 4, + err_msg="error: rmsd profile should match" + "test values", + ) + + +def test_rmsd_frames(mda_universe, correct_values): + prot = mda_universe.select_atoms("name CA") + prot_rmsd = RMSDAnalysis(prot, superposition=True).run(frames=[0, 49]) + assert_almost_equal( + prot_rmsd.results.rmsd, + correct_values, + 4, + err_msg="error: rmsd profile should match" + "test values", + ) + + +def test_rmsd_single_frame(mda_universe): + prot = mda_universe.select_atoms("name CA") + prot_rmsd = RMSDAnalysis(prot, superposition=True).run(start=5, stop=6) + single_frame = [0.91544906] + assert_almost_equal( + prot_rmsd.results.rmsd, + single_frame, + 4, + err_msg="error: rmsd profile should match" + "test values", + ) + + +def test_mass_weighted(mda_universe, correct_values): + # mass weighting the CA should give the same answer as weighing + # equally because all CA have the same mass + prot = mda_universe.select_atoms("name CA") + prot_rmsd = RMSDAnalysis(prot, superposition=True, mass_weighted=True).run(step=49) + + assert_almost_equal( + prot_rmsd.results.rmsd, + correct_values, + 4, + err_msg="error: rmsd profile should matchtest values", + ) + + +def test_custom_weighted(mda_universe, correct_values_mass): + prot = mda_universe.select_atoms("all") + prot_rmsd = RMSDAnalysis(prot, superposition=True, mass_weighted=True).run(step=49) + assert_almost_equal( + prot_rmsd.results.rmsd, + correct_values_mass, + 4, + err_msg="error: rmsd profile should matchtest values", + )