From f6fe291d363c408c609ecbfb1f52f187698fd1b0 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 25 Mar 2026 10:30:25 -0400 Subject: [PATCH 01/76] bump: pydantify NexusValuesWriter --- CHAP/common/writer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index da33beda..2e09febc 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -536,8 +536,15 @@ def write(self, data): class NexusValuesWriter(Writer): - """Writer for updating values in an existing NeXus file.""" - def write(self, data, filename, path_prefix=''): + """Writer for updating values in an existing NeXus file. + + :ivar path_prefix: Prefix to use for all paths in input `data`, + defaults to `''`. + :vartype path_prefix: str, optional + """ + path_prefix: str = '' + + def write(self, data, filename): """Write new values specified in `data` to the exising NeXus file `filename`. @@ -561,7 +568,7 @@ def write(self, data, filename, path_prefix=''): with NXFile(filename, 'a') as nxroot: self.nxs_writer( nxroot=nxroot, - path=os.path.join(path_prefix, d['path']), + path=os.path.join(self.path_prefix, d['path']), idx=d['idx'], data=d['data'] ) From 2832f19f57750ff1dbd9c4bc5d03dc807b27a50d Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 25 Mar 2026 12:46:29 -0400 Subject: [PATCH 02/76] add: NexusValuesWriter.resize_axis If `False`, same behavior as before. If an `int`, use it as the axis to resize the NXfield along (only if the target slice shape does not match the data shape). --- CHAP/common/writer.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 2e09febc..d6c6f1e3 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -7,7 +7,11 @@ # System modules import os -from typing import Optional +from typing import ( + Literal, + Optional, + Union, +) # Third party modules import numpy as np @@ -541,8 +545,14 @@ class NexusValuesWriter(Writer): :ivar path_prefix: Prefix to use for all paths in input `data`, defaults to `''`. :vartype path_prefix: str, optional + :ivar resize_axis: `False` OR the axis along which the dataset + should be resized if the target slice shape does not match the + data shape. If `False`, any mismatching shapes will just raise + an error. Defaults to `False`. + :vartype resize_axis: Union[int, Literal[False]], optional """ path_prefix: str = '' + resize_axis: Union[int, Literal[False]] = False def write(self, data, filename): """Write new values specified in `data` to the exising NeXus @@ -552,7 +562,7 @@ def write(self, data, filename): -- `'path'` identifying the location of the `NXfield` object to which values will be written, `'data'` identifying the data to be written, and `'idx'` identifying the index / indicies of - the `NXfield` to which the data will be written. + the `NXfield` to which the data will be written. :type data: list[PipelineData] :param filename: Name of an existing NeXus file to update. :type filename: str @@ -586,8 +596,8 @@ def nxs_writer(self, nxroot, path, idx, data): :type nxroot: nexusformat.nexus.NXroot :param path: Path to the dataset inside the NeXus file. :type path: str - :param idx: Index or slice where the data should be written. - :type idx: tuple or int + :param idx: Slice where the data should be written. + :type idx: slice :param data: Data to be written to the specified slice in the dataset. :type data: numpy.ndarray or compatible array-like object @@ -607,9 +617,21 @@ def nxs_writer(self, nxroot, path, idx, data): # Check that the slice shape matches the data shape data = np.asarray(data) if dataset[idx].shape != data.shape: - raise ValueError( - f'Data shape {data.shape} does not match the target slice ' - f'shape {dataset[idx].shape}.') + if self.resize_axis is not False: + # Resize along the specified axis + start = idx.start or 0 + stop = start + data.shape[0] + newshape = max(dataset.shape[self.resize_axis], stop) + self.logger.debug( + f'Resizing {path} {dataset.shape} -> {newshape} ' + + f'along axis {self.resize_axis}' + ) + + dataset.resize(newshape, axis=self.resize_axis) + else: + raise ValueError( + f'Data shape {data.shape} does not match the target slice ' + f'shape {dataset[idx].shape}.') # Write the data to the specified slice dataset[idx] = data From 0c1e7143f88aef0881f9598a1f58a96a55e93012 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 25 Mar 2026 12:52:32 -0400 Subject: [PATCH 03/76] fix: list access typo --- CHAP/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index c7e63512..522bc772 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -352,7 +352,7 @@ def get_pipelinedata_item(data, index=-1, remove=False): if isinstance(data, list): if remove: return data.pop(index)['data'] - return data.get(index)['data'] + return data[index]['data'] return data def execute(self, data): From f02dd10d4cfbb45d355864690cb2e5bfbbe28841 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 25 Mar 2026 12:54:33 -0400 Subject: [PATCH 04/76] add: expandable first dimension for all NXfields in result of MapProcessor --- CHAP/common/processor.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 172b120e..ee230b56 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1272,7 +1272,10 @@ def linkdims(nxgroup, nxdata_source): attrs={'units': dim.units, 'long_name': f'{dim.label} ({dim.units})', 'data_type': dim.data_type, - 'local_name': dim.name}) + 'local_name': dim.name}, + maxshape= (None, *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + ) else: nxentry.independent_dimensions.index = NXfield( np.arange(independent_dimensions[0].size), 'index') @@ -1291,7 +1294,10 @@ def linkdims(nxgroup, nxdata_source): units=dim.units, attrs={'long_name': f'{dim.label} ({dim.units})', 'data_type': dim.data_type, - 'local_name': dim.name})) + 'local_name': dim.name}, + maxshape=(None, *all_scalar_data[i].shape[1:]), + chunks=(1, *all_scalar_data[i].shape[1:]) + )) if (self.config.experiment_type == 'EDD' and not placeholder_data is False): scalar_signals.append('placeholder_data_used') @@ -1308,7 +1314,10 @@ def linkdims(nxgroup, nxdata_source): attrs={'units': dim.units, 'long_name': f'{dim.label} ({dim.units})', 'data_type': dim.data_type, - 'local_name': dim.name})) + 'local_name': dim.name}, + maxshape=(None, *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + )) self.config.all_scalar_data.append( PointByPointScanData(**dim.model_dump())) self.config.independent_dimensions.remove(dim) @@ -1336,7 +1345,10 @@ def linkdims(nxgroup, nxdata_source): for i, detector in enumerate(self.detector_config.detectors): nxdata[detector.get_id()] = NXfield( value=data[i], - attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]}) + attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]}, + maxshape=(None, *data[i].shape[1:]), + chunks=(1, *data[i].shape[1:]) + ) detector_ids.append(detector.get_id()) linkdims(nxdata, nxentry.independent_dimensions) if len(self.detector_config.detectors) == 1: From 2ef093628aa9b31fe9e96806e9067f83c0bc4f1e Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 26 Mar 2026 12:32:58 -0400 Subject: [PATCH 05/76] add: reminder comment --- CHAP/edd/processor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 7aa412bc..7ea2d660 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2666,6 +2666,9 @@ def process(self, data): nxdata = nxentry[nxentry.default] # Load the validated calibration configuration + # FIX make this a class field and add to pipeline_fields too? + # NB: be sure that adding this field does not mess up + # self.detector_config during pydantic validation calibration_config = self.get_config( data, schema='edd.models.MCATthCalibrationConfig', remove=False) From a318c74327b231210902066c8d39095dd7697ee2 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 26 Mar 2026 13:40:04 -0400 Subject: [PATCH 06/76] fix: slice stop indices in MapSliceProcessor --- CHAP/common/map_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 7ae1f35e..eb15bebc 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -106,13 +106,13 @@ def process(self, data, #spec_file, scan_number, # Get spec scan indices to process scan_indices = range(npts_scan)[slice( idx_slice.get('start', 0), - idx_slice.get('stop', npts_scan + 1), + idx_slice.get('stop', npts_scan), idx_slice.get('step', 1) )] # Get map indices to write to map_indices = slice( idx_slice.get('start', 0) + index_offset, - idx_slice.get('stop', npts_scan + 1) + index_offset, + idx_slice.get('stop', npts_scan) + index_offset, idx_slice.get('step', 1) ) From b0c2e105699daf793878e3c60a4e6e316129cd65 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 26 Mar 2026 13:40:51 -0400 Subject: [PATCH 07/76] fix: support oddball EDD get_detector_data return values --- CHAP/common/map_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index eb15bebc..3987207d 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -129,12 +129,18 @@ def process(self, data, #spec_file, scan_number, } for s_d in self.map_config.all_scalar_data ] + if self.map_config.experiment_type == 'EDD': + def get_detector_data(detector, index): + return scan.get_detector_data(detector.get_id(), index)[0] + else: + def get_detector_data(detector, index): + return scan.get_detector_data(detector.get_id(), index) data_points.extend( [ { 'path': f'{self.map_config.title}/data/{det.get_id()}', 'data': np.asarray([ - scan.get_detector_data(det.get_id(), i) + get_detector_data(det, i) for i in scan_indices ]), 'idx': map_indices From 9fa6985716976146387289c87d49c84841408f8a Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 26 Mar 2026 16:35:04 -0400 Subject: [PATCH 08/76] add: support for finding SCAN_N datasets in groups other than "data" --- CHAP/edd/reader.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index 8ff854a7..8bad0f6a 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -750,7 +750,7 @@ class SliceNXdataReader(Reader): provided slicing parameters. :param scan_number: Number of the SPEC scan. - :type scan_number: int + :vartype scan_number: int """ scan_number: conint(ge=0) @@ -772,14 +772,28 @@ def read(self): reader = NexusReader(**self.model_dump()) nxroot = nxcopy(reader.read()) + nxentry = None nxdata = None for nxname, nxobject in nxroot.items(): if isinstance(nxobject, NXentry): + nxentry = nxobject nxdata = nxobject.data if nxdata is None: msg = 'Could not find NXdata group' self.logger.error(msg) raise ValueError(msg) + if 'SCAN_N' not in nxdata: + self.logger.warning(f'SCAN_N not in {nxdata}') + scan_n_found = False + for k, v in nxentry.items(): + if 'SCAN_N' in v: + nxdata = v + scan_n_found = True + self.logger.warning(f'Using SCAN_N dataset in {nxdata}') + if not scan_n_found: + msg = 'Cannot find SCAN_N dataset' + self.logger.error(msg) + raise ValueError(msg) indices = np.argwhere( nxdata.SCAN_N.nxdata == self.scan_number).flatten() From 4fcdda1235786ecc383d2f620460878bebc48be5 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 1 Apr 2026 12:18:22 -0400 Subject: [PATCH 09/76] fix: add independent_dimensions to results of MapSliceProcessor --- CHAP/common/map_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 3987207d..2b71c3d0 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -59,7 +59,7 @@ class MapSliceProcessor(Processor): spec_file: FilePath scan_number: conint(gt=0) - def process(self, data, #spec_file, scan_number, + def process(self, data, idx_slice={'start': 0, 'step': 1}): """Aggregate partial spec and detector data from one scan in a map, returning results in a format suitable for writing to the @@ -128,6 +128,7 @@ def process(self, data, #spec_file, scan_number, 'idx': map_indices } for s_d in self.map_config.all_scalar_data + + self.map_config.independent_dimensions ] if self.map_config.experiment_type == 'EDD': def get_detector_data(detector, index): From 9ed9e017871e61fc6ed15b0cae94b9b19de8cf01 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 1 Apr 2026 12:19:21 -0400 Subject: [PATCH 10/76] fix: EDD-specific modifications for MapSliceProcessor --- CHAP/common/map_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 2b71c3d0..7b4c1034 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -131,8 +131,21 @@ def process(self, data, + self.map_config.independent_dimensions ] if self.map_config.experiment_type == 'EDD': + data_points.extend( + [ + { + 'path': f'{self.map_config.title}/independent_dimensions/index', + 'data': [ + i for i in range( + map_indices.start, map_indices.stop + ) + ], + 'idx': map_indices + } + ] + ) def get_detector_data(detector, index): - return scan.get_detector_data(detector.get_id(), index)[0] + return scan.get_detector_data(detector.get_id(), index)[0][0] else: def get_detector_data(detector, index): return scan.get_detector_data(detector.get_id(), index) From 76a9a611f36aee2297b143ed62d9ad67bd12ef48 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 1 Apr 2026 12:21:59 -0400 Subject: [PATCH 11/76] fix: some missing resizable NXfields for MapProcessor --- CHAP/common/processor.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index ee230b56..31f11a3c 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1273,12 +1273,15 @@ def linkdims(nxgroup, nxdata_source): 'long_name': f'{dim.label} ({dim.units})', 'data_type': dim.data_type, 'local_name': dim.name}, - maxshape= (None, *independent_dimensions[i].shape[1:]), + maxshape=(None, *independent_dimensions[i].shape[1:]), chunks=(1, *independent_dimensions[i].shape[1:]) ) else: nxentry.independent_dimensions.index = NXfield( - np.arange(independent_dimensions[0].size), 'index') + np.arange(independent_dimensions[0].size), 'index', + maxshape=(None,), + chunks=(1,) + ) # Set up scalar data NeXus NXdata group # (add the constant independent dimensions) @@ -1298,6 +1301,7 @@ def linkdims(nxgroup, nxdata_source): maxshape=(None, *all_scalar_data[i].shape[1:]), chunks=(1, *all_scalar_data[i].shape[1:]) )) + if (self.config.experiment_type == 'EDD' and not placeholder_data is False): scalar_signals.append('placeholder_data_used') @@ -1305,7 +1309,10 @@ def linkdims(nxgroup, nxdata_source): value=all_scalar_data[-1], attrs={'description': 'Indicates whether placeholder data may be present for' - 'the corresponding frames of detector data.'})) + 'the corresponding frames of detector data.'}, + maxshape=(None, *all_scalar_data[-1].shape[1:]), + chunks=(1, *all_scalar_data[-1].shape[1:]) + )) for i, dim in enumerate(deepcopy(self.config.independent_dimensions)): if i in constant_dim: scalar_signals.append(dim.label) From a68706cd0b881f07d65caa1854100192e00d519d Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 1 Apr 2026 12:38:48 -0400 Subject: [PATCH 12/76] fix: independent_dimensions paths in MapSliceProcessor results --- CHAP/common/map_utils.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 7b4c1034..541c0d70 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -128,22 +128,24 @@ def process(self, data, 'idx': map_indices } for s_d in self.map_config.all_scalar_data - + self.map_config.independent_dimensions ] + data_points.extend( + [ + { + 'path': f'{self.map_config.title}/independent_dimensions/{dim.label}', + 'data': np.asarray([ + dim.get_value( + scans, self.scan_number, i, + scalar_data=self.map_config.scalar_data + ) + for i in scan_indices + ]), + 'idx': map_indices, + } + for dim in self.map_config.independent_dimensions + ] + ) if self.map_config.experiment_type == 'EDD': - data_points.extend( - [ - { - 'path': f'{self.map_config.title}/independent_dimensions/index', - 'data': [ - i for i in range( - map_indices.start, map_indices.stop - ) - ], - 'idx': map_indices - } - ] - ) def get_detector_data(detector, index): return scan.get_detector_data(detector.get_id(), index)[0][0] else: From 4d689979fa7100fa37f4d0109d6f562b153b6cb0 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 11:35:53 -0400 Subject: [PATCH 13/76] add: common.models.IndexSliceConfig --- CHAP/common/models/__init__.py | 1 + CHAP/common/models/common.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/CHAP/common/models/__init__.py b/CHAP/common/models/__init__.py index d95277f7..062ef9ba 100755 --- a/CHAP/common/models/__init__.py +++ b/CHAP/common/models/__init__.py @@ -5,5 +5,6 @@ from CHAP.common.models.common import ( BinarizeConfig, ImageProcessorConfig, + IndexSliceConfig, UnstructuredToStructuredConfig, ) diff --git a/CHAP/common/models/common.py b/CHAP/common/models/common.py index c24b9178..b2a9b486 100755 --- a/CHAP/common/models/common.py +++ b/CHAP/common/models/common.py @@ -129,6 +129,25 @@ def validate_vrange(cls, vrange): for i in index_range] +class IndexSliceConfig(CHAPBaseModel): + """Configuration for a python `sslice` object. + + :ivar start: A `start` parameter for `slice()`, defaults to 0. + :vartype start: int, optional + :ivar stop: A `stop` parameter for `slice()`, defaults to -1. + :vartype stop: int, optional + :ivar step: A `step` parameter for `slice()`, defaults to 1. + :vartype step: int, optional + """ + start: Optional[int] = 0 + stop: Optional[int] = -1 + step: Optional[int] = 1 + + @property + def _slice(self): + return slice(self.start, self.stop, self.step) + + class UnstructuredToStructuredConfig(CHAPBaseModel): """Configuration class to reshape data in an NXdata from an "unstructured" to a "structured" representation. From 0ec1f60f935f6e0cf41c5ff6da5748ee7a03dc9d Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 11:44:32 -0400 Subject: [PATCH 14/76] add: NexusValuesWriter.idx_slice --- CHAP/common/writer.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index d6c6f1e3..b2b6b714 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -25,6 +25,7 @@ from CHAP import Writer from CHAP.pipeline import PipelineItem from CHAP.writer import validate_writer_model +from CHAP.common.models import IndexSliceConfig def validate_model(model): @@ -550,9 +551,16 @@ class NexusValuesWriter(Writer): data shape. If `False`, any mismatching shapes will just raise an error. Defaults to `False`. :vartype resize_axis: Union[int, Literal[False]], optional + :ivar idx_slice: Configuration for a slice object that will be + used to select the slice of the target array to write to. Used + only if the `"idx"` key is not present for an item in the + newest `PipelineData` item in `data`. Defaults to + `IndexSliceConfig()`. + :vartype idx_slice: CHAP.common.models.IndexSliceConfig, optional """ path_prefix: str = '' resize_axis: Union[int, Literal[False]] = False + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() def write(self, data, filename): """Write new values specified in `data` to the exising NeXus @@ -576,12 +584,15 @@ def write(self, data, filename): data = self.get_pipelinedata_item(data, remove=self.remove) for d in data: with NXFile(filename, 'a') as nxroot: - self.nxs_writer( - nxroot=nxroot, - path=os.path.join(self.path_prefix, d['path']), - idx=d['idx'], - data=d['data'] - ) + try: + self.nxs_writer( + nxroot=nxroot, + path=os.path.join(self.path_prefix, d['path']), + idx=d.get('idx', self.idx_slice._slice), + data=d['data'] + ) + except Exception as exc: + self.logger.error(exc) def nxs_writer(self, nxroot, path, idx, data): """Write data to a specific NeXus file. @@ -613,9 +624,15 @@ def nxs_writer(self, nxroot, path, idx, data): # Access the specified dataset dataset = nxroot[path] + self.logger.debug( + f'chunks, maxshape = {dataset.chunks}, {dataset.maxshape}' + ) # Check that the slice shape matches the data shape data = np.asarray(data) + self.logger.debug( + f'data shape, target shape = {data.shape}, {dataset[idx].shape}' + ) if dataset[idx].shape != data.shape: if self.resize_axis is not False: # Resize along the specified axis From 8fd846bbb8bf0925478a9cfffd1d06cf911546e9 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 11:48:28 -0400 Subject: [PATCH 15/76] add: MapProcessor.remove_constant_dims --- CHAP/common/processor.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 31f11a3c..70f924ad 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -891,15 +891,20 @@ class MapProcessor(Processor): instance of common.models.map.MapConfig. Any values in `'config'` supplant their corresponding values obtained from the pipeline data configuration. - :type config: Union[dict, common.models.map.MapConfig] + :vartype config: Union[dict, common.models.map.MapConfig] :ivar detector_config: Detector configurations of the detectors to include raw data for in the returned NeXus NXentry object (overruling detector info in the pipeline data, if present). - :type detector_config: Union[ + :vartype detector_config: Union[ dict, common.models.map.DetectorConfig] :ivar num_proc: Number of processors used to read map, defaults to `1`. - :type num_proc: int, optional + :vartype num_proc: int, optional + :ivar remove_constant_dims: Flag to indicate that any + `independent_dimension`s in the map whose values are constant + across the map should be exluded from the output + `NXentry`. Defaults to `True`. + :vartype remove_constant_dims: bool, optional """ pipeline_fields: dict = Field( default = { @@ -909,6 +914,7 @@ class MapProcessor(Processor): config: MapConfig detector_config: DetectorConfig = DetectorConfig(detectors=[]) num_proc: Optional[conint(gt=0)] = 1 + remove_constant_dims: Optional[bool] = True @field_validator('num_proc') @classmethod @@ -1266,7 +1272,7 @@ def linkdims(nxgroup, nxdata_source): nxentry.independent_dimensions = NXdata() if len(constant_dim) < len(self.config.independent_dimensions): for i, dim in enumerate(self.config.independent_dimensions): - if i not in constant_dim: + if i not in constant_dim or not self.remove_constant_dims: nxentry.independent_dimensions[dim.label] = NXfield( independent_dimensions[i], dim.label, attrs={'units': dim.units, From 3205a4c6f6cadbed389ff108f10dd299a32e041f Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 11:58:00 -0400 Subject: [PATCH 16/76] fix: include independent_dimensions.index in results of MapSliceProcessor just in case all configured ind_dims were constant across the map and removed from the NXentry by MapProcessor --- CHAP/common/map_utils.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 541c0d70..5515d358 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -118,17 +118,28 @@ def process(self, data, data_points = [ { - 'path': f'{self.map_config.title}/scalar_data/{s_d.label}', - 'data': np.asarray([ - s_d.get_value( - scans, self.scan_number, i, - scalar_data=self.map_config.scalar_data) - for i in scan_indices - ]), + 'path': f'{self.map_config.title}/independent_dimensions/index', + 'data': np.asarray( + [i for i in range(index_offset, index_offset + npts_scan)] + ), 'idx': map_indices } - for s_d in self.map_config.all_scalar_data ] + data_points.extend( + [ + { + 'path': f'{self.map_config.title}/scalar_data/{s_d.label}', + 'data': np.asarray([ + s_d.get_value( + scans, self.scan_number, i, + scalar_data=self.map_config.scalar_data) + for i in scan_indices + ]), + 'idx': map_indices + } + for s_d in self.map_config.all_scalar_data + ] + ) data_points.extend( [ { From 3791322b570f6a6b5a673a61ee1c9be13df85048 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 12:25:12 -0400 Subject: [PATCH 17/76] fix: bug with MapProcessor.remove_constant_dims --- CHAP/common/processor.py | 89 +++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 37 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 70f924ad..b9693364 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1263,31 +1263,45 @@ def linkdims(nxgroup, nxdata_source): **self.config.sample.model_dump()) # Set up independent dimensions NeXus NXdata group - # (squeeze out constant dimensions) - constant_dim = [] - for i, dim in enumerate(self.config.independent_dimensions): - unique = np.unique(independent_dimensions[i]) - if unique.size == 1: - constant_dim.append(i) nxentry.independent_dimensions = NXdata() - if len(constant_dim) < len(self.config.independent_dimensions): + if self.remove_constant_dims: + # (squeeze out constant dimensions) + constant_dim = [] for i, dim in enumerate(self.config.independent_dimensions): - if i not in constant_dim or not self.remove_constant_dims: - nxentry.independent_dimensions[dim.label] = NXfield( - independent_dimensions[i], dim.label, - attrs={'units': dim.units, - 'long_name': f'{dim.label} ({dim.units})', - 'data_type': dim.data_type, - 'local_name': dim.name}, - maxshape=(None, *independent_dimensions[i].shape[1:]), - chunks=(1, *independent_dimensions[i].shape[1:]) - ) + unique = np.unique(independent_dimensions[i]) + if unique.size == 1: + constant_dim.append(i) + if len(constant_dim) < len(self.config.independent_dimensions): + for i, dim in enumerate(self.config.independent_dimensions): + if i not in constant_dim or not self.remove_constant_dims: + nxentry.independent_dimensions[dim.label] = NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, + *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + ) + else: + nxentry.independent_dimensions.index = NXfield( + np.arange(independent_dimensions[0].size), 'index', + maxshape=(None,), + chunks=(1,) + ) else: - nxentry.independent_dimensions.index = NXfield( - np.arange(independent_dimensions[0].size), 'index', - maxshape=(None,), - chunks=(1,) - ) + for i, dim in enumerate(self.config.independent_dimensions): + nxentry.independent_dimensions[dim.label] = NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, + *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + ) # Set up scalar data NeXus NXdata group # (add the constant independent dimensions) @@ -1319,21 +1333,22 @@ def linkdims(nxgroup, nxdata_source): maxshape=(None, *all_scalar_data[-1].shape[1:]), chunks=(1, *all_scalar_data[-1].shape[1:]) )) - for i, dim in enumerate(deepcopy(self.config.independent_dimensions)): - if i in constant_dim: - scalar_signals.append(dim.label) - scalar_data.append(NXfield( - independent_dimensions[i], dim.label, - attrs={'units': dim.units, - 'long_name': f'{dim.label} ({dim.units})', - 'data_type': dim.data_type, - 'local_name': dim.name}, - maxshape=(None, *independent_dimensions[i].shape[1:]), - chunks=(1, *independent_dimensions[i].shape[1:]) - )) - self.config.all_scalar_data.append( - PointByPointScanData(**dim.model_dump())) - self.config.independent_dimensions.remove(dim) + if self.remove_constant_dims: + for i, dim in enumerate(deepcopy(self.config.independent_dimensions)): + if i in constant_dim: + scalar_signals.append(dim.label) + scalar_data.append(NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + )) + self.config.all_scalar_data.append( + PointByPointScanData(**dim.model_dump())) + self.config.independent_dimensions.remove(dim) if scalar_signals: nxentry.scalar_data = NXdata() for k, v in zip(scalar_signals, scalar_data): From 6ad40220755281ed6dd34e708bd2db7b152ea9c2 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 13:09:11 -0400 Subject: [PATCH 18/76] add: link to all scalar_data fields from main .data NXdata group --- CHAP/common/processor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index b9693364..29792425 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1257,6 +1257,7 @@ def linkdims(nxgroup, nxdata_source): NXfield(value=scans.scan_numbers, dtype='int8', attrs={'spec_file': str(scans.spec_file)}) + nxentry.data = NXdata() # Add sample metadata nxentry[self.config.sample.name] = NXsample( @@ -1353,6 +1354,7 @@ def linkdims(nxgroup, nxdata_source): nxentry.scalar_data = NXdata() for k, v in zip(scalar_signals, scalar_data): nxentry.scalar_data[k] = v + nxentry.data.makelink(nxentry.scalar_data[k]) if 'SCAN_N' in scalar_signals: nxentry.scalar_data.attrs['signal'] = 'SCAN_N' else: @@ -1361,8 +1363,7 @@ def linkdims(nxgroup, nxdata_source): nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals # Add detector data - nxdata = NXdata() - nxentry.data = nxdata + nxdata = nxentry.data nxentry.data.set_default() detector_ids = [] for k, v in self.config.attrs.items(): From 59979381cae2d4555bf4c9d7b60a3e00b9795bed Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 13:46:19 -0400 Subject: [PATCH 19/76] fix: add data type check/conversion to NexusValuesWriter --- CHAP/common/writer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index b2b6b714..b426d60e 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -628,8 +628,16 @@ def nxs_writer(self, nxroot, path, idx, data): f'chunks, maxshape = {dataset.chunks}, {dataset.maxshape}' ) - # Check that the slice shape matches the data shape data = np.asarray(data) + + # Check the datatype + if data.dtype != dataset.dtype: + self.logger.warning( + f'Converting new data (type: {data.dtype}) to {dataset.dtype}' + ) + data = data.astype(dataset.dtype) + + # Check the shape self.logger.debug( f'data shape, target shape = {data.shape}, {dataset[idx].shape}' ) From 06ce58fcf0708d651447f447f8f15481cd6cd52e Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 14:00:14 -0400 Subject: [PATCH 20/76] fix: peak_fit_info in StrainAnalysisProcessor when setup is True --- CHAP/edd/processor.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 7ea2d660..c1bb5534 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2814,8 +2814,12 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxdata.best_fit = NXfield(shape=shape, dtype=np.float64) nxdata.included_peaks = NXfield( shape=[shape[0], len(hkls)], dtype=bool) - nxdata.included_peaks.attrs['hkls'] = dumps(peak_fit_info['hkls']) - nxdata.included_peaks.attrs['use_peaks'] = peak_fit_info['use_peaks'] + nxdata.included_peaks.attrs['hkls'] = dumps( + peak_fit_info.get('hkls', '') + ) + nxdata.included_peaks.attrs['use_peaks'] = peak_fit_info.get( + 'use_peaks' + ) nxdata.residual = NXfield(shape=shape, dtype=np.float64) nxdata.redchi = NXfield(shape=[shape[0]], dtype=np.float64) nxdata.success = NXfield(shape=[shape[0]], dtype='bool') @@ -2862,7 +2866,7 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].sigmas.errors = NXfield( shape=[shape[0]], dtype=np.float64) nxcollection[hkl_name].sigmas.attrs['signal'] = 'values' - if peak_fit_info['peak_models'] == 'pvoigt': + if peak_fit_info.get('peak_models') == 'pvoigt': # Report HKL peak fractions nxcollection[hkl_name].fractions = NXdata() self._linkdims( @@ -3020,6 +3024,12 @@ def _get_nxroot(self, nxentry, calibration_config): nxprocess.strain_analysis_config = \ self.config.model_dump_json() + if len(self._peak_fit_info) == 0: + # FIX this is a temporary fix to be able to run update + # after setup. + self._peak_fit_info = [{}] * len(self.detector_config.detectors) + self.logger.warning('Missing peak_fit_info') + # Loop over the detectors to fill in the nxprocess for energies, mask, nxdata, detector, peak_fit_info in zip( self._energies, self._masks, self._nxdata_detectors, From 3cbcb31022d082600e568a450dd3841a3fb37983 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 6 Apr 2026 14:12:04 -0400 Subject: [PATCH 21/76] add: make all NXfields resizable in StrainAnalysisProcessor results --- CHAP/edd/processor.py | 109 +++++++++++++++++++++++++++++++----------- 1 file changed, 82 insertions(+), 27 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index c1bb5534..9d7f8e60 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2811,18 +2811,32 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection.results = NXdata() nxdata = nxcollection.results self._linkdims(nxdata, det_nxdata) - nxdata.best_fit = NXfield(shape=shape, dtype=np.float64) + nxdata.best_fit = NXfield( + shape=shape, dtype=np.float64, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) nxdata.included_peaks = NXfield( - shape=[shape[0], len(hkls)], dtype=bool) + shape=[shape[0], len(hkls)], dtype=bool, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) nxdata.included_peaks.attrs['hkls'] = dumps( peak_fit_info.get('hkls', '') ) nxdata.included_peaks.attrs['use_peaks'] = peak_fit_info.get( 'use_peaks' ) - nxdata.residual = NXfield(shape=shape, dtype=np.float64) - nxdata.redchi = NXfield(shape=[shape[0]], dtype=np.float64) - nxdata.success = NXfield(shape=[shape[0]], dtype='bool') + nxdata.residual = NXfield( + shape=shape, dtype=np.float64, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) + nxdata.redchi = NXfield( + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) + nxdata.success = NXfield( + shape=[shape[0]], dtype='bool', + maxshape=(None,), chunks=(1,) + ) # Peak-by-peak results for hkl in hkls: @@ -2842,9 +2856,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].amplitudes, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].amplitudes.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'counts'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'counts'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].amplitudes.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].amplitudes.attrs['signal'] = 'values' # Report HKL peak centers nxcollection[hkl_name].centers = NXdata() @@ -2852,9 +2870,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].centers, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].centers.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].centers.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].centers.attrs['signal'] = 'values' # Report HKL peak FWHMs nxcollection[hkl_name].sigmas = NXdata() @@ -2862,9 +2884,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].sigmas, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].sigmas.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].sigmas.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].sigmas.attrs['signal'] = 'values' if peak_fit_info.get('peak_models') == 'pvoigt': # Report HKL peak fractions @@ -2873,9 +2899,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].fractions, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].fractions.values = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].fractions.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].fractions.attrs['signal'] = 'values' # Report HKL peak strains (unconstrained only) if fit_type == 'unconstrained': @@ -2885,11 +2915,17 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): skip_field_dims=['energy']) values = np.full(shape=[shape[0]], fill_value=np.nan) nxcollection[hkl_name].strains.values = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].strains.errors = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].strains.residuals = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].strains.attrs['signal'] = 'values' @@ -3061,35 +3097,54 @@ def _get_nxroot(self, nxentry, calibration_config): value=energies[mask], attrs={'units': 'keV'}) det_nxdata.norm = NXfield( dtype=np.float64, - shape=(num_points,)) + shape=(num_points,), + maxshape=(None,), chunks=(1,) + ) det_nxdata.tth = NXfield( dtype=np.float64, shape=(num_points,), - attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}) + attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}, + maxshape=(None,), chunks=(1,), + ) det_nxdata.uniform_strain = NXfield( dtype=np.float64, shape=(num_points,), - attrs={'long_name': 'Strain from uniform fit (\u03B5)'}) - #attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'}) + attrs={'long_name': 'Strain from uniform fit (\u03B5)'}, + # attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'} + maxshape=(None,), chunks=(1,) + ) + det_nxdata.unconstrained_strain = NXfield( dtype=np.float64, shape=(num_points,), attrs={'long_name': - 'Strain from unconstrained fit (\u03B5)'}) - #'Strain from unconstrained fit (\u03BC\u03B5)'}) + 'Strain from unconstrained fit (\u03B5)'}, + #'Strain from unconstrained fit (\u03BC\u03B5)'}, + maxshape=(None,), chunks=(1,) + ) det_nxdata.unconstrained_strain_stdev = NXfield( dtype=np.float64, shape=(num_points,), attrs={'long_name': 'Standard deviation in strain from unconstrained ' - 'fit (\u03B5)'}) - #'fit (\u03BC\u03B5)'}) + 'fit (\u03B5)'}, + #'fit (\u03BC\u03B5)'} + maxshape=(None,), chunks=(1,) + ) # Add the detector data + _intensity=np.asarray( + [ + data[i].astype(np.float64)[mask] + for i in range(num_points) + ] + ) det_nxdata.intensity = NXfield( - value=np.asarray([data[i].astype(np.float64)[mask] - for i in range(num_points)]), - attrs={'units': 'counts'}) + value=_intensity, + attrs={'units': 'counts'}, + maxshape=(None,*_intensity.shape[1:]), + chunks=(1,*_intensity.shape[1:]) + ) det_nxdata.attrs['signal'] = 'intensity' # Get the unique HKLs and lattice spacings for the strain From f13a81b6be493e4054510057b941254fc7d2dfd5 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 10 Apr 2026 12:17:38 -0400 Subject: [PATCH 22/76] temporary: add OrnlStrainAnalysisProcessor --- CHAP/edd/processor.py | 439 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 439 insertions(+) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 9d7f8e60..cad34d02 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -3523,6 +3523,445 @@ def _strain_analysis(self): return points +class OrnlStrainAnalysisProcessor(StrainAnalysisProcessor): + """Processor that handles strain analysis in the middle of + autonomously driven experients.""" + + def process(self, data): + """Setup the strain analysis and/or return the strain analysis + results as a list of updated points or a + `nexusformat.nexus.NXroot` object. + + :param data: Input data containing configurations for a map, + completed energy/tth calibration, and (optionally) + parameters for the strain analysis. + :type data: list[PipelineData] + :raises RuntimeError: Unable to get a valid strain analysis + configuration. + :return: The strain analysis setup or results, a list of + byte stream representions of Matplotlib figures and an + animation of the fit results. + :rtype: Union[list[dict[str, object]], + nexusformat.nexus.NXroot], PipelineData, PipelineData + """ + # Third party modules + from nexusformat.nexus import ( + NXentry, + NXroot, + ) + + # Local modules + from CHAP.utils.general import list_to_string + + if not (self.setup or self.update): + raise RuntimeError('Illegal combination of setup and update') + if not self.update: + if self.interactive: + self.logger.warning('Interactive option disabled during setup') + self.interactive = False + if self.save_figures: + self.logger.warning( + 'Saving figures option disabled during setup') + self.save_figures = False + self._animation = [] + + # Load the pipeline input data + try: + nxobject = self.get_data(data) + if isinstance(nxobject, NXroot): + nxroot = nxobject + elif isinstance(nxobject, NXentry): + nxroot = NXroot() + nxroot[nxobject.nxname] = nxobject + nxobject.set_default() + else: + raise RuntimeError + except Exception as exc: + raise RuntimeError( + 'No valid input in the pipeline data') from exc + + # Load the detector data + # FIX set rel_height_cutoff + nxentry = self.get_default_nxentry(nxroot) + for k, v in nxentry.scalar_data.items(): + self.logger.debug(f'{k}.chunks = {v.chunks}') + self.logger.debug(f'{k}.maxshape = {v.maxshape}') + + nxdata = nxentry[nxentry.default] + + # Load the validated calibration configuration + # FIX make this a class field and add to pipeline_fields too? + # NB: be sure that adding this field does not mess up + # self.detector_config during pydantic validation + calibration_config = self.get_config( + data, schema='edd.models.MCATthCalibrationConfig', remove=False) + + # Load the validated calibration detector configurations + calibration_detector_config = self.get_data( + data, schema='edd.models.MCATthCalibrationConfig') + calibration_detectors = [ + MCADetectorCalibration(**d) + for d in calibration_detector_config.get('detectors', [])] + calibration_detector_ids = [d.get_id() for d in calibration_detectors] + + # Check for available raw detector data and for the available + # calibration data + if not self.detector_config.detectors: + self.detector_config.detectors = [ + MCADetectorStrainAnalysis( + id=id_, processor_type='strainanalysis') + for id_ in nxentry.detector_ids] + self.detector_config.update_detectors() + skipped_detectors = [] + sskipped_detectors = [] + detectors = [] + for detector in self.detector_config.detectors: + detector_id = detector.get_id() + if detector_id not in nxdata: + skipped_detectors.append(detector_id) + elif detector_id not in calibration_detector_ids: + sskipped_detectors.append(detector_id) + else: + raw_detector_data = nxdata[detector_id].nxdata + if raw_detector_data.ndim != 2: + self.logger.warning( + f'Skipping detector {detector_id} (Illegal data shape ' + f'{raw_detector_data.shape})') + elif raw_detector_data.sum(): + for k, v in nxdata[detector_id].attrs.items(): + detector.attrs[k] = v.nxdata + if self.config.rel_height_cutoff is not None: + detector.rel_height_cutoff = \ + self.config.rel_height_cutoff + detector.add_calibration( + calibration_detectors[ + int(calibration_detector_ids.index(detector_id))]) + detectors.append(detector) + else: + self.logger.warning( + f'Skipping detector {detector_id} (zero intensity)') + if len(skipped_detectors) == 1: + self.logger.warning( + f'Skipping detector {skipped_detectors[0]} ' + '(no raw data)') + elif skipped_detectors: + skipped_detectors = [int(d) for d in skipped_detectors] + self.logger.warning( + 'Skipping detectors ' + f'{list_to_string(skipped_detectors)} (no raw data)') + if len(sskipped_detectors) == 1: + self.logger.warning( + f'Skipping detector {sskipped_detectors[0]} ' + '(no raw data)') + elif sskipped_detectors: + skipped_detectors = [int(d) for d in sskipped_detectors] + self.logger.warning( + 'Skipping detectors ' + f'{list_to_string(skipped_detectors)} (no calibration data)') + self.detector_config.detectors = detectors + if not self.detector_config.detectors: + raise ValueError('No valid data or unable to match an available ' + 'calibrated detector for the strain analysis') + + # Load the raw MCA data and compute the detector bin energies + # and the mean spectra + self._setup_detector_data( + nxentry[nxentry.default], + strain_analysis_config=self.config, update=self.update) + + # Apply the energy mask + self._apply_energy_mask() + + # Get the mask and HKLs used in the strain analysis + self._get_mask_hkls() + + # Apply the combined energy ranges mask + self._apply_combined_mask() + + # Setup and/or run the strain analysis + points = [] + if self.update: + points = self._strain_analysis() + values = self._get_values(nxroot, points) + if self.setup: + nxprocess = self._get_nxprocess(nxentry, calibration_config) + if points: + self.logger.info(f'Adding {len(points)} points') + self.add_points(nxprocess, points, logger=self.logger) + self.logger.info(f'... done') + else: + self.logger.warning('Skip adding points') + if not (self._figures or self._animation): + return nxprocess + ret = [nxprocess] + else: + if not (self._figures or self._animation): + return values + ret = [values] + if self._figures: + ret.append( + PipelineData( + name=self.__name__, data=self._figures, + schema='common.write.ImageWriter')) + if self._animation: + ret.append( + PipelineData( + name=self.__name__, data=self._animation, + schema='common.write.ImageWriter')) + return tuple(ret) + + def _get_nxprocess(self, nxentry, calibration_config): + """Return a standalone NXprocess for the strain analysis + results & metadata. + + :param nxentry: Strain analysis map, including the raw + MCA data. + :type nxentry: nexusformat.nexus.NXentry + :param calibration_config: 2&theta calibration configuration. + :type calibration_config: + CHAP.edd.models.MCATthCalibrationConfig + :return: Strain analysis results & associated metadata. + :rtype: nexusformat.nexus.NXprocess + """ + # Third party modules + # pylint: disable=no-name-in-module + from nexusformat.nexus import ( + NXdetector, + NXfield, + NXprocess, + NXroot, + ) + # pylint: enable=no-name-in-module + + # Third party modules + from json import dumps + + # Local modules + from CHAP.edd.utils import get_unique_hkls_ds + from CHAP.utils.general import nxcopy + + if not self.interactive and not self.config.materials: + raise ValueError( + 'No material provided. Provide a material in the ' + 'StrainAnalysis Configuration, or re-run the pipeline with ' + 'the --interactive flag.') + + nxprocess = NXprocess() + nxprocess.calibration_config = \ + calibration_config.model_dump_json() + nxprocess.strain_analysis_config = \ + self.config.model_dump_json() + + if len(self._peak_fit_info) == 0: + # FIX this is a temporary fix to be able to run update + # after setup. + self._peak_fit_info = [{}] * len(self.detector_config.detectors) + self.logger.warning('Missing peak_fit_info') + + # Loop over the detectors to fill in the nxprocess + for energies, mask, nxdata, detector, peak_fit_info in zip( + self._energies, self._masks, self._nxdata_detectors, + self.detector_config.detectors, self._peak_fit_info): + + # Get the current data object + data = nxdata.nxsignal + num_points = data.shape[0] + + # Setup the NXdetector object for the current detector + self.logger.debug( + f'Setting up NXdetector group for {detector.get_id()}') + nxdetector = NXdetector() + nxprocess[detector.get_id()] = nxdetector + nxdetector.local_name = detector.get_id() + nxdetector.detector_config = detector.model_dump_json() + nxdetector.peak_fit_info = dumps(peak_fit_info) + nxdetector.data = nxcopy(nxdata, exclude_nxpaths='detector_data') + det_nxdata = nxdetector.data + if 'axes' in det_nxdata.attrs: + if isinstance(det_nxdata.attrs['axes'], str): + det_nxdata.attrs['axes'] = [ + det_nxdata.attrs['axes'], 'energy'] + else: + det_nxdata.attrs['axes'].append('energy') + else: + det_nxdata.attrs['axes'] = ['energy'] + det_nxdata.energy = NXfield( + value=energies[mask], attrs={'units': 'keV'}) + det_nxdata.norm = NXfield( + dtype=np.float64, + shape=(num_points,), + maxshape=(None,), chunks=(1,) + ) + det_nxdata.tth = NXfield( + dtype=np.float64, + shape=(num_points,), + attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}, + maxshape=(None,), chunks=(1,), + ) + det_nxdata.uniform_strain = NXfield( + dtype=np.float64, + shape=(num_points,), + attrs={'long_name': 'Strain from uniform fit (\u03B5)'}, + # attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'} + maxshape=(None,), chunks=(1,) + ) + + det_nxdata.unconstrained_strain = NXfield( + dtype=np.float64, + shape=(num_points,), + attrs={'long_name': + 'Strain from unconstrained fit (\u03B5)'}, + #'Strain from unconstrained fit (\u03BC\u03B5)'}, + maxshape=(None,), chunks=(1,) + ) + det_nxdata.unconstrained_strain_stdev = NXfield( + dtype=np.float64, + shape=(num_points,), + attrs={'long_name': + 'Standard deviation in strain from unconstrained ' + 'fit (\u03B5)'}, + #'fit (\u03BC\u03B5)'} + maxshape=(None,), chunks=(1,) + ) + + # Add the detector data + _intensity=np.asarray( + [ + data[i].astype(np.float64)[mask] + for i in range(num_points) + ] + ) + det_nxdata.intensity = NXfield( + value=_intensity, + attrs={'units': 'counts'}, + maxshape=(None,*_intensity.shape[1:]), + chunks=(1,*_intensity.shape[1:]) + ) + det_nxdata.attrs['signal'] = 'intensity' + + # Get the unique HKLs and lattice spacings for the strain + # analysis materials + hkls, _ = get_unique_hkls_ds( + self.config.materials, tth_max=detector.tth_max, + tth_tol=detector.tth_tol) + + # Get the HKLs and lattice spacings that will be used for + # fitting + hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices]) + + # Add the uniform fit nxcollection + self._add_fit_nxcollection( + nxdetector, 'uniform', hkls_fit, peak_fit_info) + + # Add the unconstrained fit nxcollection + self._add_fit_nxcollection( + nxdetector, 'unconstrained', hkls_fit, peak_fit_info) + + # Add the strain fields + tth_map = detector.get_tth_map((num_points,)) + det_nxdata.tth.nxdata = tth_map + + return nxprocess + + def _linkdims( + self, nxgroup, nxdata_source, add_field_dims=None, + skip_field_dims=None, oversampling_axis=None): + """Link the dimensions for a 'nexusformat.nexus.NXgroup` + object. + """ + # Third party modules + from nexusformat.nexus import NXfield, NXroot + from nexusformat.nexus.tree import NXlinkfield + + if not isinstance(nxgroup.nxroot, NXroot): + self.logger.warning( + 'Skipping linkdims -- type(nxgroup.nxroot) = ' + + f'{type(nxgroup.nxroot)}' + ) + return + super()._linkdims( + nxgroup, nxdata_source, add_field_dims=add_field_dims, + skip_field_dims=skip_field_dims, oversampling_axis=skip_field_dims + ) + + def _get_values(self, nxroot, points): + """Return list of dictionaries with new srtain results values + suitable for writing with `common.NexusValuesWriter`. + + :param nxroot: + :type nxroot: + :param points: + :type points: + :returns: + :rtype: + """ + # pylint: disable=no-name-in-module + from nexusformat.nexus import ( + NXdetector, + NXprocess + ) + # pylint: enable=no-name-in-module + + nxprocess = None + for nxobject in nxroot.values(): + if isinstance(nxobject, NXprocess): + nxprocess = nxobject + break + if nxprocess is None: + raise RuntimeError('Unable to find the strainanalysis object') + + nxdata_detectors = [] + for nxobject in nxprocess.values(): + if isinstance(nxobject, NXdetector): + nxdata_detectors.append(nxobject.data) + if not nxdata_detectors: + raise RuntimeError( + 'Unable to find detector data in strainanalysis object') + axes = get_axes(nxdata_detectors[0], skip_axes=['energy']) + + values = [] + if axes: + coords = np.asarray( + [nxdata_detectors[0][a].nxdata for a in axes]).T + + def get_matching_indices(all_coords, point_coords, decimals=None): + if isinstance(decimals, int): + all_coords = np.round(all_coords, decimals=decimals) + point_coords = np.round(point_coords, decimals=decimals) + coords_match = np.all(all_coords == point_coords, axis=1) + index = np.where(coords_match)[0] + return index + + # FIX: can we round to 3 decimals right away in general? + # FIX: assumes points contains a sorted and continous + # slice of updates + i_0 = get_matching_indices( + coords, + np.asarray([points[0][a] for a in axes]), decimals=3)[0] + i_f = get_matching_indices( + coords, + np.asarray([points[-1][a] for a in axes]), decimals=3)[0] + slices = {k: np.asarray([p[k] for p in points]) for k in points[0]} + for k, v in slices.items(): + values.append( + { + 'data': v, + 'path': k, + } + ) + else: + values.extend( + [ + { + 'data': v, + 'path': k, + } + for k, v in points[0].items() + ] + ) + + return values + if __name__ == '__main__': # Local modules From 00f8b587c78b0b185564363696ca0153f1d17593 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 10 Apr 2026 14:54:37 -0400 Subject: [PATCH 23/76] optimize: NexusValuesWriter loop order --- CHAP/common/writer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index b426d60e..aad605c1 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -582,8 +582,8 @@ def write(self, data, filename): from nexusformat.nexus import NXFile data = self.get_pipelinedata_item(data, remove=self.remove) - for d in data: - with NXFile(filename, 'a') as nxroot: + with NXFile(filename, 'a') as nxroot: + for d in data: try: self.nxs_writer( nxroot=nxroot, From a2ad853a2f20e94fc4dee60235cb81d8681f0b14 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 21 Apr 2026 11:08:57 -0400 Subject: [PATCH 24/76] add: StrainAnalysisProcessor.standalone --- CHAP/edd/processor.py | 797 ++++++++++++------------------------------ 1 file changed, 216 insertions(+), 581 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index cad34d02..230d0a25 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -5,7 +5,6 @@ Author : Keara Soloway, Rolf Verberg Description: Module for Processors used only by EDD experiments """ - # System modules from copy import deepcopy import os @@ -2514,6 +2513,16 @@ class StrainAnalysisProcessor(BaseStrainProcessor): results as a list of updated points or update the result from the `setup` stage, defaults to `True`. :type update: bool, optional + :ivar standalone: Return results in standalone format suitable + for autonomous/streaming experiments, defaults to `False`. + When `True`, setup returns a standalone + `nexusformat.nexus.NXprocess` (not wrapped in an + `nexusformat.nexus.NXroot`), and update returns a list of + ``{'data': array, 'path': str}`` dicts suitable for writing + with ``common.NexusValuesWriter``. When both `setup` and + `update` are `True` the standalone `NXprocess` and the + values list are returned together as a tuple. + :type standalone: bool, optional """ pipeline_fields: dict = Field( default = { @@ -2527,6 +2536,7 @@ class StrainAnalysisProcessor(BaseStrainProcessor): detector_config: MCADetectorConfig setup: Optional[bool] = True update: Optional[bool] = True + standalone: Optional[bool] = False @model_validator(mode='before') @classmethod @@ -2755,24 +2765,36 @@ def process(self, data): self._apply_combined_mask() # Setup and/or run the strain analysis - points = [] + results = {} if self.update: - points = self._strain_analysis() + results = self._strain_analysis() if self.setup: - nxroot = self._get_nxroot(nxentry, calibration_config) - if points: - self.logger.info(f'Adding {len(points)} points') - self.add_points(nxroot, points, logger=self.logger) - self.logger.info(f'... done') + if self.standalone: + nxsetup = self._get_nxprocess(nxentry, calibration_config) else: - self.logger.warning('Skip adding points') - if not (self._figures or self._animation): - return nxroot - ret = [nxroot] + nxsetup = self._get_nxroot(nxentry, calibration_config) + if not self.standalone: + points = self._get_points(results) + if points: + self.logger.info(f'Adding {len(points)} points') + self.add_points(nxsetup, points, logger=self.logger) + self.logger.info(f'... done') + else: + self.logger.warning('Skip adding points') + if self.standalone and self.update: + # Return nxprocess structure and values separately for writer + values = self._get_values(results) + ret = [nxsetup, values] + else: + if not (self._figures or self._animation): + return nxsetup + ret = [nxsetup] else: + result = self._get_values(results) if self.standalone \ + else self._get_points(results) if not (self._figures or self._animation): - return points - ret = [points] + return result + ret = [result] if self._figures: ret.append( PipelineData( @@ -3014,9 +3036,38 @@ def animate(i): f'{detector_id}_strainanalysis_unconstrained_fits')) plt.close() - def _get_nxroot(self, nxentry, calibration_config): - """Return a `nexusformat.nexus.NXroot` object initialized for - the stress analysis. + def _get_points(self, results): + """Convert strain analysis results to list-of-dicts format + expected by `add_points`. + + :param results: Strain analysis results, mapping NeXus path + strings to arrays with leading axis = num_points. + :type results: dict[str, numpy.ndarray] + :return: One dict per map coordinate, each containing all + result values at that coordinate. + :rtype: list[dict[str, object]] + """ + if not results: + return [] + num_points = next(iter(results.values())).shape[0] + return [{k: v[i] for k, v in results.items()} + for i in range(num_points)] + + def _get_values(self, results): + """Convert strain analysis results to a list of value dicts + suitable for writing with ``common.NexusValuesWriter``. + + :param results: Strain analysis results, mapping NeXus path + strings to arrays with leading axis = num_points. + :type results: dict[str, numpy.ndarray] + :return: List of ``{'data': array, 'path': str}`` dicts. + :rtype: list[dict[str, object]] + """ + return [{'data': v, 'path': k} for k, v in results.items()] + + def _get_nxprocess(self, nxentry, calibration_config): + """Return a standalone NXprocess for the strain analysis + results & metadata. :param nxentry: Strain analysis map, including the raw MCA data. @@ -3024,8 +3075,8 @@ def _get_nxroot(self, nxentry, calibration_config): :param calibration_config: 2&theta calibration configuration. :type calibration_config: CHAP.edd.models.MCATthCalibrationConfig - :return: Strain analysis results & associated metadata.. - :rtype: nexusformat.nexus.NXroot + :return: Strain analysis results & associated metadata. + :rtype: nexusformat.nexus.NXprocess """ # Third party modules # pylint: disable=no-name-in-module @@ -3033,7 +3084,6 @@ def _get_nxroot(self, nxentry, calibration_config): NXdetector, NXfield, NXprocess, - NXroot, ) # pylint: enable=no-name-in-module @@ -3050,11 +3100,7 @@ def _get_nxroot(self, nxentry, calibration_config): 'StrainAnalysis Configuration, or re-run the pipeline with ' 'the --interactive flag.') - # Create the NXroot object - nxroot = NXroot() - nxroot[nxentry.nxname] = nxentry - nxroot[f'{nxentry.nxname}_strainanalysis'] = NXprocess() - nxprocess = nxroot[f'{nxentry.nxname}_strainanalysis'] + nxprocess = NXprocess() nxprocess.calibration_config = \ calibration_config.model_dump_json() nxprocess.strain_analysis_config = \ @@ -3169,6 +3215,30 @@ def _get_nxroot(self, nxentry, calibration_config): tth_map = detector.get_tth_map((num_points,)) det_nxdata.tth.nxdata = tth_map + return nxprocess + + def _get_nxroot(self, nxentry, calibration_config): + """Return a `nexusformat.nexus.NXroot` object initialized for + the strain analysis. + + :param nxentry: Strain analysis map, including the raw + MCA data. + :type nxentry: nexusformat.nexus.NXentry + :param calibration_config: 2&theta calibration configuration. + :type calibration_config: + CHAP.edd.models.MCATthCalibrationConfig + :return: Strain analysis results & associated metadata. + :rtype: nexusformat.nexus.NXroot + """ + # Third party modules + # pylint: disable=no-name-in-module + from nexusformat.nexus import NXroot + # pylint: enable=no-name-in-module + + nxroot = NXroot() + nxroot[nxentry.nxname] = nxentry + nxprocess = self._get_nxprocess(nxentry, calibration_config) + nxroot[f'{nxentry.nxname}_strainanalysis'] = nxprocess return nxroot def _linkdims( @@ -3178,9 +3248,16 @@ def _linkdims( object. """ # Third party modules - from nexusformat.nexus import NXfield + from nexusformat.nexus import NXfield, NXroot from nexusformat.nexus.tree import NXlinkfield + if not isinstance(nxgroup.nxroot, NXroot): + self.logger.warning( + 'Skipping linkdims -- type(nxgroup.nxroot) = ' + + f'{type(nxgroup.nxroot)}' + ) + return + if skip_field_dims is None: skip_field_dims = [] if oversampling_axis is None: @@ -3232,7 +3309,13 @@ def _linkdims( nxgroup.attrs['unstructured_axes'] = unstructured_axes def _strain_analysis(self): - """Perform the strain analysis on the full or partial map.""" + """Perform the strain analysis on the full or partial map. + + :return: Strain analysis results mapping NeXus path strings to + arrays with leading axis = num_points. Returns an empty + dict when no valid peaks are found. + :rtype: dict[str, numpy.ndarray] + """ # Third party modules from nexusformat.nexus import NXfield @@ -3254,25 +3337,23 @@ def _strain_analysis(self): f'{self.detector_config.detectors[0].get_id()}_' 'strainanalysis_material_config')) - # Setup the points list with the map axes values + # Setup the results dict with the map axes values nxdata_ref = self._nxdata_detectors[0] axes = get_axes(nxdata_ref) if axes: - points = [ - {a: nxdata_ref[a].nxdata[i] for a in axes} - for i in range(nxdata_ref[axes[0]].size)] + num_points = nxdata_ref[axes[0]].size + results = {a: nxdata_ref[a].nxdata for a in axes} else: axes = ['index'] - points = [ - {'index': i} - for i in range(np.prod(nxdata_ref.nxsignal.shape[:-1]))] + num_points = int(np.prod(nxdata_ref.nxsignal.shape[:-1])) + results = {'index': np.arange(num_points)} for nxdata in self._nxdata_detectors: nxdata.attrs['axes'] = axes nxdata.index = NXfield( - np.arange(np.prod(nxdata_ref.nxsignal.shape[:-1])), + np.arange(num_points), 'index') - # Loop over the detectors to fill in the nxprocess + # Loop over the detectors to fill in the results dict for energies, mask, mean_data, nxdata, detector in zip( self._energies, self._masks, self._mean_data, self._nxdata_detectors, self.detector_config.detectors): @@ -3341,7 +3422,7 @@ def _strain_analysis(self): self.logger.warning( 'No matching peaks with heights above the threshold, ' f'skipping the fit for detector {detector.get_id()}') - return [] + return {} self._peak_fit_info.append({ 'hkls': ["".join(map(str, hkl)) for hkl in hkls_fit], 'nominal_peak_centers': peak_locations.tolist(), @@ -3354,7 +3435,7 @@ def _strain_analysis(self): np.squeeze(intensities), energies[mask], peak_locations[use_peaks], detector, num_proc=self.config.num_proc, **self.run_config) - if intensities.shape[0] == 1: + if num_points == 1: uniform_results = {k: [v] for k, v in uniform_results.items()} unconstrained_results = { k: [v] for k, v in unconstrained_results.items()} @@ -3373,8 +3454,7 @@ def _strain_analysis(self): self.logger.info('... done') - # Add the fit results to the list of points - num_points = len(points) + # Compute the strain analysis results for all map points tth_map = detector.get_tth_map((nxdata.shape[0],)) nominal_centers = np.asarray( [get_peak_locations(d0, tth_map) @@ -3410,109 +3490,103 @@ def _strain_analysis(self): unconstrained_amplitudes_vary, insert_peak_indices, [False], axis=-1) - # Add points - for i, point in enumerate(points): - point.update({ - f'{detector.get_id()}/data/intensity': intensities[i], - f'{detector.get_id()}/data/norm': intensity_norms[i], - f'{detector.get_id()}/data/uniform_strain': - uniform_strain[i], - f'{detector.get_id()}/data/unconstrained_strain': - unconstrained_strain[i], - f'{detector.get_id()}/data/unconstrained_strain_stdev': - unconstrained_strain_stdev[i], - f'{detector.get_id()}/uniform_fit/results/best_fit': - uniform_results['best_fits'][i], - f'{detector.get_id()}/uniform_fit/results/included_peaks': - uniform_amplitudes_vary[i], - f'{detector.get_id()}/uniform_fit/results/residual': - uniform_results['residuals'][i], - f'{detector.get_id()}/uniform_fit/results/redchi': - uniform_results['redchis'][i], - f'{detector.get_id()}/uniform_fit/results/success': - uniform_results['success'][i], - f'{detector.get_id()}/unconstrained_fit/results/best_fit': - unconstrained_results['best_fits'][i], - f'{detector.get_id()}/unconstrained_fit/results/' - 'included_peaks': unconstrained_amplitudes_vary[i], - f'{detector.get_id()}/unconstrained_fit/results/residual': - unconstrained_results['residuals'][i], - f'{detector.get_id()}/unconstrained_fit/results/redchi': - unconstrained_results['redchis'][i], - f'{detector.get_id()}/unconstrained_fit/results/success': - unconstrained_results['success'][i], + # Store results as full arrays (leading axis = num_points) + det_id = detector.get_id() + results.update({ + f'{det_id}/data/intensity': intensities, + f'{det_id}/data/norm': intensity_norms, + f'{det_id}/data/uniform_strain': uniform_strain, + f'{det_id}/data/unconstrained_strain': unconstrained_strain, + f'{det_id}/data/unconstrained_strain_stdev': + unconstrained_strain_stdev, + f'{det_id}/uniform_fit/results/best_fit': + np.asarray(uniform_results['best_fits']), + f'{det_id}/uniform_fit/results/included_peaks': + uniform_amplitudes_vary, + f'{det_id}/uniform_fit/results/residual': + np.asarray(uniform_results['residuals']), + f'{det_id}/uniform_fit/results/redchi': + np.asarray(uniform_results['redchis']), + f'{det_id}/uniform_fit/results/success': + np.asarray(uniform_results['success']), + f'{det_id}/unconstrained_fit/results/best_fit': + np.asarray(unconstrained_results['best_fits']), + f'{det_id}/unconstrained_fit/results/included_peaks': + unconstrained_amplitudes_vary, + f'{det_id}/unconstrained_fit/results/residual': + np.asarray(unconstrained_results['residuals']), + f'{det_id}/unconstrained_fit/results/redchi': + np.asarray(unconstrained_results['redchis']), + f'{det_id}/unconstrained_fit/results/success': + np.asarray(unconstrained_results['success']), + }) + for j, hkl in enumerate(hkls_fit[use_peaks]): + hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) + uniform_fit_path = f'{det_id}/uniform_fit/{hkl_name}' + unconstrained_fit_path = \ + f'{det_id}/unconstrained_fit/{hkl_name}' + results.update({ + f'{uniform_fit_path}/amplitudes/values': + np.asarray(uniform_results['amplitudes'][j]), + f'{uniform_fit_path}/amplitudes/errors': + np.asarray(uniform_results['amplitudes_errors'][j]), + f'{uniform_fit_path}/centers/values': + uniform_centers[j], + f'{uniform_fit_path}/centers/errors': + np.asarray(uniform_results['centers_errors'][j]), + f'{uniform_fit_path}/sigmas/values': + np.asarray(uniform_results['sigmas'][j]), + f'{uniform_fit_path}/sigmas/errors': + np.asarray(uniform_results['sigmas_errors'][j]), + f'{unconstrained_fit_path}/amplitudes/values': + np.asarray(unconstrained_results['amplitudes'][j]), + f'{unconstrained_fit_path}/amplitudes/errors': + np.asarray( + unconstrained_results['amplitudes_errors'][j]), + f'{unconstrained_fit_path}/centers/values': + unconstrained_centers[j], + f'{unconstrained_fit_path}/centers/errors': + np.asarray( + unconstrained_results['centers_errors'][j]), + f'{unconstrained_fit_path}/sigmas/values': + np.asarray(unconstrained_results['sigmas'][j]), + f'{unconstrained_fit_path}/sigmas/errors': + np.asarray(unconstrained_results['sigmas_errors'][j]), }) - for j, hkl in enumerate(hkls_fit[use_peaks]): - hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) - uniform_fit_path = \ - f'{detector.get_id()}/uniform_fit/{hkl_name}' - unconstrained_fit_path = \ - f'{detector.get_id()}/unconstrained_fit/{hkl_name}' - point.update({ - f'{uniform_fit_path}/amplitudes/values': - uniform_results['amplitudes'][j][i], - f'{uniform_fit_path}/amplitudes/errors': - uniform_results['amplitudes_errors'][j][i], - f'{uniform_fit_path}/centers/values': - uniform_centers[j][i], - f'{uniform_fit_path}/centers/errors': - uniform_results['centers_errors'][j][i], - f'{uniform_fit_path}/sigmas/values': - uniform_results['sigmas'][j][i], - f'{uniform_fit_path}/sigmas/errors': - uniform_results['sigmas_errors'][j][i], - f'{unconstrained_fit_path}/amplitudes/values': - unconstrained_results['amplitudes'][j][i], - f'{unconstrained_fit_path}/amplitudes/errors': - unconstrained_results['amplitudes_errors'][j][i], - f'{unconstrained_fit_path}/centers/values': - unconstrained_centers[j][i], - f'{unconstrained_fit_path}/centers/errors': - unconstrained_results['centers_errors'][j][i], - f'{unconstrained_fit_path}/sigmas/values': - unconstrained_results['sigmas'][j][i], - f'{unconstrained_fit_path}/sigmas/errors': - unconstrained_results['sigmas_errors'][j][i], + if detector.peak_models == 'pvoigt': + results.update({ + f'{uniform_fit_path}/fractions/values': + np.asarray(uniform_results['fractions'][j]), + f'{uniform_fit_path}/fractions/errors': + np.asarray( + uniform_results['fractions_errors'][j]), + f'{unconstrained_fit_path}/fractions/values': + np.asarray( + unconstrained_results['fractions'][j]), + f'{unconstrained_fit_path}/fractions/errors': + np.asarray( + unconstrained_results['fractions_errors'][j]), }) - if detector.peak_models == 'pvoigt': - point.update({ - f'{uniform_fit_path}/fractions/values': - uniform_results['fractions'][j][i], - f'{uniform_fit_path}/fractions/errors': - uniform_results['fractions_errors'][j][i], - f'{unconstrained_fit_path}/fractions/values': - unconstrained_results['fractions'][j][i], - f'{unconstrained_fit_path}/fractions/errors': - unconstrained_results['fractions_errors'][j][i], - }) - if unconstrained_centers[j][i]: - point.update({ - f'{unconstrained_fit_path}/strains/values': - unconstrained_strains[j][i], - f'{unconstrained_fit_path}/strains/residuals': - unconstrained_strain[i] - - unconstrained_strains[j][i], - }) - if (unconstrained_results['centers_errors'][j][i] - is None): - point.update({ - f'{unconstrained_fit_path}/strains/errors': - None, - }) - else: - point.update({ - f'{unconstrained_fit_path}/strains/errors': - unconstrained_results[ - 'centers_errors'][j][i] / - unconstrained_centers[j][i], - }) - else: - point.update({ - f'{unconstrained_fit_path}/strains/values': None, - f'{unconstrained_fit_path}/strains/errors': None, - f'{unconstrained_fit_path}/strains/residuals': - None, - }) + # Strain values: NaN where unconstrained center is zero + has_center = unconstrained_centers[j].astype(bool) + centers_errors_j = unconstrained_results['centers_errors'][j] + strain_errors_raw = np.asarray([ + (e / c if e is not None else np.nan) + for e, c in zip(centers_errors_j, + unconstrained_centers[j]) + ]) + results.update({ + f'{unconstrained_fit_path}/strains/values': + np.where(has_center, unconstrained_strains[j], + np.nan), + f'{unconstrained_fit_path}/strains/errors': + np.where(has_center, strain_errors_raw, np.nan), + f'{unconstrained_fit_path}/strains/residuals': + np.where( + has_center, + unconstrained_strain - unconstrained_strains[j], + np.nan), + }) # Create an animation of the fit points if (not self.config.skip_animation @@ -3521,446 +3595,7 @@ def _strain_analysis(self): nxdata, energies[mask], intensities, intensity_norms, unconstrained_results['best_fits'], detector.get_id()) - return points - -class OrnlStrainAnalysisProcessor(StrainAnalysisProcessor): - """Processor that handles strain analysis in the middle of - autonomously driven experients.""" - - def process(self, data): - """Setup the strain analysis and/or return the strain analysis - results as a list of updated points or a - `nexusformat.nexus.NXroot` object. - - :param data: Input data containing configurations for a map, - completed energy/tth calibration, and (optionally) - parameters for the strain analysis. - :type data: list[PipelineData] - :raises RuntimeError: Unable to get a valid strain analysis - configuration. - :return: The strain analysis setup or results, a list of - byte stream representions of Matplotlib figures and an - animation of the fit results. - :rtype: Union[list[dict[str, object]], - nexusformat.nexus.NXroot], PipelineData, PipelineData - """ - # Third party modules - from nexusformat.nexus import ( - NXentry, - NXroot, - ) - - # Local modules - from CHAP.utils.general import list_to_string - - if not (self.setup or self.update): - raise RuntimeError('Illegal combination of setup and update') - if not self.update: - if self.interactive: - self.logger.warning('Interactive option disabled during setup') - self.interactive = False - if self.save_figures: - self.logger.warning( - 'Saving figures option disabled during setup') - self.save_figures = False - self._animation = [] - - # Load the pipeline input data - try: - nxobject = self.get_data(data) - if isinstance(nxobject, NXroot): - nxroot = nxobject - elif isinstance(nxobject, NXentry): - nxroot = NXroot() - nxroot[nxobject.nxname] = nxobject - nxobject.set_default() - else: - raise RuntimeError - except Exception as exc: - raise RuntimeError( - 'No valid input in the pipeline data') from exc - - # Load the detector data - # FIX set rel_height_cutoff - nxentry = self.get_default_nxentry(nxroot) - for k, v in nxentry.scalar_data.items(): - self.logger.debug(f'{k}.chunks = {v.chunks}') - self.logger.debug(f'{k}.maxshape = {v.maxshape}') - - nxdata = nxentry[nxentry.default] - - # Load the validated calibration configuration - # FIX make this a class field and add to pipeline_fields too? - # NB: be sure that adding this field does not mess up - # self.detector_config during pydantic validation - calibration_config = self.get_config( - data, schema='edd.models.MCATthCalibrationConfig', remove=False) - - # Load the validated calibration detector configurations - calibration_detector_config = self.get_data( - data, schema='edd.models.MCATthCalibrationConfig') - calibration_detectors = [ - MCADetectorCalibration(**d) - for d in calibration_detector_config.get('detectors', [])] - calibration_detector_ids = [d.get_id() for d in calibration_detectors] - - # Check for available raw detector data and for the available - # calibration data - if not self.detector_config.detectors: - self.detector_config.detectors = [ - MCADetectorStrainAnalysis( - id=id_, processor_type='strainanalysis') - for id_ in nxentry.detector_ids] - self.detector_config.update_detectors() - skipped_detectors = [] - sskipped_detectors = [] - detectors = [] - for detector in self.detector_config.detectors: - detector_id = detector.get_id() - if detector_id not in nxdata: - skipped_detectors.append(detector_id) - elif detector_id not in calibration_detector_ids: - sskipped_detectors.append(detector_id) - else: - raw_detector_data = nxdata[detector_id].nxdata - if raw_detector_data.ndim != 2: - self.logger.warning( - f'Skipping detector {detector_id} (Illegal data shape ' - f'{raw_detector_data.shape})') - elif raw_detector_data.sum(): - for k, v in nxdata[detector_id].attrs.items(): - detector.attrs[k] = v.nxdata - if self.config.rel_height_cutoff is not None: - detector.rel_height_cutoff = \ - self.config.rel_height_cutoff - detector.add_calibration( - calibration_detectors[ - int(calibration_detector_ids.index(detector_id))]) - detectors.append(detector) - else: - self.logger.warning( - f'Skipping detector {detector_id} (zero intensity)') - if len(skipped_detectors) == 1: - self.logger.warning( - f'Skipping detector {skipped_detectors[0]} ' - '(no raw data)') - elif skipped_detectors: - skipped_detectors = [int(d) for d in skipped_detectors] - self.logger.warning( - 'Skipping detectors ' - f'{list_to_string(skipped_detectors)} (no raw data)') - if len(sskipped_detectors) == 1: - self.logger.warning( - f'Skipping detector {sskipped_detectors[0]} ' - '(no raw data)') - elif sskipped_detectors: - skipped_detectors = [int(d) for d in sskipped_detectors] - self.logger.warning( - 'Skipping detectors ' - f'{list_to_string(skipped_detectors)} (no calibration data)') - self.detector_config.detectors = detectors - if not self.detector_config.detectors: - raise ValueError('No valid data or unable to match an available ' - 'calibrated detector for the strain analysis') - - # Load the raw MCA data and compute the detector bin energies - # and the mean spectra - self._setup_detector_data( - nxentry[nxentry.default], - strain_analysis_config=self.config, update=self.update) - - # Apply the energy mask - self._apply_energy_mask() - - # Get the mask and HKLs used in the strain analysis - self._get_mask_hkls() - - # Apply the combined energy ranges mask - self._apply_combined_mask() - - # Setup and/or run the strain analysis - points = [] - if self.update: - points = self._strain_analysis() - values = self._get_values(nxroot, points) - if self.setup: - nxprocess = self._get_nxprocess(nxentry, calibration_config) - if points: - self.logger.info(f'Adding {len(points)} points') - self.add_points(nxprocess, points, logger=self.logger) - self.logger.info(f'... done') - else: - self.logger.warning('Skip adding points') - if not (self._figures or self._animation): - return nxprocess - ret = [nxprocess] - else: - if not (self._figures or self._animation): - return values - ret = [values] - if self._figures: - ret.append( - PipelineData( - name=self.__name__, data=self._figures, - schema='common.write.ImageWriter')) - if self._animation: - ret.append( - PipelineData( - name=self.__name__, data=self._animation, - schema='common.write.ImageWriter')) - return tuple(ret) - - def _get_nxprocess(self, nxentry, calibration_config): - """Return a standalone NXprocess for the strain analysis - results & metadata. - - :param nxentry: Strain analysis map, including the raw - MCA data. - :type nxentry: nexusformat.nexus.NXentry - :param calibration_config: 2&theta calibration configuration. - :type calibration_config: - CHAP.edd.models.MCATthCalibrationConfig - :return: Strain analysis results & associated metadata. - :rtype: nexusformat.nexus.NXprocess - """ - # Third party modules - # pylint: disable=no-name-in-module - from nexusformat.nexus import ( - NXdetector, - NXfield, - NXprocess, - NXroot, - ) - # pylint: enable=no-name-in-module - - # Third party modules - from json import dumps - - # Local modules - from CHAP.edd.utils import get_unique_hkls_ds - from CHAP.utils.general import nxcopy - - if not self.interactive and not self.config.materials: - raise ValueError( - 'No material provided. Provide a material in the ' - 'StrainAnalysis Configuration, or re-run the pipeline with ' - 'the --interactive flag.') - - nxprocess = NXprocess() - nxprocess.calibration_config = \ - calibration_config.model_dump_json() - nxprocess.strain_analysis_config = \ - self.config.model_dump_json() - - if len(self._peak_fit_info) == 0: - # FIX this is a temporary fix to be able to run update - # after setup. - self._peak_fit_info = [{}] * len(self.detector_config.detectors) - self.logger.warning('Missing peak_fit_info') - - # Loop over the detectors to fill in the nxprocess - for energies, mask, nxdata, detector, peak_fit_info in zip( - self._energies, self._masks, self._nxdata_detectors, - self.detector_config.detectors, self._peak_fit_info): - - # Get the current data object - data = nxdata.nxsignal - num_points = data.shape[0] - - # Setup the NXdetector object for the current detector - self.logger.debug( - f'Setting up NXdetector group for {detector.get_id()}') - nxdetector = NXdetector() - nxprocess[detector.get_id()] = nxdetector - nxdetector.local_name = detector.get_id() - nxdetector.detector_config = detector.model_dump_json() - nxdetector.peak_fit_info = dumps(peak_fit_info) - nxdetector.data = nxcopy(nxdata, exclude_nxpaths='detector_data') - det_nxdata = nxdetector.data - if 'axes' in det_nxdata.attrs: - if isinstance(det_nxdata.attrs['axes'], str): - det_nxdata.attrs['axes'] = [ - det_nxdata.attrs['axes'], 'energy'] - else: - det_nxdata.attrs['axes'].append('energy') - else: - det_nxdata.attrs['axes'] = ['energy'] - det_nxdata.energy = NXfield( - value=energies[mask], attrs={'units': 'keV'}) - det_nxdata.norm = NXfield( - dtype=np.float64, - shape=(num_points,), - maxshape=(None,), chunks=(1,) - ) - det_nxdata.tth = NXfield( - dtype=np.float64, - shape=(num_points,), - attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}, - maxshape=(None,), chunks=(1,), - ) - det_nxdata.uniform_strain = NXfield( - dtype=np.float64, - shape=(num_points,), - attrs={'long_name': 'Strain from uniform fit (\u03B5)'}, - # attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'} - maxshape=(None,), chunks=(1,) - ) - - det_nxdata.unconstrained_strain = NXfield( - dtype=np.float64, - shape=(num_points,), - attrs={'long_name': - 'Strain from unconstrained fit (\u03B5)'}, - #'Strain from unconstrained fit (\u03BC\u03B5)'}, - maxshape=(None,), chunks=(1,) - ) - det_nxdata.unconstrained_strain_stdev = NXfield( - dtype=np.float64, - shape=(num_points,), - attrs={'long_name': - 'Standard deviation in strain from unconstrained ' - 'fit (\u03B5)'}, - #'fit (\u03BC\u03B5)'} - maxshape=(None,), chunks=(1,) - ) - - # Add the detector data - _intensity=np.asarray( - [ - data[i].astype(np.float64)[mask] - for i in range(num_points) - ] - ) - det_nxdata.intensity = NXfield( - value=_intensity, - attrs={'units': 'counts'}, - maxshape=(None,*_intensity.shape[1:]), - chunks=(1,*_intensity.shape[1:]) - ) - det_nxdata.attrs['signal'] = 'intensity' - - # Get the unique HKLs and lattice spacings for the strain - # analysis materials - hkls, _ = get_unique_hkls_ds( - self.config.materials, tth_max=detector.tth_max, - tth_tol=detector.tth_tol) - - # Get the HKLs and lattice spacings that will be used for - # fitting - hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices]) - - # Add the uniform fit nxcollection - self._add_fit_nxcollection( - nxdetector, 'uniform', hkls_fit, peak_fit_info) - - # Add the unconstrained fit nxcollection - self._add_fit_nxcollection( - nxdetector, 'unconstrained', hkls_fit, peak_fit_info) - - # Add the strain fields - tth_map = detector.get_tth_map((num_points,)) - det_nxdata.tth.nxdata = tth_map - - return nxprocess - - def _linkdims( - self, nxgroup, nxdata_source, add_field_dims=None, - skip_field_dims=None, oversampling_axis=None): - """Link the dimensions for a 'nexusformat.nexus.NXgroup` - object. - """ - # Third party modules - from nexusformat.nexus import NXfield, NXroot - from nexusformat.nexus.tree import NXlinkfield - - if not isinstance(nxgroup.nxroot, NXroot): - self.logger.warning( - 'Skipping linkdims -- type(nxgroup.nxroot) = ' - + f'{type(nxgroup.nxroot)}' - ) - return - super()._linkdims( - nxgroup, nxdata_source, add_field_dims=add_field_dims, - skip_field_dims=skip_field_dims, oversampling_axis=skip_field_dims - ) - - def _get_values(self, nxroot, points): - """Return list of dictionaries with new srtain results values - suitable for writing with `common.NexusValuesWriter`. - - :param nxroot: - :type nxroot: - :param points: - :type points: - :returns: - :rtype: - """ - # pylint: disable=no-name-in-module - from nexusformat.nexus import ( - NXdetector, - NXprocess - ) - # pylint: enable=no-name-in-module - - nxprocess = None - for nxobject in nxroot.values(): - if isinstance(nxobject, NXprocess): - nxprocess = nxobject - break - if nxprocess is None: - raise RuntimeError('Unable to find the strainanalysis object') - - nxdata_detectors = [] - for nxobject in nxprocess.values(): - if isinstance(nxobject, NXdetector): - nxdata_detectors.append(nxobject.data) - if not nxdata_detectors: - raise RuntimeError( - 'Unable to find detector data in strainanalysis object') - axes = get_axes(nxdata_detectors[0], skip_axes=['energy']) - - values = [] - if axes: - coords = np.asarray( - [nxdata_detectors[0][a].nxdata for a in axes]).T - - def get_matching_indices(all_coords, point_coords, decimals=None): - if isinstance(decimals, int): - all_coords = np.round(all_coords, decimals=decimals) - point_coords = np.round(point_coords, decimals=decimals) - coords_match = np.all(all_coords == point_coords, axis=1) - index = np.where(coords_match)[0] - return index - - # FIX: can we round to 3 decimals right away in general? - # FIX: assumes points contains a sorted and continous - # slice of updates - i_0 = get_matching_indices( - coords, - np.asarray([points[0][a] for a in axes]), decimals=3)[0] - i_f = get_matching_indices( - coords, - np.asarray([points[-1][a] for a in axes]), decimals=3)[0] - slices = {k: np.asarray([p[k] for p in points]) for k in points[0]} - for k, v in slices.items(): - values.append( - { - 'data': v, - 'path': k, - } - ) - else: - values.extend( - [ - { - 'data': v, - 'path': k, - } - for k, v in points[0].items() - ] - ) - - return values + return results if __name__ == '__main__': From b0c9ae5ffb126881eff899d97df0d9279f600133 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 21 Apr 2026 11:26:17 -0400 Subject: [PATCH 25/76] optimize: edd.SliceNXdataReader --- CHAP/edd/reader.py | 83 +++++++++++++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index 8bad0f6a..ad4eb757 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -770,43 +770,64 @@ def read(self): from CHAP.common import NexusReader from CHAP.utils.general import nxcopy + # Read NXroot reader = NexusReader(**self.model_dump()) - nxroot = nxcopy(reader.read()) - nxentry = None - nxdata = None - for nxname, nxobject in nxroot.items(): - if isinstance(nxobject, NXentry): - nxentry = nxobject - nxdata = nxobject.data + nxroot = reader.read() + + # Locate NXentry + nxentry = next( + (obj for obj in nxroot.values() if isinstance(obj, NXentry)), + None, + ) + if nxentry is None: + raise ValueError('Could not find NXentry group') + nxentry_nxpath = nxentry.nxpath + self.logger.info(f'Using NXentry at: {nxentry_nxpath}') + + # Make a copy of the NXroot, excluding everything but the + # NXentry of interest. Do this so we can just slice the + # NXfields in place in the copy (because copy is not tied to + # the original input file). + exclude_nxpaths = [] + for v in nxroot.values(): + if v.nxpath != nxentry_nxpath: + exclude_nxpaths.append(v.nxpath) + nxroot = nxcopy(nxroot, exclude_nxpaths=exclude_nxpaths) + nxentry = nxroot[nxentry_nxpath] + + # Locate NXdata containining the "SCAN_N" NXfield + nxdata = getattr(nxentry, 'data', None) if nxdata is None: - msg = 'Could not find NXdata group' - self.logger.error(msg) - raise ValueError(msg) - if 'SCAN_N' not in nxdata: - self.logger.warning(f'SCAN_N not in {nxdata}') - scan_n_found = False - for k, v in nxentry.items(): - if 'SCAN_N' in v: + self.logger.warning('NXdata group missing — searching fallback') + + for v in nxentry.values(): + if hasattr(v, 'SCAN_N'): nxdata = v - scan_n_found = True - self.logger.warning(f'Using SCAN_N dataset in {nxdata}') - if not scan_n_found: - msg = 'Cannot find SCAN_N dataset' - self.logger.error(msg) - raise ValueError(msg) - - indices = np.argwhere( - nxdata.SCAN_N.nxdata == self.scan_number).flatten() - for nxname, nxobject in nxdata.items(): - if isinstance(nxobject, NXfield): - nxdata[nxname] = NXfield( - value=nxobject.nxdata[indices], - dtype=nxdata[nxname].dtype, - attrs=nxdata[nxname].attrs, - ) + break + + if nxdata is None: + raise ValueError('Cannot find SCAN_N dataset') + self.logger.info(f'Using NXdata at: {nxdata.nxpath}') + + # Get indicies of SCAN_N that match self.scan_number + scan_field = nxdata['SCAN_N'].nxdata + indices = np.flatnonzero(scan_field == self.scan_number) + + if indices.size == 0: + self.logger.warning( + f'scan_number {self.scan_number} not found in SCAN_N' + ) + self.logger.info(f'Slicing NXfields with: {indices}') + + # Slice only NXfields + for name, obj in list(nxdata.items()): + if isinstance(obj, NXfield): + self.logger.info(f'Slicing NXfield at: {obj.nxpath}') + nxdata[name] = obj[indices] return nxroot + class UpdateNXdataReader(Reader): """Companion to `edd.SetupNXdataReader` and `common.UpdateNXDataProcessor`. Constructs a list of data points From 23ba5d3fa7d53a977d363c3ae5bcfdf56b8b4a6b Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 28 Apr 2026 12:50:53 -0400 Subject: [PATCH 26/76] add: added peak fit results to the EDD calibration output yamls --- CHAP/edd/processor.py | 72 ++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 2309ff78..733ec448 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -1185,6 +1185,8 @@ class MCAEnergyCalibrationProcessor(_BaseEddProcessor): config: MCAEnergyCalibrationConfig detector_config: MCADetectorConfig + _peak_fit_results: dict = PrivateAttr(default={}) + @model_validator(mode='before') @classmethod def validate_mcaenergycalibrationprocessor_before(cls, data): @@ -1309,6 +1311,10 @@ def process(self, data): **self.config.model_dump(), 'detectors': [d.model_dump(exclude_defaults=True) for d in self.detector_config.detectors]} + # Add the fit results to the detector fields + for i, d in enumerate(self.detector_config.detectors): + configs['detectors'][i].update( + {'peak_fit_results': self._peak_fit_results[d.get_id()]}) return configs, PipelineData( name=self.__name__, data=self._figures, schema='common.write.ImageWriter') @@ -1369,7 +1375,9 @@ def _calibrate(self): self._energies, self._masks, self._mean_data, self._mask_index_ranges, self.detector_config.detectors): - self.logger.info(f'Calibrating detector {detector.get_id()}') + id_ = detector.get_id() + + self.logger.info(f'Calibrating detector {id_}') bins = low + np.arange(energies.size, dtype=np.int16) @@ -1378,13 +1386,11 @@ def _calibrate(self): for energy in peak_energies] buf, initial_peak_indices = self._get_initial_peak_positions( mean_data*np.asarray(mask).astype(np.int32), low, - detector.mask_ranges, input_indices, max_peak_index, - detector.get_id(), return_buf=self.save_figures) + detector.mask_ranges, input_indices, max_peak_index, id_, + return_buf=self.save_figures) if self.save_figures: self._figures.append( - (buf, - f'{detector.get_id()}' - '_energy_calibration_initial_peak_positions')) + (buf, f'{id_}_energy_calibration_initial_peak_positions')) # Construct the fit model and perform the fit models = [] @@ -1409,17 +1415,24 @@ def _calibrate(self): # Extract the fit results for the peaks fit_peak_amplitudes = sorted([ - mean_data_fit.best_values[f'peak{i+1}_amplitude'] + mean_data_fit.best_values[f'peak{i+1}_amplitude'].tolist() for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') fit_peak_indices = sorted([ - mean_data_fit.best_values[f'peak{i+1}_center'] + mean_data_fit.best_values[f'peak{i+1}_center'].tolist() for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') - fit_peak_fwhms = sorted([ - 2.35482*mean_data_fit.best_values[f'peak{i+1}_sigma'] + fit_peak_sigmas = sorted([ + mean_data_fit.best_values[f'peak{i+1}_sigma'].tolist() for i in range(len(initial_peak_indices))]) - self.logger.debug(f'Fit peak fwhms: {fit_peak_fwhms}') + self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') + self._peak_fit_results[id_] = { + 'amplitudes': fit_peak_amplitudes, + 'centers': fit_peak_indices, + 'sigmas': fit_peak_sigmas, + 'independent_dimension': { + 'name': 'Detector Channel', 'unit': '-'} + } # FIX for now stick with a linear energy correction fit = FitProcessor(**self.run_config) @@ -1446,8 +1459,7 @@ def _calibrate(self): bins_masked = bins[mask] fig, axs = plt.subplots(1, 2, figsize=(11, 4.25)) - fig.suptitle( - f'Detector {detector.get_id()} energy calibration') + fig.suptitle(f'Detector {id_} energy calibration') # Left plot: raw MCA data and best fit of peaks axs[0].set_title('MCA spectrum peak fit') axs[0].set_xlabel('Detector Channel (-)') @@ -1488,8 +1500,7 @@ def _calibrate(self): if self.save_figures: self._figures.append( - (fig_to_iobuf(fig), - f'{detector.get_id()}_energy_calibration_fit')) + (fig_to_iobuf(fig), f'{id_}_energy_calibration_fit')) if self.interactive: plt.show() plt.close() @@ -1740,6 +1751,8 @@ class MCATthCalibrationProcessor(_BaseEddProcessor): config: MCATthCalibrationConfig detector_config: MCADetectorConfig + _peak_fit_results: dict = PrivateAttr(default={}) + @model_validator(mode='before') @classmethod def validate_mcatthcalibrationprocessor_before(cls, data): @@ -1875,6 +1888,10 @@ def process(self, data): **self.config.model_dump(), 'detectors': [d.model_dump(exclude_defaults=True) for d in self.detector_config.detectors]} + # Add the fit results to the detector fields + for i, d in enumerate(self.detector_config.detectors): + configs['detectors'][i].update( + {'peak_fit_results': self._peak_fit_results[d.get_id()]}) return configs, PipelineData( name=self.__name__, data=self._figures, schema='common.write.ImageWriter') @@ -2128,7 +2145,26 @@ def _direct_bragg_peak_fit( residual = result.residual # Extract the Bragg peak indices from the fit - i_bragg_fit = np.asarray( + fit_peak_amplitudes = [ + result.best_values[f'peak{i+1}_amplitude'].tolist() + for i in range(len(e_bragg))] + self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') + fit_peak_indices = [ + result.best_values[f'peak{i+1}_center'].tolist() + for i in range(len(e_bragg))] + self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') + fit_peak_sigmas = [ + result.best_values[f'peak{i+1}_sigma'].tolist() + for i in range(len(e_bragg))] + self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') + self._peak_fit_results[detector.get_id()] = { + 'amplitudes': fit_peak_amplitudes, + 'centers': fit_peak_indices, + 'sigmas': fit_peak_sigmas, + 'independent_dimension': { + 'name': 'Detector Channel', 'unit': '-'} + } + fit_peak_indices = np.asarray( [result.best_values[f'peak{i+1}_center'] for i in range(len(e_bragg))]) @@ -2140,7 +2176,7 @@ def _direct_bragg_peak_fit( model = 'linear' fit = FitProcessor(**self.run_config) result = fit.process( - NXdata(NXfield(e_bragg, 'y'), NXfield(i_bragg_fit, 'x')), + NXdata(NXfield(e_bragg, 'y'), NXfield(fit_peak_indices, 'x')), {'models': [{'model': model}]}) if quadratic_energy_calibration: a_fit = result.best_values['a'] @@ -2151,7 +2187,7 @@ def _direct_bragg_peak_fit( b_fit = result.best_values['slope'] c_fit = result.best_values['intercept'] e_bragg_unconstrained = ( - (a_fit*i_bragg_fit + b_fit) * i_bragg_fit + c_fit) + (a_fit*fit_peak_indices + b_fit) * fit_peak_indices + c_fit) return { 'best_fit_unconstrained': best_fit, From 5d229a8ae53be23d3d98bafe81c32f9b9c86a733 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 28 Apr 2026 14:13:11 -0400 Subject: [PATCH 27/76] fix: YAMLWriter can write lists as well as dicts --- CHAP/common/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index aad605c1..89ddea98 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -772,7 +772,7 @@ def write(self, data): from CHAP.models import CHAPBaseModel def get_dict(data): - if isinstance(data, dict): + if isinstance(data, dict) or isinstance(data, list): return data if isinstance(data, (BaseModel, CHAPBaseModel)): try: From 4f85f99c599e794a225f4a6e302a145fc0d394fb Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 28 Apr 2026 14:14:13 -0400 Subject: [PATCH 28/76] add: JSONWriter --- CHAP/common/writer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 89ddea98..62c983ec 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -418,6 +418,22 @@ def write(self, data): raise ValueError(f'Invalid image input type {type(image_data)}') +class JSONWriter(Writer): + """Writer for JSON data. + + :ivar index: Index of ``PipelineData`` item in input ``data`` list + that should be written to the JSON file. Defaults to ``-1``. + :vartpe index: int, Optional. + """ + index: int = -1 + def write(self, data): + """Write the last """ + import json + _data = self.get_pipelinedata_item(data, remove=self.remove) + with open(self.filename, 'w') as outf: + json.dump(_data, outf) + + class MatplotlibAnimationWriter(Writer): """Writer for saving matplotlib animations. From b81e6c0f0d967286ca6b90e94eddc94f215a8b34 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 28 Apr 2026 14:17:48 -0400 Subject: [PATCH 29/76] add: StrainAnalysisProcessor.json_results --- CHAP/edd/processor.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 230d0a25..be95381f 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2523,6 +2523,11 @@ class StrainAnalysisProcessor(BaseStrainProcessor): `update` are `True` the standalone `NXprocess` and the values list are returned together as a tuple. :type standalone: bool, optional + :ivar json_results: If updating, return an additional minimal + dictionary of results that can be written to .json file for + easier access by autonomous feedback experiment + drivers. Defaults to `False`. + :vartype json_results: bool, optional """ pipeline_fields: dict = Field( default = { @@ -2537,6 +2542,7 @@ class StrainAnalysisProcessor(BaseStrainProcessor): setup: Optional[bool] = True update: Optional[bool] = True standalone: Optional[bool] = False + json_results: Optional[bool] = False @model_validator(mode='before') @classmethod @@ -2792,9 +2798,12 @@ def process(self, data): else: result = self._get_values(results) if self.standalone \ else self._get_points(results) - if not (self._figures or self._animation): - return result ret = [result] + if self.json_results: + json_results = {k: v.tolist() for k, v in results.items()} + ret.append(json_results) + if not (self._figures or self._animation): + return tuple(ret) if self._figures: ret.append( PipelineData( From 925f136a5fa53f4753b183f026f1e3e079b46328 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 29 Apr 2026 10:29:06 -0400 Subject: [PATCH 30/76] add: scan_number option for PointByPointScanData.data_type --- CHAP/common/models/map.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index 52afbf40..65bbfb98 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -370,8 +370,8 @@ class PointByPointScanData(CHAPBaseModel): units: constr(strip_whitespace=True, min_length=1) data_type: Literal[ 'spec_motor', 'spec_motor_absolute', 'spec_motor_static', - 'scan_column', 'scan_start_time', 'smb_par', 'expression', - 'detector_log_timestamps', 'scan_step_index' + 'scan_column', 'scan_number', 'scan_start_time', 'smb_par', + 'expression', 'detector_log_timestamps', 'scan_step_index' ] name: constr(strip_whitespace=True, min_length=1) ndigits: Optional[conint(ge=0)] = None @@ -557,6 +557,8 @@ def get_value( return scan_step_index scanparser = get_scanparser(spec_scans.spec_file, scan_number) return [i for i in range(scanparser.spec_scan_npts)] + if self.data_type == 'scan_number': + return scan_number return None @cache From 58b40a4f1d86441f3077b7b6ec9eda46129e2bd8 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 29 Apr 2026 11:02:45 -0400 Subject: [PATCH 31/76] fix: allow MapConfig.scalar_data to contain items with data_type==expression --- CHAP/common/models/map.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index 65bbfb98..97ee496f 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -766,7 +766,9 @@ def _validate_data_source_for_map_config(data_source, info): if data_source is not None: values = info.data if data_source.data_type == 'expression': - data_source.validate_for_scalar_data(values['scalar_data']) + if 'scalar_data' in values: + data_source.validate_for_scalar_data( + values['scalar_data']) else: import_scanparser( values['station'], values['experiment_type']) @@ -1126,6 +1128,18 @@ def validate_before(cls, data): data['attrs'] = {} return data + @model_validator(mode='after') + def validate_scalar_data_expressions(self): + """Validate any expression-type items in scalar_data against + the complete scalar_data list. This deferred check is necessary + because field validation of scalar_data runs before the list is + fully constructed. + """ + for data_source in self.scalar_data: + if data_source.data_type == 'expression': + data_source.validate_for_scalar_data(self.scalar_data) + return self + #RV maybe better to use model_validator, see v2 docs? @field_validator('attrs') @classmethod From f868c510652dc264056b089f4b0c882856a132ab Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Wed, 29 Apr 2026 14:55:18 -0400 Subject: [PATCH 32/76] fix: allow 1d detector shapes in DetectorConfig --- CHAP/common/models/map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index 97ee496f..5d159fe8 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -65,7 +65,7 @@ class Detector(CHAPBaseModel): :type attrs: dict, optional """ id_: constr(min_length=1) = Field(alias='id') - shape: Optional[tuple[int, int]] = None + shape: Optional[Union[tuple[int, int], tuple[int]]] = None attrs: Optional[Annotated[dict, Field(validate_default=True)]] = {} @field_validator('id_', mode='before') From 519f989979f2925e2dd054f816f186730cb8333b Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 30 Apr 2026 08:47:05 -0400 Subject: [PATCH 33/76] feat: support 0-scan MapConfigs --- CHAP/common/models/map.py | 10 ++++++++-- CHAP/common/processor.py | 38 ++++++++++++++++++++++++++++++++------ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index 5d159fe8..597385a6 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -159,7 +159,7 @@ class SpecScans(CHAPBaseModel): """ spec_file: FilePath scan_numbers: Union[ - constr(min_length=1), conlist(item_type=conint(gt=0), min_length=1)] + constr(min_length=1), conlist(item_type=conint(gt=0))] par_file: Optional[FilePath] = None @field_validator('spec_file') @@ -1157,13 +1157,17 @@ def validate_attrs(cls, attrs, info): """ # Get the map's scan_type for EDD experiments values = info.data + if not values['validate_data_present']: + return attrs station = values['station'] experiment_type = values['experiment_type'] if station in ['id1a3', 'id3a'] and experiment_type == 'EDD': scan_type = cls.get_smb_par_attr(values, 'scan_type') if scan_type is not None: attrs['scan_type'] = scan_type - attrs['config_id'] = cls.get_smb_par_attr(values, 'config_id') + config_id = cls.get_smb_par_attr(values, 'config_id') + if config_id is not None: + attrs['config_id'] = config_id dataset_id = cls.get_smb_par_attr( values, 'dataset_id', unique=False) if dataset_id is not None: @@ -1215,6 +1219,8 @@ def get_smb_par_attr( # f'{scans.spec_file}.') values.append(None) values = list(set(values)) + if not values: + return None if len(values) == 1: return values[0] if unique: diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 29792425..2d2f591c 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -1001,7 +1001,7 @@ def _process( spec_scans = self.config.spec_scans[0] scan_numbers = spec_scans.scan_numbers num_scan = len(scan_numbers) - if num_scan < self.num_proc: + if 0 < num_scan < self.num_proc: self.logger.warning( f'Requested number of processors ({self.num_proc}) exceeds ' f'the number of scans ({num_scan}): reset it to {num_scan}') @@ -1085,7 +1085,16 @@ def _process( offset = common_comm.scatter(offsets, root=0) # Read the raw data - if self.config.experiment_type == 'EDD': + if num_scan == 0: + num_id = len(self.config.independent_dimensions) + num_sd = len(self.config.all_scalar_data) + num_det = len(self.detector_config.detectors) + if placeholder_data is not False: + num_sd += 1 + data = np.empty((num_det, 0)) + independent_dimensions = np.empty((num_id, 0)) + all_scalar_data = np.empty((num_sd, 0)) + elif self.config.experiment_type == 'EDD': data, independent_dimensions, all_scalar_data = \ self._read_raw_data_edd( common_comm, num_scan, offset, placeholder_data) @@ -1093,7 +1102,8 @@ def _process( data, independent_dimensions, all_scalar_data = \ self._read_raw_data(common_comm, num_scan, offset) if not rank: - self.logger.debug(f'Data shape: {data.shape}') + self.logger.debug( + f'Data shape: {data.shape if data is not None else None}') if independent_dimensions is not None: self.logger.debug('Independent dimensions shape: ' f'{independent_dimensions.shape}') @@ -1142,6 +1152,11 @@ def _process( all_scalar_data = np.empty( (len(self.config.all_scalar_data), map_len)) if len(self.detector_config.detectors) > 0: + if det_shapes is False: + det_shapes = {} + for det in self.detector_config.detectors: + if det.shape is not None: + det_shapes[det.get_id()] = det.shape data = np.empty( (len(self.detector_config.detectors), map_len, @@ -1253,7 +1268,14 @@ def linkdims(nxgroup, nxdata_source): nxentry.attrs[k] = v nxentry.spec_scans = NXcollection() for scans in self.config.spec_scans: - nxentry.spec_scans[scans.scanparsers[0].scan_name] = \ + if len(scans.scanparsers) > 0: + key = scans.scanparsers[0].scan_name + else: + if str(scans.spec_file).endswith('spec.log'): + key = str(scans.spec_file).split('/')[-2] + else: + key = str(scans.spec_file).split('/')[-1] + nxentry.spec_scans[key] = \ NXfield(value=scans.scan_numbers, dtype='int8', attrs={'spec_file': str(scans.spec_file)}) @@ -1369,8 +1391,12 @@ def linkdims(nxgroup, nxdata_source): for k, v in self.config.attrs.items(): nxdata.attrs[k] = v if data is not None: - min_ = np.min(data, axis=tuple(range(1, data.ndim))) - max_ = np.max(data, axis=tuple(range(1, data.ndim))) + if data.size > 0: + min_ = np.min(data, axis=tuple(range(1, data.ndim))) + max_ = np.max(data, axis=tuple(range(1, data.ndim))) + else: + min_ = np.full(len(self.detector_config.detectors), np.nan) + max_ = np.full(len(self.detector_config.detectors), np.nan) for i, detector in enumerate(self.detector_config.detectors): nxdata[detector.get_id()] = NXfield( value=data[i], From 168f0bbb406d189b4082e96632f86b6bf82339c8 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 30 Apr 2026 09:33:50 -0400 Subject: [PATCH 34/76] add: JSONWriter.update and JSONWriter.extend --- CHAP/common/writer.py | 44 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 62c983ec..4e775ff9 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -424,14 +424,54 @@ class JSONWriter(Writer): :ivar index: Index of ``PipelineData`` item in input ``data`` list that should be written to the JSON file. Defaults to ``-1``. :vartpe index: int, Optional. + :ivar update: Update an existing file with new values, overwriting + the existing file's values. Defaults to ``False``. + :vartype update: bool, optional + :ivar extend: Extend the values in an existing file, adding to the + existing file's value lists. Defaults to ``False``. + :vartype extend: bool, optional """ index: int = -1 + update: Optional[bool] = False + extend: Optional[bool] = False + def write(self, data): """Write the last """ import json - _data = self.get_pipelinedata_item(data, remove=self.remove) + + _data = self.get_pipelinedata_item( + data, + index=self.index, + remove=self.remove, + ) + + write_data = _data + if self.update: + try: + with open(self.filename, 'r') as inf: + write_data = json.load(inf) + except: + self.logger.warning(f'Could not load JSON from {self.filename}') + write_data = {} + if self.extend: + for k, v in _data.items(): + if k in write_data: + if not isinstance(write_data[k], list): + write_data[k] = [write_data[k]] + write_data[k].extend(v) + else: + write_data[k] = v + else: + write_data.update(_data) + with open(self.filename, 'w') as outf: - json.dump(_data, outf) + json.dump(write_data, outf) + + @model_validator(mode='after') + def validate_modes(self): + if self.extend and not self.update: + self.update = True + return self class MatplotlibAnimationWriter(Writer): From e8ff4e3958c48b476aeb052b9fcea17b875cb773 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 30 Apr 2026 11:52:31 -0400 Subject: [PATCH 35/76] test: Updated tomo_one_plane.py to use the hollow_cube example --- examples/tomo/tomo_one_plane.py | 182 ++++++++++++++++++++++---------- 1 file changed, 124 insertions(+), 58 deletions(-) diff --git a/examples/tomo/tomo_one_plane.py b/examples/tomo/tomo_one_plane.py index 353d760b..90770b31 100644 --- a/examples/tomo/tomo_one_plane.py +++ b/examples/tomo/tomo_one_plane.py @@ -1,11 +1,19 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +"""Choose the path to your CHESS repo.""" import sys #sys.path.append( # '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline') sys.path.append( - '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') +#sys.path.append( +# '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + +# System modules +import os +from pprint import pprint +from tempfile import NamedTemporaryFile # Third party modules import matplotlib.pyplot as plt @@ -29,74 +37,80 @@ # Start user input #------------------------------------------------------------------------------# +"""Select run_type from +'normal': Full regular reconstruction + (like running the default hollow_cube in CHAP) +'one_plane': Reconstruction of two planes only +'two_planes': Reconstruction of one plane only +'roi': Reconstruction for a subset or rows + (find center at first and last row) +""" +#run_type = 'normal' +#run_type = 'one_plane' +#run_type = 'two_planes' +run_type = 'roi' + interactive = True +outputdir = 'output' construct_chess_map = True -map_name = 'output/map.nxs' -chess_map_name = 'output/chess_map.nxs' -load_converted_map = True -img_row_bounds = [600, 606] # None, [80, 2080] or [1050, 1251] -recon_layer_index = [600, 1400] # One or two values - # ignored when using img_row_bounds reduce_data = True -reduced_data_name = 'output/reduced_data.nxs' +find_center = True +reconstruct_data = True + +map_name = f'{outputdir}/map.nxs' +chess_map_name = f'{outputdir}/chess_map.nxs' +reduced_data_name = f'{outputdir}/reduced_data.nxs' reduce_config = { - 'remove_stripe': {'remove_all_stripe': {}}, - 'img_row_bounds': img_row_bounds, +# 'remove_stripe': {'remove_all_stripe': {}}, } -find_center = True -find_center_name = 'output/find_center.yaml' +find_center_name = f'{outputdir}/find_center.yaml' find_center_config = { - 'gaussian_sigma': 0.75, + 'gaussian_sigma': 0.05, # 'remove_stripe_sigma': None, - 'ring_width': 5, + 'ring_width': 1, } -reconstruct_data = True -reconstructed_data_name = 'output/reconstructed_data.nxs' +reconstructed_data_name = f'{outputdir}/reconstructed_data.nxs' reconstruct_config = { - 'x_bounds': [90, 2470], - 'y_bounds': [600, 2000], - 'secondary_iters:': 25, - 'gaussian_sigma': 0.75, - 'remove_stripe_sigma': None, - 'ring_width': 5, + 'x_bounds': [15, 390], + 'y_bounds': [25, 380], + 'secondary_iters': 10, +# 'gaussian_sigma': 0.75, +# 'remove_stripe_sigma': None, + 'ring_width': 1, } tdf_scan_numbers = 1 tbf_scan_numbers = 2 map_config = { - 'title': 'badran-4510-b', + 'title': 'hollow_cube', 'station': 'id3b', 'experiment_type': 'TOMO', - 'sample': {'name': 'sample_190-1-Edge'}, + 'sample': {'name': 'hollow_cube'}, 'spec_scans': [ - {'spec_file': 'data/sample_190-1-Edge/refine_1', + {'spec_file': 'raw/hollow_cube/hollow_cube', 'scan_numbers': 3} ], 'independent_dimensions': [ {'label': 'rotation_angles', 'units': 'degrees', 'data_type': 'scan_column', - 'name': 'GI_samphi'}, + 'name': 'theta'}, {'label': 'x_translation', 'units': 'mm', 'data_type': 'spec_motor', - 'name': 'saxx'}, + 'name': 'GI_samx'}, {'label': 'z_translation', 'units': 'mm', 'data_type': 'spec_motor', - 'name': 'saxz'} + 'name': 'GI_samz'} ], - 'presample_intensity': { - 'data_type': 'scan_column', - 'name': 'ic1' - } } -detector_config = {'detectors': [{'id': 'andor2'}],} +detector_config = {'detectors': [{'id': 'sim'}],} detector_setup = { - 'prefix': 'andor2', - 'rows': 2160, - 'columns': 2560, - 'pixel_size': [0.0065, 0.0065], - 'lens_magnification': 5.0, + 'prefix': 'sim', + 'rows': 40, + 'columns': 400, + 'pixel_size': [0.05, 0.005], + 'lens_magnification': 1.0, } remove_stripe_corrections = [ # 'remove_all_stripe', @@ -111,34 +125,86 @@ # 'remove_stripe_ti', ] +#------------------------------------------------------------------------------# +# Optional user input +#------------------------------------------------------------------------------# + +if run_type == 'normal': + img_row_bounds = [3, 35] + recon_layer_indices = [11, 28] +elif run_type == 'one_plane': + img_row_bounds = None + recon_layer_indices = 15 +elif run_type == 'two_planes': + img_row_bounds = None + recon_layer_indices = [11, 20] +elif run_type == 'roi': + img_row_bounds = None + recon_layer_indices = [11, 20] +else: + print('Pick valid values for img_row_bounds and recon_layer_indices') + img_row_bounds = None # list[int, int], optional + recon_layer_indices = None # int | list[int, int], optional + # ignored when using img_row_bounds + #------------------------------------------------------------------------------# # End user input #------------------------------------------------------------------------------# +if not os.path.isdir(outputdir): + os.makedirs(outputdir) +try: + NamedTemporaryFile(dir=outputdir) +except Exception as exc: + raise OSError( + 'Output directory not accessible for writing ({outputdir})') from exc + +if isinstance(recon_layer_indices, int): + recon_layer_indices = [recon_layer_indices] if img_row_bounds is None: - if len(recon_layer_index)== 1: + center_rows = None + if len(recon_layer_indices)== 1: detector_config['roi'] = [ - {'start': recon_layer_index, - 'end': recon_layer_index+1, + {'start': recon_layer_indices[0], + 'end': recon_layer_indices[0]+1, 'step': None}, None] - elif len(recon_layer_index) == 2: + elif len(recon_layer_indices) == 2: detector_config['roi'] = [ - {'start': recon_layer_index[0], - 'end': recon_layer_index[1], - 'step': recon_layer_index[1]-recon_layer_index[0]-1}, + {'start': recon_layer_indices[0], + 'end': recon_layer_indices[1]}, None] + if run_type == 'two_planes': + detector_config['roi'][0]['step'] = \ + recon_layer_indices[1]-recon_layer_indices[0]-1 + else: + detector_config['roi'][0]['step'] = None else: - print('Invalid value for recon_layer_index ({recon_layer_index})') - exit() + raise RuntimeError('Invalid value for recon_layer_indices ({recon_layer_indices})') else: assert len(img_row_bounds) == 2 - recon_layer_index = [] - detector_config['roi'] = [ - {'start': img_row_bounds[0], - 'end': img_row_bounds[1], - 'step': 1}, - None] + center_rows = recon_layer_indices + recon_layer_indices = [] + +reduce_config['img_row_bounds'] = img_row_bounds +find_center_config['center_rows'] = center_rows + +#------------------------------------------------------------------------------# + +if construct_chess_map: + print(f'\nmap_config:') + pprint(map_config) +if reduce_data: + print(f'\nreduce_config:') + pprint(reduce_config) +if find_center: + print(f'\nfind_center_config:') + pprint(find_center_config) +if reconstruct_data: + print(f'\nreconstruct_config:') + pprint(reconstruct_config) +print(f'\ndetector_config:') +pprint(detector_config) # Construct the CHESS style tomo map if not construct_chess_map: @@ -197,9 +263,7 @@ if not reduce_data: nxroot = nxload(reduced_data_name) else: - # Load as needed - if load_converted_map: - nxroot = nxload(chess_map_name) + nxroot = nxload(chess_map_name) # Reduce the data with remove_all_stripe data = [ PipelineData(name='TomoCHESSMapConverter', data=nxroot, schema=None)] @@ -222,7 +286,7 @@ image_slice = nxdata.nxsignal[0,:,0,:] vmin = image_slice.min() vmax = image_slice.max() - if len(recon_layer_index): + if len(recon_layer_indices): quick_imshow( image_slice, title=f'Slice {detector_config["roi"][0]["start"]}', @@ -231,7 +295,7 @@ vmin=vmin, vmax=vmax, save_fig=True, block=True) - if len(recon_layer_index) == 2: + if len(recon_layer_indices) == 2: quick_imshow( image_slice, title=f'Slice {detector_config["roi"][0]["end"]}', @@ -300,6 +364,8 @@ # name='TomoFindCenterProcessor', data=center_config, # schema='tomo.models.TomoFindCenterConfig') ] + from pprint import pprint + pprint(reconstruct_config) tomo = TomoReconstructProcessor( config=reconstruct_config, center_config=center_config, From 7610912b5fc4142810ebdc3071a1cd0bd6d2d0ed Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 30 Apr 2026 11:56:10 -0400 Subject: [PATCH 36/76] test: renamed tomo script consistent with CHAP pipeline version --- examples/tomo/{tomo_one_plane.py => tomo_script_id3b.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/tomo/{tomo_one_plane.py => tomo_script_id3b.py} (100%) diff --git a/examples/tomo/tomo_one_plane.py b/examples/tomo/tomo_script_id3b.py similarity index 100% rename from examples/tomo/tomo_one_plane.py rename to examples/tomo/tomo_script_id3b.py From a27102a0ac17cc6a716c10ef2b132405ef1b67ea Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 30 Apr 2026 12:06:24 -0400 Subject: [PATCH 37/76] feat: support processing multiple scans with MapSliceProcessor --- CHAP/common/map_utils.py | 287 +++++++++++++++++++++++++++------------ 1 file changed, 200 insertions(+), 87 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 5515d358..531a2157 100644 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -3,9 +3,11 @@ from pydantic import ( conint, conlist, + model_validator, Field, FilePath, ) +from typing import Optional # Local modules from CHAP import Processor @@ -45,9 +47,9 @@ class MapSliceProcessor(Processor): :ivar spec_file: SPEC file containing scan from which to read a slice of raw data. :type spec_file: pydantic.FilePath - :ivar scan_number: Number of scan from which to read a slice of + :ivar scan_numbers: Numbers of scans from which to read slices of raw data. - :type scan_number: int + :type scan_numbers: list[int] """ pipeline_fields: dict = Field( default={ @@ -57,23 +59,29 @@ class MapSliceProcessor(Processor): map_config: MapConfig detectors: conlist(item_type=Detector, min_length=1) spec_file: FilePath - scan_number: conint(gt=0) + scan_number: Optional[conint(gt=0)] = None + scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None def process(self, data, idx_slice={'start': 0, 'step': 1}): - """Aggregate partial spec and detector data from one scan in a - map, returning results in a format suitable for writing to the - full map container with `common.NexusValuesWriter` or - `common.ZarrValuesWriter`. + """Aggregate partial spec and detector data from one or more + scans in a map, returning results in a format suitable for + writing to the full map container with + `common.NexusValuesWriter` or `common.ZarrValuesWriter`. + + When all scans are adjacent in the map and `idx_slice` covers + each scan in full, data_points entries for the same path are + consolidated into a single array + slice, avoiding redundant + write calls. :param data: Result of `Reader.read` where at least one item has the value `'common.models.map.MapConfig'` for the `'schema'` key. :type data: list[PipelineData] - :type idx_slice: Parameters for the slice of the scan to + :param idx_slice: Parameters for the slice of each scan to process (slice parameters are the usual for the python - `slice` object: `'start'`, `'stop'`, and - `'step'`). Defaults to `{'start': 0, 'step': '1'}`. + `slice` object: `'start'`, `'stop'`, and `'step'`). + Defaults to `{'start': 0, 'step': 1}`. :type idx_slice: dict[str, int], optional :return: Slice of map data, ready to be written to a map container. @@ -81,102 +89,207 @@ def process(self, data, """ import numpy as np import os - from chess_scanparsers import choose_scanparser - from CHAP.common.models.map import SpecScans - ScanParser = choose_scanparser( - self.map_config.station, self.map_config.experiment_type) - scans = SpecScans( - spec_file=self.spec_file, scan_numbers=[self.scan_number]) - scan = scans.get_scanparser(self.scan_number) + scans_obj = SpecScans( + spec_file=self.spec_file, scan_numbers=self.scan_numbers) + self_spec_file_abs = os.path.abspath(str(self.spec_file)) + + if self.map_config.experiment_type == 'EDD': + def get_detector_data(scan, detector, index): + return scan.get_detector_data(detector.get_id(), index)[0][0] + else: + def get_detector_data(scan, detector, index): + return scan.get_detector_data(detector.get_id(), index) - # Get index offset for this data slice within the map - npts_scan = int(scan.spec_scan_npts) - nscans_prev = 0 - for scans in self.map_config.spec_scans: - for scan_n in scans.scan_numbers: - if (os.path.abspath(self.spec_file) == \ - os.path.abspath(self.spec_file) - and scan_n == self.scan_number): + # Build flat ordered list of (abs_spec_file, scan_number) for + # all scans in the map to determine each scan's map position + map_scan_order = [] + for spec_scans_item in self.map_config.spec_scans: + sf_abs = os.path.abspath(str(spec_scans_item.spec_file)) + for sn in spec_scans_item.scan_numbers: + map_scan_order.append((sf_abs, sn)) + scan_positions = {} + for sn in self.scan_numbers: + for pos, (sf, n) in enumerate(map_scan_order): + if sf == self_spec_file_abs and n == sn: + scan_positions[sn] = pos break - nscans_prev += 1 - index_offset = nscans_prev * npts_scan - # Get spec scan indices to process - scan_indices = range(npts_scan)[slice( - idx_slice.get('start', 0), - idx_slice.get('stop', npts_scan), - idx_slice.get('step', 1) - )] - # Get map indices to write to - map_indices = slice( - idx_slice.get('start', 0) + index_offset, - idx_slice.get('stop', npts_scan) + index_offset, - idx_slice.get('step', 1) + # Process scans in map order + sorted_scan_numbers = sorted( + self.scan_numbers, key=lambda sn: scan_positions[sn]) + + slice_start = idx_slice.get('start', 0) + slice_step = idx_slice.get('step', 1) + + # Collect per-scan metadata; assumes uniform npts across scans + # for index_offset calculation (index_offset = map_pos * npts) + per_scan = [] + for sn in sorted_scan_numbers: + scan = scans_obj.get_scanparser(sn) + npts_scan = int(scan.spec_scan_npts) + index_offset = scan_positions[sn] * npts_scan + # Cap stop at npts_scan so map_indices and data stay in sync + slice_stop = min(idx_slice.get('stop', npts_scan), npts_scan) + scan_indices = range(npts_scan)[ + slice(slice_start, slice_stop, slice_step)] + map_indices = slice( + slice_start + index_offset, + slice_stop + index_offset, + slice_step, + ) + per_scan.append({ + 'sn': sn, 'scan': scan, 'npts_scan': npts_scan, + 'index_offset': index_offset, + 'scan_indices': scan_indices, + 'map_indices': map_indices, + 'full': (slice_start == 0 + and slice_step == 1 + and slice_stop == npts_scan), + }) + + # Consolidate into single data_points entries when all scans + # are adjacent in the map and idx_slice covers each scan fully + sorted_positions = [scan_positions[sn] for sn in sorted_scan_numbers] + scans_are_adjacent = all( + sorted_positions[i + 1] == sorted_positions[i] + 1 + for i in range(len(sorted_positions) - 1) + ) + can_consolidate = ( + len(per_scan) > 1 + and scans_are_adjacent + and all(ps['full'] for ps in per_scan) ) - data_points = [ - { - 'path': f'{self.map_config.title}/independent_dimensions/index', - 'data': np.asarray( - [i for i in range(index_offset, index_offset + npts_scan)] + if can_consolidate: + first, last = per_scan[0], per_scan[-1] + merged_idx = slice( + first['index_offset'], + last['index_offset'] + last['npts_scan'], + 1, + ) + data_points = [{ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/index'), + 'data': np.arange( + first['index_offset'], + last['index_offset'] + last['npts_scan'], ), - 'idx': map_indices - } - ] - data_points.extend( - [ - { - 'path': f'{self.map_config.title}/scalar_data/{s_d.label}', + 'idx': merged_idx, + }] + for s_d in self.map_config.all_scalar_data: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/scalar_data/{s_d.label}'), + 'data': np.concatenate([ + np.asarray([ + s_d.get_value( + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + for dim in self.map_config.independent_dimensions: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/{dim.label}'), + 'data': np.concatenate([ + np.asarray([ + dim.get_value( + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + for det in self.detectors: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/data/{det.get_id()}'), + 'data': np.concatenate([ + np.asarray([ + get_detector_data(ps['scan'], det, i) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + else: + data_points = [] + for ps in per_scan: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/index'), + 'data': np.asarray( + [ps['index_offset'] + i for i in ps['scan_indices']] + ), + 'idx': ps['map_indices'], + }) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/scalar_data/{s_d.label}'), 'data': np.asarray([ s_d.get_value( - scans, self.scan_number, i, + scans_obj, ps['sn'], i, scalar_data=self.map_config.scalar_data) - for i in scan_indices + for i in ps['scan_indices'] ]), - 'idx': map_indices - } - for s_d in self.map_config.all_scalar_data - ] - ) - data_points.extend( - [ - { - 'path': f'{self.map_config.title}/independent_dimensions/{dim.label}', + 'idx': ps['map_indices'], + } for s_d in self.map_config.all_scalar_data]) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/{dim.label}'), 'data': np.asarray([ dim.get_value( - scans, self.scan_number, i, - scalar_data=self.map_config.scalar_data - ) - for i in scan_indices + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] ]), - 'idx': map_indices, - } - for dim in self.map_config.independent_dimensions - ] - ) - if self.map_config.experiment_type == 'EDD': - def get_detector_data(detector, index): - return scan.get_detector_data(detector.get_id(), index)[0][0] - else: - def get_detector_data(detector, index): - return scan.get_detector_data(detector.get_id(), index) - data_points.extend( - [ - { - 'path': f'{self.map_config.title}/data/{det.get_id()}', + 'idx': ps['map_indices'], + } for dim in self.map_config.independent_dimensions]) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/data/{det.get_id()}'), 'data': np.asarray([ - get_detector_data(det, i) - for i in scan_indices + get_detector_data(ps['scan'], det, i) + for i in ps['scan_indices'] ]), - 'idx': map_indices - } - for det in self.detectors - ] - ) + 'idx': ps['map_indices'], + } for det in self.detectors]) return data_points + @model_validator(mode='before') + @classmethod + def fill_scan_numbers(cls, data): + if not isinstance(data, dict): + return data + if 'scan_numbers' not in data or data['scan_numbers'] is None: + if data.get('scan_number') is not None: + data['scan_numbers'] = [data['scan_number']] + elif isinstance(data['scan_numbers'], int): + data['scan_numbers'] = [data['scan_numbers']] + elif isinstance(data['scan_numbers'], str): + from CHAP.utils.general import string_to_list + data['scan_numbers'] = string_to_list(data['scan_numbers']) + return data + + @model_validator(mode='after') + def validate_scan_numbers(self): + if self.scan_numbers is None: + raise ValueError( + 'scan_numbers is required; alternatively, provide scan_number') + if self.scan_number is not None \ + and self.scan_number not in self.scan_numbers: + self.scan_numbers.append(self.scan_number) + return self + class SpecScanToMapConfigProcessor(Processor): """Processor to get the `CHAP.common.models.map.MapConfig` From 9cf6b798fb8dec8c4e734b4a22d47b2656556155 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 30 Apr 2026 12:07:52 -0400 Subject: [PATCH 38/76] add: support reading data from multiple scans with edd.SliceNXdataReader --- CHAP/edd/reader.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index ad4eb757..db3cf934 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -15,6 +15,7 @@ conlist, constr, field_validator, + model_validator, ) # Local modules @@ -749,10 +750,12 @@ class SliceNXdataReader(Reader): from an NXdata group and slices all fields according to the provided slicing parameters. - :param scan_number: Number of the SPEC scan. - :vartype scan_number: int + :ivar scan_numbers: Numbers of scans from which to read slices of + raw data. + :type scan_numbers: list[int] """ - scan_number: conint(ge=0) + scan_number: Optional[conint(gt=0)] = None + scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None def read(self): """Reads an NXdata group from a NeXus file and slices the @@ -811,7 +814,7 @@ def read(self): # Get indicies of SCAN_N that match self.scan_number scan_field = nxdata['SCAN_N'].nxdata - indices = np.flatnonzero(scan_field == self.scan_number) + indices = np.flatnonzero(np.isin(scan_field, self.scan_numbers)) if indices.size == 0: self.logger.warning( @@ -827,6 +830,30 @@ def read(self): return nxroot + @model_validator(mode='before') + @classmethod + def fill_scan_numbers(cls, data): + if not isinstance(data, dict): + return data + if 'scan_numbers' not in data or data['scan_numbers'] is None: + if data.get('scan_number') is not None: + data['scan_numbers'] = [data['scan_number']] + elif isinstance(data['scan_numbers'], int): + data['scan_numbers'] = [data['scan_numbers']] + elif isinstance(data['scan_numbers'], str): + from CHAP.utils.general import string_to_list + data['scan_numbers'] = string_to_list(data['scan_numbers']) + return data + + @model_validator(mode='after') + def validate_scan_numbers(self): + if self.scan_numbers is None: + raise ValueError( + 'scan_numbers is required; alternatively, provide scan_number') + if self.scan_number is not None \ + and self.scan_number not in self.scan_numbers: + self.scan_numbers.append(self.scan_number) + return self class UpdateNXdataReader(Reader): """Companion to `edd.SetupNXdataReader` and From 9fba14e627b999f7014bc1c0bc412252bea06d43 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 30 Apr 2026 20:34:01 -0400 Subject: [PATCH 39/76] add: simplified script based calling and added an EDD calib script Also update the tomo script to the new CHAP calling style --- CHAP/common/processor.py | 2 +- CHAP/common/reader.py | 9 +- CHAP/models.py | 6 +- CHAP/pipeline.py | 71 ++++++++++- CHAP/utils/fit.py | 24 ++-- CHAP/utils/general.py | 8 +- CHAP/utils/models.py | 4 +- examples/edd/edd_calibration_script.py | 158 +++++++++++++++++++++++++ examples/tomo/tomo_script_id3b.py | 118 +++++++++--------- 9 files changed, 306 insertions(+), 94 deletions(-) create mode 100644 examples/edd/edd_calibration_script.py diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 0d198c11..0665b101 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -2422,7 +2422,7 @@ def process( :type npt: int :param mask_file: File to use for masking the input data. :type mask_file: str, optional - :param integrate1d_kwargs: Optional dictionary of keywords + :param integrate1d_kwargs: Optional dictionary of keywords. :type integrate1d_kwargs: Optional[dict] :returns: Azimuthal integration results as a dictionary of numpy arrays. diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index 398fe189..8d6042bc 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -550,7 +550,7 @@ class PandasReader(Reader): `pandas `__ """ - def read(self, filename, method='read_csv', comment='#', kwargs=None): + def read(self, filename, method='read_csv', comment='#', **kwargs): """Return a `pandas.DataFrame` read from the given file. :param filename: Name of file to read from. @@ -561,9 +561,9 @@ def read(self, filename, method='read_csv', comment='#', kwargs=None): :param comment: Character to identify comment lines in the input file, defaults to `'#'`. :type comment: str, optional - :param kwargs: Additional keyword arguments to supply to the + :param \*\*kwargs: Additional keyword arguments to supply to the `pandas` reader. - :param kwargs: dict, optional. + :type: dict, optional. :rtype: `pandas.DataFrame` """ # Third party modules @@ -575,8 +575,6 @@ def read(self, filename, method='read_csv', comment='#', kwargs=None): raise ValueError( f'{method} is not a callable pandas reader method') - if kwargs is None: - kwargs = {} if not isinstance(kwargs, dict): raise TypeError( f'Invalid kwargs type ({type(kwargs)}, should be dict)') @@ -921,7 +919,6 @@ def read(self): detector_roi=[ self.detector_config.roi[0].toslice(), self.detector_config.roi[1].toslice()] - print(f'\n\ndetector_roi: {detector_roi}\n\n') nxdata[detector.get_id()] = NXfield( value=scanparser.get_detector_data( detector.get_id(), diff --git a/CHAP/models.py b/CHAP/models.py index 7a7df819..9b4ad36c 100755 --- a/CHAP/models.py +++ b/CHAP/models.py @@ -34,7 +34,7 @@ def dict(self, *args, **kwargs): """Dump the class implemention to a dictionary. :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :type: dict, optional :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -50,7 +50,7 @@ def model_dump(self, *args, **kwargs): """Dump the class implemention to a dictionary. :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :type: dict, optional :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -71,7 +71,7 @@ def model_dump_json(self, *args, **kwargs): """Dump the class implemention to a JSON string. :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :type: dict, optional :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index ce854e1f..47df472d 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -444,18 +444,17 @@ def unwrap_pipelinedata(data): unwrapped_data = [data] return unwrapped_data - def execute(self, data):#, metadata, provenance): + def execute(self, data=None):#, metadata, provenance): """Execute the appropriate method of the object and return the result. :param data: Input data. - :type data: list[PipelineData] + :type data: list[PipelineData], optional :return: Wrapped result of executing read, process, or write. :rtype: PipelineData | tuple[PipelineData] """ # self._metadata = metadata # self._provenance = provenance - if 'data' in self._allowed_args: self._args['data'] = data t0 = time() @@ -467,6 +466,72 @@ def execute(self, data):#, metadata, provenance): f'Finished "{self._method}" in {time()-t0:.0f} seconds\n') return data + @classmethod + def run(cls, **kwargs): + """Execute the appropriate method of the object and return the + result. + + This class method gets and executes the appropriate method + (process, read or write) from the pipeline item it's called + from. It is intended to be called from a script or notebook + only and should not be called from other CHAP Processors, + Readers or Writers. + + The method expects the same parameters as those used to + instantiate its class object and run the process, read or + write method, in addition to any run time parameter in the + pipeline file config dictionary (see: + :class:`~CHAP.models.RunConfig)`. + + :param \*\*kwargs: Arbitrary keyword arguments, including: + :type: dict, optional + :keyword config: Initialization parameters for an instance + of the pipeline item this method is called from (often + used by Readers and Processors). + :type config: dict, optional + :keyword data: Input data (required for any Processor, but + is allowed to be `None` or `[]`). + :type data: list[PipelineData], optional + :keyword filename: Name of file to read (required for most + Readers and Writers). + :type filename: str, optional + :keyword force_overwrite: Flag to allow data in `filename` + to be overwritten if it already exists, defaults to + `False` (optional for Writers). + :type force_overwrite: bool, optional + :keyword remove: Flag to remove the dictionary from `data`, + defaults to `False` (optional for Writers). + :type remove: bool, optional + :return: Returned result from executing the underlying read, + process, or write method. + :rtype: Any + """ + # System modules + from importlib import import_module + from inspect import isclass + from pkgutil import walk_packages + + def _find_class_in_package(package_name, class_name): + package = import_module(package_name) + # Recursively walk through all submodules + for _, module_name, _ in walk_packages( + package.__path__, package.__name__ + "."): + try: + module = import_module(module_name) + # Check if the class exists in this module + if hasattr(module, class_name): + cls = getattr(module, class_name) + if isclass(cls) and cls.__module__ == module_name: + return cls + except ImportError: + continue + return None + + cls_name = cls.__name__ + mmc = _find_class_in_package('CHAP', cls_name) + item = mmc(modelmetaclass=mmc, **kwargs) + return item.execute(kwargs.get('data')) + class Pipeline(CHAPBaseModel): """Class representing a full `Pipeline` object. diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index c626b74b..6f65c6ed 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -1258,9 +1258,9 @@ def fit(self, config=None, **kwargs): :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional - :param kwargs: Additional key, value pairs to pass on directly - to the core fit routine. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the core fit routine. + :type: dict, optional """ # Check input parameters if self._model is None: @@ -1359,9 +1359,9 @@ class attribute. :type plot_residual: bool, optional :param plot_masked_data: :type plot_masked_data: bool, optional - :param kwargs: Additional key, value pairs to pass on directly - to the Matplotlib plot function. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the Matplotlib plot function. + :type: dict, optional """ if result is None: result = self._result @@ -2645,9 +2645,9 @@ def plot( :type plot_residual: bool, optional :param plot_masked_data: :type plot_masked_data: bool, optional - :param kwargs: Additional key, value pairs to pass on directly - to the Matplotlib plot function. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the Matplotlib plot function. + :type: dict, optional """ # Third party modules from lmfit.models import ExpressionModel @@ -2720,9 +2720,9 @@ def fit(self, config=None, **kwargs): :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional - :param kwargs: Additional key, value pairs to pass on directly - to the core fit routine. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the core fit routine. + :type: dict, optional """ # Check input parameters if self._model is None: diff --git a/CHAP/utils/general.py b/CHAP/utils/general.py index 6a717336..4c9f8c02 100755 --- a/CHAP/utils/general.py +++ b/CHAP/utils/general.py @@ -2735,9 +2735,9 @@ def quick_imshow( :type grid_linewidth: int, optional :param colorbar: Include a colorbar, defaults to `False`. :type colorbar: bool, optional - :param kwargs: Any additional keyword parameters to pass on to + :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.imshow `__. - :type kwargs: dict, optional + :type: dict, optional :raise: ValueError for invalid input data or parameters. :return: In-memory object as a byte stream represention if `return_fig` is set. @@ -2853,9 +2853,9 @@ def quick_plot( :type save_only: bool, optional :param block: Wait for the image to be closed before returning. :type block: bool, optional - :param kwargs: Any additional keyword parameters to pass on to + :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.plot `__ - :type kwargs: dict, optional + :type: dict, optional :raise: ValueError for invalid input data or parameters. """ #FIX: Update with return_buf diff --git a/CHAP/utils/models.py b/CHAP/utils/models.py index 99d8a714..ebdeb88a 100755 --- a/CHAP/utils/models.py +++ b/CHAP/utils/models.py @@ -300,7 +300,7 @@ def validate_parameters(parameters, info): :rtype: list[FitParameter] """ # System imports - import inspect + from inspect import signature if 'model' in info.data: model = info.data['model'] @@ -308,7 +308,7 @@ def validate_parameters(parameters, info): model = None if model is None or model == 'expression': return parameters - sig = dict(inspect.signature(models[model]['name']).parameters.items()) + sig = dict(signature(models[model]['name']).parameters.items()) sig.pop('x') # Check input model parameter validity diff --git a/examples/edd/edd_calibration_script.py b/examples/edd/edd_calibration_script.py new file mode 100644 index 00000000..fc3e25ff --- /dev/null +++ b/examples/edd/edd_calibration_script.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""A python script version of the CHAP pipeline version of the +calibration example. + +Modify any user choices in the "Start user input" and "Optional user +input" blocks below. + +In particular: + +* Choose the path to your CHAP repo by setting sys.path. +""" + +# System modules +import os +from pprint import pprint +import sys +from tempfile import NamedTemporaryFile + +# Third party modules +import yaml + +# Local modules +from CHAP.pipeline import PipelineData +from CHAP.common.reader import SpecReader +from CHAP.edd.processor import ( + MCAEnergyCalibrationProcessor, + MCATthCalibrationProcessor, +) + +#------------------------------------------------------------------------------# +# Start user input +#------------------------------------------------------------------------------# + +# Choose the path to your CHAP repo +sys.path.append( + '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') +#sys.path.append( +# '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + +root = None #'examples/edd' +interactive = True +outputdir = 'output' +energy_calib = True +tth_calib = True + +if root is None: + root = os.getcwd() +if outputdir is None: + outputdir = '.' +if not os.path.isabs(outputdir): + outputdir = os.path.normpath( + os.path.realpath(os.path.join(root, outputdir))) + +energy_name = f'{outputdir}/energy_calibration_result.yaml' +tth_name = f'{outputdir}/tth_calibration_result.yaml' +spec_file = f'{root}/data/ceo2-5deg-80um-calib/spec.log' +scan_numbers = 1 +detector_ids = [0, 21] + +#------------------------------------------------------------------------------# +# Optional user input +#------------------------------------------------------------------------------# + +spec_config = { + 'station': 'id3a', + 'experiment_type': 'EDD', + 'spec_scans': [ + {'spec_file': spec_file, + 'scan_numbers': scan_numbers} + ], +} +energy_config = { + 'peak_energies': [34.276, 34.717, 39.255, 40.231], + 'max_peak_index': 1, + 'materials': [ + {'material_name': 'CeO2', + 'lattice_parameters': 5.41153, + 'sgnum': 225} + ], +} +energy_detector_config = { + 'baseline': True, + 'mask_ranges': [[650, 850]], +} +if detector_ids is not None: + energy_detector_config['detectors'] = [{'id':id_} for id_ in detector_ids] +tth_config = {'tth_initial_guess': 5.2} +tth_detector_config = { + 'energy_mask_ranges': [[65, 155]], +} + +#------------------------------------------------------------------------------# +# End user input +#------------------------------------------------------------------------------# + +if not os.path.isdir(outputdir): + os.makedirs(outputdir) +try: + NamedTemporaryFile(dir=outputdir) +except Exception as exc: + raise OSError( + 'Output directory not accessible for writing ({outputdir})') from exc + +#------------------------------------------------------------------------------# + +print(f'\nspec_config:') +pprint(spec_config) +if energy_calib: + print(f'\nenergy_config:') + pprint(energy_config) + print(f'\nenergy_detector_config:') + pprint(energy_detector_config) +if tth_calib: + print(f'\ntth_config:') + pprint(tth_config) + print(f'\ntth_detector_config:') + pprint(tth_detector_config) + +# Perform the energy calibration +if energy_calib: + # Read the calibration data + spec_reader = SpecReader(config=spec_config) + nxroot = spec_reader.read() + + # Perform the energy calibration + data = [PipelineData(name='SpecReader', data=nxroot)] + energy_calib_config, images = MCAEnergyCalibrationProcessor.run( + data=data, config=energy_config, detector_config=energy_detector_config, + interactive=interactive) + + # Write the energy calibration results + with open(energy_name, 'w', encoding='utf-8') as f: + yaml.dump(energy_calib_config, f, sort_keys=False) + +# Perform the tth calibration +if tth_calib: + # Read the energy calibration results + with open(energy_name, encoding='utf-8') as f: + energy_calib_config = yaml.safe_load(f) + + # Read the calibration data + spec_reader = SpecReader(config=spec_config) + nxroot = spec_reader.read() + + # Perform the tth calibration + data = [ + PipelineData( + data=energy_calib_config, + schema='edd.models.MCAEnergyCalibrationConfig'), + PipelineData(name='SpecReader', data=nxroot), + ] + tth_calib_config, images = MCATthCalibrationProcessor.run( + data=data, config=tth_config, detector_config=tth_detector_config, + interactive=interactive) + with open(tth_name, 'w', encoding='utf-8') as f: + yaml.dump(tth_calib_config, f, sort_keys=False) + diff --git a/examples/tomo/tomo_script_id3b.py b/examples/tomo/tomo_script_id3b.py index 90770b31..c6eed623 100644 --- a/examples/tomo/tomo_script_id3b.py +++ b/examples/tomo/tomo_script_id3b.py @@ -1,24 +1,36 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +"""A python script version of the CHAP pipeline version of the id3b +example. -"""Choose the path to your CHESS repo.""" -import sys -#sys.path.append( -# '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline') -sys.path.append( - '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') -#sys.path.append( -# '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') +Modify any user choices in the "Start user input" and "Optional user +input" blocks below. + +In particular: + +* Choose the path to your CHAP repo by setting sys.path. + +* Select the run_type below from: + =========== ============================================== + run_type Description + =========== ============================================== + 'normal' Full regular reconstruction + (like running the default hollow_cube in CHAP) + 'one_plane' Reconstruction of two planes only + 'two_planes' Reconstruction of one plane only + 'roi' Reconstruction for a subset or rows + (find center at first and last row) + =========== ============================================== +""" # System modules import os from pprint import pprint +import sys from tempfile import NamedTemporaryFile # Third party modules -import matplotlib.pyplot as plt from nexusformat.nexus import nxload -from PIL import Image import yaml # Local modules @@ -37,14 +49,13 @@ # Start user input #------------------------------------------------------------------------------# -"""Select run_type from -'normal': Full regular reconstruction - (like running the default hollow_cube in CHAP) -'one_plane': Reconstruction of two planes only -'two_planes': Reconstruction of one plane only -'roi': Reconstruction for a subset or rows - (find center at first and last row) -""" +# Choose the path to your CHAP repo +#sys.path.append( +# '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') +sys.path.append( + '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + +# Select run_type #run_type = 'normal' #run_type = 'one_plane' #run_type = 'two_planes' @@ -206,44 +217,40 @@ print(f'\ndetector_config:') pprint(detector_config) -# Construct the CHESS style tomo map +# Construct the CHAP style tomo map if not construct_chess_map: nxroot = nxload(chess_map_name) else: # Create the map for the tomo stack - tomo_map = MapProcessor(config=map_config, detector_config=detector_config) - - # Read the map for the tomo stack - tomofields = tomo_map.process(data=None) + tomofields = MapProcessor.run( + config=map_config, detector_config=detector_config) tomofields.save(map_name, mode='w') # Read the dark field - tdf_spec_reader = SpecReader( + darkfield = SpecReader.run( config={ - 'station': tomo_map.config.station, - 'experiment_type': tomo_map.config.experiment_type, - 'sample': tomo_map.config.sample, + 'station': map_config['station'], + 'experiment_type': map_config['experiment_type'], + 'sample': map_config['sample'], 'spec_scans': [ - {'spec_file': tomo_map.config.spec_scans[0].spec_file, + {'spec_file': map_config['spec_scans'][0].spec_file, 'scan_numbers': tdf_scan_numbers}], }, detector_config=detector_config) - darkfield = tdf_spec_reader.read() # Read the bright field - tbf_spec_reader = SpecReader( + brightfield = SpecReader.run( config={ - 'station': tomo_map.config.station, - 'experiment_type': tomo_map.config.experiment_type, - 'sample': tomo_map.config.sample, + 'station': map_config['station'], + 'experiment_type': map_config['experiment_type'], + 'sample': map_config['sample'], 'spec_scans': [ - {'spec_file': tomo_map.config.spec_scans[0].spec_file, + {'spec_file': map_config['spec_scans'][0].spec_file, 'scan_numbers': tbf_scan_numbers}], }, detector_config=detector_config) - brightfield = tbf_spec_reader.read() - # Convert to CHESS style tomography map + # Convert to CHAP style tomography map data = [ PipelineData( name='MapProcessor', data=tomofields, schema='tomofields'), @@ -254,8 +261,7 @@ name='YAMLReader', data=detector_setup, schema='tomo.models.Detector') ] - chess_map = TomoCHESSMapConverter() - _, _, nxroot = chess_map.process(data=data) + _, _, nxroot = TomoCHESSMapConverter.run(data=data) nxroot = nxroot['data'] nxroot.save(chess_map_name, mode='w') @@ -267,26 +273,17 @@ # Reduce the data with remove_all_stripe data = [ PipelineData(name='TomoCHESSMapConverter', data=nxroot, schema=None)] - tomo = TomoReduceProcessor( - config=reduce_config, interactive=interactive) - (metadata, provenance, images, reduced_data) = tomo.process(data) + _, _, _, reduced_data = TomoReduceProcessor.run( + data=data, config=reduce_config, interactive=interactive) reduced_data = reduced_data['data'] reduced_data.save(reduced_data_name, mode='w') -# for (buf, ext), name in images.get('data', []): -# buf.seek(0) -# plt.imshow(Image.open(buf)) -# plt.axis('off') -# plt.tight_layout() -# plt.show(block=True) -# buf.close() -# plt.close() nxentry = reduced_data[reduced_data.default] nxdata = nxentry[nxentry.default] image_slice = nxdata.nxsignal[0,:,0,:] vmin = image_slice.min() vmax = image_slice.max() - if len(recon_layer_indices): + if recon_layer_indices: quick_imshow( image_slice, title=f'Slice {detector_config["roi"][0]["start"]}', @@ -310,10 +307,10 @@ data = [PipelineData( name='TomoCHESSMapConverter', data=nxroot, schema=None)] - tomo = TomoReduceProcessor( + _, _, _, reduced_data = TomoReduceProcessor.run( + data=data, config=reduce_config.update({'remove_stripe': {method: {}}}), interactive=interactive) - (metadata, provenance, images, reduced_data) = tomo.process(data) reduced_data = reduced_data['data'] reduced_data.save(f'reduced_data_{method}.nxs', mode='w') @@ -338,11 +335,10 @@ data = [PipelineData( name='TomoCHESSMapConverter', data=reduced_data, schema=None)] - tomo = TomoFindCenterProcessor( - interactive=interactive, config=find_center_config) - (metadata, provenance, images, center_config) = tomo.process(data) + _, _, _, center_config = TomoFindCenterProcessor.run( + data=data, config=find_center_config, interactive=interactive) center_config = center_config['data'] - with open(find_center_name, 'w') as f: + with open(find_center_name, 'w', encoding='utf-8') as f: yaml.dump(center_config, f, sort_keys=False) # Reconstruct @@ -351,7 +347,7 @@ pass else: reduced_data = nxload(reduced_data_name) - with open(find_center_name) as f: + with open(find_center_name, encoding='utf-8') as f: center_config = yaml.safe_load(f) data = [ @@ -364,12 +360,8 @@ # name='TomoFindCenterProcessor', data=center_config, # schema='tomo.models.TomoFindCenterConfig') ] - from pprint import pprint - pprint(reconstruct_config) - tomo = TomoReconstructProcessor( - config=reconstruct_config, - center_config=center_config, + _, _, _, reconstructed_data = TomoReconstructProcessor.run( + data=data, config=reconstruct_config, center_config=center_config, interactive=interactive) - (metadata, provenance, images, reconstructed_data) = tomo.process(data) reconstructed_data = reconstructed_data['data'] reconstructed_data.save(reconstructed_data_name, mode='w') From 4f4ce623dc01648644a6b4cb06ad068c44f34fea Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 30 Apr 2026 21:02:55 -0400 Subject: [PATCH 40/76] docs: cleaned up kwargs docstrings --- CHAP/common/reader.py | 1 - CHAP/models.py | 9 +++------ CHAP/pipeline.py | 3 +-- CHAP/utils/fit.py | 4 ---- CHAP/utils/general.py | 2 -- 5 files changed, 4 insertions(+), 15 deletions(-) diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index 8d6042bc..de9b78a9 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -563,7 +563,6 @@ def read(self, filename, method='read_csv', comment='#', **kwargs): :type comment: str, optional :param \*\*kwargs: Additional keyword arguments to supply to the `pandas` reader. - :type: dict, optional. :rtype: `pandas.DataFrame` """ # Third party modules diff --git a/CHAP/models.py b/CHAP/models.py index 9b4ad36c..435c2b57 100755 --- a/CHAP/models.py +++ b/CHAP/models.py @@ -33,8 +33,7 @@ class CHAPBaseModel(BaseModel): def dict(self, *args, **kwargs): """Dump the class implemention to a dictionary. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict, optional + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -49,8 +48,7 @@ class variables that have an alias., defaults to `True`. def model_dump(self, *args, **kwargs): """Dump the class implemention to a dictionary. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict, optional + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -70,8 +68,7 @@ class variables that have an alias., defaults to `True`. def model_dump_json(self, *args, **kwargs): """Dump the class implemention to a JSON string. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict, optional + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 47df472d..802a9b38 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -483,8 +483,7 @@ def run(cls, **kwargs): pipeline file config dictionary (see: :class:`~CHAP.models.RunConfig)`. - :param \*\*kwargs: Arbitrary keyword arguments, including: - :type: dict, optional + :param \*\*kwargs: Optional keyword arguments, including: :keyword config: Initialization parameters for an instance of the pipeline item this method is called from (often used by Readers and Processors). diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index 6f65c6ed..56489ab8 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -1260,7 +1260,6 @@ def fit(self, config=None, **kwargs): :type config: CHAP.utils.models.FitConfig, optional :param \*\*kwargs: Additional key, value pairs to pass on directly to the core fit routine. - :type: dict, optional """ # Check input parameters if self._model is None: @@ -1361,7 +1360,6 @@ class attribute. :type plot_masked_data: bool, optional :param \*\*kwargs: Additional key, value pairs to pass on directly to the Matplotlib plot function. - :type: dict, optional """ if result is None: result = self._result @@ -2647,7 +2645,6 @@ def plot( :type plot_masked_data: bool, optional :param \*\*kwargs: Additional key, value pairs to pass on directly to the Matplotlib plot function. - :type: dict, optional """ # Third party modules from lmfit.models import ExpressionModel @@ -2722,7 +2719,6 @@ def fit(self, config=None, **kwargs): :type config: CHAP.utils.models.FitConfig, optional :param \*\*kwargs: Additional key, value pairs to pass on directly to the core fit routine. - :type: dict, optional """ # Check input parameters if self._model is None: diff --git a/CHAP/utils/general.py b/CHAP/utils/general.py index 4c9f8c02..5b9cfb5a 100755 --- a/CHAP/utils/general.py +++ b/CHAP/utils/general.py @@ -2737,7 +2737,6 @@ def quick_imshow( :type colorbar: bool, optional :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.imshow `__. - :type: dict, optional :raise: ValueError for invalid input data or parameters. :return: In-memory object as a byte stream represention if `return_fig` is set. @@ -2855,7 +2854,6 @@ def quick_plot( :type block: bool, optional :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.plot `__ - :type: dict, optional :raise: ValueError for invalid input data or parameters. """ #FIX: Update with return_buf From d46b27d7f1fc1bc08f195b86260ad6faea4ac6cd Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Fri, 1 May 2026 11:01:19 -0400 Subject: [PATCH 41/76] fix: fixed pipelineitem.run for duplicate class names --- CHAP/pipeline.py | 19 ++++++++++++++++--- examples/edd/edd_calibration_script.py | 6 ++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 802a9b38..1b206945 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -510,9 +510,13 @@ def run(cls, **kwargs): from inspect import isclass from pkgutil import walk_packages + # Local modules + from CHAP.utils.general import input_menu + def _find_class_in_package(package_name, class_name): package = import_module(package_name) # Recursively walk through all submodules + found_classes = [] for _, module_name, _ in walk_packages( package.__path__, package.__name__ + "."): try: @@ -520,11 +524,20 @@ def _find_class_in_package(package_name, class_name): # Check if the class exists in this module if hasattr(module, class_name): cls = getattr(module, class_name) - if isclass(cls) and cls.__module__ == module_name: - return cls + if (isclass(cls) and cls.__module__ == module_name + and cls not in found_classes): + found_classes.append(cls) except ImportError: continue - return None + if not found_classes: + raise ImportError(f'Unable to find {class_name} in CHAP') + if len(found_classes) == 1: + return found_classes[0] + index = input_menu( + [v.__module__ for v in found_classes], + header=f'\nFound multiple classes named {class_name} in ' + f'CHAP\nUse {class_name} from:') + return found_classes[index] cls_name = cls.__name__ mmc = _find_class_in_package('CHAP', cls_name) diff --git a/examples/edd/edd_calibration_script.py b/examples/edd/edd_calibration_script.py index fc3e25ff..b8ed5b9d 100644 --- a/examples/edd/edd_calibration_script.py +++ b/examples/edd/edd_calibration_script.py @@ -120,8 +120,7 @@ # Perform the energy calibration if energy_calib: # Read the calibration data - spec_reader = SpecReader(config=spec_config) - nxroot = spec_reader.read() + nxroot = SpecReader.run(config=spec_config) # Perform the energy calibration data = [PipelineData(name='SpecReader', data=nxroot)] @@ -140,8 +139,7 @@ energy_calib_config = yaml.safe_load(f) # Read the calibration data - spec_reader = SpecReader(config=spec_config) - nxroot = spec_reader.read() + nxroot = SpecReader.run(config=spec_config) # Perform the tth calibration data = [ From 661e81d8e62e97183c0217f8956e937d5b51cc72 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 1 May 2026 13:55:29 -0400 Subject: [PATCH 42/76] feat: support 0-scan maps in edd.StrainAnalysisProcessor with setup=True and update=False --- CHAP/edd/processor.py | 59 ++++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index be95381f..146a28ae 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2719,6 +2719,17 @@ def process(self, data): self.logger.warning( f'Skipping detector {detector_id} (Illegal data shape ' f'{raw_detector_data.shape})') + elif raw_detector_data.size == 0 and self.setup: + # 0-scan map: no spectra yet, include for setup + for k, v in nxdata[detector_id].attrs.items(): + detector.attrs[k] = v.nxdata + if self.config.rel_height_cutoff is not None: + detector.rel_height_cutoff = \ + self.config.rel_height_cutoff + detector.add_calibration( + calibration_detectors[ + int(calibration_detector_ids.index(detector_id))]) + detectors.append(detector) elif raw_detector_data.sum(): for k, v in nxdata[detector_id].attrs.items(): detector.attrs[k] = v.nxdata @@ -2770,9 +2781,16 @@ def process(self, data): # Apply the combined energy ranges mask self._apply_combined_mask() + # Populate _peak_fit_info when there are no spectra (0-scan setup) + no_raw_data = ( + bool(self._nxdata_detectors) + and self._nxdata_detectors[0].nxsignal.shape[0] == 0) + if no_raw_data: + self._populate_peak_fit_info() + # Setup and/or run the strain analysis results = {} - if self.update: + if self.update and not no_raw_data: results = self._strain_analysis() if self.setup: if self.standalone: @@ -2959,6 +2977,30 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): ) nxcollection[hkl_name].strains.attrs['signal'] = 'values' + def _populate_peak_fit_info(self): + """Populate _peak_fit_info with all configured HKLs for each + detector, with all peaks marked as used. + + Called when no raw detector spectra are present (e.g. a 0-scan + setup run), so that the resulting NXprocess contains entries for + every HKL that could be encountered in a subsequent update run. + """ + # Local modules + from CHAP.edd.utils import get_peak_locations, get_unique_hkls_ds + + for detector in self.detector_config.detectors: + hkls, ds = get_unique_hkls_ds( + self.config.materials, tth_max=detector.tth_max, + tth_tol=detector.tth_tol) + hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices]) + ds_fit = np.asarray([ds[i] for i in detector.hkl_indices]) + peak_locations = get_peak_locations(ds_fit, detector.tth_calibrated) + self._peak_fit_info.append({ + 'hkls': [''.join(map(str, hkl)) for hkl in hkls_fit], + 'nominal_peak_centers': peak_locations.tolist(), + 'peak_models': detector.peak_models, + 'use_peaks': np.ones(len(hkls_fit), dtype=bool).tolist(), + }) def _create_animation( self, nxdata, energies, intensities, intensity_norms, best_fits, @@ -3188,17 +3230,16 @@ def _get_nxprocess(self, nxentry, calibration_config): ) # Add the detector data - _intensity=np.asarray( - [ - data[i].astype(np.float64)[mask] - for i in range(num_points) - ] - ) + num_energy_bins = mask.sum() + _intensity = np.empty( + (num_points, num_energy_bins), dtype=np.float64) + for i in range(num_points): + _intensity[i] = data[i].astype(np.float64)[mask] det_nxdata.intensity = NXfield( value=_intensity, attrs={'units': 'counts'}, - maxshape=(None,*_intensity.shape[1:]), - chunks=(1,*_intensity.shape[1:]) + maxshape=(None, num_energy_bins), + chunks=(1, num_energy_bins) ) det_nxdata.attrs['signal'] = 'intensity' From fc796f0db681ba57da382c6fff1faf0246e00eee Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 1 May 2026 14:29:58 -0400 Subject: [PATCH 43/76] cut: exclude spectra-like data from JSON results of StrainAnalysisProcessor --- CHAP/edd/processor.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 146a28ae..ca0d9f94 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -2818,7 +2818,13 @@ def process(self, data): else self._get_points(results) ret = [result] if self.json_results: - json_results = {k: v.tolist() for k, v in results.items()} + json_results = { + k: v.tolist() + for k, v in results.items() + if not k.endswith('intensity') # exclude raw detector data + and not k.endswith('best_fit') # exlude best_fit spectra + and not k.endswith('residual') + } ret.append(json_results) if not (self._figures or self._animation): return tuple(ret) From 02eae4190bc78cc0f60a3ed504a5e5f62a17ee47 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 1 May 2026 14:30:55 -0400 Subject: [PATCH 44/76] add: include placeholder values for unused peaks in JSON results of StrainAnalysisProcessor --- CHAP/edd/processor.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index ca0d9f94..620c0b2c 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -3644,6 +3644,29 @@ def _strain_analysis(self): np.nan), }) + if self.json_results: + # Include placeholder values for unused peaks in results + placeholder = np.full(num_points, np.nan) + for j, hkl in enumerate(hkls_fit[~use_peaks]): + hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) + for fitmode in ('unconstrained', 'uniform'): + fitparams = ('amplitudes', 'centers', 'sigmas') + if detector.peak_models == 'pvoigt': + fitparams = (*fitparams, 'fractions') + if fitmode == 'unconstrained': + fitparams = (*fitparams, 'strains') + for fitparam in fitparams: + fitquants = ('values', 'errors') + if fitmode == 'unconstrained' \ + and fitparam == 'strains': + fitquants = (*fitquants, 'residuals') + for fitquant in fitquants: + key = str( + f'{det_id}/{fitmode}_fit/{hkl_name}/' + f'{fitparam}/{fitquant}' + ) + results.update({key: placeholder}) + # Create an animation of the fit points if (not self.config.skip_animation and (self.interactive or self.save_figures)): From 266086e4d69ecf766e2ef4d897b2f7c41f76438c Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 5 May 2026 13:50:19 -0400 Subject: [PATCH 45/76] fix: peak fit results for the EDD calibration in channel energies --- CHAP/edd/processor.py | 71 ++++++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 733ec448..102c554c 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -1413,26 +1413,20 @@ def _calibrate(self): NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), {'models': models, 'method': 'trf'}) + # Extract the fit results for the peaks - fit_peak_amplitudes = sorted([ - mean_data_fit.best_values[f'peak{i+1}_amplitude'].tolist() + fit_peak_amplitudes = np.asarray([ + mean_data_fit.best_values[f'peak{i+1}_amplitude'] for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') - fit_peak_indices = sorted([ - mean_data_fit.best_values[f'peak{i+1}_center'].tolist() + fit_peak_indices = np.asarray([ + mean_data_fit.best_values[f'peak{i+1}_center'] for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') - fit_peak_sigmas = sorted([ - mean_data_fit.best_values[f'peak{i+1}_sigma'].tolist() + fit_peak_sigmas = np.asarray([ + mean_data_fit.best_values[f'peak{i+1}_sigma'] for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') - self._peak_fit_results[id_] = { - 'amplitudes': fit_peak_amplitudes, - 'centers': fit_peak_indices, - 'sigmas': fit_peak_sigmas, - 'independent_dimension': { - 'name': 'Detector Channel', 'unit': '-'} - } # FIX for now stick with a linear energy correction fit = FitProcessor(**self.run_config) @@ -1505,6 +1499,17 @@ def _calibrate(self): plt.show() plt.close() + # Collect the fit results for the peaks + # FIX assume Gaussian peaks for now and a=0 + self._peak_fit_results[id_] = { + 'heights': (fit_peak_amplitudes / + (np.sqrt(2*np.pi)*fit_peak_sigmas)).tolist(), + 'centers': (b*fit_peak_indices + c).tolist(), + 'fwhm': (b*fit_peak_sigmas*np.sqrt(8*np.log(2))).tolist(), + 'independent_dimension': { + 'name': 'Energy', 'unit': '-'}, + } + def _get_initial_peak_positions( self, y, low, index_ranges, input_indices, input_max_peak_index, detector_id, reset_flag=0, return_buf=False): @@ -2145,28 +2150,18 @@ def _direct_bragg_peak_fit( residual = result.residual # Extract the Bragg peak indices from the fit - fit_peak_amplitudes = [ - result.best_values[f'peak{i+1}_amplitude'].tolist() - for i in range(len(e_bragg))] + fit_peak_amplitudes = np.asarray([ + result.best_values[f'peak{i+1}_amplitude'] + for i in range(len(e_bragg))]) self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') - fit_peak_indices = [ - result.best_values[f'peak{i+1}_center'].tolist() - for i in range(len(e_bragg))] + fit_peak_indices = np.asarray([ + result.best_values[f'peak{i+1}_center'] + for i in range(len(e_bragg))]) self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') - fit_peak_sigmas = [ - result.best_values[f'peak{i+1}_sigma'].tolist() - for i in range(len(e_bragg))] + fit_peak_sigmas = np.asarray([ + result.best_values[f'peak{i+1}_sigma'] + for i in range(len(e_bragg))]) self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') - self._peak_fit_results[detector.get_id()] = { - 'amplitudes': fit_peak_amplitudes, - 'centers': fit_peak_indices, - 'sigmas': fit_peak_sigmas, - 'independent_dimension': { - 'name': 'Detector Channel', 'unit': '-'} - } - fit_peak_indices = np.asarray( - [result.best_values[f'peak{i+1}_center'] - for i in range(len(e_bragg))]) # Fit a line through zero strain peak energies vs detector # energy bins @@ -2189,6 +2184,18 @@ def _direct_bragg_peak_fit( e_bragg_unconstrained = ( (a_fit*fit_peak_indices + b_fit) * fit_peak_indices + c_fit) + # Collect the fit results for the peaks + # FWHM to first order (i.e. ignore non-zero a) + # FIX assume Gaussian peaks for now + self._peak_fit_results[detector.get_id()] = { + 'heights': (fit_peak_amplitudes / + (np.sqrt(2*np.pi)*fit_peak_sigmas)).tolist(), + 'centers': e_bragg_unconstrained.tolist(), + 'fwhm': (b_fit*fit_peak_sigmas*np.sqrt(8*np.log(2))).tolist(), + 'independent_dimension': { + 'name': 'Energy', 'unit': 'keV'}, + } + return { 'best_fit_unconstrained': best_fit, 'residual_unconstrained': residual, From d248941db7b6360e34eb2958f142cb225b62856a Mon Sep 17 00:00:00 2001 From: CHESS MSNC Date: Tue, 12 May 2026 22:11:54 -0400 Subject: [PATCH 46/76] fix: missing call to import_scanparser in MapConfig validation --- CHAP/common/models/map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index fa463f6b..2d4b0ca5 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -1201,6 +1201,7 @@ def validate_experiment_type(cls, experiment_type, info): f'For station {station}, allowed experiment types are ' f'{", ".join(allowed_experiment_types)}. ' f'Supplied experiment type {experiment_type} is not allowed.') + import_scanparser(station, experiment_type) return experiment_type From 3b760679e0c0e021894cbd86e16c78970bb6abec Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 12 May 2026 22:19:32 -0400 Subject: [PATCH 47/76] fix: missing call to import_scanparser in MapConfig validation --- CHAP/common/models/map.py | 1 + 1 file changed, 1 insertion(+) diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index fa463f6b..2d4b0ca5 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -1201,6 +1201,7 @@ def validate_experiment_type(cls, experiment_type, info): f'For station {station}, allowed experiment types are ' f'{", ".join(allowed_experiment_types)}. ' f'Supplied experiment type {experiment_type} is not allowed.') + import_scanparser(station, experiment_type) return experiment_type From 278eb49ebd46ccf8322506516dd955ae34c10664 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 14 May 2026 10:19:07 -0400 Subject: [PATCH 48/76] fix: saxswaxs processor imports & docstring typo --- CHAP/saxswaxs/processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index f2933d86..5c8621ef 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -745,7 +745,7 @@ def process(self, data): # pylint: enable=import-error # Local modules - from CHAP.common import ( + from CHAP.common.processor import ( MapProcessor, NexusToZarrProcessor, ) @@ -1238,7 +1238,7 @@ def process(self, data, idx_slice=None): import os # Local modules - from CHAP.common import MapSliceProcessor + from CHAP.common.map_utils import MapSliceProcessor from CHAP.pipeline import PipelineData #from CHAP.saxswaxs.processor import PyfaiIntegrationProcessor @@ -1278,7 +1278,7 @@ def set_logger(pipeline_item): ).process(None, idx_slice=idx_slice) def _get_detector_data(values, name): - "Get the detector data.""" + """Get the detector data.""" for v in values: if os.path.basename(v['path']) == name: return v['data'] From 8d4d5a4c27b4c3d19616d9c56a6c7ee5b929c774 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 14 May 2026 10:33:09 -0400 Subject: [PATCH 49/76] fix: erroneous in-place modification of pipeline data in UpdateValuesProcessor --- CHAP/saxswaxs/processor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 5c8621ef..88202735 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -1234,6 +1234,7 @@ def process(self, data, idx_slice=None): # Pass detector data to PyfaiIntegration processor # Concatenate & return results # System modules + from copy import deepcopy import logging import os @@ -1242,6 +1243,7 @@ def process(self, data, idx_slice=None): from CHAP.pipeline import PipelineData #from CHAP.saxswaxs.processor import PyfaiIntegrationProcessor + _data = deepcopy(data) def set_logger(pipeline_item): """Set the logger and logging handler for given pipeline item. @@ -1286,7 +1288,7 @@ def _get_detector_data(values, name): # Use raw detector data as input to integration for d in self.detectors: - data.append( + _data.append( PipelineData( name=d.get_id(), data=_get_detector_data(raw_values, d.get_id()), @@ -1295,7 +1297,7 @@ def _get_detector_data(values, name): # Get integrated data processed_values = set_logger( PyfaiIntegrationProcessor(config=self.pyfai_config) - ).process(data, idx_slices=[idx_slice]) + ).process(_data, idx_slices=[idx_slice]) if self.raw_data: return raw_values + processed_values From c75303e47c1a347437f355e9f416f4f58e865491 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 14 May 2026 11:37:23 -0400 Subject: [PATCH 50/76] fix: typo --- CHAP/edd/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py index f4d7a836..12c0e49f 100755 --- a/CHAP/edd/utils.py +++ b/CHAP/edd/utils.py @@ -1523,8 +1523,8 @@ def get_spectra_fits( 'amplitudes': uniform_fit_amplitudes, 'amplitudes_errors': uniform_fit_amplitudes_errors, 'amplitudes_vary': uniform_fit_amplitudes_vary, - 'sigmas': uniform_fit_fractions, - 'sigmas_errors': uniform_fit_fractions_errors, + 'sigmas': uniform_fit_sigmas, + 'sigmas_errors': uniform_fit_sigmas_errors, 'best_fits': uniform_fit.best_fit, 'residuals': uniform_fit.residual, 'redchis': uniform_fit.redchi, From acf8b1dbafd5972bae31f80c9fab2f17171a6c20 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 14:51:40 -0400 Subject: [PATCH 51/76] fix: module path to IndexSliceConfig --- CHAP/common/writer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 04d6e84b..64b35be5 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -26,7 +26,7 @@ Writer, validate_writer_model, ) -from CHAP.common.models import IndexSliceConfig +from CHAP.common.models.common import IndexSliceConfig @@ -640,7 +640,7 @@ class NexusValuesWriter(Writer): only if the `"idx"` key is not present for an item in the newest `PipelineData` item in `data`. Defaults to `IndexSliceConfig()`. - :vartype idx_slice: CHAP.common.models.IndexSliceConfig, optional + :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional """ path_prefix: str = '' resize_axis: Union[int, Literal[False]] = False From f9e9cfb2df32799ab42871442cc71f3e80834431 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 14:55:05 -0400 Subject: [PATCH 52/76] add: ZarrValuesWriter.resize_axis and ZarrValuesWriter.idx_slice --- CHAP/common/writer.py | 70 +++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 64b35be5..d951e533 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -906,12 +906,24 @@ class ZarrValuesWriter(Writer): """Writer for updating values in arrays of an existing `Zarr `__ file. - :ivar path_prefix: Prefix to prepend to all "path" fields in - `data` before writing. Defaults to `""`. + :ivar path_prefix: Prefix to use for all paths in input `data`, + defaults to `''`. :vartype path_prefix: str, optional + :ivar resize_axis: `False` OR the axis along which the dataset + should be resized if the target slice shape does not match the + data shape. If `False`, any mismatching shapes will just raise + an error. Defaults to `False`. + :vartype resize_axis: Union[int, Literal[False]], optional + :ivar idx_slice: Configuration for a slice object that will be + used to select the slice of the target array to write to. Used + only if the `"idx"` key is not present for an item in the + newest `PipelineData` item in `data`. Defaults to + `IndexSliceConfig()`. + :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional """ - path_prefix: Optional[str] = '' + resize_axis: Union[int, Literal[False]] = False + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() def write(self, data): """Write values to specific paths and slices in an existing @@ -931,11 +943,16 @@ def write(self, data): # Get list of PyfaiIntegrationProcessor results to write for d in self.get_pipelinedata_item(data, remove=self.remove): - self.zarr_writer( - zarrfile=zarrfile, - path=os.path.join(self.path_prefix, d['path']), - idx=d['idx'], - data=d['data']) + try: + self.logger.info(f'd = {d}') + self.zarr_writer( + zarrfile=zarrfile, + path=os.path.join(self.path_prefix, d['path']), + idx=d.get('idx', self.idx_slice._slice), + data=d['data'] + ) + except Exception as exc: + self.logger.error(exc) def zarr_writer(self, zarrfile, path, idx, data): """Write data to a specific dataset. @@ -963,18 +980,41 @@ def zarr_writer(self, zarrfile, path, idx, data): # Check if the dataset exists if path not in zarrfile: raise ValueError( - f'Dataset "{path}" does not exist in the Zarr file.') + f'Dataset "{path}" does not exist in the Zarr file.' + ) # Access the specified dataset dataset = zarrfile[path] + self.logger.debug( + f'chunks = {dataset.chunks}' + ) - # Check that the slice shape matches the data shape - if dataset[idx].shape != data.shape and data.shape[0] == 1: - data = np.squeeze(data, axis=0) + data = np.asarray(data) + + # Check the shape + self.logger.info( + f'data shape, target shape = {data.shape}, {dataset[idx].shape}' + ) if dataset[idx].shape != data.shape: - raise ValueError( - f'Data shape {data.shape} does not match the target slice ' - f'shape {dataset[idx].shape}.') + if self.resize_axis is not False: + # Resize along the specified axis + start = idx.start or 0 + stop = start + data.shape[0] + newshape_axis = max(dataset.shape[self.resize_axis], stop) + oldshape = dataset.shape + newshape = ( + *oldshape[0:self.resize_axis], + newshape_axis, + *oldshape[self.resize_axis + 1:] + ) + self.logger.info( + f'Resizing {path} {oldshape} -> {newshape}' + ) + dataset.resize(newshape) + else: + raise ValueError( + f'Data shape {data.shape} does not match the target slice ' + f'shape {dataset[idx].shape}.') # Write the data to the specified slice dataset[idx] = data From d665f2f28bde26296aeed31c0b240b26bae84728 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 16:29:25 -0400 Subject: [PATCH 53/76] fix: preserve NXlinks when passing through NexusToZarrProcessor and ZarrToNexusProcessor --- CHAP/common/processor.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index cd5d48f7..1a1ff447 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -2209,6 +2209,7 @@ def process(self, data, chunks='auto'): from nexusformat.nexus import ( NXfield, NXgroup, + NXlink, ) # pylint: disable=import-error import zarr @@ -2241,9 +2242,19 @@ def copy_group(nexus_group, zarr_group): else: zarr_group.attrs[attr_key] = attr_value.nxvalue - # Copy datasets and sub-groups + # Copy datasets, sub-groups, and links + nxlinks = {} for key, item in nexus_group.items(): - if isinstance(item, NXfield): + if isinstance(item, NXlink): + self.logger.info(f'Recording link {item.nxpath}') + if item.nxfilename is not None: + nxlinks[key] = { + 'target': item.nxtarget, + 'file': item.nxfilename, + } + else: + nxlinks[key] = item.nxtarget + elif isinstance(item, NXfield): if isinstance(item.nxdata, np.ndarray): try: # Determine chunks @@ -2278,6 +2289,8 @@ def copy_group(nexus_group, zarr_group): # Recursively copy subgroup zarr_subgroup = zarr_group.create_group(key) copy_group(item, zarr_subgroup) + if nxlinks: + zarr_group.attrs['__nxlinks__'] = nxlinks copy_group(nexus_group, zarr_group) return zarr_group @@ -3514,9 +3527,10 @@ def copy_group(zarr_group, nexus_group): :type: nexusformat.nexus.NXgroup """ self.logger.info(f'Copying {zarr_group.path}') - # Copy attributes + # Copy attributes (skip the internal link-metadata key) for attr_key, attr_value in zarr_group.attrs.items(): - nexus_group.attrs[attr_key] = attr_value + if attr_key != '__nxlinks__': + nexus_group.attrs[attr_key] = attr_value # Copy datasets and sub-groups for key, item in zarr_group.members(): @@ -3538,6 +3552,17 @@ def copy_group(zarr_group, nexus_group): nexus_subgroup = nexus_group.create_group(key) copy_group(item, nexus_subgroup) + # Recreate NXlinks + for link_name, link_target in \ + zarr_group.attrs.get('__nxlinks__', {}).items(): + self.logger.info( + f'Creating link {zarr_group.path}/{link_name}') + if isinstance(link_target, dict): + nexus_group[link_name] = h5py.ExternalLink( + link_target['file'], link_target['target']) + else: + nexus_group[link_name] = h5py.SoftLink(link_target) + # Start copying from the root group copy_group(zarr_file, nexus_file) From 72cc37190844cb4364e2e25bee2cc52aecc01009 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 16:31:49 -0400 Subject: [PATCH 54/76] bump: pydantify idx_slice parameter in common.map_utils.MapSliceProcessor and saxswaxs.UpdateValuesProcessor --- CHAP/common/map_utils.py | 20 ++++++++++---------- CHAP/saxswaxs/processor.py | 12 ++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 08551fa0..b29139f5 100755 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -13,6 +13,7 @@ from typing import Optional # Local modules +from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( Detector, MapConfig, @@ -65,6 +66,9 @@ class MapSliceProcessor(Processor): :ivar scan_numbers: Numbers of scans from which to read slices of raw data. :vartype scan_numbers: list[int] + :ivar idx_slice: Parameters for the slice of each scan to process + Defaults to `IndexSliceConfig()`. + :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional """ pipeline_fields: dict = Field( @@ -77,9 +81,9 @@ class MapSliceProcessor(Processor): spec_file: FilePath scan_number: Optional[conint(gt=0)] = None scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() - def process(self, data, - idx_slice={'start': 0, 'step': 1}): + def process(self, data): """Aggregate partial spec and detector data from one or more scans in a map, returning results in a format suitable for @@ -96,11 +100,6 @@ def process(self, data, has the value `'common.models.map.MapConfig'` for the `'schema'` key. :type data: list[PipelineData] - :param idx_slice: Parameters for the slice of each scan to - process (slice parameters are the usual for the python - `slice` object: `'start'`, `'stop'`, and `'step'`). - Defaults to `{'start': 0, 'step': 1}`. - :type idx_slice: dict[str, int], optional :return: Slice of map data, ready to be written to a map container. :rtype: list[dict[str, Any]] @@ -143,8 +142,8 @@ def get_detector_data(scan, detector, index): sorted_scan_numbers = sorted( self.scan_numbers, key=lambda sn: scan_positions[sn]) - slice_start = idx_slice.get('start', 0) - slice_step = idx_slice.get('step', 1) + slice_start = self.idx_slice._slice.start + slice_step = self.idx_slice._slice.step # Collect per-scan metadata; assumes uniform npts across scans # for index_offset calculation (index_offset = map_pos * npts) @@ -154,7 +153,8 @@ def get_detector_data(scan, detector, index): npts_scan = int(scan.spec_scan_npts) index_offset = scan_positions[sn] * npts_scan # Cap stop at npts_scan so map_indices and data stay in sync - slice_stop = min(idx_slice.get('stop', npts_scan), npts_scan) + slice_stop = min(self.idx_slice._slice.stop, npts_scan) \ + if self.idx_slice._slice.stop > 0 else npts_scan scan_indices = range(npts_scan)[ slice(slice_start, slice_stop, slice_step)] map_indices = slice( diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 88202735..232db427 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -22,6 +22,7 @@ import numpy as np # Local modules +from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( Detector, MapConfig, @@ -1211,8 +1212,9 @@ class UpdateValuesProcessor(Processor): scan_number: conint(gt=0) detectors: conlist(item_type=Detector, min_length=1) raw_data: Optional[bool] = True + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() - def process(self, data, idx_slice=None): + def process(self, data): """Processes a slice of data for updating values in an existing container for a SAXS/WAXS experiment. @@ -1266,18 +1268,16 @@ def set_logger(pipeline_item): pipeline_item.logger.addHandler(handler) return pipeline_item - if idx_slice is None: - idx_slice = {'start':0, 'step': 1} - # Read in slice of raw data raw_values = set_logger( MapSliceProcessor( map_config=self.map_config, detectors=self.detectors, spec_file=str(self.spec_file), - scan_number=self.scan_number, + scan_numbers=[self.scan_number], + idx_slice=self.idx_slice ) - ).process(None, idx_slice=idx_slice) + ).process(None) def _get_detector_data(values, name): """Get the detector data.""" From a5d616ae58fb7aae0bdd420b2620fb2eeb279c8b Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 16:35:55 -0400 Subject: [PATCH 55/76] rm: idx_slices parameter for saxswaxs.PyfaiIntegrationProcessor --- CHAP/saxswaxs/processor.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 232db427..ac92125e 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -499,7 +499,7 @@ class PyfaiIntegrationProcessor(Processor): init_var=True) config: PyfaiIntegrationConfig - def process(self, data, idx_slices=None): + def process(self, data): """Perform a set of integrations on 2D detector data. :param data: Input data. @@ -516,19 +516,11 @@ def process(self, data, idx_slices=None): # System modules import time - if idx_slices is None: - idx_slices = [{'start':0, 'step': 1}] - # Organize input for integrations input_data = {d['name']: d['data'] for d in [d for d in data if isinstance(d['data'], np.ndarray)]} ais = {ai.get_id(): ai for ai in self.config.azimuthal_integrators} - # Finalize idx slice for results - idx = tuple(slice(idx_slice.get('start'), - idx_slice.get('stop'), - idx_slice.get('step')) for idx_slice in idx_slices) - # Perform integration(s), package results for ZarrResultsWriter results = [] nframes = len(input_data[list(input_data.keys())[0]]) @@ -544,7 +536,6 @@ def process(self, data, idx_slices=None): [ { 'path': f'{integration.name}/data/I', - 'idx': idx, 'data': np.asarray(result['intensities']), }, ] @@ -1297,7 +1288,7 @@ def _get_detector_data(values, name): # Get integrated data processed_values = set_logger( PyfaiIntegrationProcessor(config=self.pyfai_config) - ).process(_data, idx_slices=[idx_slice]) + ).process(_data) if self.raw_data: return raw_values + processed_values From 140104452f3e340459a865cef73f47d8df2a0033 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 15 May 2026 16:37:50 -0400 Subject: [PATCH 56/76] feat: support 0-sized setup arrays in saxswaxs.SetupProcessor --- CHAP/saxswaxs/processor.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index ac92125e..deaf1dcf 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -575,7 +575,7 @@ class SetupResultsProcessor(Processor): }, init_var=True) config: PyfaiIntegrationConfig - dataset_shape: conlist(item_type=conint(gt=0), min_length=1) + dataset_shape: conlist(item_type=conint(ge=0), min_length=1) dataset_chunks: Optional[ Union[ Literal['auto'], @@ -704,9 +704,9 @@ class SetupProcessor(Processor): # Pipeline to pass validation. map_config: MapConfig = None pyfai_config: PyfaiIntegrationConfig - detectors: conlist(item_type=Detector, min_length=1) + detectors: conlist(item_type=Detector, min_length=0) dataset_shape: Optional[ - conlist(item_type=conint(gt=0), min_length=1)] = None + conlist(item_type=conint(ge=0), min_length=1)] = None dataset_chunks: Optional[ Union[ Literal['auto'], @@ -768,7 +768,8 @@ def set_logger(pipeline_item): # Get NXroot container for raw data map map_processor_kwargs = { - 'config': self.map_config + 'config': self.map_config, + 'remove_constant_dims': False, } if self.raw_data: map_processor_kwargs['detector_config'] = { @@ -778,13 +779,9 @@ def set_logger(pipeline_item): map_processor_kwargs['detector_config'] = { 'detectors': [] } - setup_map_processor = set_logger( - MapProcessor( - **map_processor_kwargs - # config=self.map_config, - # detector_config={'detectors': self.detectors}, - ) - ) + self.map_config.spec_scans[0].scan_numbers = [] + + setup_map_processor = set_logger(MapProcessor(**map_processor_kwargs)) ddata = [ PipelineData( data=setup_map_processor.process( From eb06d7b335ada7d8c8b30f25787fe1f7a60570ec Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 18 May 2026 08:26:23 -0400 Subject: [PATCH 57/76] fix: do not duplicate log handlers for PipelineItems with the same name --- CHAP/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 1b206945..2f606d37 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -88,7 +88,7 @@ def validate_pipelineitem_after(self): log_handler.setFormatter(logging.Formatter( '{asctime}: {name:20}: {levelname}: {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')) - self.logger.addHandler(log_handler) + self.logger.handlers = [log_handler] self.logger.setLevel(self.log_level) # Optinal, but it's already available in the 'name' field #if self.get_schema() is None: From d6587dd6e0ea5414f053757c032f74928a263431 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 18 May 2026 08:26:58 -0400 Subject: [PATCH 58/76] style: include lineno in log formatters --- CHAP/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 2f606d37..ba843ca3 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -86,7 +86,7 @@ def validate_pipelineitem_after(self): self.logger.propagate = False log_handler = logging.StreamHandler() log_handler.setFormatter(logging.Formatter( - '{asctime}: {name:20}: {levelname}: {message}', + '{asctime}: {name:20} (L{lineno}): {levelname}: {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')) self.logger.handlers = [log_handler] self.logger.setLevel(self.log_level) From 47b24e0a420a426ee872f09bc106357f3175e825 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 21 May 2026 15:46:53 -0400 Subject: [PATCH 59/76] add: support detector_config in addition to detectors field in MapSliceProcessor --- CHAP/common/map_utils.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index b29139f5..91c57e85 100755 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -15,6 +15,7 @@ # Local modules from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( + DetectorConfig, Detector, MapConfig, ) @@ -74,10 +75,12 @@ class MapSliceProcessor(Processor): pipeline_fields: dict = Field( default={ 'map_config': 'common.models.map.MapConfig', + 'detector_config': 'common.models.map.DetectorConfig' }, init_var=True) map_config: MapConfig - detectors: conlist(item_type=Detector, min_length=1) + detector_config: DetectorConfig + detectors: Optional[conlist(item_type=Detector)] = None spec_file: FilePath scan_number: Optional[conint(gt=0)] = None scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None @@ -231,7 +234,7 @@ def get_detector_data(scan, detector, index): ]), 'idx': merged_idx, }) - for det in self.detectors: + for det in self.detector_config.detectors: data_points.append({ 'path': (f'{self.map_config.title}' f'/data/{det.get_id()}'), @@ -285,7 +288,7 @@ def get_detector_data(scan, detector, index): for i in ps['scan_indices'] ]), 'idx': ps['map_indices'], - } for det in self.detectors]) + } for det in self.detector_config.detectors]) return data_points @model_validator(mode='before') @@ -313,6 +316,21 @@ def validate_scan_numbers(self): self.scan_numbers.append(self.scan_number) return self + @model_validator(mode='before') + def fill_detector_config(cls, data): + if not isinstance(data, dict): + return data + if 'detector_config' not in data or data['detector_config'] is None: + if data.get('detectors') is not None: + data['detector_config'] = DetectorConfig( + detectors=data['detectors'] + ) + else: + raise ValueError( + 'detector_config is required; alternatively, provide detectors' + ) + return data + class SpecScanToMapConfigProcessor(Processor): """Processor to get the From ea2510cea9dff44ef7007a994db5a9301d35cbba Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 21 May 2026 15:49:11 -0400 Subject: [PATCH 60/76] feat: in zarr_tree methods of pyFAI integration configs, support adding attributes that will result in NXlinks when converted to NeXus format --- CHAP/common/models/integration.py | 69 ++++++++++++++++++++++++------- 1 file changed, 54 insertions(+), 15 deletions(-) diff --git a/CHAP/common/models/integration.py b/CHAP/common/models/integration.py index fd88bb48..1bce0d24 100755 --- a/CHAP/common/models/integration.py +++ b/CHAP/common/models/integration.py @@ -634,17 +634,37 @@ def integrate(self, azimuthal_integrators, data): return results - def zarr_tree(self, dataset_shape, dataset_chunks='auto'): + def zarr_tree(self, dataset_shape, dataset_chunks='auto', nxlinks=None): """Return a dictionary representing a `zarr.group` that can be used to contain results from this integration. - :return: A `zarr.group` that can be used to contain the - integration results. + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``, defaults to ``'auto'``. + :type dataset_chunks: list[int] or str, optional + :param nxlinks: NeXus path(s) to link into the ``data`` group. + When the zarr tree is written to a ``.zarr`` file and + converted to ``.nxs`` with + :class:`~CHAP.common.processor.ZarrToNexusProcessor`, each + path produces an ``NXlink`` whose name is + ``os.path.basename(path)``. Accepts a single path string or + a list of path strings. + :type nxlinks: str or list[str], optional + :returns: Nested dict representing the zarr group tree for this + integration. :rtype: dict """ # Third party modules #import json + if isinstance(nxlinks, str): + nxlinks = [nxlinks] + data_attrs = {**self.get_axes_indices(len(dataset_shape))} + if nxlinks: + data_attrs['__nxlinks__'] = { + os.path.basename(p): p for p in nxlinks} tree = { # NXprocess 'attributes': { @@ -654,10 +674,7 @@ def zarr_tree(self, dataset_shape, dataset_chunks='auto'): 'children': { 'data': { # NXdata - 'attributes': { - # 'axes': self.result_axes(), - **self.get_axes_indices(len(dataset_shape)) - }, + 'attributes': data_attrs, 'children': { 'I': { # NXfield @@ -732,28 +749,50 @@ def validate_config(cls, data): data['azimuthal_integrators'] = ais return data - def zarr_tree(self, dataset_shape, dataset_chunks='auto'): + def zarr_tree(self, dataset_shape, dataset_chunks='auto', nxlinks=None): """Return a dictionary representing a `zarr.group` that can be used to contain results from :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor`. - :return: A `zarr.group` that can be used to contain the - integration results. + Each integration defined in :attr:`integrations` gets its own + sub-group keyed by the integration's ``name``. See + :meth:`PyfaiIntegratorConfig.zarr_tree` for the structure of + each sub-group. + + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``, defaults to ``'auto'``. + :type dataset_chunks: list[int] or str, optional + :param nxlinks: NeXus links to inject into each integration's + ``data`` group. May be a single path string or list of path + strings (forwarded to every integration), or a dict keyed by + integration name mapping each integration to its own path(s). + See :meth:`PyfaiIntegratorConfig.zarr_tree` for details on + how individual paths are handled. + :type nxlinks: str or list[str] or dict[str, str or list[str]], + optional + :returns: Nested dict representing the zarr group tree for all + integrations. :rtype: dict """ ais = {ai.get_id(): ai for ai in self.azimuthal_integrators} for integration in self.integrations: integration.init_placeholder_results(ais) + if not isinstance(nxlinks, dict): + nxlinks = {intg.name: nxlinks for intg in self.integrations} tree = { 'root': { 'attributes': { 'description': 'Container for processed SAXS/WAXS data' }, - 'children': { - integration.name: integration.zarr_tree( - dataset_shape, dataset_chunks) - for integration in self.integrations - } + }, + 'children': { + integration.name: integration.zarr_tree( + dataset_shape, dataset_chunks, + nxlinks=nxlinks.get(integration.name)) + for integration in self.integrations } } return tree From 740072ba2f6549c449653f7f617515d7368623cc Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 21 May 2026 15:54:21 -0400 Subject: [PATCH 61/76] doc: update MapSliceProcessor docstring --- CHAP/common/map_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 91c57e85..384aabd4 100755 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -58,9 +58,8 @@ class MapSliceProcessor(Processor): :ivar map_config: Map configuration. :vartype map_config: MapConfig - :ivar detectors: Detector configurations. - :vartype detectors: - list[:class:`~CHAP.common.models.map.Detector`] + :ivar detector_config: Detector configurations. + :vartype detector_config: :class:`~CHAP.common.models.map.DetectorConfig` :ivar spec_file: SPEC file containing scan from which to read a slice of raw data. :vartype spec_file: pydantic.FilePath From 1fda1d39e91fd72ce9a07865bdc55c5fbdee04eb Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 21 May 2026 15:57:53 -0400 Subject: [PATCH 62/76] feat: support slicewise processing for corrected data --- CHAP/saxswaxs/models.py | 356 ++++++++++++++++++++++++++ CHAP/saxswaxs/processor.py | 496 ++++++++++++++++++++++--------------- CHAP/saxswaxs/utils.py | 60 +++++ 3 files changed, 709 insertions(+), 203 deletions(-) create mode 100644 CHAP/saxswaxs/models.py create mode 100644 CHAP/saxswaxs/utils.py diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py new file mode 100644 index 00000000..86df7a9a --- /dev/null +++ b/CHAP/saxswaxs/models.py @@ -0,0 +1,356 @@ +"""Models to help construct containers for results of +``saxswaxs.*CorrectionProcessor`` tools.""" + +# System modules +from functools import cached_property +import os +import re +from typing import Literal, Optional, Union + +# Third party modules +from pydantic import ( + confloat, + conlist, + model_validator, +) + +# Local modules +from CHAP import version as chap_version +from CHAP.models import CHAPBaseModel +from CHAP.common.models.common import IndexSliceConfig +from CHAP.common.models.map import SpecScans + + +class Background(SpecScans): + """Configuration for background scan data associated with a + correction. + + Extends :class:`~CHAP.common.models.map.SpecScans` with an + optional index slice so that only a subset of scan steps is used + when reading background images. + + :ivar idx_slice: Index slice selecting which scan steps of the + background scan(s) to read and average. Defaults to all steps. + :vartype idx_slice: IndexSliceConfig + """ + + idx_slice: IndexSliceConfig = IndexSliceConfig() + + def zarr_arrays(self, integration_shape): + """Return a dictionary describing the zarr array that will hold + the averaged integrated background intensity. + + :param integration_shape: Shape of one frame of integration + results, as returned by + :attr:`~CHAP.common.models.integration.PyfaiIntegratorConfig.result_shape`. + :type integration_shape: tuple[int, ...] + :returns: Dict mapping the array name ``'I_background'`` to a + zarr array specification (``dtype``, ``shape``, and + ``attributes`` keys) compatible with + :func:`~CHAP.saxswaxs.utils.dict_to_zarr`. + :rtype: dict + """ + return { + 'I_background': { + 'attributes': { + 'long_name': 'Intensity (a.u)', + 'units': 'a.u,' + }, + 'dtype': 'float64', + 'shape': integration_shape, + } + } + + +class CorrectionConfig(CHAPBaseModel): + """Base configuration for a single SAXS/WAXS correction step. + + Describes one correction (flux, flux+absorption, or + flux+absorption+background) to be applied to azimuthally integrated + detector data. Subclasses pin ``correction_type`` and may require + additional fields such as ``background``. + + :ivar correction_type: Identifies the correction algorithm. + One of ``'flux'``, ``'flux_absorption'``, or + ``'flux_absorption_background'``. + :vartype correction_type: str + :ivar name: Human-readable name used as the key for this + correction's group in the output zarr / NeXus tree. + :vartype name: str + :ivar uncorrected_data_name: Name of the + :class:`~CHAP.common.models.integration.PyfaiIntegratorConfig` + integration whose output serves as the uncorrected input for + this correction. Must match the ``name`` field of one of the + integrations in the associated + :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. + :vartype uncorrected_data_name: str + :ivar presample_intensity_reference_rate: Fixed reference counting + rate for the pre-sample beam intensity monitor. When ``None`` + the rate is computed from the scan data as + ``nanmean(presample_intensity / dwell_time_actual)``. + :vartype presample_intensity_reference_rate: float, optional + :ivar background: Background scan configuration. Required for + ``'flux_absorption'`` and ``'flux_absorption_background'`` + correction types. + :vartype background: Background, optional + """ + + correction_type: Literal['flux', 'flux_absorption', + 'flux_absorption_background'] + name: str + uncorrected_data_name: str + presample_intensity_reference_rate: Optional[float] = None + background: Optional[Background] = None + + def zarr_tree(self, dataset_shape, dataset_chunks, integration_shape, + nxlinks=None): + """Return a dictionary representing the zarr tree for this + correction's output container. + + The returned tree is compatible with + :func:`~CHAP.saxswaxs.utils.dict_to_zarr` and, after conversion + with :class:`~CHAP.common.processor.ZarrToNexusProcessor`, + produces an ``NXprocess`` group containing a ``data`` sub-group + with an ``I_corrected`` dataset and, when a ``background`` is + configured, an ``I_background`` dataset. + + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``. + :type dataset_chunks: list[int] or str + :param integration_shape: Shape of one frame of integration + results for the integration named by + ``uncorrected_data_name``. + :type integration_shape: tuple[int, ...] + :param nxlinks: NeXus path(s) to link into the ``data`` group. + When the zarr tree is written to a ``.zarr`` file and + converted to ``.nxs`` with + :class:`~CHAP.common.processor.ZarrToNexusProcessor`, each + path produces an ``NXlink`` whose name is + ``os.path.basename(path)``. Accepts a single path string or + a list of path strings. All links must be explicit; none are + auto-generated. + :type nxlinks: str or list[str], optional + :returns: Nested dict representing the zarr group tree for this + correction. + :rtype: dict + """ + if isinstance(nxlinks, str): + nxlinks = [nxlinks] + data_attrs = {} + if nxlinks: + data_attrs['__nxlinks__'] = { + os.path.basename(p): p for p in nxlinks + } + if self.background is None: + background_arrays = {} + else: + background_arrays = self.background.zarr_arrays(integration_shape) + data_attrs['background'] = str(self.background.model_dump()) + return { + # NXprocess + 'attributes': { + 'correction_type': self.correction_type, + 'default': 'data', + }, + 'children': { + 'program': 'CHAP.saxswaxs', + 'version': chap_version, + 'data': { + # NXdata + 'attributes': data_attrs, + 'children': { + 'I_corrected': { + 'attributes': { + 'long_name': 'Intensity (a.u)', + 'units': 'a.u,' + }, + 'dtype': 'float64', + 'shape': (*dataset_shape, *integration_shape), + }, + **background_arrays, + } + } + } + } + + @cached_property + def processor_name(self): + """Name of the processor class that implements this correction. + + Derived from ``correction_type`` by capitalising each + ``'_'``-separated word and appending ``'CorrectionProcessor'`` + (e.g. ``'flux_absorption'`` → + ``'FluxAbsorptionCorrectionProcessor'``). + + :type: str + """ + return ''.join( + [x.capitalize() for x in self.correction_type.split('_')] + + ['CorrectionProcessor'] + ) + + @cached_property + def processor_module(self): + """Module object containing the processor class for this + correction. + + :type: module + """ + return __import__('CHAP.saxswaxs.processor', + fromlist=[self.processor_name]) + + @cached_property + def processor_class(self): + """Processor class that implements this correction. + + :type: type + """ + return getattr(self.processor_module, self.processor_name) + + @property + def processor(self): + """A new instance of the processor class for this correction, + initialised with the current configuration. + + :type: Processor + """ + return self.processor_class(config=self.model_dump()) + + +class CorrectionsConfig(CHAPBaseModel): + """Configuration container for an ordered list of SAXS/WAXS + correction steps to apply to integrated detector data. + + :ivar corrections: Ordered list of correction configurations. + :vartype corrections: list[CorrectionConfig] + """ + + corrections: conlist(item_type=CorrectionConfig) + + def zarr_tree(self, dataset_shape, dataset_chunks, + integration_shapes, nxlinks=None): + """Return a dictionary representing the zarr tree for all + corrections in this configuration. + + Each correction gets its own sub-group keyed by + :attr:`CorrectionConfig.name`. See + :meth:`CorrectionConfig.zarr_tree` for the structure of each + sub-group. + + :param dataset_shape: Shape of the measurement (scan) dimensions, + excluding integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``. + :type dataset_chunks: list[int] or str + :param integration_shapes: Mapping from integration name to the + shape of one integration result frame. Used to look up the + ``integration_shape`` for each correction via + :attr:`CorrectionConfig.uncorrected_data_name`. + :type integration_shapes: dict[str, tuple[int, ...]] + :param nxlinks: NeXus links to inject into each correction's + ``data`` group. May be a single path string or list of path + strings (forwarded to every correction), or a dict keyed by + correction name mapping each correction to its own path(s). + See :meth:`CorrectionConfig.zarr_tree` for details on how + individual paths are handled. + :type nxlinks: str or list[str] or dict[str, str or list[str]], + optional + :returns: Nested dict representing the zarr group tree for all + corrections. + :rtype: dict + """ + if not isinstance(nxlinks, dict): + nxlinks = {corr.name: nxlinks for corr in self.corrections} + return { + 'root': { + 'attributes': {}, + }, + 'children': { + corr.name: corr.zarr_tree( + dataset_shape, dataset_chunks, + integration_shapes.get( + corr.uncorrected_data_name, None + ), + nxlinks=nxlinks.get(corr.name), + ) + for corr in self.corrections + } + } + + +class FluxCorrectionConfig(CorrectionConfig): + """Correction configuration for flux-only correction. + + Applies a flux correction that normalises the measured intensity by + the pre-sample beam monitor counts, referenced to a fixed counting + rate. No background scan is required. + """ + + correction_type: Literal['flux'] = 'flux' + + +class FluxAbsorptionCorrectionConfig(FluxCorrectionConfig): + """Correction configuration for combined flux and absorption + correction. + + Extends :class:`FluxCorrectionConfig` with a required background + scan used to determine the sample transmission. + + :ivar background: Background scan configuration used to compute + the sample transmission term. + :vartype background: Background + """ + + correction_type: Literal['flux_absorption'] = 'flux_absorption' + background: Background + + +class FluxAbsorptionBackgroundCorrectionConfig( + FluxAbsorptionCorrectionConfig): + """Correction configuration for combined flux, absorption, and + background-subtraction correction with optional thickness + normalisation. + + Extends :class:`FluxAbsorptionCorrectionConfig` with an integrated + background subtraction step and an optional sample thickness or + linear attenuation coefficient for thickness normalisation. At + most one of ``sample_thickness_cm`` and ``sample_mu_inv_cm`` may be + provided. + + :ivar background: Background scan configuration. + :vartype background: Background + :ivar sample_thickness_cm: Sample thickness in centimetres. When + provided, corrected intensities are divided by this value. + Mutually exclusive with ``sample_mu_inv_cm``. + :vartype sample_thickness_cm: float, optional + :ivar sample_mu_inv_cm: Sample linear attenuation coefficient in + inverse centimetres. When provided, the effective thickness is + derived from the measured transmission. Mutually exclusive with + ``sample_thickness_cm``. + :vartype sample_mu_inv_cm: float, optional + """ + + correction_type: Literal[ + 'flux_absorption_background'] = 'flux_absorption_background' + background: Background + sample_thickness_cm: Optional[confloat(gt=0)] + sample_mu_inv_cm: Optional[confloat(gt=0)] + + @model_validator(mode='after') + def validate_thickness(self): + """Ensure ``sample_thickness_cm`` and ``sample_mu_inv_cm`` are + not both specified. + + :raises ValueError: If both fields are set. + :returns: The validated model instance. + :rtype: FluxAbsorptionBackgroundCorrectionConfig + """ + if self.sample_thickness_cm and self.sample_mu_inv_cm: + raise ValueError( + 'Use sample_thickness_cm OR sample_mu_inv_cm, not both.' + ) + return self diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index deaf1dcf..4f8872e1 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -24,12 +24,21 @@ # Local modules from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( + DetectorConfig, Detector, MapConfig, ) from CHAP.common.models.integration import PyfaiIntegrationConfig from CHAP.common.processor import ExpressionProcessor from CHAP.processor import Processor +from CHAP.pipeline import PipelineData +from CHAP.saxswaxs.models import ( + CorrectionsConfig, + FluxCorrectionConfig, + FluxAbsorptionCorrectionConfig, + FluxAbsorptionBackgroundCorrectionConfig, +) +from CHAP.saxswaxs.utils import dict_to_zarr class CfProcessor(Processor): @@ -216,7 +225,7 @@ def process( class FluxCorrectionProcessor(ExpressionProcessor): """Processor for flux correction.""" - + config: FluxCorrectionConfig def process( self, data, presample_intensity_reference_rate=None, nxprocess=False): @@ -252,7 +261,7 @@ def process( data, name='presample_intensity', ) intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, ) # nxfieldtable = { # 'intensity': intensity, @@ -287,7 +296,7 @@ def process( class FluxAbsorptionCorrectionProcessor(ExpressionProcessor): """Processor for flux and absorption correction.""" - + config: FluxAbsorptionCorrectionConfig def process( self, data, presample_intensity_reference_rate=None, nxprocess=False): @@ -316,7 +325,7 @@ def process( :rtype: Any """ intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, #'intensity', ) if presample_intensity_reference_rate is None: @@ -366,7 +375,7 @@ def process( class FluxAbsorptionBackgroundCorrectionProcessor(ExpressionProcessor): """Processor for flux, absorption, and background correction as well as optional thickness correction.""" - + config: FluxAbsorptionBackgroundCorrectionConfig def process( self, data, presample_intensity_reference_rate=None, sample_thickness_cm=None, sample_mu_inv_cm=None, nxprocess=False): @@ -414,7 +423,7 @@ def process( )) intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, ) if presample_intensity_reference_rate is None: @@ -543,119 +552,6 @@ def process(self, data): return results -class SetupResultsProcessor(Processor): - """Processor for creating an intital - `Zarr group `__ - object with empty datasets for filling in by - :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor` and - :class:`~CHAP.common.ZarrValuesWriter`. - - :ivar config: Initialization parameters for an instance of - :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. - :vartype config: dict - :ivar dataset_shape: Shape of the completed dataset that will be - processed later on (shape of the measurement itself, _not_ - including the dimensions of any signals collected at each point - in that measurement). - :vartype dataset_shape: int or list[int] - :ivar dataset_chunks: Extent of chunks along each dimension of the - completed dataset / measurement. Choose this according to how - you will process your data -- for example, if your - `dataset_shape` is `[m, n]`, and you are planning to process - each of the `m` rows as chunks, `dataset_chunks` should be - `[1, n]`. But if you plan to process each of the `n` columns as - chunks, `dataset_chunks` should be `[m, 1]`, - defaults to `"auto"`. - :vartype dataset_chunks: list[int] or Literal["auto"], optional - """ - - pipeline_fields: dict = Field( - default={ - 'config': 'common.models.integration.PyfaiIntegrationConfig' - }, - init_var=True) - config: PyfaiIntegrationConfig - dataset_shape: conlist(item_type=conint(ge=0), min_length=1) - dataset_chunks: Optional[ - Union[ - Literal['auto'], - conlist(item_type=conint(gt=0), min_length=1) - ]] = 'auto' - - def process(self, data): - """Create and return a - `Zarr group `__ - object to hold processed SAXS/WAXS data processed - by :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor`. - - :param data: Input data. - :type data: list[PipelineData] - :return: Empty structure for filling in SAXS/WAXS data. - :rtype: zarr.Group - """ - - # Get Zarr tree as dict from the PyfaiIntegrationConfig - tree = self.config.zarr_tree(self.dataset_shape, self.dataset_chunks) - - # Construct & return the root Zarr group - return self.zarr_setup(tree) - - def zarr_setup(self, tree): - """Create a - `Zarr group `__ - object based on a dictionary representing a Zarr tree of groups - and arrays. - - :param tree: Nested dictionary representing a Zarr tree of - groups and arrays. - :type tree: dict[str, Any] - :return: Zarr group corresponding to the contents of `tree`. - :rtype: zarr.Group - """ - # Third party modules - # pylint: disable=import-error - import zarr - from zarr.storage import MemoryStore - - def create_group_or_dataset(node, zarr_parent, indent=0): - """Create and return a - `Zarr group `__ - `Zarr dataset `__. - - :param node: Child Zarr tree group. - :type node: zarr.Group or zarr.Array - :param zarr_parent: Parent Zarr tree group. - :type zarr_parent: zarr.Group - :param indent: Indentation level, defaults to 0. - :type indent: int, optional - """ - # Set attributes if present - if 'attributes' in node: - for key, value in node['attributes'].items(): - zarr_parent.attrs[key] = value - # Create children (groups or datasets) - if 'children' in node: - for name, child in node['children'].items(): - if 'shape' in child or 'data' in child: - # It's a dataset - self.logger.debug(f'Adding dset: {name}') - zarr_parent.create_dataset( - name, - **child, - ) - # Set dataset attributes - if 'attributes' in child: - for key, value in child['attributes'].items(): - zarr_parent[name].attrs[key] = value - else: - # It's a group - group = zarr_parent.create_group(name) - create_group_or_dataset(child, group, indent=indent+2) - results = zarr.create_group(store=MemoryStore({})) - create_group_or_dataset(tree['root'], results) - return results - - class SetupProcessor(Processor): """Convenience Processor for setting up a container for SAXS/WAXS experiments. @@ -666,13 +562,13 @@ class SetupProcessor(Processor): :ivar pyfai_config: Initialization parameters for an instance of :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. :vartype pyfai_config: dict, optional - :ivar detectors: Detector configurations. - :vartype detectors: DetectorConfig + :ivar detector_config: Detector configurations. + :vartype detector_config: DetectorConfig :ivar dataset_shape: Shape of the completed dataset that will be processed later on (shape of the measurement itself, _not_ including the dimensions of any signals collected at each point - in that measurement). - :vartype dataset_shape: int or list[int] + in that measurement). Defaults to `[0]`. + :vartype dataset_shape: int or list[int], optional :ivar dataset_chunks: Extent of chunks along each dimension of the completed dataset / measurement. Choose this according to how you will process your data -- for example, if your @@ -695,7 +591,9 @@ class SetupProcessor(Processor): pipeline_fields: dict = Field( default={ 'map_config': 'common.models.map.MapConfig', - 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig' + 'detector_config': 'common.models.map.DetectorConfig', + 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig', + 'correction_config': 'saxswaxs.models.CorrectionsConfig', }, init_var=True) # map_config needs a default value because the map configuration @@ -704,9 +602,10 @@ class SetupProcessor(Processor): # Pipeline to pass validation. map_config: MapConfig = None pyfai_config: PyfaiIntegrationConfig - detectors: conlist(item_type=Detector, min_length=0) + detector_config: DetectorConfig = DetectorConfig(detectors=[]) + correction_config: CorrectionsConfig dataset_shape: Optional[ - conlist(item_type=conint(ge=0), min_length=1)] = None + conlist(item_type=conint(ge=0), min_length=1)] = [0] dataset_chunks: Optional[ Union[ Literal['auto'], @@ -742,46 +641,19 @@ def process(self, data): NexusToZarrProcessor, ) from CHAP.pipeline import PipelineData - #from CHAP.saxswaxs.processor import SetupResultsProcessor - - def set_logger(pipeline_item): - """Set the logger and logging handler for given pipeline - item. - - :param pipeline_item: Pipeline item. - :type pipeline_item: PipelineItem - :return: Input Pipeline item, with updated logger and - logging handler. - :rtype: PipelineItem - """ - pipeline_item.logger = self.logger - pipeline_item.logger.name = pipeline_item.__class__.__name__ - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter( - '{asctime}: {name:20} (from '+ self.__class__.__name__ - + '): {levelname}: {message}', - datefmt='%Y-%m-%d %H:%M:%S', style='{')) - pipeline_item.logger.removeHandler( - pipeline_item.logger.handlers[0]) - pipeline_item.logger.addHandler(handler) - return pipeline_item # Get NXroot container for raw data map map_processor_kwargs = { 'config': self.map_config, + 'detector_config': self.detector_config, 'remove_constant_dims': False, } - if self.raw_data: - map_processor_kwargs['detector_config'] = { - 'detectors': self.detectors - } - else: - map_processor_kwargs['detector_config'] = { - 'detectors': [] - } + if not self.raw_data: self.map_config.spec_scans[0].scan_numbers = [] - setup_map_processor = set_logger(MapProcessor(**map_processor_kwargs)) + setup_map_processor = self.setup_pipelineitem( + MapProcessor(**map_processor_kwargs) + ) ddata = [ PipelineData( data=setup_map_processor.process( @@ -813,18 +685,124 @@ def set_logger(pipeline_item): self.dataset_chunks = [self.dataset_chunks] # Convert raw data map container to Zarr format - ddata_converter = set_logger(NexusToZarrProcessor()) + ddata_converter = self.setup_pipelineitem(NexusToZarrProcessor()) zarr_map = ddata_converter.process(ddata, chunks=self.dataset_chunks) - # Get Zarr container for integration results - setup_results_processor = set_logger( - SetupResultsProcessor( - config=self.pyfai_config, - dataset_shape=self.dataset_shape, - dataset_chunks=self.dataset_chunks, + # Get paths to independent_dimension arrays + dim_paths = [ + f'{self.map_config.title}/independent_dimensions/{dim.label}' + for dim in self.map_config.independent_dimensions + ] + + # Get Zarr container for pyfai integrations + zarr_pyfai_tree = self.pyfai_config.zarr_tree( + self.dataset_shape, self.dataset_chunks, + nxlinks=dim_paths + ) + zarr_pyfai = dict_to_zarr(zarr_pyfai_tree, logger=self.logger) + + + # Get zarr container for corrected datasets + integration_shapes = { + integration.name: integration.result_shape + for integration in self.pyfai_config.integrations + } + intg_by_name = { + intg.name: intg for intg in self.pyfai_config.integrations + } + corr_nxlinks = { + corr.name: ( + dim_paths + + [f'{corr.uncorrected_data_name}/data/I'] + + [ + f'{corr.uncorrected_data_name}/data/{coord}' + for coord in intg_by_name[ + corr.uncorrected_data_name].result_coords + ] ) + for corr in self.correction_config.corrections + } + zarr_corr = dict_to_zarr( + self.correction_config.zarr_tree( + self.dataset_shape, self.dataset_chunks, + integration_shapes, + nxlinks=corr_nxlinks, + ), + logger=self.logger, ) - zarr_results = setup_results_processor.process(data) + + # For corrections that include a background, read and integrate + # that background data and store it in zarr_corr + for corr_cfg in self.correction_config.corrections: + if corr_cfg.background is None: + continue + if not self.detector_config.detectors: + self.logger.warning( + f'No detectors configured; skipping background ' + f'processing for correction "{corr_cfg.name}"') + continue + # Read in background raw detector data + self.logger.info( + f'Reading background data for correction "{corr_cfg.name}"') + idx_slice = corr_cfg.background.idx_slice + bg_det_images = { + d.get_id(): [] for d in self.detector_config.detectors} + for scan_number in corr_cfg.background.scan_numbers: + scanparser = corr_cfg.background.get_scanparser(scan_number) + npts = int(scanparser.spec_scan_npts) + slice_stop = ( + min(idx_slice._slice.stop, npts) + if idx_slice._slice.stop > 0 else npts + ) + scan_indices = range(npts)[ + slice(idx_slice._slice.start, slice_stop, + idx_slice._slice.step)] + for i in scan_indices: + for det in self.detector_config.detectors: + bg_det_images[det.get_id()].append( + scanparser.get_detector_data(det.get_id(), i)) + # Average all background data to a single frame before + # processing with appropriate integration + for det_id in bg_det_images: + bg_det_images[det_id] = np.mean( + bg_det_images[det_id], axis=0, keepdims=True) + self.logger.info( + f'Integrating background data for correction ' + f'"{corr_cfg.name}"') + bg_pyfai_input = [ + PipelineData(name=det_id, data=imgs) + for det_id, imgs in bg_det_images.items() + ] + bg_pyfai_config = self.pyfai_config.model_copy( + update={'integrations': [ + intg for intg in self.pyfai_config.integrations + if intg.name == corr_cfg.uncorrected_data_name + ]} + ) + bg_integrated = self.setup_pipelineitem( + PyfaiIntegrationProcessor(config=bg_pyfai_config) + ).process(bg_pyfai_input)[0] + # Read background scalar data + bg_presample = self.map_config.presample_intensity.get_value( + corr_cfg.background, scan_number, -1, + self.map_config.scalar_data + )[idx_slice._slice] + bg_postsample = self.map_config.postsample_intensity.get_value( + corr_cfg.background, scan_number, -1, + self.map_config.scalar_data + )[idx_slice._slice] + # Fill in placeholder zarr arrays with real background + # data + data_group = zarr_corr[corr_cfg.name]['data'] + data_group['I_background'][:] = bg_integrated['data'] + bg_presample_arr = data_group.create_array( + 'background_presample_intensity', + shape=bg_presample.shape, dtype=bg_presample.dtype) + bg_presample_arr[:] = bg_presample + bg_postsample_arr = data_group.create_array( + 'background_postsample_intensity', + shape=bg_postsample.shape, dtype=bg_postsample.dtype) + bg_postsample_arr[:] = bg_postsample # Assemble containers for raw & processed data zarr_root = zarr.create_group(store=MemoryStore({})) @@ -842,10 +820,34 @@ async def copy_zarr_store(source_store, dest_store): buf = await source_store.get( k, prototype=default_buffer_prototype()) await dest_store.set(k, buf) - asyncio.run(copy_zarr_store(zarr_map.store, zarr_root.store)) - asyncio.run(copy_zarr_store(zarr_results.store, zarr_root.store)) + for zarr_group in (zarr_map, zarr_pyfai, zarr_corr): + asyncio.run(copy_zarr_store(zarr_group.store, zarr_root.store)) return zarr_root + def setup_pipelineitem(self, pipeline_item): + """Convenience method to use a nice logger when this + ``Processor`` calls on another ``PipelineItem``. Set the + logger and logging handler for given pipeline item. + + :param pipeline_item: Pipeline item. + :type pipeline_item: PipelineItem + :return: Input Pipeline item, with updated logger and + logging handler. + :rtype: PipelineItem + """ + import logging + + pipeline_item.logger = logging.getLogger( + pipeline_item.__class__.__name__, + ) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter( + '{asctime}: {name:20} (from '+ self.__class__.__name__ + + ') (L{lineno}): {levelname}: {message}', + datefmt='%Y-%m-%d %H:%M:%S', style='{')) + pipeline_item.logger.handlers = [handler] + return pipeline_item + class UnstructuredToStructuredProcessor(Processor): """Processor to aggregate "unstructured" data into a single NeXus @@ -1176,18 +1178,19 @@ class UpdateValuesProcessor(Processor): :ivar scan_number: Scan number from which to read and process a slice of raw data. :vartype scan_number: int - :ivar detectors: Detector configurations. - :vartype detectors: list[CHAP.common.models.map.Detector] + :ivar detector_config: Detector configurations. + :vartype detector_config: :class:`~CHAP.common.models.map.DetectorConfig` :ivar raw_data: Flag to indicate wether or not space for raw detector data should be included in the values returned, defaults to `True`. :vartype raw_data: bool, optional """ - pipeline_fields: dict = Field( - default={ + { 'map_config': 'common.models.map.MapConfig', - 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig' + 'detector_config': 'common.models.map.DetectorConfig', + 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig', + 'correction_config': 'saxswaxs.models.CorrectionsConfig', }, init_var=True) # map_config needs a default value because the map configuration @@ -1196,9 +1199,11 @@ class UpdateValuesProcessor(Processor): # Pipeline to pass validation. map_config: MapConfig = None pyfai_config: PyfaiIntegrationConfig + detector_config: DetectorConfig = DetectorConfig(detectors=[]) + correction_config: CorrectionsConfig spec_file: FilePath scan_number: conint(gt=0) - detectors: conlist(item_type=Detector, min_length=1) + filename: Optional[str] = None raw_data: Optional[bool] = True idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() @@ -1231,36 +1236,17 @@ def process(self, data): # Local modules from CHAP.common.map_utils import MapSliceProcessor from CHAP.pipeline import PipelineData - #from CHAP.saxswaxs.processor import PyfaiIntegrationProcessor + # Use a copy of input data so we can append to it inside this + # Processor without modifying the actual Pipeline's data + # unecessarily. _data = deepcopy(data) - def set_logger(pipeline_item): - """Set the logger and logging handler for given pipeline - item. - - :param pipeline_item: Pipeline item. - :type pipeline_item: PipelineItem - :return: Input Pipeline item, with updated logger and - logging handler. - :rtype: PipelineItem - """ - pipeline_item.logger = self.logger - pipeline_item.logger.name = pipeline_item.__class__.__name__ - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter( - '{asctime}: {name:20} (from '+ self.__class__.__name__ - + '): {levelname}: {message}', - datefmt='%Y-%m-%d %H:%M:%S', style='{')) - pipeline_item.logger.removeHandler( - pipeline_item.logger.handlers[0]) - pipeline_item.logger.addHandler(handler) - return pipeline_item # Read in slice of raw data - raw_values = set_logger( + raw_values = self.setup_pipelineitem( MapSliceProcessor( map_config=self.map_config, - detectors=self.detectors, + detector_config=self.detector_config, spec_file=str(self.spec_file), scan_numbers=[self.scan_number], idx_slice=self.idx_slice @@ -1275,7 +1261,7 @@ def _get_detector_data(values, name): return None # Use raw detector data as input to integration - for d in self.detectors: + for d in self.detector_config.detectors: _data.append( PipelineData( name=d.get_id(), @@ -1283,19 +1269,123 @@ def _get_detector_data(values, name): ) ) # Get integrated data - processed_values = set_logger( + processed_values = self.setup_pipelineitem( PyfaiIntegrationProcessor(config=self.pyfai_config) ).process(_data) + # Get corrected data + corrected_values = [ + { + 'data': self.setup_pipelineitem( + corr_cfg.processor + ).process( + self.get_corrections_input_data( + raw_values, processed_values, corr_cfg + ), + nxprocess=False, + ), + 'path': f'{corr_cfg.name}/data/I_corrected' + } + for corr_cfg in self.correction_config.corrections + ] + if self.raw_data: - return raw_values + processed_values + return raw_values + processed_values + corrected_values - detector_ids = [d.get_id() for d in self.detectors] + detector_ids = [d.get_id() for d in self.detector_config.detectors] scalar_values = [ d for d in raw_values if not os.path.basename(d['path']) in detector_ids ] - return scalar_values + processed_values + return scalar_values + processed_values + corrected_values + + def get_corrections_input_data(self, raw_values, processed_values, + corr_cfg): + corr_data = [] + for x in ('dwell_time_actual', 'presample_intensity', + 'postsample_intensity'): + for d in raw_values: + if d['path'].endswith(x): + corr_data.append( + PipelineData(data=d['data'], name=x) + ) + break + for d in processed_values: + if d['path'].startswith(f'{corr_cfg.uncorrected_data_name}/'): + corr_data.append( + PipelineData( + data=d['data'], + name=corr_cfg.uncorrected_data_name, + ) + ) + if corr_cfg.background is not None: + if self.filename is None: + self.logger.warning( + f'No filename configured; cannot read background ' + f'intensities for correction "{corr_cfg.name}"') + else: + # System modules + import os + pre_path = (f'{corr_cfg.name}/data/' + 'background_presample_intensity') + post_path = (f'{corr_cfg.name}/data/' + 'background_postsample_intensity') + if os.path.splitext(self.filename)[1] in ('.nxs', '.h5', + '.hdf5'): + import h5py + with h5py.File(self.filename, 'r') as f: + for path, name in ( + (pre_path, + 'background_presample_intensity'), + (post_path, + 'background_postsample_intensity')): + if path in f: + corr_data.append(PipelineData( + data=np.asarray(f[path]), + name=name, + )) + else: + self.logger.warning( + f'{path} not found in {self.filename}') + else: + import zarr + zarrfile = zarr.open(self.filename, mode='r') + for path, name in ( + (pre_path, 'background_presample_intensity'), + (post_path, 'background_postsample_intensity')): + try: + corr_data.append(PipelineData( + data=np.asarray(zarrfile[path]), + name=name, + )) + except KeyError: + self.logger.warning( + f'{path} not found in {self.filename}') + return corr_data + + def setup_pipelineitem(self, pipeline_item): + """Convenience method to use a nice logger when this + ``Processor`` calls on another ``PipelineItem``. Set the + logger and logging handler for given pipeline item. + + :param pipeline_item: Pipeline item. + :type pipeline_item: PipelineItem + :return: Input Pipeline item, with updated logger and + logging handler. + :rtype: PipelineItem + """ + import logging + + pipeline_item.logger = logging.getLogger( + pipeline_item.__class__.__name__, + ) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter( + '{asctime}: {name:20} (from '+ self.__class__.__name__ + + ') (L{lineno}): {levelname}: {message}', + datefmt='%Y-%m-%d %H:%M:%S', style='{')) + pipeline_item.logger.handlers = [handler] + return pipeline_item if __name__ == '__main__': diff --git a/CHAP/saxswaxs/utils.py b/CHAP/saxswaxs/utils.py new file mode 100644 index 00000000..4a9cdfb7 --- /dev/null +++ b/CHAP/saxswaxs/utils.py @@ -0,0 +1,60 @@ +import zarr + +from CHAP.runner import set_logger + +logger, _ = set_logger(log_level='DEBUG') + +def dict_to_zarr(tree, logger=logger): + """Create a + `Zarr group `__ + object based on a dictionary representing a Zarr tree of groups + and arrays. + + :param tree: Nested dictionary representing a Zarr tree of + groups and arrays. + :type tree: dict[str, Any] + :return: Zarr group corresponding to the contents of `tree`. + :rtype: zarr.Group + """ + # Third party modules + # pylint: disable=import-error + import zarr + from zarr.storage import MemoryStore + + def create_group_or_dataset(node, zarr_parent, indent=0): + """Create and return a + `Zarr group `__ + `Zarr dataset `__. + + :param node: Child Zarr tree group. + :type node: zarr.Group or zarr.Array + :param zarr_parent: Parent Zarr tree group. + :type zarr_parent: zarr.Group + :param indent: Indentation level, defaults to 0. + :type indent: int, optional + """ + # Set attributes if present + if 'attributes' in node: + for key, value in node['attributes'].items(): + zarr_parent.attrs[key] = value + # Create children (groups or datasets) + if 'children' in node: + for name, child in node['children'].items(): + if 'shape' in child or 'data' in child: + # It's a dataset + logger.debug(f'Adding dset: {name}') + zarr_parent.create_dataset( + name, + **child, + ) + # Set dataset attributes + if 'attributes' in child: + for key, value in child['attributes'].items(): + zarr_parent[name].attrs[key] = value + else: + # It's a group + group = zarr_parent.create_group(name) + create_group_or_dataset(child, group, indent=indent+2) + results = zarr.create_group(store=MemoryStore({})) + create_group_or_dataset(tree, results) + return results From 8fc3d6efda37e961c995ced3b5846eacb988f535 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Thu, 21 May 2026 15:59:14 -0400 Subject: [PATCH 63/76] fix: set logger handlers instead of using addHandler (avoids duplicate logging messages when scripting) --- CHAP/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/runner.py b/CHAP/runner.py index 9e191d16..b7045658 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -227,9 +227,9 @@ def set_logger(log_level='INFO'): logger.setLevel(log_level) log_handler = logging.StreamHandler() log_handler.setFormatter(logging.Formatter( - '{asctime}: {name:20}: {levelname}: {message}', + '{asctime}: {name:20} (L{lineno}): {levelname}: {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')) - logger.addHandler(log_handler) + logger.handlers = [log_handler] return logger, log_handler def run( From f899e7c3bbea03dcadcce8a0e989c9ae5334c998 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Fri, 22 May 2026 13:39:31 -0400 Subject: [PATCH 64/76] feat: added create_copy to NexusReader to avoid linked data issues --- CHAP/common/reader.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index de9b78a9..5fc45067 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -584,6 +584,9 @@ def read(self, filename, method='read_csv', comment='#', **kwargs): class NexusReader(Reader): """Reader for `NeXus `__ files. + :ivar create_copy: Return a copy of the selected data, resolving + any and all linked NeXus objects, defaults to `False`.. + :vartype create_copy: bool, optional :ivar nxpath: Path to a specific location in the NeXus file tree to read from, defaults to `'/'`. :vartype nxpath: str, optional @@ -595,10 +598,11 @@ class NexusReader(Reader): :vartype nxmemory: int, optional """ - nxpath: Optional[constr(strip_whitespace=True, min_length=1)] = '/' + create_copy: Optional[bool] = False idx: Optional[conint(ge=0)] = None mode: Literal['r', 'rw', 'r+', 'w', 'a'] = 'r' nxmemory: Optional[conint(gt=0)] = None + nxpath: Optional[constr(strip_whitespace=True, min_length=1)] = '/' def read(self): """Return the NeXus Style @@ -617,11 +621,22 @@ def read(self): nxsetconfig, ) + # Local modules + from CHAP.utils.general import nxcopy + if self.nxmemory is not None: nxsetconfig(memory=self.nxmemory) if self.idx is not None: - return nxload(self.filename, mode=self.mode)[self.nxpath][self.idx] - return nxload(self.filename, mode=self.mode)[self.nxpath] + nxobject = nxload( + self.filename, mode=self.mode)[self.nxpath][self.idx] + if self.create_copy: + return nxcopy(nxobject) + return nxobject + #return nxload(self.filename, mode=self.mode)[self.nxpath][self.idx] + nxobject = nxload(self.filename, mode=self.mode)[self.nxpath] + if self.create_copy: + nxobject = nxcopy(nxobject) + return nxobject class NXdataReader(Reader): From 0b0f2f43bc667a78d3488a1c9349644eca9498e3 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 22 May 2026 13:52:58 -0400 Subject: [PATCH 65/76] fix: squeeze empty first dim out of I_background before putting in zarr source --- CHAP/saxswaxs/processor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 4f8872e1..b9f486b0 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -794,7 +794,9 @@ def process(self, data): # Fill in placeholder zarr arrays with real background # data data_group = zarr_corr[corr_cfg.name]['data'] - data_group['I_background'][:] = bg_integrated['data'] + data_group['I_background'][:] = np.squeeze( + bg_integrated['data'], axis=0 + ) bg_presample_arr = data_group.create_array( 'background_presample_intensity', shape=bg_presample.shape, dtype=bg_presample.dtype) From afc02798e948a05510431ce2159966f40ca6d677 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 22 May 2026 13:54:24 -0400 Subject: [PATCH 66/76] feat: alias old CorrectionConfig field names --- CHAP/saxswaxs/models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py index 86df7a9a..79bfd101 100644 --- a/CHAP/saxswaxs/models.py +++ b/CHAP/saxswaxs/models.py @@ -12,6 +12,8 @@ confloat, conlist, model_validator, + Field, + AliasChoices, ) # Local modules @@ -97,8 +99,9 @@ class CorrectionConfig(CHAPBaseModel): correction_type: Literal['flux', 'flux_absorption', 'flux_absorption_background'] - name: str - uncorrected_data_name: str + name: str = Field(validation_alias=AliasChoices('name', 'title')) + uncorrected_data_name: str = Field(validation_alias=AliasChoices( + 'uncorrected_data_name', 'uncorrected_data_title')) presample_intensity_reference_rate: Optional[float] = None background: Optional[Background] = None From 6fe7c82d5d2d4f97755fe2d4fd263c51d2a1599d Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 22 May 2026 14:10:47 -0400 Subject: [PATCH 67/76] fix: add default values for optional fields sample_thickness_cm and sample_mu_inv_cm --- CHAP/saxswaxs/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py index 79bfd101..af0c5a3b 100644 --- a/CHAP/saxswaxs/models.py +++ b/CHAP/saxswaxs/models.py @@ -340,8 +340,8 @@ class FluxAbsorptionBackgroundCorrectionConfig( correction_type: Literal[ 'flux_absorption_background'] = 'flux_absorption_background' background: Background - sample_thickness_cm: Optional[confloat(gt=0)] - sample_mu_inv_cm: Optional[confloat(gt=0)] + sample_thickness_cm: Optional[confloat(gt=0)] = None + sample_mu_inv_cm: Optional[confloat(gt=0)] = None @model_validator(mode='after') def validate_thickness(self): From dc3d5f94833e4de10b8d8bf31f705addaa984257 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 22 May 2026 14:22:59 -0400 Subject: [PATCH 68/76] fix: swap old method args for new config fields in saxswaxs.*CorrectionProcessors --- CHAP/saxswaxs/models.py | 2 +- CHAP/saxswaxs/processor.py | 218 +++++++++++++++++++------------------ 2 files changed, 115 insertions(+), 105 deletions(-) diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py index af0c5a3b..48667db9 100644 --- a/CHAP/saxswaxs/models.py +++ b/CHAP/saxswaxs/models.py @@ -80,7 +80,7 @@ class CorrectionConfig(CHAPBaseModel): correction's group in the output zarr / NeXus tree. :vartype name: str :ivar uncorrected_data_name: Name of the - :class:`~CHAP.common.models.integration.PyfaiIntegratorConfig` + :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig` integration whose output serves as the uncorrected input for this correction. Must match the ``name`` field of one of the integrations in the associated diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index b9f486b0..03c731fe 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -224,36 +224,41 @@ def process( class FluxCorrectionProcessor(ExpressionProcessor): - """Processor for flux correction.""" + """Processor for applying a flux correction to azimuthally + integrated SAXS/WAXS intensity data. + + Normalises the measured intensity by the pre-sample beam monitor + counts, referenced to a fixed counting rate configured via + :attr:`config`. + + :ivar config: Correction configuration. + :vartype config: FluxCorrectionConfig + """ config: FluxCorrectionConfig - def process( - self, data, presample_intensity_reference_rate=None, - nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - and `'dwell_time_actual'`, compute the flux corrected intensity - signal. - - :param data: Input data list containing items with names - `'intensity'`, `'presample_intensity'`, and - `'dwell_time_actual'` (if - `presample_intensity_reference_rate` is not specified). + def process(self, data, nxprocess=False): + """Compute the flux-corrected intensity. + + Requires input data items named ``'intensity'`` (the raw + integrated signal) and ``'presample_intensity'`` (beam monitor + counts). When + :attr:`~saxswaxs.models.FluxCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will set to - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param nxprocess: Flag to indicate the flux corrected data - should be returned as a NeXus style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux corrected version of input `'intensity'` data. - :rtype: Any + :returns: Flux-corrected intensity array (or NXobject when + ``nxprocess`` is ``True``). + :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -280,7 +285,7 @@ def process( ) symtable = { 'presample_intensity_reference_rate': - presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity } @@ -295,41 +300,47 @@ def process( class FluxAbsorptionCorrectionProcessor(ExpressionProcessor): - """Processor for flux and absorption correction.""" + """Processor for applying combined flux and absorption correction + to azimuthally integrated SAXS/WAXS intensity data. + + In addition to flux normalisation (see + :class:`FluxCorrectionProcessor`), divides by the sample + transmission, which is estimated from the ratio of post-sample to + pre-sample intensities for both the sample and a background scan. + + :ivar config: Correction configuration. + :vartype config: FluxAbsorptionCorrectionConfig + """ config: FluxAbsorptionCorrectionConfig - def process( - self, data, presample_intensity_reference_rate=None, - nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - `'postsample_intensity'`, `'background_presample_intensity'`, - `'background_postsample_intensity'`, and - `'dwell_time_actual'`, compute the flux and absorption - corrected intensity signal. - - :param data: Input data list containing all necessary data - labelled with their proper names. + def process(self, data, nxprocess=False): + """Compute the flux- and absorption-corrected intensity. + + Requires input data items named ``'intensity'``, + ``'presample_intensity'``, ``'postsample_intensity'``, + ``'background_presample_intensity'``, and + ``'background_postsample_intensity'``. When + :attr:`~saxswaxs.models.FluxAbsorptionCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will be calculated with - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param nxprocess: Flag to indicate the flux and absorption - corrected data should be returned as a NeXus style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux and absorption corrected version of input - `'intensity'` data. - :rtype: Any + :returns: Flux- and absorption-corrected intensity array (or + NXobject when ``nxprocess`` is ``True``). + :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ intensity = self.get_data( data, name=self.config.uncorrected_data_name, #'intensity', ) - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -356,7 +367,7 @@ def process( symtable = { 'presample_intensity_reference_rate': - presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity, 'tt': tt @@ -373,61 +384,60 @@ def process( class FluxAbsorptionBackgroundCorrectionProcessor(ExpressionProcessor): - """Processor for flux, absorption, and background correction as - well as optional thickness correction.""" + """Processor for applying combined flux, absorption, and + background-subtraction correction to azimuthally integrated + SAXS/WAXS intensity data, with optional thickness normalisation. + + Extends :class:`FluxAbsorptionCorrectionProcessor` by subtracting + an integrated background signal after flux and absorption + correction. Thickness normalisation is applied when + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.sample_thickness_cm` + or + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.sample_mu_inv_cm` + is set in :attr:`config`. + + :ivar config: Correction configuration. + :vartype config: FluxAbsorptionBackgroundCorrectionConfig + """ config: FluxAbsorptionBackgroundCorrectionConfig - def process( - self, data, presample_intensity_reference_rate=None, - sample_thickness_cm=None, sample_mu_inv_cm=None, nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - `'postsample_intensity'`, `'background_presample_intensity'`, - `'background_postsample_intensity'`, `'background_intensity'`, - and `'dwell_time_actual'`, return flux, absorption and - background corrected intensity signal. - - :param data: Input data list containing all necessary data - labelled with their proper names. + def process(self, data, nxprocess=False): + """Compute the flux-, absorption-, and background-corrected + intensity, with optional thickness normalisation. + + Requires input data items named ``'intensity'``, + ``'presample_intensity'``, ``'postsample_intensity'``, + ``'background_presample_intensity'``, + ``'background_postsample_intensity'``, and + ``'background_intensity'``. When + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + Thickness normalisation is controlled by :attr:`config`: if + ``config.sample_thickness_cm`` is set the corrected signal is + divided by that value; if ``config.sample_mu_inv_cm`` is set + the effective thickness is derived from the measured + transmission; otherwise no thickness normalisation is applied. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will be calculated with - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param sample_thickness_cm: Sample thickness in - centimeters. If specified, this processor will - additionally perform thickness correction. Use of this - parameter is mutualy exclusive with - use of `sample_mu_inv_cm`. - :type sample_thickness_cm: float, optional - :param sample_mu_inv_cm: Sample linear attenuation coefficient - in inverse centimeters. If specified, this processor will - additionally perform thickness correction. Use of this - parameter is mutualy exclusive with use of - `sample_thickness_cm`. - :type sample_mu_inv_cm: float, optional - :param nxprocess: Flag to indicate the flux, absorption, and - background corrected data should be returned as a Nexus - style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux, absorption and background corrected version of - input `'intensity'` data. - :rtype: Any + :returns: Flux-, absorption-, and background-corrected + intensity array (or NXobject when ``nxprocess`` is + ``True``). + :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ - if sample_thickness_cm is not None and sample_mu_inv_cm is not None: - raise ValueError(( - 'Cannot use sample_thickness_cm and sample_mu_inv_cm' - ' at the same time' - )) - intensity = self.get_data( data, name=self.config.uncorrected_data_name, ) - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -460,17 +470,17 @@ def process( background_intensity = np.broadcast_to( background_intensity, intensity.shape) - if sample_thickness_cm is not None: - t = sample_thickness_cm - elif sample_mu_inv_cm is not None: - t = -np.log(tt / sample_mu_inv_cm) + if self.config.sample_thickness_cm is not None: + t = self.config.sample_thickness_cm + elif self.config.sample_mu_inv_cm is not None: + t = -np.log(tt / self.config.sample_mu_inv_cm) else: t = 1 symtable = { 't': t, 'presample_intensity_reference_rate': - presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity, 'tt': tt, From 1fb4b1f506f5290664983937063fa4c726fff77b Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Fri, 22 May 2026 14:23:49 -0400 Subject: [PATCH 69/76] fix: include background intensities in input data for correction Processors that need them --- CHAP/saxswaxs/processor.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index 03c731fe..aa4be885 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -611,9 +611,9 @@ class SetupProcessor(Processor): # the case, then map_config needs SOME value in order for the # Pipeline to pass validation. map_config: MapConfig = None - pyfai_config: PyfaiIntegrationConfig + pyfai_config: PyfaiIntegrationConfig = None detector_config: DetectorConfig = DetectorConfig(detectors=[]) - correction_config: CorrectionsConfig + correction_config: CorrectionsConfig = CorrectionsConfig(corrections=[]) dataset_shape: Optional[ conlist(item_type=conint(ge=0), min_length=1)] = [0] dataset_chunks: Optional[ @@ -1342,6 +1342,7 @@ def get_corrections_input_data(self, raw_values, processed_values, 'background_presample_intensity') post_path = (f'{corr_cfg.name}/data/' 'background_postsample_intensity') + intens_path = f'{corr_cfg.name}/data/I_background' if os.path.splitext(self.filename)[1] in ('.nxs', '.h5', '.hdf5'): import h5py @@ -1350,7 +1351,8 @@ def get_corrections_input_data(self, raw_values, processed_values, (pre_path, 'background_presample_intensity'), (post_path, - 'background_postsample_intensity')): + 'background_postsample_intensity'), + (intens_path, 'background_intensity')): if path in f: corr_data.append(PipelineData( data=np.asarray(f[path]), @@ -1364,7 +1366,8 @@ def get_corrections_input_data(self, raw_values, processed_values, zarrfile = zarr.open(self.filename, mode='r') for path, name in ( (pre_path, 'background_presample_intensity'), - (post_path, 'background_postsample_intensity')): + (post_path, 'background_postsample_intensity'), + (intens_path, 'background_intensity')): try: corr_data.append(PipelineData( data=np.asarray(zarrfile[path]), From e8f16c9701766d406e8eda90931b7cf548518b25 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Thu, 28 May 2026 12:54:12 -0400 Subject: [PATCH 70/76] add: Added a max # func evals for option for the fit processor Applied its use to the EDD workflow. --- CHAP/common/writer.py | 3 +-- CHAP/edd/models.py | 7 +++++-- CHAP/edd/processor.py | 5 +++-- CHAP/edd/reader.py | 2 +- CHAP/edd/utils.py | 4 +++- CHAP/utils/fit.py | 21 ++++++++++++++------- 6 files changed, 27 insertions(+), 15 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 04d6e84b..a118ba22 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -26,8 +26,7 @@ Writer, validate_writer_model, ) -from CHAP.common.models import IndexSliceConfig - +from CHAP.common.models.common import IndexSliceConfig def validate_model(model): diff --git a/CHAP/edd/models.py b/CHAP/edd/models.py index d224a69e..0ad7dabe 100755 --- a/CHAP/edd/models.py +++ b/CHAP/edd/models.py @@ -883,15 +883,17 @@ class StrainAnalysisConfig(MCACalibrationConfig): the maximum mean intensity for that detector. Defaults to `0` in which case this step is ignored. :vartype find_peak_cutoff: float, optional + :ivar max_nfev: Maximum number of function evaluations in the + the strain analysis peak fitting routine. + :vartype max_nfev: int, optional :ivar num_proc: Number of processors to be used by the strain analysis peak fitting routine. - :vartype num_proc: int + :vartype num_proc: int, optional :ivar rel_height_cutoff: Used to excluded peaks based on the `find_peak` parameter as well as for peak fitting exclusion of the individual detector spectra (see the strain detector configuration :class:`~CHAP.edd.models.MCADetectorStrainAnalysis`). - Defaults to `None`. :vartype rel_height_cutoff: float, optional :ivar skip_animation: Skip the animation and plotting of the strain analysis fits, defaults to `False`. @@ -904,6 +906,7 @@ class StrainAnalysisConfig(MCACalibrationConfig): #:vartype oversampling: FIX find_peak_cutoff: Optional[confloat(ge=0.0, allow_inf_nan=False)] = 0.0 + max_nfev: Optional[conint(gt=0)] = None num_proc: Optional[conint(gt=0)] = max(1, os.cpu_count()//4) #oversampling: dict = {'num': 10} rel_height_cutoff: Optional[ diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index db685bbe..971aef78 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -3377,7 +3377,7 @@ def _linkdims( from nexusformat.nexus.tree import NXlinkfield if not isinstance(nxgroup.nxroot, NXroot): - self.logger.warning( + self.logger.debug( 'Skipping linkdims -- type(nxgroup.nxroot) = ' + f'{type(nxgroup.nxroot)}' ) @@ -3561,7 +3561,8 @@ def _strain_analysis(self): uniform_results, unconstrained_results = get_spectra_fits( np.squeeze(intensities), energies[mask], peak_locations[use_peaks], detector, - num_proc=self.config.num_proc, **self.run_config) + num_proc=self.config.num_proc, + max_nfev=self.config.max_nfev, **self.run_config) if num_points == 1: uniform_results = {k: [v] for k, v in uniform_results.items()} unconstrained_results = { diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index e813ee03..4a89b18d 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -785,7 +785,7 @@ def read(self): from nexusformat.nexus import NXentry, NXfield # Local modules - from CHAP.common import NexusReader + from CHAP.common.reader import NexusReader from CHAP.utils.general import nxcopy # Read NXroot diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py index 8079d4b5..82f8b625 100755 --- a/CHAP/edd/utils.py +++ b/CHAP/edd/utils.py @@ -1311,7 +1311,8 @@ def get_spectra_fits( # Local modules from CHAP.utils.fit import FitProcessor - num_proc = kwargs.get('num_proc', 1) + num_proc = kwargs.pop('num_proc', 1) + max_nfev = kwargs.pop('max_nfev', 64000) rel_height_cutoff = detector.rel_height_cutoff num_peak = len(peak_locations) nxdata = NXdata(NXfield(spectra, 'y'), NXfield(energies, 'x')) @@ -1341,6 +1342,7 @@ def get_spectra_fits( 'models': models, # 'plot': True, 'num_proc': num_proc, + 'max_nfev': max_nfev, 'rel_height_cutoff': rel_height_cutoff, # 'method': 'trf', 'method': 'leastsq', diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index 56489ab8..af4fc1a4 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -106,9 +106,9 @@ def process(self, data, config=None): # Refit/continue the fit with possibly updated parameters fit = data if isinstance(data, FitMap): - fit.fit(config=fit_config) + fit.fit(config=fit_config, max_nfev=config.get('max_nfev')) else: - fit.fit(config=fit_config) + fit.fit(config=fit_config, max_nfev=config.get('max_nfev')) if fit_config is not None: if fit_config.print_report: fit.print_fit_report() @@ -145,7 +145,7 @@ def process(self, data, config=None): # Instantiate the Fit or FitMap object and fit the data if np.squeeze(nxdata.nxsignal).ndim == 1: fit = Fit(nxdata, fit_config, self.logger) - fit.fit() + fit.fit(max_nfev=config.get('max_nfev')) if fit_config.print_report: fit.print_fit_report() if fit_config.plot: @@ -154,8 +154,9 @@ def process(self, data, config=None): fit = FitMap(nxdata, fit_config, self.logger) fit.fit( rel_height_cutoff=fit_config.rel_height_cutoff, - num_proc=fit_config.num_proc, plot=fit_config.plot, - print_report=fit_config.print_report) + max_nfev=config.get('max_nfev'), + num_proc=fit_config.num_proc, + plot=fit_config.plot, print_report=fit_config.print_report) else: raise ValueError(f'Invalid input data ({type(data)}: {data})') @@ -2087,12 +2088,12 @@ def _fit_nonlinear_model(self, x, y, **kwargs): 'gtol': 10*FLOAT_EPS, } if self._method == 'leastsq': - lskws['maxfev'] = 64000 + lskws['maxfev'] = kwarg.get('max_nfev', 64000) result = leastsq( self._residual, pars_init, args=(x, y), full_output=True, **lskws) else: - lskws['max_nfev'] = 64000 + lskws['max_nfev'] = kwarg.get('max_nfev', 64000) result = least_squares( self._residual, pars_init, bounds=bounds, method=self._method, args=(x, y), **lskws) @@ -3089,6 +3090,9 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): # Regular full fit result = self._fit_with_bounds_check(n, current_best_values, **kwargs) + if result.nfev == kwargs.get('max_nfev'): + self._logger.info( + f'Hit max_nfev limit for n={n}\n\tnfev: {result.nfev}') if self._rel_height_cutoff is not None: # Check for low heights peaks and refit without them @@ -3123,6 +3127,9 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): # Reset fixed amplitudes back to default self._parameters = deepcopy(parameters_save) self._parameter_bounds = deepcopy(parameters_bounds_save) + if result.nfev == kwargs.get('max_nfev'): + self._logger.info( + f'\nHit max_nfev limit again after refit for n={n}') if result.redchi >= self._redchi_cutoff: result.success = False From 580d5964e4a8aff0a838280ee53931b7ed98e9bc Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Fri, 29 May 2026 13:38:10 -0400 Subject: [PATCH 71/76] fix: bug in max_nfev kwargs for scipy fitting --- CHAP/common/writer.py | 2 +- CHAP/utils/fit.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index a118ba22..ceb3c661 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -897,7 +897,7 @@ def get_dict(data): self.status = 'written' # Right now does nothing yet, but could # add a sort of modification flag later - # Return provenance with the output file name added + # Return provenance with the output file name added return self._update_provenance(data) diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index af4fc1a4..4b15d05e 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -2087,13 +2087,16 @@ def _fit_nonlinear_model(self, x, y, **kwargs): 'xtol': 1.49012e-08, 'gtol': 10*FLOAT_EPS, } + max_nfev = kwargs.get('max_nfev') if self._method == 'leastsq': - lskws['maxfev'] = kwarg.get('max_nfev', 64000) + if max_nfev is not None: + lskws['maxfev'] = max_nfev result = leastsq( self._residual, pars_init, args=(x, y), full_output=True, **lskws) else: - lskws['max_nfev'] = kwarg.get('max_nfev', 64000) + if max_nfev is not None: + lskws['max_nfev'] = max_nfev result = least_squares( self._residual, pars_init, bounds=bounds, method=self._method, args=(x, y), **lskws) From 6e0bc658c5da85db22cbb618b1e44e357f99367d Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Mon, 1 Jun 2026 09:23:12 -0400 Subject: [PATCH 72/76] fix: add CHAP.saxswaxs.models.Background.idx_slice --- CHAP/saxswaxs/models.py | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py index 48667db9..25d30d22 100644 --- a/CHAP/saxswaxs/models.py +++ b/CHAP/saxswaxs/models.py @@ -31,12 +31,59 @@ class Background(SpecScans): optional index slice so that only a subset of scan steps is used when reading background images. + Accepts either ``idx_slice`` (a :class:`~CHAP.common.models.common.IndexSliceConfig` + dict) or the convenience field ``scan_step_indices`` (a list of + integers or a compact string such as ``"0-4, 6"``), but not both. + When ``scan_step_indices`` is given it is converted to the + equivalent ``idx_slice``; the indices must form a uniformly-spaced + sequence expressible as a Python :class:`slice`. + :ivar idx_slice: Index slice selecting which scan steps of the background scan(s) to read and average. Defaults to all steps. :vartype idx_slice: IndexSliceConfig + :ivar scan_step_indices: Convenience alternative to ``idx_slice``. + A list of integer step indices (or a compact string such as + ``"0-4, 6"``) that are converted to an ``idx_slice`` during + validation. The indices must be uniformly spaced. Mutually + exclusive with ``idx_slice``. + :vartype scan_step_indices: list[int] or str, optional """ idx_slice: IndexSliceConfig = IndexSliceConfig() + scan_step_indices: Optional[Union[list[int], str]] = None + + @model_validator(mode='before') + @classmethod + def fill_idx_slice(cls, data): + scan_step_indices = data.get('scan_step_indices') + idx_slice = data.get('idx_slice') + if scan_step_indices is not None and idx_slice is not None: + raise ValueError( + 'Specify idx_slice or scan_step_indices, not both.') + if scan_step_indices is not None: + if isinstance(scan_step_indices, str): + from CHAP.utils.general import string_to_list + scan_step_indices = string_to_list(scan_step_indices) + # scan_step_indices is now list[int]; derive a uniform slice + indices = sorted(scan_step_indices) + if len(indices) == 1: + start, step = indices[0], 1 + else: + step = indices[1] - indices[0] + if step <= 0: + raise ValueError( + 'scan_step_indices must contain distinct, ' + 'positive-step values.') + diffs = [indices[i+1] - indices[i] + for i in range(len(indices) - 1)] + if len(set(diffs)) != 1: + raise ValueError( + 'scan_step_indices must be uniformly spaced so ' + 'they can be expressed as a slice.') + start = indices[0] + stop = indices[-1] + step + data['idx_slice'] = {'start': start, 'stop': stop, 'step': step} + return data def zarr_arrays(self, integration_shape): """Return a dictionary describing the zarr array that will hold From ad7de50c0b51def6bdc9785932e6ed1e328931f1 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 2 Jun 2026 09:21:05 -0400 Subject: [PATCH 73/76] fix: changed config to FitProcessor class field, change input data type for fit --- CHAP/edd/processor.py | 66 +++++------ CHAP/edd/utils.py | 8 +- CHAP/pipeline.py | 2 +- CHAP/tomo/processor.py | 10 +- CHAP/utils/fit.py | 261 +++++++++++++++++++++-------------------- 5 files changed, 175 insertions(+), 172 deletions(-) diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index db685bbe..c2258031 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -735,11 +735,10 @@ def _measure_dvl(self, scanned_vals): models.append({'model': model, 'prefix': f'{model}_'}) models.append({'model': 'gaussian'}) self.logger.debug('Fitting mean spectrum') - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata( - NXfield(masked_sum, 'y'), NXfield(x, 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': x, 'y': masked_sum}, + config={'models': models, 'method': 'trf'}, + **self.run_config) # Calculate / manually select diffraction volume length detector.dvl = float( @@ -1406,12 +1405,10 @@ def _calibrate(self): 'fwhm_min': detector.fwhm_min, 'fwhm_max': detector.fwhm_max}) self.logger.debug('Fitting spectrum') - fit = FitProcessor(**self.run_config) - mean_data_fit = fit.process( - NXdata( - NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'models': models, 'method': 'trf'}) - + mean_data_fit = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) # Extract the fit results for the peaks fit_peak_amplitudes = np.asarray([ @@ -1428,12 +1425,10 @@ def _calibrate(self): self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') # FIX for now stick with a linear energy correction - fit = FitProcessor(**self.run_config) - energy_fit = fit.process( - NXdata( - NXfield(peak_energies, 'y'), - NXfield(fit_peak_indices, 'x')), - {'models': [{'model': 'linear'}]}) + energy_fit = FitProcessor.run( + data={'x': fit_peak_indices, 'y': peak_energies}, + config={'models': [{'model': 'linear'}]}, + **self.run_config) a = 0.0 b = float(energy_fit.best_values['slope']) c = float(energy_fit.best_values['intercept']) @@ -2141,10 +2136,10 @@ def _direct_bragg_peak_fit( 'fwhm_max': detector.fwhm_max}) # Perform an unconstrained fit in terms of MCA bin index - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) best_fit = result.best_fit residual = result.residual @@ -2168,10 +2163,10 @@ def _direct_bragg_peak_fit( model = 'quadratic' else: model = 'linear' - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(e_bragg, 'y'), NXfield(fit_peak_indices, 'x')), - {'models': [{'model': model}]}) + result = FitProcessor.run( + data={'x': fit_peak_indices, 'y': e_bragg}, + config={'models': [{'model': model}]}, + **self.run_config) if quadratic_energy_calibration: a_fit = result.best_values['a'] b_fit = result.best_values['b'] @@ -2298,10 +2293,11 @@ def _direct_fit_tth_ecc( {'name': 'sigma', 'min': sig_min, 'max': sig_max}]}) # Perform the fit - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'parameters': parameters, 'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={ + 'parameters': parameters, 'models': models, 'method': 'trf'}, + **self.run_config) # Extract values of interest from the best values tth_fit = np.degrees(result.best_values['tth']) @@ -2333,14 +2329,10 @@ def _direct_fit_tth_ecc( 'centers_range': b_fit * detector.centers_range, 'fwhm_min': b_fit * detector.fwhm_min, 'fwhm_max': b_fit * detector.fwhm_max}] - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'parameters': parameters, 'models': models, 'method': 'trf'}) - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), - NXfield(energies[mask], 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': energies[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) e_xrf_unconstrained = np.sort( [result.best_values[f'peak{i+1}_center'] for i in range(num_xrf)]) diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py index 8079d4b5..7a1537ef 100755 --- a/CHAP/edd/utils.py +++ b/CHAP/edd/utils.py @@ -1314,7 +1314,6 @@ def get_spectra_fits( num_proc = kwargs.get('num_proc', 1) rel_height_cutoff = detector.rel_height_cutoff num_peak = len(peak_locations) - nxdata = NXdata(NXfield(spectra, 'y'), NXfield(energies, 'x')) # Construct the fit model models = [] @@ -1350,8 +1349,8 @@ def get_spectra_fits( # Perform uniform fit # FIX make more generic for fit parameters - fit = FitProcessor(**kwargs) - uniform_fit = fit.process(nxdata, config) + uniform_fit = FitProcessor.fit( + data={'x': energies, 'y': spectra}, config=config, **kwargs) uniform_success = uniform_fit.success if spectra.ndim == 1: if uniform_success: @@ -1518,7 +1517,8 @@ def get_spectra_fits( # Perform unconstrained fit config['models'][-1]['fit_type'] = 'unconstrained' - unconstrained_fit = fit.process(uniform_fit, config) + fit.config = config + unconstrained_fit = fit.process(uniform_fit) unconstrained_success = unconstrained_fit.success if spectra.ndim == 1: if unconstrained_success: diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index ba843ca3..91800b24 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -416,7 +416,7 @@ def get_pipelinedata_item(data, index=-1, remove=False): :return: Matching data item. :rtype: Any """ - if isinstance(data, list): + if isinstance(data, list) and isinstance(data[index], PipelineData): if remove: return data.pop(index)['data'] return data[index]['data'] diff --git a/CHAP/tomo/processor.py b/CHAP/tomo/processor.py index 18993aa5..9d794348 100755 --- a/CHAP/tomo/processor.py +++ b/CHAP/tomo/processor.py @@ -1627,7 +1627,6 @@ def _set_detector_bounds( # Try to get a fit from the bright field row_sum = np.sum(tbf, 1) num = len(row_sum) - fit = FitProcessor(**self.run_config) model = {'model': 'rectangle', 'parameters': [ {'name': 'amplitude', @@ -1641,11 +1640,10 @@ def _set_detector_bounds( 'min': 0.0, 'max': num}, {'name': 'sigma2', 'value': num/7.0, 'min': sys.float_info.min}]} - bounds_fit = fit.process( - data=NXdata( - NXfield(row_sum, 'y'), - NXfield(np.array(range(num)), 'x')), - config={'models': [model], 'method': 'trf'}) + bounds_fit = FitProcessor.run( + data={'x': np.array(range(num)), 'y':row_sum}, + config={'models': [model], 'method': 'trf'}, + **self.run_config) parameters = bounds_fit.best_values row_low_fit = parameters.get('center1', None) row_upp_fit = parameters.get('center2', None) diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index 56489ab8..86f810c4 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -14,6 +14,7 @@ from shutil import rmtree from sys import float_info #from time import time +from typing import Optional # Third party modules try: @@ -24,8 +25,8 @@ HAVE_JOBLIB = True except ImportError: HAVE_JOBLIB = False -from nexusformat.nexus import NXdata import numpy as np +from pydantic import Field # Local modules from CHAP.processor import Processor @@ -35,6 +36,7 @@ index_nearest, quick_plot, ) +from CHAP.utils.models import FitConfig FLOAT_MIN = float_info.min FLOAT_MAX = float_info.max @@ -60,104 +62,93 @@ class FitProcessor(Processor): - """A processor to perform a fit on a data set or data map. """ + """A processor to perform a fit on a data set or data map. - def process(self, data, config=None): + :ivar config: Initialization parameters for an instance of + :class:`~CHAP.utils.models.FitConfig`. + :vartype config: dict, optional + """ + + pipeline_fields: dict = Field( + default = { + 'config': 'CHAP.utils.models.FitConfig'}, init_var=True) + config: Optional[FitConfig] = None + + def process(self, data): """Fit the data and return a :class:`~CHAP.utils.fit.Fit` or :class:`~CHAP.utils.fit.FitMap` object depending on the - dimensionality of the input data. The input data should be or - contain a NeXus style - `NXdata `__ - object, with properly defined signal and axis, or a - :class:`~CHAP.utils.fit.Fit` or :class:`~CHAP.utils.fit.FitMap` - object from a previous fit. - - :param data: Input data containing the - nexusformat.nexus.NXdata object to fit. - :type data: list[PipelineData] or Fit or FitMap or - nexusformat.nexus.NXdata - :param config: Fit configuration. - :type config: dict, optional - :raises ValueError: Invalid input or configuration parameter. + dimensionality of the input data. The input data should be an + array-like object (tuple, list or numpy.ndarray) or a + dictionary with at least a `"y"` field with a value equal to + one or a :class:`~CHAP.utils.fit.Fit` or + :class:`~CHAP.utils.fit.FitMap` object from a previous fit. + + :param data: Input data containing the data object to fit. + :type data: list[PipelineData] or Fit or FitMap or array-like :return: The fitted data object. :rtype: Fit or FitMap """ # Local modules - from CHAP.utils.models import ( - FitConfig, - Multipeak, - ) + from CHAP.utils.models import Multipeak - # Unwrap the PipelineData if called as a Pipeline Processor - if (not isinstance(data, (Fit, FitMap)) - and not isinstance(data, NXdata)): - data = self.get_pipelinedata_item(data) - - # Get the validated fit configuration - fit_config = None - if config is not None: - try: - fit_config = FitConfig(**config) - except Exception as exc: - raise RuntimeError from exc +# # Unwrap the PipelineData if called as a Pipeline Processor +# data = self.get_pipelinedata_item(data) if isinstance(data, (Fit, FitMap)): # Refit/continue the fit with possibly updated parameters fit = data if isinstance(data, FitMap): - fit.fit(config=fit_config) + fit.fit(config=self.config) else: - fit.fit(config=fit_config) - if fit_config is not None: - if fit_config.print_report: + fit.fit(config=self.config) + if self.config is not None: + if self.config.print_report: fit.print_fit_report() - if fit_config.plot: + if self.config.plot: fit.plot(skip_init=True) - elif isinstance(data, NXdata): + else: - # Get the default NXdata object + # Test for the correct input data type try: - nxdata = data.get_default() - assert nxdata is not None - except (AssertionError, ValueError) as exc: - if nxdata is None or nxdata.nxclass != 'NXdata': - raise ValueError( - 'Invalid default pathway to an NXdata ' - f'object in ({data})') from exc + if isinstance(data, dict): + x = np.asarray(data.get('x')) + y = np.asarray(data.get('y')) + else: + y = np.asarray(data) + except (ValueError, TypeError) as exc: + raise ValueError(f'Invalid input data ({type(data)}\n{data})') # Expand multipeak model if present found_multipeak = False - for i, model in enumerate(deepcopy(fit_config.models)): + for i, model in enumerate(deepcopy(self.config.models)): if isinstance(model, Multipeak): if found_multipeak: raise ValueError( - f'Invalid parameter models ({fit_config.models}) ' + f'Invalid parameter models ({self.config.models}) ' '(multiple instances of multipeak not allowed)') parameters, models = self.create_multipeak_model(model) if parameters: - fit_config.parameters += parameters - fit_config.models += models - fit_config.models.pop(i) + self.config.parameters += parameters + self.config.models += models + self.config.models.pop(i) found_multipeak = True # Instantiate the Fit or FitMap object and fit the data - if np.squeeze(nxdata.nxsignal).ndim == 1: - fit = Fit(nxdata, fit_config, self.logger) + if np.squeeze(y).ndim == 1: + fit = Fit(data, self.config, self.logger) fit.fit() - if fit_config.print_report: + if self.config.print_report: fit.print_fit_report() - if fit_config.plot: + if self.config.plot: fit.plot(skip_init=True) else: - fit = FitMap(nxdata, fit_config, self.logger) + fit = FitMap(data, self.config, self.logger) fit.fit( - rel_height_cutoff=fit_config.rel_height_cutoff, - num_proc=fit_config.num_proc, plot=fit_config.plot, - print_report=fit_config.print_report) - else: - raise ValueError(f'Invalid input data ({type(data)}: {data})') + rel_height_cutoff=self.config.rel_height_cutoff, + num_proc=self.config.num_proc, plot=self.config.plot, + print_report=self.config.print_report) return fit @@ -577,11 +568,11 @@ def fit_report(self, show_correl=False): class Fit: """Wrapper class for scipy/lmfit.""" - def __init__(self, nxdata, config, logger): + def __init__(self, data, config, logger): """Initialize Fit. - :param nxdata: The input data. - :type nxdata: nexusformat.nexus.NXdata + :param data: The input data. + :type data: array-like, or dict :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional :param logger: A python Logger object. @@ -633,20 +624,20 @@ def __init__(self, nxdata, config, logger): # raise ValueError( # 'Invalid value of keyword argument try_linear_fit ' # f'({self._try_linear_fit})') - if nxdata is not None: - if isinstance(nxdata.attrs['axes'], str): - dim_x = nxdata.attrs['axes'] - else: - dim_x = nxdata.attrs['axes'][-1] - self._x = np.asarray(nxdata[dim_x]) - self._y = np.squeeze(nxdata.nxsignal) - if self._x.ndim != 1: - raise ValueError( - f'Invalid x dimension ({self._x.ndim})') - if self._x.size != self._y.size: + if data is not None: + try: + if isinstance(data, dict): + self._x = np.asarray(data.get('x')) + self._y = np.squeeze(data.get('y')) + assert self._x.size == self._y.size + else: + self._y = np.squeeze(data) + self._x = np.arange(self._y.size) + except (ValueError, TypeError) as exc: + raise ValueError(f'Invalid input data ({type(data)}\n{data})') + if self._y.ndim != 1: raise ValueError( - f'Inconsistent x and y dimensions ({self._x.size} vs ' - f'{self._y.size})') + f'Invalid input data dimension ({self._y.ndim})') # if 'mask' in kwargs: # self._mask = kwargs.pop('mask') if True: #self._mask is None: @@ -928,14 +919,6 @@ def var_names(self): return None return getattr(self._result, 'var_names', None) - @property - def x(self): - """Return the input x-coordinates. - - :type: numpy.ndarray - """ - return self._x - @property def y(self): """Return the input y-coordinates. @@ -1216,8 +1199,7 @@ def add_model(self, model, prefix): if name not in new_parameters: name = pprefix+name if name not in new_parameters: - raise ValueError( - f'Unable to match parameter {name}') + raise ValueError(f'Unable to match parameter {name}') if parameter.expr is None: self._parameters[name].set( value=parameter.value, min=parameter.min, @@ -1331,11 +1313,13 @@ def fit(self, config=None, **kwargs): return None def plot( - self, y=None, *, y_title=None, title=None, result=None, + self, x=None, y=None, *, y_title=None, title=None, result=None, skip_init=False, plot_comp=True, plot_comp_legends=False, plot_residual=False, plot_masked_data=True, **kwargs): """Plot the best fit. + :param x: x-coordinates. + :type x: array-like, optional :param y: y-coordinates. :type y: array-like, optional :param y_title: y-axis label. @@ -1372,34 +1356,45 @@ class attribute. plot_masked_data = False else: mask = self._mask + if x is not None: + if not isinstance(x, (tuple, list, np.ndarray)): + self._logger.warning( + 'Ignoring invalid parameter x ({type(x)})') + if len(x) != len(self._x): + self._logger.warning( + 'Ignoring parameter x in plot (wrong dimension)') + x = None + if x is None: + x = self._x if y is not None: if not isinstance(y, (tuple, list, np.ndarray)): - self._logger.warning('Ignorint invalid parameter y ({y}') - if len(y) != len(self._x): + self._logger.warning( + 'Ignoring invalid parameter y ({type(y)})') + if len(y) != len(x): self._logger.warning( 'Ignoring parameter y in plot (wrong dimension)') y = None if y is not None: if y_title is None or not isinstance(y_title, str): y_title = 'data' - plots += [(self._x, y, '.')] + plots += [(x, y, '.')] legend += [y_title] if self._y is not None: - plots += [(self._x, np.asarray(self._y), 'b.')] + plots += [(x, np.asarray(self._y), 'b.')] legend += ['data'] if plot_masked_data: - plots += [(self._x[mask], np.asarray(self._y)[mask], 'bx')] + plots += [(x[mask], np.asarray(self._y)[mask], 'bx')] legend += ['masked data'] if isinstance(plot_residual, bool) and plot_residual: - plots += [(self._x[~mask], result.residual, 'r-')] + plots += [(x[~mask], result.residual, 'r-')] legend += ['residual'] - plots += [(self._x[~mask], result.best_fit, 'k-')] + plots += [(x[~mask], result.best_fit, 'k-')] legend += ['best fit'] if not skip_init and hasattr(result, 'init_fit'): - plots += [(self._x[~mask], result.init_fit, 'g-')] + plots += [(x[~mask], result.init_fit, 'g-')] legend += ['init'] if plot_comp: - components = result.eval_components(x=self._x[~mask]) + components = result.eval_components(x=x[~mask]) num_components = len(components) if 'tmp_normalization_offset_' in components: num_components -= 1 @@ -1413,8 +1408,8 @@ class attribute. if len(modelname) > 20: modelname = f'{modelname[0:16]} ...' if isinstance(y_comp, (int, float)): - y_comp *= np.ones(self._x[~mask].size) - plots += [(self._x[~mask], y_comp, '--')] + y_comp *= np.ones(x[~mask].size) + plots += [(x[~mask], y_comp, '--')] if plot_comp_legends: if modelname[-1] == '_': legend.append(modelname[:-1]) @@ -2311,11 +2306,11 @@ def _residual(self, pars, x, y): class FitMap(Fit): """Wrapper to the Fit class to fit data on a N-dimensional map.""" - def __init__(self, nxdata, config, logger): + def __init__(self, data, config, logger): """Initialize FitMap. - :param nxdata: The input data. - :type nxdata: nexusformat.nexus.NXdata + :param data: The input data. + :type data: array-like, or dict :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional :param logger: A python Logger object. @@ -2345,16 +2340,16 @@ def __init__(self, nxdata, config, logger): # At this point the fastest index should always be the signal # dimension so that the slowest ndim-1 dimensions are the # map dimensions - self._x = np.asarray(nxdata[nxdata.attrs['axes'][-1]]) - self._ymap = np.asarray(nxdata.nxsignal) - - # Check input parameters - if self._x.ndim != 1: - raise ValueError(f'Invalid x dimension ({self._x.ndim})') - if self._x.size != self._ymap.shape[-1]: - raise ValueError( - f'Inconsistent x and y dimensions ({self._x.size} vs ' - f'{self._ymap.shape[-1]})') + try: + if isinstance(data, dict): + self._x = np.asarray(data.get('x')) + self._ymap = np.asarray(data.get('y')) + assert self._x.size == self._ymap.shape[-1] + else: + self._ymap = np.asarray(data) + self._x = np.arange(self._ymap.shape[-1]) + except (ValueError, TypeError) as exc: + raise ValueError(f'Invalid input data ({type(data)}\n{data})') # Flatten the map # Store the flattened map in self._ymap_norm @@ -2385,7 +2380,7 @@ def __init__(self, nxdata, config, logger): self._y_range = ymap_max-ymap_min if self._y_range > 0.0: self._norm = (ymap_min, self._y_range) - self._ymap_norm = (self._ymap_norm-self._norm[0]) / self._norm[1] + self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1] else: self._redchi_cutoff *= self._y_range**2 @@ -2588,7 +2583,12 @@ def best_parameters(self, dims=None): :rtype: list[str] or dict """ if dims is None: - return self._best_parameters + parameters_dict = {} + for i, name in enumerate(self._best_parameters): + parameters_dict[name] = { + 'values': self._best_values[i], + 'errors': self._best_errors[i]} + return parameters_dict if (not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape)): raise ValueError('Invalid parameter dims ({dims})') @@ -2627,10 +2627,12 @@ def freemem(self): self._logger.warning('Could not clean-up automatically.') def plot( - self, dims=None, *, y_title=None, plot_comp_legends=False, + self, x=None, dims=None, *, y_title=None, plot_comp_legends=False, plot_residual=False, plot_masked_data=True, **kwargs): """Plot the best fits. + :param x: x-coordinates. + :type x: array-like, optional :param dims: Map indices of the data point to plot, defaults to `None` which will plot the first data point. :type dims: list or tuple, optional @@ -2649,6 +2651,16 @@ def plot( # Third party modules from lmfit.models import ExpressionModel + if x is not None: + if not isinstance(x, (tuple, list, np.ndarray)): + self._logger.warning( + 'Ignoring invalid parameter x ({type(x)})') + if len(x) != len(self._x): + self._logger.warning( + 'Ignoring parameter x in plot (wrong dimension)') + x = None + if x is None: + x = self._x if dims is None: dims = [0]*len(self._map_shape) if (not isinstance(dims, (list, tuple)) @@ -2663,20 +2675,20 @@ def plot( if y_title is None or not isinstance(y_title, str): y_title = 'data' if self._mask is None: - mask = np.zeros(self._x.size).astype(bool) + mask = np.zeros(x.size).astype(bool) plot_masked_data = False else: mask = self._mask - plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')] + plots = [(x, np.asarray(self._ymap[dims]), 'b.')] legend = [y_title] if plot_masked_data: plots += \ - [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] + [(x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] legend += ['masked data'] - plots += [(self._x[~mask], self.best_fit[dims], 'k-')] + plots += [(x[~mask], self.best_fit[dims], 'k-')] legend += ['best fit'] if plot_residual: - plots += [(self._x[~mask], self.residual[dims], 'r--')] + plots += [(x[~mask], self.residual[dims], 'r--')] legend += ['residual'] # Create current parameters parameters = deepcopy(self._parameters) @@ -2703,10 +2715,10 @@ def plot( modelname = f'{component._name}' if len(modelname) > 20: modelname = f'{modelname[0:16]} ...' - y = component.eval(params=parameters, x=self._x[~mask]) + y = component.eval(params=parameters, x=x[~mask]) if isinstance(y, (int, float)): - y *= np.ones(self._x[~mask].size) - plots += [(self._x[~mask], y, '--')] + y *= np.ones(x[~mask].size) + plots += [(x[~mask], y, '--')] if plot_comp_legends: legend.append(modelname) quick_plot( @@ -3092,6 +3104,7 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): if self._rel_height_cutoff is not None: # Check for low heights peaks and refit without them + # FIX make sure to add "height" and "fwhm" to peak-like models heights = [] names = [] for component in result.components: From 372c6152e777f1a3d426a8a1ddb667bb6111ccb9 Mon Sep 17 00:00:00 2001 From: Keara Soloway Date: Tue, 2 Jun 2026 10:03:56 -0400 Subject: [PATCH 74/76] fix: merge commit bug --- CHAP/utils/fit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index d09599f7..a834d46d 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -99,9 +99,9 @@ def process(self, data): # Refit/continue the fit with possibly updated parameters fit = data if isinstance(data, FitMap): - fit.fit(config=self.config, max_nfev=config.get('max_nfev')) + fit.fit(config=self.config, max_nfev=self.config.max_nfev) else: - fit.fit(config=self.config, max_nfev=config.get('max_nfev')) + fit.fit(config=self.config, max_nfev=self.config.max_nfev) if self.config is not None: if self.config.print_report: fit.print_fit_report() From 3de02faaab39788cc1f4d55cc20971a1cc1fea93 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 2 Jun 2026 11:26:06 -0400 Subject: [PATCH 75/76] fix: merge more fit processor commit bug --- CHAP/edd/utils.py | 6 +++--- CHAP/utils/fit.py | 23 ++++++++++++---------- CHAP/utils/models.py | 46 ++++++++++++++++++++++++++------------------ 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py index 837bc4d5..dceb4a78 100755 --- a/CHAP/edd/utils.py +++ b/CHAP/edd/utils.py @@ -1351,7 +1351,7 @@ def get_spectra_fits( # Perform uniform fit # FIX make more generic for fit parameters - uniform_fit = FitProcessor.fit( + uniform_fit = FitProcessor.run( data={'x': energies, 'y': spectra}, config=config, **kwargs) uniform_success = uniform_fit.success if spectra.ndim == 1: @@ -1519,8 +1519,8 @@ def get_spectra_fits( # Perform unconstrained fit config['models'][-1]['fit_type'] = 'unconstrained' - fit.config = config - unconstrained_fit = fit.process(uniform_fit) + unconstrained_fit = FitProcessor.run( + data=uniform_fit, config=config, **kwargs) unconstrained_success = unconstrained_fit.success if spectra.ndim == 1: if unconstrained_success: diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index a834d46d..c4671cf8 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -136,7 +136,7 @@ def process(self, data): found_multipeak = True # Instantiate the Fit or FitMap object and fit the data - if np.squeeze(data.nxsignal).ndim == 1: + if np.squeeze(y).ndim == 1: fit = Fit(data, self.config, self.logger) fit.fit(max_nfev=self.config.max_nfev) if self.config.print_report: @@ -149,9 +149,8 @@ def process(self, data): rel_height_cutoff=self.config.rel_height_cutoff, max_nfev=self.config.max_nfev, num_proc=self.config.num_proc, - plot=self.config.plot, print_report=self.config.print_report) - else: - raise ValueError(f'Invalid input data ({type(data)}: {data})') + plot=self.config.plot, + print_report=self.config.print_report) return fit @@ -2589,12 +2588,16 @@ def best_parameters(self, dims=None): :rtype: list[str] or dict """ if dims is None: - parameters_dict = {} - for i, name in enumerate(self._best_parameters): - parameters_dict[name] = { - 'values': self._best_values[i], - 'errors': self._best_errors[i]} - return parameters_dict + return self._best_parameters +# FIX use something else, self._best_parameters is "reserved" to get the +# parameters in the EDD strain analysis and must return the order of the +# parameters in self.best_values and self.best_errors +# parameters_dict = {} +# for i, name in enumerate(self._best_parameters): +# parameters_dict[name] = { +# 'values': self._best_values[i], +# 'errors': self._best_errors[i]} +# return parameters_dict if (not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape)): raise ValueError('Invalid parameter dims ({dims})') diff --git a/CHAP/utils/models.py b/CHAP/utils/models.py index ebdeb88a..e63b7e30 100755 --- a/CHAP/utils/models.py +++ b/CHAP/utils/models.py @@ -768,46 +768,54 @@ class FitConfig(CHAPBaseModel): :ivar code: Specifies is lmfit is used to perform the fit or if the scipy fit method is called directly, default to `'lmfit'`. :vartype code: Literal['lmfit', 'scipy'], optional - :ivar parameters: Fit model parameters in addition to those - implicitly defined through the build-in model functions, - defaults to `[]`' - :vartype parameters: - list[:class:`~CHAP.utils.models.FitParameter`], optional + :ivar max_nfev: Maximum number of function evaluations in the + the strain analysis peak fitting routine. + :vartype max_nfev: int, optional + :ivar memfolder: Folder name for the temporary memory map if + multiple processors are used, defaults to `'joblib_memmap'`. + :vartype memfolder: str, optional + :ivar method: SciPy non-linear fit method, defaults to + `"leastsq"`. + :vartype method: Literal[ + 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] :ivar models: The component(s) of the (composite) fit model. :vartype models: list[:attr:`~CHAP.utils.models.FitConfig.models`] - :ivar rel_height_cutoff: Relative peak height cutoff for - peak fitting (any peak with a height smaller than - `rel_height_cutoff` times the maximum height of all peaks - gets removed from the fit model). - :vartype rel_height_cutoff: float, optional :ivar num_proc: The number of processors used in fitting a map of data, defaults to `1`. :vartype num_proc: int, optional + :ivar parameters: Fit model parameters in addition to those + implicitly defined through the build-in model functions, + defaults to `[]`' + :vartype parameters: + list[:class:`~CHAP.utils.models.FitParameter`], optional :ivar plot: Whether a plot of the fit result is generated, defaults to `False`. :vartype plot: bool, optional. :ivar print_report: Whether to generate a fit result printout, defaults to `False`. :vartype print_report: bool, optional. - :ivar memfolder: Folder name for the temporary memory map if - multiple processors are used, defaults to `'joblib_memmap'`. - :vartype memfolder: str, optional + :ivar rel_height_cutoff: Relative peak height cutoff for + peak fitting (any peak with a height smaller than + `rel_height_cutoff` times the maximum height of all peaks + gets removed from the fit model). + :vartype rel_height_cutoff: float, optional """ code: Literal['lmfit', 'scipy'] = 'scipy' - parameters: conlist(item_type=FitParameter) = [] + max_nfev: Optional[conint(gt=0)] = None + memfolder: str = 'joblib_memmap' + method: Literal[ + 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] = 'leastsq' models: conlist(item_type=Union[ Constant, Linear, Quadratic, Exponential, Gaussian, Lorentzian, PseudoVoigt, Rectangle, Expression, Multipeak], min_length=1) - method: Literal[ - 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] = 'leastsq' - rel_height_cutoff: Optional[ - confloat(gt=0, lt=1.0, allow_inf_nan=False)] = None num_proc: conint(gt=0) = 1 + parameters: conlist(item_type=FitParameter) = [] plot: StrictBool = False print_report: StrictBool = False - memfolder: str = 'joblib_memmap' + rel_height_cutoff: Optional[ + confloat(gt=0, lt=1.0, allow_inf_nan=False)] = None @field_validator('method') @classmethod From cf1e0887935b706eb1c11425dc09d85cb919c7e4 Mon Sep 17 00:00:00 2001 From: Rolf Verberg Date: Tue, 9 Jun 2026 15:26:04 -0400 Subject: [PATCH 76/76] refactor: change valid input data types of FitProcessor and Fit/FitMap --- CHAP/pipeline.py | 2 +- CHAP/utils/fit.py | 542 ++++++++++++++++++++++--------------------- CHAP/utils/models.py | 447 ++++++++++++++++------------------- 3 files changed, 486 insertions(+), 505 deletions(-) diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index 91800b24..ba046a68 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -269,7 +269,7 @@ def get_data(data, name=None, schema=None, remove=True): `NXobject `__ object or matches a given name or schema. Pick the last item for which the `'name'` key matches `name` if set or the `'schema'` key matches - `schema` if set, pick the last match for a `NXobjecta` object + `schema` if set, pick the last match for a `NXobject` object otherwise. Return the data object. :param data: Input data. diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index c4671cf8..019608af 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -74,56 +74,91 @@ class FitProcessor(Processor): 'config': 'CHAP.utils.models.FitConfig'}, init_var=True) config: Optional[FitConfig] = None + def _get_pipelinedata_item(self, data, remove=True): + """Retrieve the input data to :meth:`process` from the list of + PipelineData items. + + :param data: Input data. + :type data: list[PipelineData] + :param remove: If there is a matching entry in `data`, remove + it from the list, defaults to `True`. + :type remove: bool, optional + :return: Matching data item(s). + :rtype: Any + """ + # Retrieve the data to (re)fit from the pipeline + for i, d in reversed(list(enumerate(data))): + ddata = d.get('data') + if isinstance(ddata, (Fit, FitMap)): + if remove: + data.pop(i) + return ddata + if d.get('name') == 'signal': + if remove: + data.pop(i) + break + else: + raise ValueError( + f'Unable to extract suitable fit input data from {data}') + + # Retrieve the optional coordinates from the pipeline + try: + y = np.asarray(ddata) + for i, d in reversed(list(enumerate(data))): + if d.get('name') == 'coordinates': + x = np.asarray(d.get('data')) + if remove: + data.pop(i) + assert x.size == y.shape[-1] + break + else: + x = None + except (ValueError, TypeError) as exc: + raise ValueError( + f'Unable to extract suitable fit input data from {data}') + return x, y + def process(self, data): """Fit the data and return a :class:`~CHAP.utils.fit.Fit` or :class:`~CHAP.utils.fit.FitMap` object depending on the - dimensionality of the input data. The input data should be an - array-like object (tuple, list or numpy.ndarray) or a - dictionary with at least a `"y"` field with a value equal to - one or a :class:`~CHAP.utils.fit.Fit` or - :class:`~CHAP.utils.fit.FitMap` object from a previous fit. - - :param data: Input data containing the data object to fit. - :type data: list[PipelineData] or Fit or FitMap or array-like + dimensionality of the input data. The input data should be a + list of PipelineData items containing either one with a `data` + field of type :class:`~CHAP.utils.fit.Fit` or + :class:`~CHAP.utils.fit.FitMap` (to refit or continue a + previous fit), or one with the `name` of `signal` and an + array-like `data` field. In the latter case, optional + x-coordinates can be supplied by a second PipelineData item + with the `name` of `coordinates` and again an array-like + `data` field. + + :param data: Input data. + :type data: list[PipelineData] :return: The fitted data object. :rtype: Fit or FitMap """ # Local modules - from CHAP.utils.models import Multipeak + from CHAP.utils.models import MultipeakModel -# # Unwrap the PipelineData if called as a Pipeline Processor -# data = self.get_pipelinedata_item(data) + # Unwrap the PipelineData + data = self._get_pipelinedata_item(data) if isinstance(data, (Fit, FitMap)): # Refit/continue the fit with possibly updated parameters fit = data - if isinstance(data, FitMap): - fit.fit(config=self.config, max_nfev=self.config.max_nfev) - else: - fit.fit(config=self.config, max_nfev=self.config.max_nfev) - if self.config is not None: - if self.config.print_report: - fit.print_fit_report() - if self.config.plot: - fit.plot(skip_init=True) + fit.fit(config=self.config, max_nfev=self.config.max_nfev) + if self.config is not None and not isinstance(data, FitMap): + if self.config.print_report: + fit.print_fit_report() + if self.config.plot: + fit.plot(skip_init=True) else: - # Test for the correct input data type - try: - if isinstance(data, dict): - x = np.asarray(data.get('x')) - y = np.asarray(data.get('y')) - else: - y = np.asarray(data) - except (ValueError, TypeError) as exc: - raise ValueError(f'Invalid input data ({type(data)}\n{data})') - # Expand multipeak model if present found_multipeak = False for i, model in enumerate(deepcopy(self.config.models)): - if isinstance(model, Multipeak): + if isinstance(model, MultipeakModel): if found_multipeak: raise ValueError( f'Invalid parameter models ({self.config.models}) ' @@ -136,15 +171,15 @@ def process(self, data): found_multipeak = True # Instantiate the Fit or FitMap object and fit the data - if np.squeeze(y).ndim == 1: - fit = Fit(data, self.config, self.logger) + if np.squeeze(data[1]).ndim == 1: + fit = Fit(data[1], self.config, self.logger, x=data[0]) fit.fit(max_nfev=self.config.max_nfev) if self.config.print_report: fit.print_fit_report() if self.config.plot: fit.plot(skip_init=True) else: - fit = FitMap(data, self.config, self.logger) + fit = FitMap(data[1], self.config, self.logger, x=data[0]) fit.fit( rel_height_cutoff=self.config.rel_height_cutoff, max_nfev=self.config.max_nfev, @@ -155,66 +190,65 @@ def process(self, data): return fit @staticmethod - def create_multipeak_model(model_config): + def create_multipeak_model(model): """Create a multipeak model. - :param model_config: A Multipeak fit model class. - :type model_config: :class:`~CHAP.utils.models.Multipeak` + :param model: A Multipeak fit model class. + :type model: :class:`~CHAP.utils.models.MultipeakModel` :return: The fit parameters and fit model classes. :rtype: list[:attr:`~CHAP.utils.models.FitParameter`], list[:attr:`~CHAP.utils.models.FitConfig.models`] """ # Local modules from CHAP.utils.models import ( + PEAK_LIKE_MODELS, FitParameter, - models, ) - peak_model_name = model_config.peak_models - peak_model_class = models[peak_model_name]['class'] + peak_model_class = PEAK_LIKE_MODELS[model.peak_models] parameters = [] peak_models = [] - num_peak = len(model_config.centers) - if num_peak == 1 and model_config.fit_type == 'uniform': - model_config.fit_type = 'unconstrained' + num_peak = len(model.centers) + if num_peak == 1 and model.fit_type == 'uniform': + model.fit_type = 'unconstrained' sig_min = FLOAT_MIN sig_max = np.inf - if (model_config.fwhm_min is not None - or model_config.fwhm_max is not None): + if (model.fwhm_min is not None + or model.fwhm_max is not None): # Third party modules from asteval import Interpreter ast = Interpreter() - if model_config.fwhm_min is not None: - ast(f'fwhm = {model_config.fwhm_min}') - sig_min = ast(fwhm_factor[model_config.peak_models]) - if model_config.fwhm_max is not None: - ast(f'fwhm = {model_config.fwhm_max}') - sig_max = ast(fwhm_factor[model_config.peak_models]) + if model.fwhm_min is not None: + ast(f'fwhm = {model.fwhm_min}') + sig_min = ast(fwhm_factor[model.peak_models]) + if model.fwhm_max is not None: + ast(f'fwhm = {model.fwhm_max}') + sig_max = ast(fwhm_factor[model.peak_models]) prefix = '' - if model_config.fit_type == 'uniform': + if model.fit_type == 'uniform': parameters.append(FitParameter( name='scale_factor', value=1.0, min=FLOAT_MIN)) - for i, cen in enumerate(model_config.centers): + for i, cen in enumerate(model.centers): if num_peak > 1: prefix = f'peak{i+1}_' peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, {'name': 'center', 'expr': f'scale_factor*{cen}'}, {'name': 'sigma', 'min': sig_min, 'max': sig_max}])) else: - for i, cen in enumerate(model_config.centers): + for i, cen in enumerate(model.centers): if num_peak > 1: prefix = f'peak{i+1}_' - if model_config.centers_range == 0: + if model.centers_range == 0: peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, @@ -222,14 +256,14 @@ def create_multipeak_model(model_config): {'name': 'sigma', 'min': sig_min, 'max': sig_max} ])) else: - if model_config.centers_range is None: + if model.centers_range is None: cen_min = None cen_max = None else: - cen_min = cen - model_config.centers_range - cen_max = cen + model_config.centers_range + cen_min = cen - model.centers_range + cen_max = cen + model.centers_range peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, @@ -253,12 +287,15 @@ def __init__(self, model, prefix=''): :type prefix: str, optional """ # Local modules - from CHAP.utils.models import models + #from CHAP.utils.models import MODEL_TYPE_TO_CLASS - self.func = models[model.model]['name'] + self.func = model.func #MODEL_TYPE_TO_CLASS[model.model_type] + self.func_args = model._func_args #MODEL_TYPE_TO_CLASS[model.model_type] self.param_names = [f'{prefix}{par.name}' for par in model.parameters] self.prefix = prefix - self._name = model.model + self._name = model.model_type + names = [f'{par.name}' for par in model.parameters] + self.func_args_indices = [names.index(arg) for arg in self.func_args] class Components(dict): @@ -284,24 +321,6 @@ def components(self): """ return self.values() - def add(self, model, prefix=''): - """Add a model to the model fit components dictionary. - - :param model: A fit model class. - :type model: :attr:`~CHAP.utils.models.FitConfig.models` - :param prefix: Model prefix. - :type prefix: str, optional - """ - # Local modules - from CHAP.utils.models import model_classes - - if not isinstance(model, model_classes): - raise ValueError(f'Invalid parameter model ({model})') - if not isinstance(prefix, str): - raise ValueError(f'Invalid parameter prefix ({prefix})') - name = f'{prefix}{model.model}' - self.__setitem__(name, Component(model, prefix)) - class Parameters(dict): """A dictionary of FitParameter objects, mimicking the @@ -357,7 +376,7 @@ class ModelResult(): """ def __init__( - self, model, parameters, x=None, y=None, method=None, ast=None, + self, model, parameters, *, x=None, y=None, method=None, ast=None, res_par_exprs=None, res_par_indices=None, res_par_names=None, result=None): """Initialize SetNumexprThreads. @@ -498,7 +517,9 @@ def eval_components(self, x=None, parameters=None): name = component._name else: name = component.prefix - result[name] = component.func(x, *par_values) + ppar_values = tuple( + par_values[i] for i in component.func_args_indices) + result[name] = component.func(x, *ppar_values) return result def fit_report(self, show_correl=False): @@ -570,22 +591,24 @@ def fit_report(self, show_correl=False): class Fit: """Wrapper class for scipy/lmfit.""" - def __init__(self, data, config, logger): + def __init__(self, y, config, logger, x=None): """Initialize Fit. - :param data: The input data. - :type data: array-like, or dict + :param y: Input signal data. + :type y: array-like :param config: Fit configuration. - :type config: CHAP.utils.models.FitConfig, optional + :type config: CHAP.utils.models.FitConfig :param logger: A python Logger object. :type logger: logging.Logger + :param x: Input coordinate data. + :type x: array-like, optional """ self._code = config.code for model in config.models: - if model.model == 'expression' and self._code != 'lmfit': + if model.model_type == 'expression' and self._code != 'lmfit': self._code = 'lmfit' - logger.warning('Using lmfit instead of scipy with ' - 'an expression model') + logger.warning('Using lmfit instead of scipy with an ' + 'expression model') if self._code == 'scipy': # Local modules from CHAP.utils.fit import Parameters @@ -626,20 +649,16 @@ def __init__(self, data, config, logger): # raise ValueError( # 'Invalid value of keyword argument try_linear_fit ' # f'({self._try_linear_fit})') - if data is not None: - try: - if isinstance(data, dict): - self._x = np.asarray(data.get('x')) - self._y = np.squeeze(data.get('y')) - assert self._x.size == self._y.size - else: - self._y = np.squeeze(data) - self._x = np.arange(self._y.size) - except (ValueError, TypeError) as exc: - raise ValueError(f'Invalid input data ({type(data)}\n{data})') + if y is not None: + self._y = np.squeeze(y) if self._y.ndim != 1: raise ValueError( - f'Invalid input data dimension ({self._y.ndim})') + f'Invalid input signal dimension ({self._y.ndim})') + if x is None: + self._x = np.arange(self._y.size) + else: + self._x = x + assert self._x.size == self._y.size # if 'mask' in kwargs: # self._mask = kwargs.pop('mask') if True: #self._mask is None: @@ -995,12 +1014,10 @@ def add_model(self, model, prefix): RectangleModel, ) - if model.model == 'expression': + if model.model_type == 'expression': expr = model.expr else: expr = None - parameters = model.parameters - model_name = model.model if prefix is None: pprefix = '' @@ -1008,149 +1025,147 @@ def add_model(self, model, prefix): pprefix = prefix if self._code == 'scipy': new_parameters = [] - for par in deepcopy(parameters): + for par in deepcopy(model.parameters): + name = par.name self._parameters.add(par, pprefix) if self._parameters[par.name].expr is None: self._parameters[par.name].set(value=par.default) new_parameters.append(par.name) - self._res_num_pars += [len(parameters)] + if name in model.LINEAR_PARAMETERS: + self._linear_parameters.append(par.name) + elif name not in model.MODEL_PARAMETERS: + self._nonlinear_parameters.append(par.name) + self._res_num_pars += [len(model.parameters)] - if model_name == 'constant': - # Par: c - if self._code == 'lmfit': + if self._code == 'lmfit': + if model.model_type == 'constant': + # Par: c newmodel = ConstantModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}c') - elif model_name == 'linear': - # Par: slope, intercept - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}c') + elif model.model_type == 'linear': + # Par: slope, intercept newmodel = LinearModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}slope') - self._linear_parameters.append(f'{pprefix}intercept') - elif model_name == 'quadratic': - # Par: a, b, c - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}slope') + self._linear_parameters.append(f'{pprefix}intercept') + elif model.model_type == 'parabolic': + # Par: a, b, c newmodel = QuadraticModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}a') - self._linear_parameters.append(f'{pprefix}b') - self._linear_parameters.append(f'{pprefix}c') -# elif model_name == 'polynomial': -# # Par: c0, c1,..., c7 -# degree = kwargs.get('degree') -# if degree is not None: -# kwargs.pop('degree') -# if degree is None or not is_int(degree, ge=0, le=7): -# raise ValueError( -# 'Invalid parameter degree for build-in step model ' -# f'({degree})') -# if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}a') + self._linear_parameters.append(f'{pprefix}b') + self._linear_parameters.append(f'{pprefix}c') +# elif model.model_type == 'polynomial': +# # Par: c0, c1,..., c7 +# degree = kwargs.get('degree') +# if degree is not None: +# kwargs.pop('degree') +# if degree is None or not is_int(degree, ge=0, le=7): +# raise ValueError( +# 'Invalid parameter degree for build-in step model ' +# f'({degree})') # newmodel = PolynomialModel(degree=degree, prefix=prefix) -# for i in range(degree+1): -# self._linear_parameters.append(f'{pprefix}c{i}') - elif model_name == 'exponential': - # Par: amplitude, decay - if self._code == 'lmfit': +# for i in range(degree+1): +# self._linear_parameters.append(f'{pprefix}c{i}') + elif model.model_type == 'exponential': + # Par: amplitude, decay newmodel = ExponentialModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}decay') - elif model_name == 'gaussian': - # Par: amplitude, center, sigma (fwhm, height) - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}decay') + elif model.model_type == 'gaussian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = GaussianModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'lorentzian': - # Par: amplitude, center, sigma (fwhm, height) - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'lorentzian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = LorentzianModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'pvoigt': - # Par: amplitude, center, sigma (fwhm, height), fraction - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'pvoigt': + # Par: amplitude, center, sigma (fwhm, height), fraction newmodel = PseudoVoigtModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._linear_parameters.append(f'{pprefix}fraction') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') -# elif model_name == 'step': -# # Par: amplitude, center, sigma -# form = kwargs.get('form') -# if form is not None: -# kwargs.pop('form') -# if (form is None or form not in -# ('linear', 'atan', 'arctan', 'erf', 'logistic')): -# raise ValueError( -# 'Invalid parameter form for build-in step model ' -# f'({form})') -# if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._linear_parameters.append(f'{pprefix}fraction') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') +# elif model.model_type == 'step': +# # Par: amplitude, center, sigma +# form = kwargs.get('form') +# if form is not None: +# kwargs.pop('form') +# if (form is None or form not in +# ('linear', 'atan', 'arctan', 'erf', 'logistic')): +# raise ValueError( +# 'Invalid parameter form for build-in step model ' +# f'({form})') # newmodel = StepModel(prefix=prefix, form=form) -# self._linear_parameters.append(f'{pprefix}amplitude') -# self._nonlinear_parameters.append(f'{pprefix}center') -# self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'rectangle': - # Par: amplitude, center1, center2, sigma1, sigma2 - form = 'atan' #kwargs.get('form') - #if form is not None: - # kwargs.pop('form') - # RV: Implement and test other forms when needed - if (form is None or form not in - ('linear', 'atan', 'arctan', 'erf', 'logistic')): - raise ValueError( - 'Invalid parameter form for build-in rectangle model ' - f'({form})') - if self._code == 'lmfit': +# self._linear_parameters.append(f'{pprefix}amplitude') +# self._nonlinear_parameters.append(f'{pprefix}center') +# self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'rectangle': + # Par: amplitude, center1, center2, sigma1, sigma2 + form = 'atan' #kwargs.get('form') + #if form is not None: + # kwargs.pop('form') + # RV: Implement and test other forms when needed + if (form is None or form not in + ('linear', 'atan', 'arctan', 'erf', 'logistic')): + raise ValueError( + 'Invalid parameter form for build-in rectangle model ' + f'({form})') newmodel = RectangleModel(prefix=prefix, form=form) - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center1') - self._nonlinear_parameters.append(f'{pprefix}center2') - self._nonlinear_parameters.append(f'{pprefix}sigma1') - self._nonlinear_parameters.append(f'{pprefix}sigma2') - elif model_name == 'expression' and self._code == 'lmfit': - # Third party modules - from asteval import ( - Interpreter, - get_ast_names, - ) + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center1') + self._nonlinear_parameters.append(f'{pprefix}center2') + self._nonlinear_parameters.append(f'{pprefix}sigma1') + self._nonlinear_parameters.append(f'{pprefix}sigma2') + elif model.model_type == 'expression': + # FIX move to a validator + # Third party modules + from asteval import ( + Interpreter, + get_ast_names, + ) - for par in parameters: - if par.expr is not None: - raise KeyError( - f'Invalid "expr" key ({par.expr}) in ' - f'parameter ({par}) for an expression model') - ast = Interpreter() - expr_parameters = [ - name for name in get_ast_names(ast.parse(expr)) - if (name != 'x' and name not in self._parameters - and name not in ast.symtable)] - if prefix is None: - newmodel = ExpressionModel(expr=expr) - else: - for name in expr_parameters: - expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) + for par in model.parameters: + if par.expr is not None: + raise KeyError( + f'Invalid "expr" key ({par.expr}) in ' + f'parameter ({par}) for an expression model') + ast = Interpreter() expr_parameters = [ - f'{prefix}{name}' for name in expr_parameters] - newmodel = ExpressionModel(expr=expr, name=model_name) - # Remove already existing names - for name in newmodel.param_names.copy(): - if name not in expr_parameters: - newmodel._func_allargs.remove(name) - newmodel._param_names.remove(name) - else: - raise ValueError(f'Unknown fit model ({model_name})') + name for name in get_ast_names(ast.parse(expr)) + if (name != 'x' and name not in self._parameters + and name not in ast.symtable)] + if prefix is None: + newmodel = ExpressionModel(expr=expr) + else: + for name in expr_parameters: + expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) + expr_parameters = [ + f'{prefix}{name}' for name in expr_parameters] + newmodel = ExpressionModel(expr=expr, name=model.model_type) + # Remove already existing names + for name in newmodel.param_names.copy(): + if name not in expr_parameters: + newmodel._func_allargs.remove(name) + newmodel._param_names.remove(name) + else: + raise ValueError(f'Unknown fit model ({model.model_type})') # Add the new model to the current one if self._code == 'scipy': if self._model is None: self._model = Components() - self._model.add(model, prefix) + self._model |= { + f'{prefix}{model.model_type}': Component(model, prefix)} else: if self._model is None: self._model = newmodel @@ -1186,17 +1201,23 @@ def add_model(self, model, prefix): value = par.value _min = par.min _max = par.max - if not (model_name == 'pvoigt' and 'fraction' in name): - if value is not None: - value *= self._norm[1] - if not np.isinf(_min) and abs(_min) != FLOAT_MIN: - _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != FLOAT_MIN: - _max *= self._norm[1] +# if not (model.model_type == 'pvoigt' and 'fraction' in name): +# if value is not None: +# value *= self._norm[1] +# if not np.isinf(_min) and abs(_min) != FLOAT_MIN: +# _min *= self._norm[1] +# if not np.isinf(_max) and abs(_max) != FLOAT_MIN: +# _max *= self._norm[1] + if value is not None: + value *= self._norm[1] + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: + _max *= self._norm[1] par.set(value=value, min=_min, max=_max) # Initialize the model parameters - for parameter in deepcopy(parameters): + for parameter in deepcopy(model.parameters): name = parameter.name if name not in new_parameters: name = pprefix+name @@ -1538,7 +1559,7 @@ def _create_prefixes(self, models): names = [] prefixes = [] for model in models: - names.append(f'{model.prefix}{model.model}') + names.append(f'{model.prefix}{model.model_type}') prefixes.append(model.prefix) counts = Counter(names) for model, count in counts.items(): @@ -1651,10 +1672,10 @@ def _setup_fit(self, config, guess=False): # Add constant offset for a normalized model if self._result is None and self._norm is not None and self._norm[0]: # Local modules - from CHAP.utils.models import Constant + from CHAP.utils.models import ConstantModel - model = Constant( - model='constant', + model = ConstantModel( + model_type='constant', parameters=[{ 'name': 'c', 'value': -self._norm[0], @@ -1667,14 +1688,14 @@ def _setup_fit(self, config, guess=False): # Local modules from CHAP.utils.models import ( FitConfig, - Multipeak, + MultipeakModel, ) # Expand multipeak model if present scale_factor = None for i, model in enumerate(deepcopy(config.models)): found_multipeak = False - if isinstance(model, Multipeak): + if isinstance(model, MultipeakModel): if found_multipeak: raise ValueError( f'Invalid parameter models ({config.models}) ' @@ -1758,10 +1779,9 @@ def _setup_fit(self, config, guess=False): if par.expr: self._res_par_exprs.append( {'expr': par.expr, 'index': i}) - else: - if par.vary: - self._res_par_indices.append(i) - self._res_par_names.append(name) + elif par.vary: + self._res_par_indices.append(i) + self._res_par_names.append(name) # Check for uninitialized parameters for name, par in self._parameters.items(): @@ -2077,7 +2097,6 @@ def _fit_nonlinear_model(self, x, y, **kwargs): else: bounds = (-np.inf, np.inf) init_params = deepcopy(self._parameters) -# t0 = time() lskws = { 'ftol': 1.49012e-08, 'xtol': 1.49012e-08, @@ -2096,12 +2115,11 @@ def _fit_nonlinear_model(self, x, y, **kwargs): result = least_squares( self._residual, pars_init, bounds=bounds, method=self._method, args=(x, y), **lskws) -# t1 = time() -# print(f'\n\nFitting took {1000*(t1-t0):.3f} ms\n\n') model_result = ModelResult( - self._model, self._parameters, x, y, self._method, self._ast, - self._res_par_exprs, self._res_par_indices, - self._res_par_names, result) + self._model, self._parameters, x=x, y=y, method=self._method, + ast=self._ast, res_par_exprs=self._res_par_exprs, + res_par_indices=self._res_par_indices, + res_par_names=self._res_par_names, result=result) model_result.init_params = init_params model_result.init_values = {} for name, par in init_params.items(): @@ -2111,12 +2129,9 @@ def _fit_nonlinear_model(self, x, y, **kwargs): fit_kws = {} # if 'Dfun' in kwargs: # fit_kws['Dfun'] = kwargs.pop('Dfun') -# t0 = time() model_result = self._model.fit( y, self._parameters, x=x, method=self._method, fit_kws=fit_kws, **kwargs) -# t1 = time() -# print(f'\n\nFitting took {1000*(t1-t0):.3f} ms\n\n') return model_result @@ -2302,8 +2317,11 @@ def _residual(self, pars, x, y): self._ast.eval(expr['expr']) for component, num_par in zip( self._model.components, self._res_num_pars): + values = self._res_par_values[n_par:n_par+num_par] + vvalues = [values[i] for i in component.func_args_indices] res += component.func( - x, *tuple(self._res_par_values[n_par:n_par+num_par])) +# x, *tuple(self._res_par_values[n_par:n_par+num_par])) + x, *tuple(vvalues)) n_par += num_par return res - y @@ -2311,15 +2329,17 @@ def _residual(self, pars, x, y): class FitMap(Fit): """Wrapper to the Fit class to fit data on a N-dimensional map.""" - def __init__(self, data, config, logger): + def __init__(self, y, config, logger, x=None): """Initialize FitMap. - :param data: The input data. - :type data: array-like, or dict + :param y: Input signal data. + :type y: array-like :param config: Fit configuration. - :type config: CHAP.utils.models.FitConfig, optional + :type config: CHAP.utils.models.FitConfig :param logger: A python Logger object. :type logger: logging.Logger + :param x: Input coordinate data. + :type x: array-like, optional """ super().__init__(None, config, logger) self._best_errors = None @@ -2345,16 +2365,12 @@ def __init__(self, data, config, logger): # At this point the fastest index should always be the signal # dimension so that the slowest ndim-1 dimensions are the # map dimensions - try: - if isinstance(data, dict): - self._x = np.asarray(data.get('x')) - self._ymap = np.asarray(data.get('y')) - assert self._x.size == self._ymap.shape[-1] - else: - self._ymap = np.asarray(data) - self._x = np.arange(self._ymap.shape[-1]) - except (ValueError, TypeError) as exc: - raise ValueError(f'Invalid input data ({type(data)}\n{data})') + self._ymap = y + if x is None: + self._x = np.arange(self._ymap.shape[-1]) + else: + self._x = x + assert self._x.size == self._ymap.shape[-1] # Flatten the map # Store the flattened map in self._ymap_norm diff --git a/CHAP/utils/models.py b/CHAP/utils/models.py index e63b7e30..657265a3 100755 --- a/CHAP/utils/models.py +++ b/CHAP/utils/models.py @@ -4,6 +4,7 @@ # System modules from typing import ( + ClassVar, Literal, Optional, Union, @@ -19,6 +20,7 @@ confloat, constr, field_validator, + model_validator, ) from typing_extensions import Annotated import numpy as np @@ -27,9 +29,6 @@ from CHAP.models import CHAPBaseModel from CHAP.utils.general import not_zero, tiny -# pylint: disable=no-member -tiny = np.finfo(np.float64).resolution -# pylint: enable=no-member s2pi = np.sqrt(2*np.pi) s2ln2 = np.sqrt(2*np.log(2)) @@ -39,7 +38,7 @@ def constant(x, c=0.0): :param c: Constant, defaults to `0`. :type c: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -58,7 +57,7 @@ def linear(x, slope=1.0, intercept=0.0): :type slope: float, optional :param intercept: Intercept, defaults to `0`. :type intercept: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -71,8 +70,8 @@ def linear(x, slope=1.0, intercept=0.0): return slope * x + intercept -#def quadratic(x, a=0.5, b=0.4, c=0.1): -def quadratic(x, a=0.0, b=0.0, c=0.0): +#def parabolic(x, a=0.5, b=0.4, c=0.1): +def parabolic(x, a=0.0, b=0.0, c=0.0): r"""Return a parabolic function. :param a: Quadratic polynomial coefficient, defaults to an @@ -84,7 +83,7 @@ def quadratic(x, a=0.0, b=0.0, c=0.0): :param c: Constant polynomial coefficient, defaults to an initial value of `0`. :type c: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -104,7 +103,7 @@ def exponential(x, amplitude=1.0, decay=1.0): :type amplitude: float, optional :param decay: Exponential decay, defaults to `1`. :type decay: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -128,7 +127,7 @@ def gaussian(x, amplitude=1.0, center=0.0, sigma=1.0): :type center: float, optional :param sigma: Standard deviation, defaults to `1`. :type sigma: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -160,7 +159,7 @@ def lorentzian(x, amplitude=1.0, center=0.0, sigma=1.0): :type center: float, optional :param sigma: Standard deviation, defaults to `1`. :type sigma: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -196,7 +195,7 @@ def pvoigt(x, amplitude=1.0, center=0.0, sigma=1.0, fraction=0.5): :param fraction: Relative weight of the Gaussian and Lorentzian components, defaults to `0.5`. :type fraction: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -248,7 +247,7 @@ def rectangle( - ``'logistic'``: Sigmoidal (logistic function) transitions. :type form: str, optional - :returns: The evaluated rectangle function values. + :returns: Evaluated rectangle function values. :rtype: float or numpy.ndarray .. note:: @@ -289,53 +288,6 @@ def rectangle( return amplitude*rect -def validate_parameters(parameters, info): - """Validate the parameters. - - :param parameters: Fit model parameters. - :type parameters: list[FitParameter] - :param info: Model parameter validation information. - :type info: pydantic.ValidationInfo - :return: List of fit model parameters. - :rtype: list[FitParameter] - """ - # System imports - from inspect import signature - - if 'model' in info.data: - model = info.data['model'] - else: - model = None - if model is None or model == 'expression': - return parameters - sig = dict(signature(models[model]['name']).parameters.items()) - sig.pop('x') - - # Check input model parameter validity - for par in parameters: - if par.name not in sig: - raise ValueError('Invalid parameter {par.name} in {model} model') - - # Set model parameters - output_parameters = [] - for sig_name, sig_par in sig.items(): - if model == 'rectangle' and sig_name == 'form': - continue - for par in parameters: - if sig_name == par.name: - break - else: - par = FitParameter(name=sig_name) - if sig_par.default != sig_par.empty: - par._default = sig_par.default - if model == 'pvoigt' and sig_name == 'fraction': - par.min = 0.0 - par.max = 1.0 - output_parameters.append(par) - - return output_parameters - - class FitParameter(CHAPBaseModel): """Class representing a specific fit parameter for the fit processor. @@ -489,232 +441,271 @@ def set(self, value=None, min=None, max=None, vary=None, expr=None): self.value = self.min self.expr = None -class Constant(CHAPBaseModel): + +class ConstantModel(CHAPBaseModel): """Class representing a Constant model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['constant'] + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['constant'] :ivar parameters: Function parameters, defaults to those auto generated from the function signature (excluding the independent variable). :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. + :ivar prefix: Model prefix, defaults to `''`. :vartype prefix: str, optional """ - model: Literal['constant'] + LINEAR_PARAMETERS: ClassVar[list[str]] = ['c'] + MODEL_PARAMETERS: ClassVar[list[str]] = [] + MODEL_IDENTIFIERS: ClassVar[list[str]] = [] + model_type: Literal['constant'] parameters: Annotated[ conlist(item_type=FitParameter), Field(validate_default=True)] = [] prefix: Optional[str] = '' - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + _func: PrivateAttr + _func_args: PrivateAttr + + @model_validator(mode='after') + def validate_model_after(self): + """Validate the model configuration and initialize the + appropriate parameters from the model function signature. + + :return: Validated and initialized configuration. + :rtype: Model + """ + # System imports + from inspect import signature + + if self.model_type == 'expression': + return self + self._func = globals()[self.model_type] + sig = dict(signature(self._func).parameters) + sig.pop('x') + self._func_args = list(sig.keys()) + + # Check input model parameter validity + par_names = [] + for par in self.parameters: + if par.name not in sig: + raise ValueError( + 'Invalid parameter {par.name} in {self.model_type} model ' + f'valid function arguments: {list(sig.keys())}') + par_names.append(par.name) + + # Set model parameters + for sig_name, sig_par in sig.items(): + if sig_name in self.MODEL_IDENTIFIERS or sig_name in par_names: +# if ((self.model_type == 'rectangle' and sig_name == 'form') +# or sig_name in par_names): + continue + par = FitParameter(name=sig_name) + if sig_par.default != sig_par.empty: + par._default = sig_par.default + self.parameters.append(par) + par_names.append(par.name) + + # Perform any additional validation of model parameters + if hasattr(self, '_validate_parameters'): + self._validate_parameters() + return self + @property + def func(self): + """Return the model function -class Linear(CHAPBaseModel): + :type: function + """ + if hasattr(self, '_func'): + return self._func + return None + + +class LinearModel(ConstantModel): """Class representing a Linear model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['linear'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['linear'] """ - model: Literal['linear'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['slope', 'intercept'] + model_type: Literal['linear'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) - -class Quadratic(CHAPBaseModel): +class QuadraticModel(ConstantModel): """Class representing a Quadratic model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['quadratic'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['parabolic'] """ - model: Literal['quadratic'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['a', 'b', 'c'] + model_type: Literal['parabolic'] -class Exponential(CHAPBaseModel): +class ExponentialModel(ConstantModel): """Class representing an Exponential model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['exponential'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['exponential'] """ - model: Literal['exponential'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['exponential'] -class Gaussian(CHAPBaseModel): +class GaussianModel(ConstantModel): """Class representing a Gaussian model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['gaussian'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['gaussian'] """ - model: Literal['gaussian'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['gaussian'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'sigma': + par.min = 0.0 -class Lorentzian(CHAPBaseModel): +class LorentzianModel(ConstantModel): """Class representing a Lorentzian model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['lorentzian'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['lorentzian'] """ - model: Literal['lorentzian'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['lorentzian'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'sigma': + par.min = 0.0 -class PseudoVoigt(CHAPBaseModel): +class PseudoVoigtModel(ConstantModel): """Class representing a PseudoVoigt model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['pvoigt'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['pvoigt'] """ - model: Literal['pvoigt'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + MODEL_PARAMETERS: ClassVar[list[str]] = ['fraction'] + model_type: Literal['pvoigt'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'fraction': + par.min = 0.0 + par.max = 1.0 + elif par.name == 'sigma': + par.min = 0.0 -class Rectangle(CHAPBaseModel): +class RectangleModel(ConstantModel): """Class representing a Rectangle model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['rectangle'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['rectangle'] """ - model: Literal['rectangle'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + MODEL_IDENTIFIERS: ClassVar[list[str]] = ['form'] + model_type: Literal['rectangle'] + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'form': + assert form in ('linear', 'atan', 'arctan', 'erf', 'logistic') -class Expression(CHAPBaseModel): +class ExpressionModel(ConstantModel): """Class representing an Expression model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['expression'] + :ivar model_type: The model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['expression'] :ivar expr: Mathematical expression to represent the model component. :vartype expr: str - :ivar parameters: Function parameters, defaults to those auto - generated from the model expression (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional """ - model: Literal['expression'] + model_type: Literal['expression'] expr: constr(strip_whitespace=True, min_length=1) - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + +# Available models for components of the fitting function +#MODEL_CLASSES = [ +# ConstantModel, +# LinearModel, +# QuadraticModel, +# ExponentialModel, +# GaussianModel, +# LorentzianModel, +# PseudoVoigtModel, +# RectangleModel, +# ExpressionModel, +#] + +# Reusable Discriminator Union for supported fit model components. +Model = Annotated[ +# FIX for Python 3.11+ Union[*MODEL_CLASSES], + Union[ + ConstantModel, + LinearModel, + QuadraticModel, + ExponentialModel, + GaussianModel, + LorentzianModel, + PseudoVoigtModel, + RectangleModel, + ExpressionModel, + ], + Field(discriminator='model_type') +] + +# Peak-like models: with amplitude, center and sigma as their +# function arguments +PEAK_LIKE_MODELS = { + 'gaussian': GaussianModel, + 'lorentzian': LorentzianModel, + 'pvoigt': PseudoVoigtModel, +} + +#MODEL_TYPE_TO_CLASS = {#v.model_type:v for v in MODEL_CLASSES} +# 'constant': constant, +# 'linear': linear, +# 'parabolic': parabolic, +# 'exponential': exponential, +# 'gaussian': gaussian, +# 'lorentzian': lorentzian, +# 'pvoigt': pvoigt, +# 'rectangle': rectangle, +#} -class Multipeak(CHAPBaseModel): +class MultipeakModel(CHAPBaseModel): """Class representing a multipeak model. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['expression'] + :ivar model_type: The model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['expression'] :ivar centers: Peak centers. :vartype center: list[float] :ivar centers_range: Range of peak centers around their centers. @@ -730,7 +721,7 @@ class Multipeak(CHAPBaseModel): optional. """ - model: Literal['multipeak'] + model_type: Literal['multipeak'] centers: conlist(item_type=confloat(allow_inf_nan=False), min_length=1) centers_range: Optional[confloat(allow_inf_nan=False)] = None fit_type: Optional[Literal['uniform', 'unconstrained']] = 'unconstrained' @@ -739,29 +730,6 @@ class Multipeak(CHAPBaseModel): peak_models: Literal['gaussian', 'lorentzian', 'pvoigt'] = 'gaussian' -models = { - 'constant': {'name': constant, 'class': Constant}, - 'linear': {'name': linear, 'class': Linear}, - 'quadratic': {'name': quadratic, 'class': Quadratic}, - 'exponential': {'name': exponential, 'class': Exponential}, - 'gaussian': {'name': gaussian, 'class': Gaussian}, - 'lorentzian': {'name': lorentzian, 'class': Lorentzian}, - 'pvoigt': {'name': pvoigt, 'class': PseudoVoigt}, - 'rectangle': {'name': rectangle, 'class': Rectangle}, -} - -model_classes = ( - Constant, - Linear, - Quadratic, - Exponential, - Gaussian, - Lorentzian, - PseudoVoigt, - Rectangle, -) - - class FitConfig(CHAPBaseModel): """Class representing the configuration for the fit processor. @@ -779,8 +747,7 @@ class FitConfig(CHAPBaseModel): :vartype method: Literal[ 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] :ivar models: The component(s) of the (composite) fit model. - :vartype models: - list[:attr:`~CHAP.utils.models.FitConfig.models`] + :vartype models: list[Model, MultipeakModel] :ivar num_proc: The number of processors used in fitting a map of data, defaults to `1`. :vartype num_proc: int, optional @@ -807,9 +774,7 @@ class FitConfig(CHAPBaseModel): memfolder: str = 'joblib_memmap' method: Literal[ 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] = 'leastsq' - models: conlist(item_type=Union[ - Constant, Linear, Quadratic, Exponential, Gaussian, Lorentzian, - PseudoVoigt, Rectangle, Expression, Multipeak], min_length=1) + models: conlist(item_type=Union[Model, MultipeakModel], min_length=1) num_proc: conint(gt=0) = 1 parameters: conlist(item_type=FitParameter) = [] plot: StrictBool = False