Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
49d17ff
add timing printouts within cov2d
garrettwrong Feb 26, 2026
822e9f9
add fetch and eval_t times
garrettwrong Feb 26, 2026
66cf925
stashing initial rad ctf port, trying compare to our filter
garrettwrong Mar 19, 2026
9680000
stashing to_radial and freq pt scaling patches
garrettwrong Mar 31, 2026
a88fd9b
cleanup debugging logic a bit
garrettwrong Mar 31, 2026
f621822
continue cleanup
garrettwrong Mar 31, 2026
c84a96c
use existing expand_method, add radial, instead of new flag
garrettwrong Apr 9, 2026
798951b
use existing expand_method, add radial, instead of new flag
garrettwrong Apr 9, 2026
4568e6f
cleanup
garrettwrong Apr 9, 2026
4a562e7
cleanup
garrettwrong Apr 9, 2026
45a6caf
remove warnings
garrettwrong Apr 9, 2026
af522e7
fix logic error
garrettwrong Apr 10, 2026
1713e0f
refix logic error
garrettwrong Apr 10, 2026
7ab9c37
add tqdm to cov2d filter to basis mat
garrettwrong Apr 10, 2026
780a166
optimal fle basis comp
garrettwrong Apr 14, 2026
af7d1bc
stub in bulk ctf code
garrettwrong Apr 16, 2026
74e60bf
tests passing except FFB opts
garrettwrong Apr 21, 2026
0a1f3c5
ffb equiv checkpoint
garrettwrong Apr 21, 2026
0dc76ee
cleanup
garrettwrong Apr 21, 2026
0582bf7
cleanup
garrettwrong Apr 22, 2026
faad4ca
cleanup FFB doc strings
garrettwrong Apr 22, 2026
1cf187a
dtype sensitivity
garrettwrong Apr 22, 2026
042bad7
more cleanup, maybe CI passing
garrettwrong Apr 22, 2026
338ebfa
more cleanup, maybe docs runs
garrettwrong Apr 22, 2026
60ae2bf
cleanup extra loop
garrettwrong Apr 23, 2026
2d9ad39
add force diag option to cov2d
garrettwrong Apr 29, 2026
2c30fb5
ctf stack unit test patches
garrettwrong May 6, 2026
760f013
minimal xform/pipeline patches
garrettwrong May 12, 2026
3a93549
got both filter (ds) and filter stack running
garrettwrong May 14, 2026
90b32a1
initial attempt extending multiplicative filter bcast
garrettwrong May 14, 2026
92c74b5
hacktastic ctf param passthrough
garrettwrong May 18, 2026
3658db3
revert last approach in favor of using evaluate per dev meeting
garrettwrong May 21, 2026
820d9b6
rm unused var
garrettwrong May 21, 2026
c87f753
satisfy tox
garrettwrong May 22, 2026
03eba11
stashing, got filter stack to basis mat eval working for ffb2d
garrettwrong May 22, 2026
6024124
begin filter_basis_mat cleanup
garrettwrong May 26, 2026
f0e1429
continue filter_basis_mat cleanup
garrettwrong May 26, 2026
e9d95fd
initial documentation updates
garrettwrong May 27, 2026
36a39f4
make np.unique call compat with older numpy
garrettwrong May 27, 2026
97e0326
should been better, was not
garrettwrong Jun 2, 2026
756f3e0
first round cleanup
garrettwrong Jun 2, 2026
2428e09
cleanup unused rmat and move reshapes out of loop
garrettwrong Jun 2, 2026
c6c76b2
cleanup additional ffb2d test
garrettwrong Jun 2, 2026
37a71a3
cleanup some filter eval concerns
garrettwrong Jun 2, 2026
3db308e
tox cleanup
garrettwrong Jun 2, 2026
69748b3
tox cleanup
garrettwrong Jun 4, 2026
10ce299
missing xp
garrettwrong Jun 4, 2026
888fc5f
add large covar pytest file and some minor changes for one of the cases
garrettwrong Jun 10, 2026
f262f4e
add a top of file docstring to test
garrettwrong Jun 11, 2026
c3463b9
more cleanup
garrettwrong Jun 11, 2026
7214293
improve test coverage
garrettwrong Jun 22, 2026
6c11b8f
increase testing
garrettwrong Jun 24, 2026
091697d
force covering the batching case
garrettwrong Jun 24, 2026
df12950
remove redundant function
garrettwrong Jun 25, 2026
8b9beee
consolidate redundant functions
garrettwrong Jun 25, 2026
8b8cec9
comment cleanup
garrettwrong Jun 29, 2026
e0f2311
cleanup filter src and test file
garrettwrong Jun 29, 2026
466311a
cleanup
garrettwrong Jun 29, 2026
a7b0282
breakout Image convolve from Image filter
garrettwrong Jun 29, 2026
8d9e1ae
minor cleanup new test file
garrettwrong Jun 30, 2026
82423b2
add ability to force non-radial filter_stack_to_basis_mats given radi…
garrettwrong Jul 1, 2026
84b5339
hide/note force_diag
garrettwrong Jul 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gallery/experiments/save_simulation_relion_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
# that RELION will recover as optics groups.

vol = emdb_2660()
ctf_filters = [RadialCTFFilter(defocus=d) for d in defocus]
ctf_filters = RadialCTFFilter(defocus=defocus)


# %%
Expand All @@ -64,7 +64,7 @@
sim = Simulation(
n=n_particles,
vols=vol,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=WhiteNoiseAdder.from_snr(snr),
)
sim.save(star_path, overwrite=True)
Expand Down
12 changes: 7 additions & 5 deletions gallery/experiments/simulated_abinitio_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,19 @@ def noise_function(x, y):
alpha = 0.1 # Amplitude contrast

# Create filters
ctf_filters = [
RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)

# Finally create the Simulation
src = Simulation(
n=num_imgs,
vols=og_v,
noise_adder=custom_noise,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
)

# Downsample
Expand Down
7 changes: 2 additions & 5 deletions gallery/tutorials/aspire_introduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,7 @@ def noise_function(x, y):
defocus_ct = 7

# Generate several CTFs.
ctf_filters = [
RadialCTFFilter(defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct))

# %%
# Combining into a Simulation
Expand All @@ -586,7 +583,7 @@ def noise_function(x, y):
amplitudes=1,
offsets=0,
noise_adder=white_noise_adder,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
seed=42,
)

Expand Down
8 changes: 3 additions & 5 deletions gallery/tutorials/pipeline_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@
defocus_max = 25000
defocus_ct = 7

ctf_filters = [
RadialCTFFilter(defocus=d)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(defocus=np.linspace(defocus_min, defocus_max, defocus_ct))


# %%
# Initialize Simulation Object
Expand All @@ -96,7 +94,7 @@
n=2500, # number of projections
vols=original_vol, # volume source
offsets=0, # Default: images are randomly shifted
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=WhiteNoiseAdder(var=0.0002), # desired noise variance
).cache()

Expand Down
17 changes: 9 additions & 8 deletions gallery/tutorials/tutorials/cov2d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@

print("Initialize simulation object and CTF filters.")
# Create filters
ctf_filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# Load the map file of a 70S Ribosome
print(
Expand All @@ -89,7 +92,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
offsets=0.0,
amplitudes=1.0,
dtype=dtype,
Expand All @@ -109,9 +112,7 @@
h_idx = sim.filter_indices

# Evaluate CTF in the 8X8 FB basis
h_ctf_fb = [
ffbbasis.filter_to_basis_mat(filt, pixel_size=pixel_size) for filt in ctf_filters
]
h_ctf_fb = ffbbasis.filter_stack_to_basis_mats(ctf_filters, pixel_size=pixel_size)

# Get clean images from projections of 3D map.
print("Apply CTF filters to clean images.")
Expand Down
2 changes: 1 addition & 1 deletion gallery/tutorials/tutorials/cov3d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=[RadialCTFFilter(defocus=d) for d in np.linspace(1.5e4, 2.5e4, 7)],
filter_stack=RadialCTFFilter(defocus=np.linspace(1.5e4, 2.5e4, 7)),
dtype=dtype,
)

Expand Down
7 changes: 3 additions & 4 deletions gallery/tutorials/tutorials/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,8 @@ def generate_example_image(L, noise_variance=0.1):

# Construct a range of CTF filters.
defoci = [2500, 5000, 10000, 20000]
ctf_filters = [
RadialCTFFilter(voltage=200, defocus=d, Cs=2.26, alpha=0.07, B=0) for d in defoci
]
ctf_filters = RadialCTFFilter(voltage=200, defocus=defoci, Cs=2.26, alpha=0.07, B=0)


# %%
# Generate CTF corrupted Images
Expand Down Expand Up @@ -334,7 +333,7 @@ def generate_example_image(L, noise_variance=0.1):
from aspire.source import Simulation

# Create the Source. ``ctf_filters`` are re-used from earlier section.
src = Simulation(L=64, n=4, unique_filters=ctf_filters, pixel_size=1)
src = Simulation(L=64, n=4, filter_stack=ctf_filters, pixel_size=1)
src.images[:4].show()

# %%
Expand Down
4 changes: 1 addition & 3 deletions gallery/tutorials/tutorials/micrograph_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,7 @@

# Create our CTF Filter and add it to a list.
# This configuration will apply the same CTF to all particles.
ctfs = [
RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0),
]
ctfs = RadialCTFFilter(voltage=200, defocus=15000, Cs=2.26, alpha=0.07, B=0)

src = MicrographSimulation(
vol,
Expand Down
13 changes: 8 additions & 5 deletions gallery/tutorials/tutorials/orient3d_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@

print("Initialize simulation object and CTF filters.")
# Create CTF filters
filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# %%
# Downsampling
Expand All @@ -74,7 +77,7 @@
# Create a simulation object with specified filters and the downsampled 3D map
print("Use downsampled map to creat simulation object.")
sim = Simulation(
L=img_size, n=num_imgs, vols=vols, unique_filters=filters, pixel_size=5, dtype=dtype
L=img_size, n=num_imgs, vols=vols, filter_stack=filters, pixel_size=5, dtype=dtype
)

print("Get true rotation angles generated randomly by the simulation object.")
Expand Down
13 changes: 8 additions & 5 deletions gallery/tutorials/tutorials/preprocess_imgs_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@

print("Initialize simulation object and CTF filters.")
# Create CTF filters
ctf_filters = [
RadialCTFFilter(voltage, defocus=d, Cs=2.0, alpha=0.1)
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
]
ctf_filters = RadialCTFFilter(
voltage,
defocus=np.linspace(defocus_min, defocus_max, defocus_ct),
Cs=2.0,
alpha=0.1,
)


# Load the map file of a 70S ribosome and downsample the 3D map to desired image size.
print("Load 3D map from mrc file")
Expand All @@ -73,7 +76,7 @@
L=img_size,
n=num_imgs,
vols=vols,
unique_filters=ctf_filters,
filter_stack=ctf_filters,
noise_adder=noise_adder,
pixel_size=pixel_size,
)
Expand Down
2 changes: 1 addition & 1 deletion src/aspire/basis/fb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,6 @@ def calculate_bispectrum(

def filter_to_basis_mat(self, *args, **kwargs):
"""
See `SteerableBasis2D.filter_to_basis_mat`.
See `SteerableBasis2D.filter_stack_to_basis_mat`.
"""
return super().filter_to_basis_mat(*args, **kwargs)
125 changes: 108 additions & 17 deletions src/aspire/basis/ffb_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def _build(self):
self._precomp["gl_nodes"]
)

# Generate radial filter point set for radial optimized eval
# Weights appear a little sensitive to dtype, otherwise could use self._precomp["gl_nodes"]
k_vals, _ = lgwt(self.n_r, 0, 0.5, dtype=np.float64)
self._filter_pts = np.pad(
2 * np.pi * k_vals.reshape(1, -1), ((0, 1), (0, 0))
).astype(self.dtype)

def _precomp(self):
"""
Precomute the basis functions on a polar Fourier grid
Expand Down Expand Up @@ -236,15 +243,19 @@ def _evaluate_t(self, x):

return xp.asnumpy(v)

def filter_to_basis_mat(self, f, **kwargs):
def _filter_stack_to_basis_mats(self, f, **kwargs):
"""
See `SteerableBasis2D.filter_to_basis_mat`.
"""
# Note 'method' and 'truncate' not relevant for this optimized FFB code.
if kwargs.get("method", None) is not None:
# Note 'truncate' not relevant for this specific FFB code.
# `expand_method=radial` should have already been diverted
# by the wrapping code in `super().filter_stack_to_basis_mats`.
# Permits forcing 2d calc on a radial filter via `expand_method=evaluate_t`
expand_method = kwargs.get("expand_method", None)
if expand_method not in [None, "evaluate_t"]:
raise NotImplementedError(
"`FFBBasis2D.filter_to_basis_mat` method {method} not supported."
" Use `method=None`."
f"`FFBBasis2D.filter_to_basis_mat` expand_method '{expand_method}' not supported."
" Use `expand_method=` `None` or `evaluate_t`."
)

pixel_size = kwargs.get("pixel_size", None)
Expand All @@ -271,28 +282,108 @@ def filter_to_basis_mat(self, f, **kwargs):
omegay = k * np.sin(theta)
omega = 2 * np.pi * np.vstack((omegax.flatten("C"), omegay.flatten("C")))

# This should return either a single 2d array, or stack of 2d arrays
# Reshape singleton to stack of 1.
h_vals2d = (
h_fun(omega, pixel_size=pixel_size).reshape(n_k, n_theta).astype(self.dtype)
h_fun(omega, pixel_size=pixel_size)
.reshape(len(f), n_k, n_theta)
.astype(self.dtype)
)
h_vals = np.sum(h_vals2d, axis=1) / n_theta
h_vals = h_vals2d.sum(axis=-1) / n_theta

# Represent 1D function values in basis
h_basis = BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype)
# Represent each 1D functions values in basis
h_basis = [
BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype) for _ in h_vals
]
ind_ell = 0
# Reshapes for broadcasting
k_vals = k_vals.reshape(n_k, 1)
wts = wts.reshape(n_k, 1)
h_vals = h_vals.reshape(len(f), n_k, 1)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
rmat = 2 * k_vals.reshape(n_k, 1) * self.r0[ell][0:k_max].T
basis_vals = np.zeros_like(rmat)
basis_vals = np.zeros((n_k, k_max), dtype=self.dtype)
ind_radial = np.sum(self.k_max[0:ell])
basis_vals[:, 0:k_max] = radial[ind_radial : ind_radial + k_max].T
h_basis_vals = basis_vals * h_vals.reshape(n_k, 1)
h_basis_ell = basis_vals.T @ (
h_basis_vals * k_vals.reshape(n_k, 1) * wts.reshape(n_k, 1)
)
h_basis[ind_ell] = h_basis_ell
h_basis_vals = basis_vals * h_vals
h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts)

# loop over assignment blocks.
for i in range(len(f)):
h_basis[i][ind_ell] = h_basis_ell[i]
ind_ell += 1
if ell > 0:
h_basis[ind_ell] = h_basis[ind_ell - 1]
for i in range(len(f)):
h_basis[i][ind_ell] = h_basis[i][ind_ell - 1]
ind_ell += 1

return h_basis

def filter_to_basis_mat(self, f, **kwargs):
"""
See `SteerableBasis2D.filter_stack_to_basis_mats`.
"""
if len(f) != 1:
raise RuntimeError("Unexpected filter length.")
return self._filter_stack_to_basis_mats(f, **kwargs)[0]

def expand_radial_vec(self, radial_vec, **kwargs):
"""
Expands radial vector or stack of vetors `radial_vec` to basis matrix.

:param radial_vec: Array holding radial vector,
shaped (n_radial_pts) or (n_vectors, n_radial_pts)
:force_diag: Optionally flush off-diagonal elements to zero and return `DiagMatrix`
:return: List of `BlkDiagMatrix`, or list of `DiagMatrix`
"""
force_diag = kwargs.get("force_diag", False)

# Convert vector to (1,...)
if radial_vec.ndim == 1:
radial_vec = radial_vec.reshape(1, *radial_vec.shape)
# Optionally transfer to GPU
radial_vec = xp.asarray(radial_vec)

# Set same dimensions as basis object
n_k = self.n_r
radial = self._precomp["radial"]

k_vals = xp.asarray(self._precomp["gl_nodes"])
wts = xp.asarray(self._precomp["gl_weights"])

# Represent 1D function values in basis
h_basis = [
BlkDiagMatrix.empty(2 * self.ell_max + 1, dtype=self.dtype)
for _ in radial_vec
]

ind_ell = 0
# Reshapes for broadcasting
radial_vec = radial_vec.reshape(len(h_basis), n_k, 1)
k_vals = k_vals.reshape(1, n_k, 1)
wts = wts.reshape(1, n_k, 1)
for ell in range(0, self.ell_max + 1):
k_max = self.k_max[ell]
basis_vals = xp.zeros((n_k, k_max), dtype=self.dtype)
ind_radial = np.sum(self.k_max[0:ell])
basis_vals[:, 0:k_max] = xp.asarray(
radial[ind_radial : ind_radial + k_max]
).T
h_basis_vals = basis_vals * radial_vec
h_basis_ell = basis_vals.T @ (h_basis_vals * k_vals * wts)
h_basis_ell = xp.asnumpy(h_basis_ell)
for _filter in range(len(radial_vec)):
_tmp = h_basis[_filter][ind_ell] = h_basis_ell[_filter]
if ell > 0:
h_basis[_filter][ind_ell + 1] = _tmp
if _filter == len(radial_vec) - 1:
ind_ell += 1
if ell > 0:
ind_ell += 1
if force_diag:
logger.warning(
"Forcing block diagonal to diagonal. Zeroing all off diagonal values."
)
h_basis = [h.diag() for h in h_basis]

return h_basis
Loading