From 5ca99e8dc6110dfdc23194bdd8eaf0924403de26 Mon Sep 17 00:00:00 2001 From: jpweideman Date: Mon, 25 May 2026 18:12:34 +0200 Subject: [PATCH 1/2] fix sglrw for dict params and multi dim tensors --- posteriors/sgmcmc/sglrw.py | 14 ++++++-------- tests/sgmcmc/test_baoa.py | 4 ++-- tests/sgmcmc/test_sghmc.py | 4 ++-- tests/sgmcmc/test_sgld.py | 4 ++-- tests/sgmcmc/test_sglrw.py | 4 ++-- tests/sgmcmc/test_sgnht.py | 4 ++-- 6 files changed, 16 insertions(+), 18 deletions(-) diff --git a/posteriors/sgmcmc/sglrw.py b/posteriors/sgmcmc/sglrw.py index 7555bf5..d7e2ce3 100644 --- a/posteriors/sgmcmc/sglrw.py +++ b/posteriors/sgmcmc/sglrw.py @@ -91,18 +91,13 @@ def update( # Resolve schedules lr_val = lr(state.step) if callable(lr) else lr T_val = temperature(state.step) if callable(temperature) else temperature - lr_val = torch.as_tensor( - lr_val, dtype=state.params.dtype, device=state.params.device - ) - T_val = torch.as_tensor(T_val, dtype=state.params.dtype, device=state.params.device) - # Spatial stepsize to make update binary - diffusion_val = torch.sqrt(2.0 * T_val) - delta_x = torch.sqrt(lr_val) * diffusion_val + diffusion_val = (2.0 * T_val) ** 0.5 + delta_x = lr_val ** 0.5 * diffusion_val # Per-parameter binary LRW transform def transform_params(p, g): - p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[:, 2] + p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[..., 2] u = torch.rand_like(p_plus) step_sign = torch.where( @@ -139,6 +134,9 @@ def ternary_probs( Returns: Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus]. """ + diffusion_val = torch.as_tensor(diffusion_val, dtype=drift_val.dtype, device=drift_val.device) + stepsize = torch.as_tensor(stepsize, dtype=drift_val.dtype, device=drift_val.device) + delta_x = torch.as_tensor(delta_x, dtype=drift_val.dtype, device=drift_val.device) desired_mean = stepsize * drift_val desired_var = stepsize * diffusion_val**2 scaled_mean = desired_mean / delta_x diff --git a/tests/sgmcmc/test_baoa.py b/tests/sgmcmc/test_baoa.py index 468f4f9..892074e 100644 --- a/tests/sgmcmc/test_baoa.py +++ b/tests/sgmcmc/test_baoa.py @@ -40,8 +40,8 @@ def lr(step): # Build transform transform = baoa.build(log_prob, lr) - # Initialise - params = torch.randn(dim) + # Initialise + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sghmc.py b/tests/sgmcmc/test_sghmc.py index f0bdc24..2f9a208 100644 --- a/tests/sgmcmc/test_sghmc.py +++ b/tests/sgmcmc/test_sghmc.py @@ -48,8 +48,8 @@ def lr(step): # Build transform transform = sghmc.build(log_prob, lr) - # Initialise - params = torch.randn(dim) + # Initialise + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sgld.py b/tests/sgmcmc/test_sgld.py index 348853e..43b5958 100644 --- a/tests/sgmcmc/test_sgld.py +++ b/tests/sgmcmc/test_sgld.py @@ -33,8 +33,8 @@ def lr(step): # Build transform transform = sgld.build(log_prob, lr) - # Initialise - params = torch.randn(dim) + # Initialise + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sglrw.py b/tests/sgmcmc/test_sglrw.py index 93c3a0c..380c776 100644 --- a/tests/sgmcmc/test_sglrw.py +++ b/tests/sgmcmc/test_sglrw.py @@ -32,8 +32,8 @@ def lr(step): # Build transform transform = sglrw.build(log_prob, lr) - # Initialise - params = torch.randn(dim) + # Initialise + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py index 74b0d9a..3b59d05 100644 --- a/tests/sgmcmc/test_sgnht.py +++ b/tests/sgmcmc/test_sgnht.py @@ -48,8 +48,8 @@ def lr(step): # Build transform transform = sgnht.build(log_prob, lr) - # Initialise - params = torch.randn(dim) + # Initialise + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) From 9e7ef79e9b492bda349065b6a2c89daa25f4934f Mon Sep 17 00:00:00 2001 From: jpweideman Date: Tue, 26 May 2026 11:52:45 +0200 Subject: [PATCH 2/2] ruff formatting --- posteriors/sgmcmc/sglrw.py | 6 ++++-- tests/sgmcmc/test_baoa.py | 2 +- tests/sgmcmc/test_sghmc.py | 2 +- tests/sgmcmc/test_sgld.py | 2 +- tests/sgmcmc/test_sglrw.py | 2 +- tests/sgmcmc/test_sgnht.py | 2 +- 6 files changed, 9 insertions(+), 7 deletions(-) diff --git a/posteriors/sgmcmc/sglrw.py b/posteriors/sgmcmc/sglrw.py index d7e2ce3..52ce96b 100644 --- a/posteriors/sgmcmc/sglrw.py +++ b/posteriors/sgmcmc/sglrw.py @@ -93,7 +93,7 @@ def update( T_val = temperature(state.step) if callable(temperature) else temperature # Spatial stepsize to make update binary diffusion_val = (2.0 * T_val) ** 0.5 - delta_x = lr_val ** 0.5 * diffusion_val + delta_x = lr_val**0.5 * diffusion_val # Per-parameter binary LRW transform def transform_params(p, g): @@ -134,7 +134,9 @@ def ternary_probs( Returns: Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus]. """ - diffusion_val = torch.as_tensor(diffusion_val, dtype=drift_val.dtype, device=drift_val.device) + diffusion_val = torch.as_tensor( + diffusion_val, dtype=drift_val.dtype, device=drift_val.device + ) stepsize = torch.as_tensor(stepsize, dtype=drift_val.dtype, device=drift_val.device) delta_x = torch.as_tensor(delta_x, dtype=drift_val.dtype, device=drift_val.device) desired_mean = stepsize * drift_val diff --git a/tests/sgmcmc/test_baoa.py b/tests/sgmcmc/test_baoa.py index 892074e..131c0c8 100644 --- a/tests/sgmcmc/test_baoa.py +++ b/tests/sgmcmc/test_baoa.py @@ -40,7 +40,7 @@ def lr(step): # Build transform transform = baoa.build(log_prob, lr) - # Initialise + # Initialise params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update diff --git a/tests/sgmcmc/test_sghmc.py b/tests/sgmcmc/test_sghmc.py index 2f9a208..94573e1 100644 --- a/tests/sgmcmc/test_sghmc.py +++ b/tests/sgmcmc/test_sghmc.py @@ -48,7 +48,7 @@ def lr(step): # Build transform transform = sghmc.build(log_prob, lr) - # Initialise + # Initialise params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update diff --git a/tests/sgmcmc/test_sgld.py b/tests/sgmcmc/test_sgld.py index 43b5958..4dac258 100644 --- a/tests/sgmcmc/test_sgld.py +++ b/tests/sgmcmc/test_sgld.py @@ -33,7 +33,7 @@ def lr(step): # Build transform transform = sgld.build(log_prob, lr) - # Initialise + # Initialise params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update diff --git a/tests/sgmcmc/test_sglrw.py b/tests/sgmcmc/test_sglrw.py index 380c776..458012f 100644 --- a/tests/sgmcmc/test_sglrw.py +++ b/tests/sgmcmc/test_sglrw.py @@ -32,7 +32,7 @@ def lr(step): # Build transform transform = sglrw.build(log_prob, lr) - # Initialise + # Initialise params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py index 3b59d05..0707cb3 100644 --- a/tests/sgmcmc/test_sgnht.py +++ b/tests/sgmcmc/test_sgnht.py @@ -48,7 +48,7 @@ def lr(step): # Build transform transform = sgnht.build(log_prob, lr) - # Initialise + # Initialise params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update