From 8ce0baec1df87c128003e4f88ce0af43fd276428 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 3 Jul 2026 13:50:44 +0100 Subject: [PATCH 1/4] Allow arm-specific numerator/denominator weight models SEQopts numerator/denominator now accept either a single formula (shared across treatment arms, unchanged) or a list with one formula per treatment_level, fitting a separate weight model with its own covariates in each arm. Supported for post-expansion weights only (weight_preexpansion=False), mirroring SEQTaRget commit 761d0d6. _fit_numerator/_fit_denominator select each arm's formula inside the per-level loop; _col_string gathers referenced columns across list elements; _param_checker validates the list (weighted non-ITT, post-expansion, no None entries, length equal to treatment_level) and makes the identical-formula warning list-aware. The prediction path is unchanged because it uses the fitted per-level models directly. Adds tests/test_armspecific_weights.py. --- pySEQTarget/SEQopts.py | 36 +++++-- pySEQTarget/error/_param_checker.py | 64 ++++++++++-- pySEQTarget/helpers/_col_string.py | 10 +- pySEQTarget/weighting/_weight_fit.py | 27 +++-- tests/test_armspecific_weights.py | 141 +++++++++++++++++++++++++++ 5 files changed, 249 insertions(+), 29 deletions(-) create mode 100644 tests/test_armspecific_weights.py diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 19becbd..07d8a26 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -1,7 +1,7 @@ import multiprocessing import os from dataclasses import dataclass, field -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union @dataclass @@ -20,7 +20,12 @@ class SEQopts: :param cense_eligible_colname: Column name to identify which rows are eligible for censoring model fitting :param compevent_colname: Column name specifying a competing event to the outcome :param covariates: Override to specify the outcome patsy formula for outcome model fitting - :param denominator: Override to specify the outcome patsy formula for denominator model fitting + :param denominator: Override to specify the patsy formula for denominator weight + model fitting. A single string fits the same model in every treatment arm. + A list with one formula per ``treatment_level`` (in ``treatment_level`` + order) fits a separate denominator model, with its own covariates, in each + arm; this is only supported for post-expansion weights + (``weight_preexpansion=False``). :param excused: Boolean to allow excused conditions when method is censoring :param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]`` :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, @@ -38,8 +43,12 @@ class SEQopts: :param indicator_squared: How to indicate squared columns in models :param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates :param ncores: Number of cores to use if running in parallel, default ``max(1, cpu_count() - 1)`` - :param numerator: Override to specify the outcome patsy formula for - numerator models; "1" or "" indicate intercept only model + :param numerator: Override to specify the patsy formula for numerator weight + models; "1" or "" indicate intercept only model. A single string fits the + same model in every treatment arm. A list with one formula per + ``treatment_level`` (in ``treatment_level`` order) fits a separate + numerator model, with its own covariates, in each arm; this is only + supported for post-expansion weights (``weight_preexpansion=False``). :param offload: Boolean to offload intermediate model data to disk :param offload_dir: Directory to offload intermediate model data :param parallel: Boolean to run model fitting in parallel @@ -81,7 +90,7 @@ class SEQopts: compevent_colname: Optional[str] = None covariates: Optional[str] = None cox_package: Literal["lifelines", "scikit-survival"] = "lifelines" - denominator: Optional[str] = None + denominator: Optional[Union[str, List[str]]] = None excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) expand_only: bool = False @@ -97,7 +106,7 @@ class SEQopts: indicator_squared: str = "_sq" km_curves: bool = False ncores: Optional[int] = None - numerator: Optional[str] = None + numerator: Optional[Union[str, List[str]]] = None offload: bool = False offload_dir: str = "_seq_models" parallel: bool = False @@ -198,7 +207,20 @@ def _normalize_formulas(self): "cense_denominator", ): attr = getattr(self, i) - if attr is not None and not isinstance(attr, list): + if attr is None: + continue + if isinstance(attr, list): + # Per-treatment-level formulas (numerator/denominator): strip + # whitespace from each element, leaving None entries untouched. + setattr( + self, + i, + [ + "".join(a.split()) if isinstance(a, str) else a + for a in attr + ], + ) + else: setattr(self, i, "".join(attr.split())) def __post_init__(self): diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index 9efa54a..941cd85 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -68,23 +68,67 @@ def _param_checker(self): "For weighted ITT analyses, cense_colname or visit_colname must be provided." ) + # Per-treatment-level weight models: 'numerator'/'denominator' may be a list + # with one formula per treatment_level (in treatment_level order), fitting a + # separate model per arm. Only supported for post-expansion weights. + for name in ("numerator", "denominator"): + spec = getattr(self, name) + if isinstance(spec, (list, tuple)): + if not self.weighted or self.method == "ITT": + raise ValueError( + f"Per-treatment-level '{name}' models require a weighted, " + "non-ITT analysis." + ) + if self.weight_preexpansion: + raise ValueError( + f"Per-treatment-level '{name}' models are only supported for " + "post-expansion weights (weight_preexpansion=False)." + ) + if any(f is None for f in spec): + raise ValueError( + f"Per-treatment-level '{name}' formulas contain None; supply " + "one formula per treatment level." + ) + if len(spec) != len(self.treatment_level): + raise ValueError( + f"'{name}' must be a single formula or one per treatment " + f"level ({len(self.treatment_level)} expected, in " + f"'treatment_level' order) but {len(spec)} were supplied." + ) + if ( self.weighted and self.method != "ITT" and self.numerator is not None and self.denominator is not None - and self.numerator == self.denominator ): - warnings.warn( - f"Numerator and denominator weight models use identical " - f"covariates ('{self.numerator}'); the stabilized weights " - "will all equal 1 (i.e., no weighting). The denominator " - "should typically include the time-varying confounders " - "that the numerator omits — check for a typo in either or " - "both of 'numerator' and 'denominator'.", - UserWarning, - stacklevel=2, + num_list = ( + list(self.numerator) + if isinstance(self.numerator, (list, tuple)) + else [self.numerator] + ) + den_list = ( + list(self.denominator) + if isinstance(self.denominator, (list, tuple)) + else [self.denominator] ) + # Warn on any arm whose numerator and denominator formulas coincide + # (element-wise when both are per-arm; the shared/shared case reduces to + # comparing the two single formulas). + if len(num_list) == len(den_list): + same = sorted({n for n, d in zip(num_list, den_list) if n == d}) + if same: + covs = "', '".join(same) + warnings.warn( + f"Numerator and denominator weight models use identical " + f"covariates ('{covs}'); the stabilized weights " + "will all equal 1 (i.e., no weighting). The denominator " + "should typically include the time-varying confounders " + "that the numerator omits — check for a typo in either or " + "both of 'numerator' and 'denominator'.", + UserWarning, + stacklevel=2, + ) if self.excused: _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames) diff --git a/pySEQTarget/helpers/_col_string.py b/pySEQTarget/helpers/_col_string.py index 907ef80..8c67709 100644 --- a/pySEQTarget/helpers/_col_string.py +++ b/pySEQTarget/helpers/_col_string.py @@ -1,6 +1,12 @@ def _col_string(expressions): cols = set() for expression in expressions: - if expression is not None: - cols.update(expression.replace("+", " ").replace("*", " ").split()) + if expression is None: + continue + # numerator/denominator may be a list of per-treatment-level formulas; + # gather the referenced columns across every element. + parts = expression if isinstance(expression, (list, tuple)) else [expression] + for part in parts: + if part is not None: + cols.update(part.replace("+", " ").replace("*", " ").split()) return cols diff --git a/pySEQTarget/weighting/_weight_fit.py b/pySEQTarget/weighting/_weight_fit.py index 8fcb0bf..851bff8 100644 --- a/pySEQTarget/weighting/_weight_fit.py +++ b/pySEQTarget/weighting/_weight_fit.py @@ -5,6 +5,17 @@ from ..helpers._glum_fit import _fit_glum +def _formula_rhs_for_level(spec, level_idx): + """ + Resolve the right-hand-side formula for a treatment level. ``spec`` is either + a single formula string (shared across arms) or a list with one formula per + treatment level (in treatment_level order), fitting a separate model per arm. + """ + if isinstance(spec, (list, tuple)): + return spec[level_idx] + return spec + + def _get_subset_for_level( self, WDT, level_idx, level, tx_lag_col, exclude_followup_zero=False ): @@ -88,11 +99,6 @@ def _fit_numerator(self, WDT): if self.method == "ITT": return predictor = "switch" if self.excused else self.treatment_col - # Handle intercept-only formula when numerator is "1" or empty - if self.numerator in ("1", ""): - formula = f"{predictor}~1" - else: - formula = f"{predictor}~{self.numerator}" tx_lag_col = ( f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag" ) @@ -101,6 +107,9 @@ def _fit_numerator(self, WDT): # treatment_level=[1,2] or dose-response always uses mnlogit is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" for i, level in enumerate(self.treatment_level): + # numerator may be per-treatment-level; select this arm's formula. + rhs = _formula_rhs_for_level(self.numerator, i) + formula = f"{predictor}~1" if rhs in ("1", "") else f"{predictor}~{rhs}" DT_subset = _get_subset_for_level(self, WDT, i, level, tx_lag_col) if len(DT_subset[predictor].unique()) < 2: fits.append(None) @@ -131,17 +140,15 @@ def _fit_denominator(self, WDT): if self.excused and not self.weight_preexpansion else self.treatment_col ) - # Handle intercept-only formula when denominator is "1" or empty - if self.denominator in ("1", ""): - formula = f"{predictor}~1" - else: - formula = f"{predictor}~{self.denominator}" fits = [] # Use logit for binary 0/1 treatment with censoring method only # treatment_level=[1,2] or dose-response always uses mnlogit is_binary = sorted(self.treatment_level) == [0, 1] and self.method == "censoring" exclude_followup_zero = not self.weight_preexpansion for i, level in enumerate(self.treatment_level): + # denominator may be per-treatment-level; select this arm's formula. + rhs = _formula_rhs_for_level(self.denominator, i) + formula = f"{predictor}~1" if rhs in ("1", "") else f"{predictor}~{rhs}" DT_subset = _get_subset_for_level( self, WDT, i, level, "tx_lag", exclude_followup_zero=exclude_followup_zero ) diff --git a/tests/test_armspecific_weights.py b/tests/test_armspecific_weights.py new file mode 100644 index 0000000..062e967 --- /dev/null +++ b/tests/test_armspecific_weights.py @@ -0,0 +1,141 @@ +"""Per-treatment-level weight models: 'numerator'/'denominator' given as a list +fit a separate model (with its own covariates) in each treatment arm. Post- +expansion weights only (weight_preexpansion=False). +""" + +import numpy as np +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _model(numerator, denominator, **opts): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=False, + numerator=numerator, + denominator=denominator, + seed=42, + **opts, + ), + ) + s.expand() + s.fit() + return s + + +def _coef_names(model): + return list(model.params.index) + + +def test_per_arm_denominator_fits_different_covariates_in_each_arm(): + s = _model(numerator=["sex", "sex"], denominator=["N+sex", "N+L+P+sex"]) + + den0 = _coef_names(s.denominator_model[0]) + den1 = _coef_names(s.denominator_model[1]) + + # Arm 0's model excludes L and P; arm 1's includes them. + assert not any(n.startswith("L") for n in den0) + assert not any(n.startswith("P") for n in den0) + assert any(n.startswith("N") for n in den0) + assert any(n.startswith("L") for n in den1) + assert any(n.startswith("P") for n in den1) + + +def test_per_arm_formulas_with_identical_elements_match_shared_fit(): + shared = _model(numerator="sex", denominator="N+L+P+sex") + perarm = _model( + numerator=["sex", "sex"], denominator=["N+L+P+sex", "N+L+P+sex"] + ) + + for i in range(2): + np.testing.assert_allclose( + np.asarray(shared.numerator_model[i].params), + np.asarray(perarm.numerator_model[i].params), + ) + np.testing.assert_allclose( + np.asarray(shared.denominator_model[i].params), + np.asarray(perarm.denominator_model[i].params), + ) + + # Same weights feed the outcome model, so its coefficients must match too. + shared_out = np.concatenate( + [np.asarray(m["outcome"].params) for m in shared.outcome_model] + ) + perarm_out = np.concatenate( + [np.asarray(m["outcome"].params) for m in perarm.outcome_model] + ) + np.testing.assert_allclose(shared_out, perarm_out) + + +def test_per_arm_weight_formulas_rejected_for_preexpansion_weights(): + with pytest.raises(ValueError, match="post-expansion weights"): + SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + numerator=["sex", "sex"], + denominator=["N+sex", "N+L+P+sex"], + ), + ) + + +def test_per_arm_weight_formulas_rejected_for_wrong_length(): + with pytest.raises(ValueError, match="one per treatment level"): + SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=False, + denominator=["N+sex", "N+L+P+sex", "N+sex"], + ), + ) + + +def test_per_arm_weight_formulas_rejected_for_itt(): + with pytest.raises(ValueError, match="weighted, non-ITT"): + SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + weighted=True, + cense_colname="eligible", + numerator=["sex", "sex"], + denominator=["N+sex", "N+L+P+sex"], + ), + ) From dd8483496d065accf697e86b857066e13c454e75 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 3 Jul 2026 13:51:08 +0100 Subject: [PATCH 2/4] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 258e568..d463f86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.9" +version = "0.14.0" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 66e0a32aaf79c9d49d122363c750df4217623be7 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 3 Jul 2026 14:07:22 +0100 Subject: [PATCH 3/4] Set empty_as_null to True on explode calls for Polars 2.0 Polars 2.0 flips the explode() empty_as_null default from True to False, which would silently turn exploded empty lists from null rows into dropped rows. --- pySEQTarget/analysis/_hazard.py | 4 ++-- pySEQTarget/analysis/_survival_pred.py | 2 +- pySEQTarget/expansion/_mapper.py | 2 +- pySEQTarget/helpers/_bootstrap.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index eb16eba..2a0f9c2 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -130,7 +130,7 @@ def _one_boot_log_hr(self, data, idx, model_pos, sample_idx): data.lazy() .join(counts.lazy(), on=self.id_col, how="inner") .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) - .explode("_rep") + .explode("_rep", empty_as_null=True) .drop("_count", "_rep") .collect() ) @@ -209,7 +209,7 @@ def _hazard_handler(self, data, idx, boot_idx, rng): .first() .sort([self.id_col, "trial"]) .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")]) - .explode("followup") + .explode("followup", empty_as_null=True) .with_columns( [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")] ) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index c9e17a7..725f412 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -148,7 +148,7 @@ def _calculate_risk(self, data, idx=None, val=None): .first() .drop(["followup", f"followup{self.indicator_squared}"]) .with_columns([pl.lit(followup_range).alias("followup")]) - .explode("followup") + .explode("followup", empty_as_null=True) .with_columns( [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")] ) diff --git a/pySEQTarget/expansion/_mapper.py b/pySEQTarget/expansion/_mapper.py index 37600e5..6918779 100644 --- a/pySEQTarget/expansion/_mapper.py +++ b/pySEQTarget/expansion/_mapper.py @@ -19,7 +19,7 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in ).alias("period") ] ) - .explode("period") + .explode("period", empty_as_null=True) .drop(pl.col(time_col)) .with_columns( [ diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index a80fecb..ff1fec8 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -60,7 +60,7 @@ def _prepare_boot_data(self, data, boot_id): data.lazy() .join(counts.lazy(), on=self.id_col, how="inner") .with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) - .explode("replicate") + .explode("replicate", empty_as_null=True) .with_columns(new_id) .drop("count", "replicate") .collect() From 096490d6912c3a7ab3b9e2e74c0728c0c4ea6b0d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 13:07:45 +0000 Subject: [PATCH 4/4] Auto-format code --- pySEQTarget/SEQopts.py | 5 +---- tests/test_armspecific_weights.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 07d8a26..cf63e76 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -215,10 +215,7 @@ def _normalize_formulas(self): setattr( self, i, - [ - "".join(a.split()) if isinstance(a, str) else a - for a in attr - ], + ["".join(a.split()) if isinstance(a, str) else a for a in attr], ) else: setattr(self, i, "".join(attr.split())) diff --git a/tests/test_armspecific_weights.py b/tests/test_armspecific_weights.py index 062e967..767a87b 100644 --- a/tests/test_armspecific_weights.py +++ b/tests/test_armspecific_weights.py @@ -55,9 +55,7 @@ def test_per_arm_denominator_fits_different_covariates_in_each_arm(): def test_per_arm_formulas_with_identical_elements_match_shared_fit(): shared = _model(numerator="sex", denominator="N+L+P+sex") - perarm = _model( - numerator=["sex", "sex"], denominator=["N+L+P+sex", "N+L+P+sex"] - ) + perarm = _model(numerator=["sex", "sex"], denominator=["N+L+P+sex", "N+L+P+sex"]) for i in range(2): np.testing.assert_allclose(