Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/_static/docstring_previews/de_plot_bcv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
188 changes: 153 additions & 35 deletions pertpy/tools/_differential_gene_expression/_edger.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from collections.abc import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.pyplot import Figure
from scipy.sparse import issparse

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy._logger import logger

from ._base import LinearModelBase
Expand All @@ -24,21 +28,108 @@ def fit(self, **kwargs): # adata, design, mask, layer
Args:
**kwargs: Keyword arguments specific to glmQLFit()
"""
try:
from rpy2 import robjects as ro
from rpy2.robjects import numpy2ri, pandas2ri
from rpy2.robjects.conversion import get_conversion, localconverter
from rpy2.robjects.packages import importr

except ImportError:
raise ImportError("edger requires rpy2 to be installed.") from None

try:
edger = importr("edgeR")
except ImportError as e:
raise ImportError(
"edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
) from e
ro, edger = self._ensure_deps("ro", "edger")

if not hasattr(self, "dge") or not hasattr(self, "design_r"):
self._prepare_dge()

logger.info("Fitting linear model")
fit = edger.glmQLFit(self.dge, design=self.design_r, **kwargs)

ro.globalenv["fit"] = fit
self.fit = fit

@_doc_params(common_plot_args=doc_common_plot_args)
def plot_bcv(  # pragma: no cover # noqa: D417
    self,
    *,
    xlabel: str | None = "Average log CPM",
    ylabel: str | None = "Biological coefficient of variation",
    marker: str = "o",
    point_size: float = 0.2,
    common_col: str = "red",
    trend_col: str = "blue",
    tagwise_col: str = "black",
    legend: bool = True,
    return_fig: bool = False,
    **kwargs,
) -> Figure | None:
    """Draw a biological coefficient of variation (BCV) plot in the style of edgeR::plotBCV.

    The square root of each dispersion read from the DGEList is plotted against the
    average log CPM: tagwise values as points, the common value as a horizontal line
    and the trended values as a curve sorted by average log CPM.

    Args:
        xlabel: Text for the x-axis label.
        ylabel: Text for the y-axis label.
        marker: Marker style of the tagwise points (also used for the legend entry).
        point_size: Point size factor; it is multiplied by 20 to obtain the scatter marker size.
        common_col: Colour of the common dispersion line.
        trend_col: Colour of the trended dispersion curve.
        tagwise_col: Colour of the tagwise dispersion points.
        legend: Whether a legend is drawn in the upper right corner.
        {common_plot_args}
        **kwargs: Extra keyword arguments forwarded to `ax.scatter`, `ax.axhline` and `ax.plot`.

    Returns:
        The figure if `return_fig` is `True`; otherwise the plot is shown and `None` is returned.

    Examples:
        >>> import pertpy as pt
        >>> import decoupler as dc
        >>> adata = pt.dt.zhang_2021()
        >>> adata = adata[adata.obs["Origin"] == "t", :].copy()
        >>> adata.layers["counts"] = adata.X.copy()
        >>> pdata = dc.pp.pseudobulk(adata, sample_col="Patient", groups_col="Cluster", layer="counts", mode="sum")
        >>> dc.pp.filter_samples(pdata, inplace=True)
        >>> edgr = pt.tl.EdgeR(pdata, design="~Efficacy+Treatment")
        >>> edgr.plot_bcv()

    Preview:
        .. image:: /_static/docstring_previews/de_plot_bcv.png
    """
    # NOTE(review): this plot is edgeR-specific for now. Reviewer feedback on this change
    # asks for it to also work with the statsmodels / pydeseq2 backends, and for the
    # docstring example to use pydeseq2, because edgeR support may be phased out later.
    if not hasattr(self, "dge"):
        # No DGEList on this instance yet, so build it before reading dispersions.
        self._prepare_dge()

    numpy2ri, get_conversion, localconverter = self._ensure_deps("numpy2ri", "get_conversion", "localconverter")
    r_to_numpy = get_conversion() + numpy2ri.converter

    # Pull the values out of the R object while the numpy conversion rules are active.
    with localconverter(r_to_numpy):
        dge = self.dge
        ave_log_cpm = np.asarray(dge.rx2("AveLogCPM"))
        tagwise_disp = np.asarray(dge.rx2("tagwise.dispersion"))
        common_disp = float(dge.rx2("common.dispersion")[0])
        trended_disp = np.asarray(dge.rx2("trended.dispersion"))

    fig, ax = plt.subplots(dpi=300)

    # Tagwise BCV: one point per entry.
    ax.scatter(
        ave_log_cpm,
        np.sqrt(tagwise_disp),
        c=tagwise_col,
        s=point_size * 20,
        marker=marker,
        linewidths=0,
        **kwargs,
    )

    # Common BCV: a single horizontal reference line.
    ax.axhline(np.sqrt(common_disp), color=common_col, linewidth=2, **kwargs)

    # Trended BCV: sort by abundance so the curve is drawn left to right.
    by_abundance = np.argsort(ave_log_cpm)
    trended_bcv = np.sqrt(trended_disp)
    ax.plot(ave_log_cpm[by_abundance], trended_bcv[by_abundance], color=trend_col, linewidth=2, **kwargs)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if legend:
        ax.legend(
            handles=[
                Line2D([0], [0], marker=marker, linestyle="", color=tagwise_col, label="Tagwise"),
                Line2D([0], [0], linestyle="-", color=common_col, label="Common"),
                Line2D([0], [0], linestyle="-", color=trend_col, label="Trend"),
            ],
            loc="upper right",
            frameon=True,
        )

    plt.tight_layout()

    if not return_fig:
        plt.show()
        return None

    return fig

def _prepare_dge(self) -> None:
"""Create DGEList, calculate normalization factors, and estimate dispersions."""
numpy2ri, pandas2ri, get_conversion, localconverter, edger = self._ensure_deps(
"numpy2ri", "pandas2ri", "get_conversion", "localconverter", "edger"
)

# Convert dataframe
with localconverter(get_conversion() + numpy2ri.converter):
Expand All @@ -60,11 +151,8 @@ def fit(self, **kwargs): # adata, design, mask, layer
logger.info("Estimating Dispersions")
dge = edger.estimateDisp(dge, design=design_r)

logger.info("Fitting linear model")
fit = edger.glmQLFit(dge, design=design_r, **kwargs)

ro.globalenv["fit"] = fit
self.fit = fit
self.dge = dge
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why these changes here?

self.design_r = design_r

def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataFrame: # noqa: D417
"""Conduct test for each contrast and return a data frame.
Expand All @@ -81,21 +169,9 @@ def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataF
# parse **kwargs to R function
# Fix mask for .fit()

try:
from rpy2 import robjects as ro
from rpy2.robjects import numpy2ri, pandas2ri
from rpy2.robjects.conversion import get_conversion, localconverter
from rpy2.robjects.packages import importr

except ImportError:
raise ImportError("edger requires rpy2 to be installed.") from None

try:
importr("edgeR")
except ImportError:
raise ImportError(
"edgeR requires a valid R installation with the following packages: edgeR, BiocParallel, RhpcBLASctl"
) from None
ro, numpy2ri, pandas2ri, get_conversion, localconverter = self._ensure_deps(
"ro", "numpy2ri", "pandas2ri", "get_conversion", "localconverter"
)

# Convert vector to R, which drops a category like `self.design_matrix` to use the intercept for the left out.
with localconverter(get_conversion() + numpy2ri.converter) as cv:
Expand Down Expand Up @@ -126,3 +202,45 @@ def _test_single_contrast(self, contrast: Sequence[float], **kwargs) -> pd.DataF
de_res = de_res.reset_index()

return de_res.rename(columns={"PValue": "p_value", "logFC": "log_fc", "FDR": "adj_p_value"})

def _ensure_deps(self, *names):
    """Return the requested rpy2 / edgeR handles, importing them on first use.

    The handles are imported once and stored on the instance as ``_imports_cache``;
    later calls only read from that mapping. Known names are ``"ro"``,
    ``"numpy2ri"``, ``"pandas2ri"``, ``"get_conversion"``, ``"localconverter"``
    and ``"edger"``.

    Args:
        *names: Names of the handles to return.

    Returns:
        A tuple with one handle per requested name, in the order requested.

    Raises:
        ImportError: If rpy2 cannot be imported or the edgeR R package cannot be loaded.
        KeyError: If a requested name is not one of the known handles.

    Example:
        ro, numpy2ri, edger = self._ensure_deps("ro", "numpy2ri", "edger")
    """
    not_loaded = object()
    cache = getattr(self, "_imports_cache", not_loaded)

    if cache is not_loaded:
        # First call on this instance: import everything once and remember it.
        try:
            from rpy2 import robjects as ro
            from rpy2.robjects import numpy2ri, pandas2ri
            from rpy2.robjects.conversion import get_conversion, localconverter
            from rpy2.robjects.packages import importr

        except ImportError:
            raise ImportError("edger requires rpy2 to be installed.") from None

        try:
            edger = importr("edgeR")
        except ImportError as e:
            raise ImportError(
                "edgeR requires a valid R installation with the following packages:\nedgeR, BiocParallel, RhpcBLASctl"
            ) from e

        cache = {
            "ro": ro,
            "numpy2ri": numpy2ri,
            "pandas2ri": pandas2ri,
            "get_conversion": get_conversion,
            "localconverter": localconverter,
            "edger": edger,
        }
        self._imports_cache = cache

    resolved = []
    for requested in names:
        if requested not in cache:
            raise KeyError(f"Unknown import request: '{requested}'")
        resolved.append(cache[requested])

    return tuple(resolved)
Loading