Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import multiprocessing
import os
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union


@dataclass
Expand All @@ -20,7 +20,12 @@ class SEQopts:
:param cense_eligible_colname: Column name to identify which rows are eligible for censoring model fitting
:param compevent_colname: Column name specifying a competing event to the outcome
:param covariates: Override to specify the outcome patsy formula for outcome model fitting
:param denominator: Override to specify the outcome patsy formula for denominator model fitting
:param denominator: Override to specify the patsy formula for denominator weight
model fitting. A single string fits the same model in every treatment arm.
A list with one formula per ``treatment_level`` (in ``treatment_level``
order) fits a separate denominator model, with its own covariates, in each
arm; this is only supported for post-expansion weights
(``weight_preexpansion=False``).
:param excused: Boolean to allow excused conditions when method is censoring
:param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]``
:param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting,
Expand All @@ -38,8 +43,12 @@ class SEQopts:
:param indicator_squared: How to indicate squared columns in models
:param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates
:param ncores: Number of cores to use if running in parallel, default ``max(1, cpu_count() - 1)``
:param numerator: Override to specify the outcome patsy formula for
numerator models; "1" or "" indicate intercept only model
:param numerator: Override to specify the patsy formula for numerator weight
models; "1" or "" indicate intercept only model. A single string fits the
same model in every treatment arm. A list with one formula per
``treatment_level`` (in ``treatment_level`` order) fits a separate
numerator model, with its own covariates, in each arm; this is only
supported for post-expansion weights (``weight_preexpansion=False``).
:param offload: Boolean to offload intermediate model data to disk
:param offload_dir: Directory to offload intermediate model data
:param parallel: Boolean to run model fitting in parallel
Expand Down Expand Up @@ -81,7 +90,7 @@ class SEQopts:
compevent_colname: Optional[str] = None
covariates: Optional[str] = None
cox_package: Literal["lifelines", "scikit-survival"] = "lifelines"
denominator: Optional[str] = None
denominator: Optional[Union[str, List[str]]] = None
excused: bool = False
excused_colnames: List[str] = field(default_factory=lambda: [])
expand_only: bool = False
Expand All @@ -97,7 +106,7 @@ class SEQopts:
indicator_squared: str = "_sq"
km_curves: bool = False
ncores: Optional[int] = None
numerator: Optional[str] = None
numerator: Optional[Union[str, List[str]]] = None
offload: bool = False
offload_dir: str = "_seq_models"
parallel: bool = False
Expand Down Expand Up @@ -198,7 +207,17 @@ def _normalize_formulas(self):
"cense_denominator",
):
attr = getattr(self, i)
if attr is not None and not isinstance(attr, list):
if attr is None:
continue
if isinstance(attr, list):
# Per-treatment-level formulas (numerator/denominator): strip
# whitespace from each element, leaving None entries untouched.
setattr(
self,
i,
["".join(a.split()) if isinstance(a, str) else a for a in attr],
)
else:
setattr(self, i, "".join(attr.split()))

def __post_init__(self):
Expand Down
4 changes: 2 additions & 2 deletions pySEQTarget/analysis/_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _one_boot_log_hr(self, data, idx, model_pos, sample_idx):
data.lazy()
.join(counts.lazy(), on=self.id_col, how="inner")
.with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep"))
.explode("_rep")
.explode("_rep", empty_as_null=True)
.drop("_count", "_rep")
.collect()
)
Expand Down Expand Up @@ -209,7 +209,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
.first()
.sort([self.id_col, "trial"])
.with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")])
.explode("followup")
.explode("followup", empty_as_null=True)
.with_columns(
[(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
)
Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _calculate_risk(self, data, idx=None, val=None):
.first()
.drop(["followup", f"followup{self.indicator_squared}"])
.with_columns([pl.lit(followup_range).alias("followup")])
.explode("followup")
.explode("followup", empty_as_null=True)
.with_columns(
[(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
)
Expand Down
64 changes: 54 additions & 10 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,67 @@ def _param_checker(self):
"For weighted ITT analyses, cense_colname or visit_colname must be provided."
)

# Per-treatment-level weight models: 'numerator'/'denominator' may be a list
# with one formula per treatment_level (in treatment_level order), fitting a
# separate model per arm. Only supported for post-expansion weights.
for name in ("numerator", "denominator"):
spec = getattr(self, name)
if isinstance(spec, (list, tuple)):
if not self.weighted or self.method == "ITT":
raise ValueError(
f"Per-treatment-level '{name}' models require a weighted, "
"non-ITT analysis."
)
if self.weight_preexpansion:
raise ValueError(
f"Per-treatment-level '{name}' models are only supported for "
"post-expansion weights (weight_preexpansion=False)."
)
if any(f is None for f in spec):
raise ValueError(
f"Per-treatment-level '{name}' formulas contain None; supply "
"one formula per treatment level."
)
if len(spec) != len(self.treatment_level):
raise ValueError(
f"'{name}' must be a single formula or one per treatment "
f"level ({len(self.treatment_level)} expected, in "
f"'treatment_level' order) but {len(spec)} were supplied."
)

if (
self.weighted
and self.method != "ITT"
and self.numerator is not None
and self.denominator is not None
and self.numerator == self.denominator
):
warnings.warn(
f"Numerator and denominator weight models use identical "
f"covariates ('{self.numerator}'); the stabilized weights "
"will all equal 1 (i.e., no weighting). The denominator "
"should typically include the time-varying confounders "
"that the numerator omits — check for a typo in either or "
"both of 'numerator' and 'denominator'.",
UserWarning,
stacklevel=2,
num_list = (
list(self.numerator)
if isinstance(self.numerator, (list, tuple))
else [self.numerator]
)
den_list = (
list(self.denominator)
if isinstance(self.denominator, (list, tuple))
else [self.denominator]
)
# Warn on any arm whose numerator and denominator formulas coincide
# (element-wise when both are per-arm; the shared/shared case reduces to
# comparing the two single formulas).
if len(num_list) == len(den_list):
same = sorted({n for n, d in zip(num_list, den_list) if n == d})
if same:
covs = "', '".join(same)
warnings.warn(
f"Numerator and denominator weight models use identical "
f"covariates ('{covs}'); the stabilized weights "
"will all equal 1 (i.e., no weighting). The denominator "
"should typically include the time-varying confounders "
"that the numerator omits — check for a typo in either or "
"both of 'numerator' and 'denominator'.",
UserWarning,
stacklevel=2,
)

if self.excused:
_, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/expansion/_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
).alias("period")
]
)
.explode("period")
.explode("period", empty_as_null=True)
.drop(pl.col(time_col))
.with_columns(
[
Expand Down
2 changes: 1 addition & 1 deletion pySEQTarget/helpers/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _prepare_boot_data(self, data, boot_id):
data.lazy()
.join(counts.lazy(), on=self.id_col, how="inner")
.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate"))
.explode("replicate")
.explode("replicate", empty_as_null=True)
.with_columns(new_id)
.drop("count", "replicate")
.collect()
Expand Down
10 changes: 8 additions & 2 deletions pySEQTarget/helpers/_col_string.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
def _col_string(expressions):
cols = set()
for expression in expressions:
if expression is not None:
cols.update(expression.replace("+", " ").replace("*", " ").split())
if expression is None:
continue
# numerator/denominator may be a list of per-treatment-level formulas;
# gather the referenced columns across every element.
parts = expression if isinstance(expression, (list, tuple)) else [expression]
for part in parts:
if part is not None:
cols.update(part.replace("+", " ").replace("*", " ").split())
return cols
27 changes: 17 additions & 10 deletions pySEQTarget/weighting/_weight_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from ..helpers._glum_fit import _fit_glum


def _formula_rhs_for_level(spec, level_idx):
"""
Resolve the right-hand-side formula for a treatment level. ``spec`` is either
a single formula string (shared across arms) or a list with one formula per
treatment level (in treatment_level order), fitting a separate model per arm.
"""
if isinstance(spec, (list, tuple)):
return spec[level_idx]
return spec


def _get_subset_for_level(
self, WDT, level_idx, level, tx_lag_col, exclude_followup_zero=False
):
Expand Down Expand Up @@ -88,11 +99,6 @@ def _fit_numerator(self, WDT):
if self.method == "ITT":
return
predictor = "switch" if self.excused else self.treatment_col
# Handle intercept-only formula when numerator is "1" or empty
if self.numerator in ("1", ""):
formula = f"{predictor}~1"
else:
formula = f"{predictor}~{self.numerator}"
tx_lag_col = (
f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag"
)
Expand All @@ -101,6 +107,9 @@ def _fit_numerator(self, WDT):
# treatment_level=[1,2] or dose-response always uses mnlogit
is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring"
for i, level in enumerate(self.treatment_level):
# numerator may be per-treatment-level; select this arm's formula.
rhs = _formula_rhs_for_level(self.numerator, i)
formula = f"{predictor}~1" if rhs in ("1", "") else f"{predictor}~{rhs}"
DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col)
if len(DT_subset[predictor].unique()) < 2:
fits.append(None)
Expand Down Expand Up @@ -131,17 +140,15 @@ def _fit_denominator(self, WDT):
if self.excused and not self.weight_preexpansion
else self.treatment_col
)
# Handle intercept-only formula when denominator is "1" or empty
if self.denominator in ("1", ""):
formula = f"{predictor}~1"
else:
formula = f"{predictor}~{self.denominator}"
fits = []
# Use logit for binary 0/1 treatment with censoring method only
# treatment_level=[1,2] or dose-response always uses mnlogit
is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring"
exclude_followup_zero = not self.weight_preexpansion
for i, level in enumerate(self.treatment_level):
# denominator may be per-treatment-level; select this arm's formula.
rhs = _formula_rhs_for_level(self.denominator, i)
formula = f"{predictor}~1" if rhs in ("1", "") else f"{predictor}~{rhs}"
DT_subset = _get_subset_for_level(
self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.13.9"
version = "0.14.0"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
Loading
Loading