diff --git a/CHAP/common/map_utils.py b/CHAP/common/map_utils.py index 1e2c8706..384aabd4 100755 --- a/CHAP/common/map_utils.py +++ b/CHAP/common/map_utils.py @@ -6,12 +6,16 @@ from pydantic import ( conint, conlist, + model_validator, Field, FilePath, ) +from typing import Optional # Local modules +from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( + DetectorConfig, Detector, MapConfig, ) @@ -54,44 +58,50 @@ class MapSliceProcessor(Processor): :ivar map_config: Map configuration. :vartype map_config: MapConfig - :ivar detectors: Detector configurations. - :vartype detectors: - list[:class:`~CHAP.common.models.map.Detector`] + :ivar detector_config: Detector configurations. + :vartype detector_config: :class:`~CHAP.common.models.map.DetectorConfig` :ivar spec_file: SPEC file containing scan from which to read a slice of raw data. - :vartype spec_file: str - :ivar scan_number: Number of scan from which to read a slice of + :vartype spec_file: pydantic.FilePath + :ivar scan_numbers: Numbers of scans from which to read slices of raw data. - :vartype scan_number: int + :vartype scan_numbers: list[int] + :ivar idx_slice: Parameters for the slice of each scan to process + Defaults to `IndexSliceConfig()`. 
+ :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional """ pipeline_fields: dict = Field( default={ 'map_config': 'common.models.map.MapConfig', + 'detector_config': 'common.models.map.DetectorConfig' }, init_var=True) map_config: MapConfig - detectors: conlist(item_type=Detector, min_length=1) + detector_config: DetectorConfig + detectors: Optional[conlist(item_type=Detector)] = None spec_file: FilePath - scan_number: conint(gt=0) + scan_number: Optional[conint(gt=0)] = None + scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() - def process(self, data, #spec_file, scan_number, - idx_slice={'start': 0, 'step': 1}): - """Aggregate partial spec and detector data from one scan in a - map, returning results in a format suitable for writing to the - full map container with + def process(self, data): + + """Aggregate partial spec and detector data from one or more + scans in a map, returning results in a format suitable for + writing to the full map container with :class:`~CHAP.common.writer.NexusValuesWriter` or :class:`~CHAP.common.writer.ZarrValuesWriter`. + When all scans are adjacent in the map and `idx_slice` covers + each scan in full, data_points entries for the same path are + consolidated into a single array + slice, avoiding redundant + write calls. + :param data: Result of `Reader.read` where at least one item has the value `'common.models.map.MapConfig'` for the `'schema'` key. :type data: list[PipelineData] - :type idx_slice: Parameters for the slice of the scan to - process (slice parameters are the usual for the python - `slice` object: `'start'`, `'stop'`, and - `'step'`). Defaults to `{'start': 0, 'step': '1'}`. - :type idx_slice: dict[str, int], optional :return: Slice of map data, ready to be written to a map container. 
:rtype: list[dict[str, Any]] @@ -103,68 +113,223 @@ def process(self, data, #spec_file, scan_number, import numpy as np # Local modules - from chess_scanparsers import choose_scanparser from CHAP.common.models.map import SpecScans - ScanParser = choose_scanparser( - self.map_config.station, self.map_config.experiment_type) - scans = SpecScans( - spec_file=self.spec_file, scan_numbers=[self.scan_number]) - scan = scans.get_scanparser(self.scan_number) - - # Get index offset for this data slice within the map - npts_scan = int(scan.spec_scan_npts) - nscans_prev = 0 - for scans in self.map_config.spec_scans: - for scan_n in scans.scan_numbers: - if (os.path.abspath(self.spec_file) == \ - os.path.abspath(self.spec_file) - and scan_n == self.scan_number): + scans_obj = SpecScans( + spec_file=self.spec_file, scan_numbers=self.scan_numbers) + self_spec_file_abs = os.path.abspath(str(self.spec_file)) + + if self.map_config.experiment_type == 'EDD': + def get_detector_data(scan, detector, index): + return scan.get_detector_data(detector.get_id(), index)[0][0] + else: + def get_detector_data(scan, detector, index): + return scan.get_detector_data(detector.get_id(), index) + + # Build flat ordered list of (abs_spec_file, scan_number) for + # all scans in the map to determine each scan's map position + map_scan_order = [] + for spec_scans_item in self.map_config.spec_scans: + sf_abs = os.path.abspath(str(spec_scans_item.spec_file)) + for sn in spec_scans_item.scan_numbers: + map_scan_order.append((sf_abs, sn)) + scan_positions = {} + for sn in self.scan_numbers: + for pos, (sf, n) in enumerate(map_scan_order): + if sf == self_spec_file_abs and n == sn: + scan_positions[sn] = pos break - nscans_prev += 1 - index_offset = nscans_prev * npts_scan - - # Get spec scan indices to process - scan_indices = range(npts_scan)[slice( - idx_slice.get('start', 0), - idx_slice.get('stop', npts_scan + 1), - idx_slice.get('step', 1) - )] - # Get map indices to write to - map_indices = slice( - 
idx_slice.get('start', 0) + index_offset, - idx_slice.get('stop', npts_scan + 1) + index_offset, - idx_slice.get('step', 1) + + # Process scans in map order + sorted_scan_numbers = sorted( + self.scan_numbers, key=lambda sn: scan_positions[sn]) + + slice_start = self.idx_slice._slice.start + slice_step = self.idx_slice._slice.step + + # Collect per-scan metadata; assumes uniform npts across scans + # for index_offset calculation (index_offset = map_pos * npts) + per_scan = [] + for sn in sorted_scan_numbers: + scan = scans_obj.get_scanparser(sn) + npts_scan = int(scan.spec_scan_npts) + index_offset = scan_positions[sn] * npts_scan + # Cap stop at npts_scan so map_indices and data stay in sync + slice_stop = min(self.idx_slice._slice.stop, npts_scan) \ + if self.idx_slice._slice.stop > 0 else npts_scan + scan_indices = range(npts_scan)[ + slice(slice_start, slice_stop, slice_step)] + map_indices = slice( + slice_start + index_offset, + slice_stop + index_offset, + slice_step, + ) + per_scan.append({ + 'sn': sn, 'scan': scan, 'npts_scan': npts_scan, + 'index_offset': index_offset, + 'scan_indices': scan_indices, + 'map_indices': map_indices, + 'full': (slice_start == 0 + and slice_step == 1 + and slice_stop == npts_scan), + }) + + # Consolidate into single data_points entries when all scans + # are adjacent in the map and idx_slice covers each scan fully + sorted_positions = [scan_positions[sn] for sn in sorted_scan_numbers] + scans_are_adjacent = all( + sorted_positions[i + 1] == sorted_positions[i] + 1 + for i in range(len(sorted_positions) - 1) + ) + can_consolidate = ( + len(per_scan) > 1 + and scans_are_adjacent + and all(ps['full'] for ps in per_scan) ) - data_points = [ - { - 'path': f'{self.map_config.title}/scalar_data/{s_d.label}', - 'data': np.asarray([ - s_d.get_value( - scans, self.scan_number, i, - scalar_data=self.map_config.scalar_data) - for i in scan_indices - ]), - 'idx': map_indices - } - for s_d in self.map_config.all_scalar_data - ] - 
data_points.extend( - [ - { - 'path': f'{self.map_config.title}/data/{det.get_id()}', + if can_consolidate: + first, last = per_scan[0], per_scan[-1] + merged_idx = slice( + first['index_offset'], + last['index_offset'] + last['npts_scan'], + 1, + ) + data_points = [{ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/index'), + 'data': np.arange( + first['index_offset'], + last['index_offset'] + last['npts_scan'], + ), + 'idx': merged_idx, + }] + for s_d in self.map_config.all_scalar_data: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/scalar_data/{s_d.label}'), + 'data': np.concatenate([ + np.asarray([ + s_d.get_value( + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + for dim in self.map_config.independent_dimensions: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/{dim.label}'), + 'data': np.concatenate([ + np.asarray([ + dim.get_value( + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + for det in self.detector_config.detectors: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/data/{det.get_id()}'), + 'data': np.concatenate([ + np.asarray([ + get_detector_data(ps['scan'], det, i) + for i in ps['scan_indices'] + ]) + for ps in per_scan + ]), + 'idx': merged_idx, + }) + else: + data_points = [] + for ps in per_scan: + data_points.append({ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/index'), + 'data': np.asarray( + [ps['index_offset'] + i for i in ps['scan_indices']] + ), + 'idx': ps['map_indices'], + }) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/scalar_data/{s_d.label}'), 'data': np.asarray([ - scan.get_detector_data(det.get_id(), i) - for i in scan_indices + s_d.get_value( + scans_obj, ps['sn'], i, + 
scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] ]), - 'idx': map_indices - } - for det in self.detectors - ] - ) + 'idx': ps['map_indices'], + } for s_d in self.map_config.all_scalar_data]) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/independent_dimensions/{dim.label}'), + 'data': np.asarray([ + dim.get_value( + scans_obj, ps['sn'], i, + scalar_data=self.map_config.scalar_data) + for i in ps['scan_indices'] + ]), + 'idx': ps['map_indices'], + } for dim in self.map_config.independent_dimensions]) + data_points.extend([{ + 'path': (f'{self.map_config.title}' + f'/data/{det.get_id()}'), + 'data': np.asarray([ + get_detector_data(ps['scan'], det, i) + for i in ps['scan_indices'] + ]), + 'idx': ps['map_indices'], + } for det in self.detector_config.detectors]) return data_points + @model_validator(mode='before') + @classmethod + def fill_scan_numbers(cls, data): + if not isinstance(data, dict): + return data + if 'scan_numbers' not in data or data['scan_numbers'] is None: + if data.get('scan_number') is not None: + data['scan_numbers'] = [data['scan_number']] + elif isinstance(data['scan_numbers'], int): + data['scan_numbers'] = [data['scan_numbers']] + elif isinstance(data['scan_numbers'], str): + from CHAP.utils.general import string_to_list + data['scan_numbers'] = string_to_list(data['scan_numbers']) + return data + + @model_validator(mode='after') + def validate_scan_numbers(self): + if self.scan_numbers is None: + raise ValueError( + 'scan_numbers is required; alternatively, provide scan_number') + if self.scan_number is not None \ + and self.scan_number not in self.scan_numbers: + self.scan_numbers.append(self.scan_number) + return self + + @model_validator(mode='before') + def fill_detector_config(cls, data): + if not isinstance(data, dict): + return data + if 'detector_config' not in data or data['detector_config'] is None: + if data.get('detectors') is not None: + data['detector_config'] = DetectorConfig( + 
detectors=data['detectors'] + ) + else: + raise ValueError( + 'detector_config is required; alternatively, provide detectors' + ) + return data + class SpecScanToMapConfigProcessor(Processor): """Processor to get the diff --git a/CHAP/common/models/common.py b/CHAP/common/models/common.py index a739c4a8..d5652448 100755 --- a/CHAP/common/models/common.py +++ b/CHAP/common/models/common.py @@ -146,6 +146,25 @@ def validate_vrange(cls, vrange, info): for i in info.data['index_range']] +class IndexSliceConfig(CHAPBaseModel): + """Configuration for a python `slice` object. + + :ivar start: A `start` parameter for `slice()`, defaults to 0. + :vartype start: int, optional + :ivar stop: A `stop` parameter for `slice()`, defaults to -1. + :vartype stop: int, optional + :ivar step: A `step` parameter for `slice()`, defaults to 1. + :vartype step: int, optional + """ + start: Optional[int] = 0 + stop: Optional[int] = -1 + step: Optional[int] = 1 + + @property + def _slice(self): + return slice(self.start, self.stop, self.step) + + class UnstructuredToStructuredConfig(CHAPBaseModel): """Configuration class to reshape data in an `NXdata `__ diff --git a/CHAP/common/models/integration.py b/CHAP/common/models/integration.py index fd88bb48..1bce0d24 100755 --- a/CHAP/common/models/integration.py +++ b/CHAP/common/models/integration.py @@ -634,17 +634,37 @@ def integrate(self, azimuthal_integrators, data): return results - def zarr_tree(self, dataset_shape, dataset_chunks='auto'): + def zarr_tree(self, dataset_shape, dataset_chunks='auto', nxlinks=None): """Return a dictionary representing a `zarr.group` that can be used to contain results from this integration. - :return: A `zarr.group` that can be used to contain the - integration results. + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. + :type dataset_shape: tuple[int, ...] 
+ :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``, defaults to ``'auto'``. + :type dataset_chunks: list[int] or str, optional + :param nxlinks: NeXus path(s) to link into the ``data`` group. + When the zarr tree is written to a ``.zarr`` file and + converted to ``.nxs`` with + :class:`~CHAP.common.processor.ZarrToNexusProcessor`, each + path produces an ``NXlink`` whose name is + ``os.path.basename(path)``. Accepts a single path string or + a list of path strings. + :type nxlinks: str or list[str], optional + :returns: Nested dict representing the zarr group tree for this + integration. :rtype: dict """ # Third party modules #import json + if isinstance(nxlinks, str): + nxlinks = [nxlinks] + data_attrs = {**self.get_axes_indices(len(dataset_shape))} + if nxlinks: + data_attrs['__nxlinks__'] = { + os.path.basename(p): p for p in nxlinks} tree = { # NXprocess 'attributes': { @@ -654,10 +674,7 @@ def zarr_tree(self, dataset_shape, dataset_chunks='auto'): 'children': { 'data': { # NXdata - 'attributes': { - # 'axes': self.result_axes(), - **self.get_axes_indices(len(dataset_shape)) - }, + 'attributes': data_attrs, 'children': { 'I': { # NXfield @@ -732,28 +749,50 @@ def validate_config(cls, data): data['azimuthal_integrators'] = ais return data - def zarr_tree(self, dataset_shape, dataset_chunks='auto'): + def zarr_tree(self, dataset_shape, dataset_chunks='auto', nxlinks=None): """Return a dictionary representing a `zarr.group` that can be used to contain results from :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor`. - :return: A `zarr.group` that can be used to contain the - integration results. + Each integration defined in :attr:`integrations` gets its own + sub-group keyed by the integration's ``name``. See + :meth:`PyfaiIntegratorConfig.zarr_tree` for the structure of + each sub-group. + + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. 
+ :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``, defaults to ``'auto'``. + :type dataset_chunks: list[int] or str, optional + :param nxlinks: NeXus links to inject into each integration's + ``data`` group. May be a single path string or list of path + strings (forwarded to every integration), or a dict keyed by + integration name mapping each integration to its own path(s). + See :meth:`PyfaiIntegratorConfig.zarr_tree` for details on + how individual paths are handled. + :type nxlinks: str or list[str] or dict[str, str or list[str]], + optional + :returns: Nested dict representing the zarr group tree for all + integrations. :rtype: dict """ ais = {ai.get_id(): ai for ai in self.azimuthal_integrators} for integration in self.integrations: integration.init_placeholder_results(ais) + if not isinstance(nxlinks, dict): + nxlinks = {intg.name: nxlinks for intg in self.integrations} tree = { 'root': { 'attributes': { 'description': 'Container for processed SAXS/WAXS data' }, - 'children': { - integration.name: integration.zarr_tree( - dataset_shape, dataset_chunks) - for integration in self.integrations - } + }, + 'children': { + integration.name: integration.zarr_tree( + dataset_shape, dataset_chunks, + nxlinks=nxlinks.get(integration.name)) + for integration in self.integrations } } return tree diff --git a/CHAP/common/models/map.py b/CHAP/common/models/map.py index fa463f6b..70c929f3 100755 --- a/CHAP/common/models/map.py +++ b/CHAP/common/models/map.py @@ -77,7 +77,7 @@ class Detector(CHAPBaseModel): """ id_: constr(min_length=1) = Field(alias='id') - shape: Optional[tuple[int, int]] = None + shape: Optional[Union[tuple[int, int], tuple[int]]] = None attrs: Optional[dict] = {} @field_validator('id_', mode='before') @@ -190,7 +190,7 @@ class SpecScans(CHAPBaseModel): spec_file: FilePath scan_numbers: Union[ - constr(min_length=1), conlist(item_type=conint(gt=0), min_length=1)] + 
constr(min_length=1), conlist(item_type=conint(gt=0))] par_file: Optional[FilePath] = None @field_validator('spec_file') @@ -417,9 +417,11 @@ class PointByPointScanData(CHAPBaseModel): label: constr(min_length=1) units: constr(strip_whitespace=True, min_length=1) data_type: Literal[ - 'expression', 'detector_log_timestamps', 'scan_column', - 'scan_start_time', 'scan_step_index', 'smb_par', 'spec_motor', - 'spec_motor_absolute', 'spec_motor_static'] + 'detector_log_timestamps', 'expression', 'scan_column', + 'scan_number', 'scan_start_time', 'scan_step_index', + 'smb_par', 'spec_motor', 'spec_motor_absolute', + 'spec_motor_static' + ] name: constr(strip_whitespace=True, min_length=1) ndigits: Optional[conint(ge=0)] = None @@ -625,6 +627,8 @@ def get_value( return scan_step_index scanparser = get_scanparser(spec_scans.spec_file, scan_number) return [i for i in range(scanparser.spec_scan_npts)] + if self.data_type == 'scan_number': + return scan_number return None @cache @@ -838,7 +842,9 @@ def _validate_data_source_for_map_config(data_source, info): if data_source is not None: values = info.data if data_source.data_type == 'expression': - data_source.validate_for_scalar_data(values['scalar_data']) + if 'scalar_data' in values: + data_source.validate_for_scalar_data( + values['scalar_data']) else: import_scanparser( values['station'], values['experiment_type']) @@ -1201,6 +1207,7 @@ def validate_experiment_type(cls, experiment_type, info): f'For station {station}, allowed experiment types are ' f'{", ".join(allowed_experiment_types)}. ' f'Supplied experiment type {experiment_type} is not allowed.') + import_scanparser(station, experiment_type) return experiment_type @@ -1220,6 +1227,18 @@ def validate_before(cls, data): data['attrs'] = {} return data + @model_validator(mode='after') + def validate_scalar_data_expressions(self): + """Validate any expression-type items in scalar_data against + the complete scalar_data list. 
This deferred check is necessary + because field validation of scalar_data runs before the list is + fully constructed. + """ + for data_source in self.scalar_data: + if data_source.data_type == 'expression': + data_source.validate_for_scalar_data(self.scalar_data) + return self + #RV maybe better to use model_validator, see v2 docs? @field_validator('attrs') @classmethod @@ -1238,13 +1257,17 @@ def validate_attrs(cls, attrs, info): """ # Get the map's scan_type for EDD experiments values = info.data + if not values['validate_data_present']: + return attrs station = values['station'] experiment_type = values['experiment_type'] if station in ['id1a3', 'id3a'] and experiment_type == 'EDD': scan_type = cls.get_smb_par_attr(values, 'scan_type') if scan_type is not None: attrs['scan_type'] = scan_type - attrs['config_id'] = cls.get_smb_par_attr(values, 'config_id') + config_id = cls.get_smb_par_attr(values, 'config_id') + if config_id is not None: + attrs['config_id'] = config_id dataset_id = cls.get_smb_par_attr( values, 'dataset_id', unique=False) if dataset_id is not None: @@ -1296,6 +1319,8 @@ def get_smb_par_attr( # f'{scans.spec_file}.') values.append(None) values = list(set(values)) + if not values: + return None if len(values) == 1: return values[0] if unique: diff --git a/CHAP/common/processor.py b/CHAP/common/processor.py index 0d198c11..1a1ff447 100755 --- a/CHAP/common/processor.py +++ b/CHAP/common/processor.py @@ -965,6 +965,11 @@ class MapProcessor(Processor): :ivar num_proc: Number of processors used to read map, defaults to `1`. :vartype num_proc: int, optional + :ivar remove_constant_dims: Flag to indicate that any + `independent_dimension`s in the map whose values are constant + across the map should be excluded from the output + `NXentry`. Defaults to `True`. 
+ :vartype remove_constant_dims: bool, optional """ pipeline_fields: dict = Field( @@ -975,6 +980,7 @@ class MapProcessor(Processor): config: Optional[MapConfig] = None detector_config: DetectorConfig = DetectorConfig(detectors=[]) num_proc: Optional[conint(gt=0)] = 1 + remove_constant_dims: Optional[bool] = True @field_validator('num_proc') @classmethod @@ -1054,7 +1060,7 @@ def process( spec_scans = self.config.spec_scans[0] scan_numbers = spec_scans.scan_numbers num_scan = len(scan_numbers) - if num_scan < self.num_proc: + if 0 < num_scan < self.num_proc: self.logger.warning( f'Requested number of processors ({self.num_proc}) exceeds ' f'the number of scans ({num_scan}): reset it to {num_scan}') @@ -1138,7 +1144,16 @@ def process( offset = common_comm.scatter(offsets, root=0) # Read the raw data - if self.config.experiment_type == 'EDD': + if num_scan == 0: + num_id = len(self.config.independent_dimensions) + num_sd = len(self.config.all_scalar_data) + num_det = len(self.detector_config.detectors) + if placeholder_data is not False: + num_sd += 1 + data = np.empty((num_det, 0)) + independent_dimensions = np.empty((num_id, 0)) + all_scalar_data = np.empty((num_sd, 0)) + elif self.config.experiment_type == 'EDD': data, independent_dimensions, all_scalar_data = \ self._read_raw_data_edd( common_comm, num_scan, offset, placeholder_data) @@ -1146,7 +1161,8 @@ def process( data, independent_dimensions, all_scalar_data = \ self._read_raw_data(common_comm, num_scan, offset) if not rank: - self.logger.debug(f'Data shape: {data.shape}') + self.logger.debug( + f'Data shape: {data.shape if data is not None else None}') if independent_dimensions is not None: self.logger.debug('Independent dimensions shape: ' f'{independent_dimensions.shape}') @@ -1195,6 +1211,11 @@ def process( all_scalar_data = np.empty( (len(self.config.all_scalar_data), map_len)) if len(self.detector_config.detectors) > 0: + if det_shapes is False: + det_shapes = {} + for det in 
self.detector_config.detectors: + if det.shape is not None: + det_shapes[det.get_id()] = det.shape data = np.empty( (len(self.detector_config.detectors), map_len, @@ -1343,35 +1364,63 @@ def linkdims(nxgroup, nxdata_source): nxentry.attrs[k] = v nxentry.spec_scans = NXcollection() for scans in self.config.spec_scans: - nxentry.spec_scans[scans.scanparsers[0].scan_name] = \ + if len(scans.scanparsers) > 0: + key = scans.scanparsers[0].scan_name + else: + if str(scans.spec_file).endswith('spec.log'): + key = str(scans.spec_file).split('/')[-2] + else: + key = str(scans.spec_file).split('/')[-1] + nxentry.spec_scans[key] = \ NXfield(value=scans.scan_numbers, dtype='int8', attrs={'spec_file': str(scans.spec_file)}) + nxentry.data = NXdata() # Add sample metadata nxentry[self.config.sample.name] = NXsample( **self.config.sample.model_dump()) - # Set up independent dimensions NXdata group - # (squeeze out constant dimensions) - constant_dim = [] - for i, dim in enumerate(self.config.independent_dimensions): - unique = np.unique(independent_dimensions[i]) - if unique.size == 1: - constant_dim.append(i) + # Set up independent dimensions NeXus NXdata group nxentry.independent_dimensions = NXdata() - if len(constant_dim) < len(self.config.independent_dimensions): + if self.remove_constant_dims: + # (squeeze out constant dimensions) + constant_dim = [] for i, dim in enumerate(self.config.independent_dimensions): - if i not in constant_dim: - nxentry.independent_dimensions[dim.label] = NXfield( - independent_dimensions[i], dim.label, - attrs={'units': dim.units, - 'long_name': f'{dim.label} ({dim.units})', - 'data_type': dim.data_type, - 'local_name': dim.name}) + unique = np.unique(independent_dimensions[i]) + if unique.size == 1: + constant_dim.append(i) + if len(constant_dim) < len(self.config.independent_dimensions): + for i, dim in enumerate(self.config.independent_dimensions): + if i not in constant_dim or not self.remove_constant_dims: + 
nxentry.independent_dimensions[dim.label] = NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, + *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + ) + else: + nxentry.independent_dimensions.index = NXfield( + np.arange(independent_dimensions[0].size), 'index', + maxshape=(None,), + chunks=(1,) + ) else: - nxentry.independent_dimensions.index = NXfield( - np.arange(independent_dimensions[0].size), 'index') + for i, dim in enumerate(self.config.independent_dimensions): + nxentry.independent_dimensions[dim.label] = NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, + *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + ) # Set up scalar data NXdata group # (add the constant independent dimensions) @@ -1387,7 +1436,11 @@ def linkdims(nxgroup, nxdata_source): units=dim.units, attrs={'long_name': f'{dim.label} ({dim.units})', 'data_type': dim.data_type, - 'local_name': dim.name})) + 'local_name': dim.name}, + maxshape=(None, *all_scalar_data[i].shape[1:]), + chunks=(1, *all_scalar_data[i].shape[1:]) + )) + if (self.config.experiment_type == 'EDD' and not placeholder_data is False): scalar_signals.append('placeholder_data_used') @@ -1395,23 +1448,31 @@ def linkdims(nxgroup, nxdata_source): value=all_scalar_data[-1], attrs={'description': 'Indicates whether placeholder data may be present for' - 'the corresponding frames of detector data.'})) - for i, dim in enumerate(deepcopy(self.config.independent_dimensions)): - if i in constant_dim: - scalar_signals.append(dim.label) - scalar_data.append(NXfield( - independent_dimensions[i], dim.label, - attrs={'units': dim.units, - 'long_name': f'{dim.label} 
({dim.units})', - 'data_type': dim.data_type, - 'local_name': dim.name})) - self.config.all_scalar_data.append( - PointByPointScanData(**dim.model_dump())) - self.config.independent_dimensions.remove(dim) + 'the corresponding frames of detector data.'}, + maxshape=(None, *all_scalar_data[-1].shape[1:]), + chunks=(1, *all_scalar_data[-1].shape[1:]) + )) + if self.remove_constant_dims: + for i, dim in enumerate(deepcopy(self.config.independent_dimensions)): + if i in constant_dim: + scalar_signals.append(dim.label) + scalar_data.append(NXfield( + independent_dimensions[i], dim.label, + attrs={'units': dim.units, + 'long_name': f'{dim.label} ({dim.units})', + 'data_type': dim.data_type, + 'local_name': dim.name}, + maxshape=(None, *independent_dimensions[i].shape[1:]), + chunks=(1, *independent_dimensions[i].shape[1:]) + )) + self.config.all_scalar_data.append( + PointByPointScanData(**dim.model_dump())) + self.config.independent_dimensions.remove(dim) if scalar_signals: nxentry.scalar_data = NXdata() for k, v in zip(scalar_signals, scalar_data): nxentry.scalar_data[k] = v + nxentry.data.makelink(nxentry.scalar_data[k]) if 'SCAN_N' in scalar_signals: nxentry.scalar_data.attrs['signal'] = 'SCAN_N' else: @@ -1420,19 +1481,25 @@ def linkdims(nxgroup, nxdata_source): nxentry.scalar_data.attrs['auxiliary_signals'] = scalar_signals # Add detector data - nxdata = NXdata() - nxentry.data = nxdata + nxdata = nxentry.data nxentry.data.set_default() detector_ids = [] for k, v in self.config.attrs.items(): nxdata.attrs[k] = v if data is not None: - min_ = np.min(data, axis=tuple(range(1, data.ndim))) - max_ = np.max(data, axis=tuple(range(1, data.ndim))) + if data.size > 0: + min_ = np.min(data, axis=tuple(range(1, data.ndim))) + max_ = np.max(data, axis=tuple(range(1, data.ndim))) + else: + min_ = np.full(len(self.detector_config.detectors), np.nan) + max_ = np.full(len(self.detector_config.detectors), np.nan) for i, detector in enumerate(self.detector_config.detectors): 
nxdata[detector.get_id()] = NXfield( value=data[i], - attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]}) + attrs={**detector.attrs, 'min': min_[i], 'max': max_[i]}, + maxshape=(None, *data[i].shape[1:]), + chunks=(1, *data[i].shape[1:]) + ) detector_ids.append(detector.get_id()) linkdims(nxdata, nxentry.independent_dimensions) if len(self.detector_config.detectors) == 1: @@ -2142,6 +2209,7 @@ def process(self, data, chunks='auto'): from nexusformat.nexus import ( NXfield, NXgroup, + NXlink, ) # pylint: disable=import-error import zarr @@ -2174,9 +2242,19 @@ def copy_group(nexus_group, zarr_group): else: zarr_group.attrs[attr_key] = attr_value.nxvalue - # Copy datasets and sub-groups + # Copy datasets, sub-groups, and links + nxlinks = {} for key, item in nexus_group.items(): - if isinstance(item, NXfield): + if isinstance(item, NXlink): + self.logger.info(f'Recording link {item.nxpath}') + if item.nxfilename is not None: + nxlinks[key] = { + 'target': item.nxtarget, + 'file': item.nxfilename, + } + else: + nxlinks[key] = item.nxtarget + elif isinstance(item, NXfield): if isinstance(item.nxdata, np.ndarray): try: # Determine chunks @@ -2211,6 +2289,8 @@ def copy_group(nexus_group, zarr_group): # Recursively copy subgroup zarr_subgroup = zarr_group.create_group(key) copy_group(item, zarr_subgroup) + if nxlinks: + zarr_group.attrs['__nxlinks__'] = nxlinks copy_group(nexus_group, zarr_group) return zarr_group @@ -2422,7 +2502,7 @@ def process( :type npt: int :param mask_file: File to use for masking the input data. :type mask_file: str, optional - :param integrate1d_kwargs: Optional dictionary of keywords + :param integrate1d_kwargs: Optional dictionary of keywords. :type integrate1d_kwargs: Optional[dict] :returns: Azimuthal integration results as a dictionary of numpy arrays. 
@@ -3447,9 +3527,10 @@ def copy_group(zarr_group, nexus_group): :type: nexusformat.nexus.NXgroup """ self.logger.info(f'Copying {zarr_group.path}') - # Copy attributes + # Copy attributes (skip the internal link-metadata key) for attr_key, attr_value in zarr_group.attrs.items(): - nexus_group.attrs[attr_key] = attr_value + if attr_key != '__nxlinks__': + nexus_group.attrs[attr_key] = attr_value # Copy datasets and sub-groups for key, item in zarr_group.members(): @@ -3471,6 +3552,17 @@ def copy_group(zarr_group, nexus_group): nexus_subgroup = nexus_group.create_group(key) copy_group(item, nexus_subgroup) + # Recreate NXlinks + for link_name, link_target in \ + zarr_group.attrs.get('__nxlinks__', {}).items(): + self.logger.info( + f'Creating link {zarr_group.path}/{link_name}') + if isinstance(link_target, dict): + nexus_group[link_name] = h5py.ExternalLink( + link_target['file'], link_target['target']) + else: + nexus_group[link_name] = h5py.SoftLink(link_target) + # Start copying from the root group copy_group(zarr_file, nexus_file) diff --git a/CHAP/common/reader.py b/CHAP/common/reader.py index 398fe189..5fc45067 100755 --- a/CHAP/common/reader.py +++ b/CHAP/common/reader.py @@ -550,7 +550,7 @@ class PandasReader(Reader): `pandas `__ """ - def read(self, filename, method='read_csv', comment='#', kwargs=None): + def read(self, filename, method='read_csv', comment='#', **kwargs): """Return a `pandas.DataFrame` read from the given file. :param filename: Name of file to read from. @@ -561,9 +561,8 @@ def read(self, filename, method='read_csv', comment='#', kwargs=None): :param comment: Character to identify comment lines in the input file, defaults to `'#'`. :type comment: str, optional - :param kwargs: Additional keyword arguments to supply to the + :param \*\*kwargs: Additional keyword arguments to supply to the `pandas` reader. - :param kwargs: dict, optional. 
:rtype: `pandas.DataFrame` """ # Third party modules @@ -575,8 +574,6 @@ def read(self, filename, method='read_csv', comment='#', kwargs=None): raise ValueError( f'{method} is not a callable pandas reader method') - if kwargs is None: - kwargs = {} if not isinstance(kwargs, dict): raise TypeError( f'Invalid kwargs type ({type(kwargs)}, should be dict)') @@ -587,6 +584,9 @@ def read(self, filename, method='read_csv', comment='#', kwargs=None): class NexusReader(Reader): """Reader for `NeXus `__ files. + :ivar create_copy: Return a copy of the selected data, resolving + any and all linked NeXus objects, defaults to `False`.. + :vartype create_copy: bool, optional :ivar nxpath: Path to a specific location in the NeXus file tree to read from, defaults to `'/'`. :vartype nxpath: str, optional @@ -598,10 +598,11 @@ class NexusReader(Reader): :vartype nxmemory: int, optional """ - nxpath: Optional[constr(strip_whitespace=True, min_length=1)] = '/' + create_copy: Optional[bool] = False idx: Optional[conint(ge=0)] = None mode: Literal['r', 'rw', 'r+', 'w', 'a'] = 'r' nxmemory: Optional[conint(gt=0)] = None + nxpath: Optional[constr(strip_whitespace=True, min_length=1)] = '/' def read(self): """Return the NeXus Style @@ -620,11 +621,22 @@ def read(self): nxsetconfig, ) + # Local modules + from CHAP.utils.general import nxcopy + if self.nxmemory is not None: nxsetconfig(memory=self.nxmemory) if self.idx is not None: - return nxload(self.filename, mode=self.mode)[self.nxpath][self.idx] - return nxload(self.filename, mode=self.mode)[self.nxpath] + nxobject = nxload( + self.filename, mode=self.mode)[self.nxpath][self.idx] + if self.create_copy: + return nxcopy(nxobject) + return nxobject + #return nxload(self.filename, mode=self.mode)[self.nxpath][self.idx] + nxobject = nxload(self.filename, mode=self.mode)[self.nxpath] + if self.create_copy: + nxobject = nxcopy(nxobject) + return nxobject class NXdataReader(Reader): @@ -921,7 +933,6 @@ def read(self): detector_roi=[ 
self.detector_config.roi[0].toslice(), self.detector_config.roi[1].toslice()] - print(f'\n\ndetector_roi: {detector_roi}\n\n') nxdata[detector.get_id()] = NXfield( value=scanparser.get_detector_data( detector.get_id(), diff --git a/CHAP/common/writer.py b/CHAP/common/writer.py index 7c1e7baa..094e40a5 100755 --- a/CHAP/common/writer.py +++ b/CHAP/common/writer.py @@ -6,7 +6,11 @@ # System modules import os -from typing import Optional +from typing import ( + Literal, + Optional, + Union, +) # Third party modules import numpy as np @@ -22,6 +26,7 @@ Writer, validate_writer_model, ) +from CHAP.common.models.common import IndexSliceConfig def validate_model(model): @@ -427,6 +432,62 @@ def write(self, data): raise ValueError(f'Invalid image input type {type(image_data)}') +class JSONWriter(Writer): + """Writer for JSON data. + + :ivar index: Index of ``PipelineData`` item in input ``data`` list + that should be written to the JSON file. Defaults to ``-1``. + :vartpe index: int, Optional. + :ivar update: Update an existing file with new values, overwriting + the existing file's values. Defaults to ``False``. + :vartype update: bool, optional + :ivar extend: Extend the values in an existing file, adding to the + existing file's value lists. Defaults to ``False``. 
+ :vartype extend: bool, optional + """ + index: int = -1 + update: Optional[bool] = False + extend: Optional[bool] = False + + def write(self, data): + """Write the last """ + import json + + _data = self.get_pipelinedata_item( + data, + index=self.index, + remove=self.remove, + ) + + write_data = _data + if self.update: + try: + with open(self.filename, 'r') as inf: + write_data = json.load(inf) + except: + self.logger.warning(f'Could not load JSON from {self.filename}') + write_data = {} + if self.extend: + for k, v in _data.items(): + if k in write_data: + if not isinstance(write_data[k], list): + write_data[k] = [write_data[k]] + write_data[k].extend(v) + else: + write_data[k] = v + else: + write_data.update(_data) + + with open(self.filename, 'w') as outf: + json.dump(write_data, outf) + + @model_validator(mode='after') + def validate_modes(self): + if self.extend and not self.update: + self.update = True + return self + + class MatplotlibAnimationWriter(Writer): """Writer for saving `Matplotlib `__ animations. @@ -563,20 +624,36 @@ def write(self, data): class NexusValuesWriter(Writer): - """Writer for updating values in an existing - `NeXus `__ file.""" - - def write(self, data, filename, path_prefix=''): - """Write new values specified in `data` to the exising - `NeXus `__ file `filename`. - - :param data: List of dictionaries with the following entries -- - `'path'` identifying the location of the NeXus style - `NXfield `__ - object to which values will be written, `'data'` - identifying the data to be written, and `'idx'` - identifying the index / indicies of the NXfield to which - the data will be written. + """Writer for updating values in an existing NeXus file. + + :ivar path_prefix: Prefix to use for all paths in input `data`, + defaults to `''`. + :vartype path_prefix: str, optional + :ivar resize_axis: `False` OR the axis along which the dataset + should be resized if the target slice shape does not match the + data shape. 
If `False`, any mismatching shapes will just raise + an error. Defaults to `False`. + :vartype resize_axis: Union[int, Literal[False]], optional + :ivar idx_slice: Configuration for a slice object that will be + used to select the slice of the target array to write to. Used + only if the `"idx"` key is not present for an item in the + newest `PipelineData` item in `data`. Defaults to + `IndexSliceConfig()`. + :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional + """ + path_prefix: str = '' + resize_axis: Union[int, Literal[False]] = False + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() + + def write(self, data, filename): + """Write new values specified in `data` to the exising NeXus + file `filename`. + + :param data: List of dictionaries with the following entries + -- `'path'` identifying the location of the `NXfield` object + to which values will be written, `'data'` identifying the data + to be written, and `'idx'` identifying the index / indicies of + the `NXfield` to which the data will be written. :type data: list[PipelineData] :param filename: Name of an existing NeXus file to update. 
:type filename: str @@ -588,14 +665,17 @@ def write(self, data, filename, path_prefix=''): from nexusformat.nexus import NXFile data = self.get_pipelinedata_item(data, remove=self.remove) - for d in data: - with NXFile(filename, 'a') as nxroot: - self.nxs_writer( - nxroot=nxroot, - path=os.path.join(path_prefix, d['path']), - idx=d['idx'], - data=d['data'] - ) + with NXFile(filename, 'a') as nxroot: + for d in data: + try: + self.nxs_writer( + nxroot=nxroot, + path=os.path.join(self.path_prefix, d['path']), + idx=d.get('idx', self.idx_slice._slice), + data=d['data'] + ) + except Exception as exc: + self.logger.error(exc) def nxs_writer(self, nxroot, path, idx, data): """Write data to a specific @@ -613,8 +693,8 @@ def nxs_writer(self, nxroot, path, idx, data): :type nxroot: nexusformat.nexus.NXroot :param path: Path to the dataset inside the NeXus file. :type path: str - :param idx: Index or slice where the data should be written. - :type idx: tuple or int + :param idx: Slice where the data should be written. + :type idx: slice :param data: Data to be written to the specified slice in the dataset. 
:type data: numpy.ndarray or compatible array-like object @@ -630,13 +710,39 @@ def nxs_writer(self, nxroot, path, idx, data): # Access the specified dataset dataset = nxroot[path] + self.logger.debug( + f'chunks, maxshape = {dataset.chunks}, {dataset.maxshape}' + ) - # Check that the slice shape matches the data shape data = np.asarray(data) + + # Check the datatype + if data.dtype != dataset.dtype: + self.logger.warning( + f'Converting new data (type: {data.dtype}) to {dataset.dtype}' + ) + data = data.astype(dataset.dtype) + + # Check the shape + self.logger.debug( + f'data shape, target shape = {data.shape}, {dataset[idx].shape}' + ) if dataset[idx].shape != data.shape: - raise ValueError( - f'Data shape {data.shape} does not match the target slice ' - f'shape {dataset[idx].shape}.') + if self.resize_axis is not False: + # Resize along the specified axis + start = idx.start or 0 + stop = start + data.shape[0] + newshape = max(dataset.shape[self.resize_axis], stop) + self.logger.debug( + f'Resizing {path} {dataset.shape} -> {newshape} ' + + f'along axis {self.resize_axis}' + ) + + dataset.resize(newshape, axis=self.resize_axis) + else: + raise ValueError( + f'Data shape {data.shape} does not match the target slice ' + f'shape {dataset[idx].shape}.') # Write the data to the specified slice dataset[idx] = data @@ -757,7 +863,7 @@ def write(self, data): from CHAP.models import CHAPBaseModel def get_dict(data): - if isinstance(data, dict): + if isinstance(data, dict) or isinstance(data, list): return data if isinstance(data, (BaseModel, CHAPBaseModel)): try: @@ -791,7 +897,7 @@ def get_dict(data): self.status = 'written' # Right now does nothing yet, but could # add a sort of modification flag later - # Return provenance with the output file name added + # Return provenance with the output file name added return self._update_provenance(data) @@ -799,12 +905,24 @@ class ZarrValuesWriter(Writer): """Writer for updating values in arrays of an existing `Zarr `__ file. 
- :ivar path_prefix: Prefix to prepend to all "path" fields in - `data` before writing. Defaults to `""`. + :ivar path_prefix: Prefix to use for all paths in input `data`, + defaults to `''`. :vartype path_prefix: str, optional + :ivar resize_axis: `False` OR the axis along which the dataset + should be resized if the target slice shape does not match the + data shape. If `False`, any mismatching shapes will just raise + an error. Defaults to `False`. + :vartype resize_axis: Union[int, Literal[False]], optional + :ivar idx_slice: Configuration for a slice object that will be + used to select the slice of the target array to write to. Used + only if the `"idx"` key is not present for an item in the + newest `PipelineData` item in `data`. Defaults to + `IndexSliceConfig()`. + :vartype idx_slice: CHAP.common.models.common.IndexSliceConfig, optional """ - path_prefix: Optional[str] = '' + resize_axis: Union[int, Literal[False]] = False + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() def write(self, data): """Write values to specific paths and slices in an existing @@ -824,11 +942,16 @@ def write(self, data): # Get list of PyfaiIntegrationProcessor results to write for d in self.get_pipelinedata_item(data, remove=self.remove): - self.zarr_writer( - zarrfile=zarrfile, - path=os.path.join(self.path_prefix, d['path']), - idx=d['idx'], - data=d['data']) + try: + self.logger.info(f'd = {d}') + self.zarr_writer( + zarrfile=zarrfile, + path=os.path.join(self.path_prefix, d['path']), + idx=d.get('idx', self.idx_slice._slice), + data=d['data'] + ) + except Exception as exc: + self.logger.error(exc) def zarr_writer(self, zarrfile, path, idx, data): """Write data to a specific dataset. @@ -856,18 +979,41 @@ def zarr_writer(self, zarrfile, path, idx, data): # Check if the dataset exists if path not in zarrfile: raise ValueError( - f'Dataset "{path}" does not exist in the Zarr file.') + f'Dataset "{path}" does not exist in the Zarr file.' 
+ ) # Access the specified dataset dataset = zarrfile[path] + self.logger.debug( + f'chunks = {dataset.chunks}' + ) + + data = np.asarray(data) - # Check that the slice shape matches the data shape - if dataset[idx].shape != data.shape and data.shape[0] == 1: - data = np.squeeze(data, axis=0) + # Check the shape + self.logger.info( + f'data shape, target shape = {data.shape}, {dataset[idx].shape}' + ) if dataset[idx].shape != data.shape: - raise ValueError( - f'Data shape {data.shape} does not match the target slice ' - f'shape {dataset[idx].shape}.') + if self.resize_axis is not False: + # Resize along the specified axis + start = idx.start or 0 + stop = start + data.shape[0] + newshape_axis = max(dataset.shape[self.resize_axis], stop) + oldshape = dataset.shape + newshape = ( + *oldshape[0:self.resize_axis], + newshape_axis, + *oldshape[self.resize_axis + 1:] + ) + self.logger.info( + f'Resizing {path} {oldshape} -> {newshape}' + ) + dataset.resize(newshape) + else: + raise ValueError( + f'Data shape {data.shape} does not match the target slice ' + f'shape {dataset[idx].shape}.') # Write the data to the specified slice dataset[idx] = data diff --git a/CHAP/edd/models.py b/CHAP/edd/models.py index d224a69e..0ad7dabe 100755 --- a/CHAP/edd/models.py +++ b/CHAP/edd/models.py @@ -883,15 +883,17 @@ class StrainAnalysisConfig(MCACalibrationConfig): the maximum mean intensity for that detector. Defaults to `0` in which case this step is ignored. :vartype find_peak_cutoff: float, optional + :ivar max_nfev: Maximum number of function evaluations in the + the strain analysis peak fitting routine. + :vartype max_nfev: int, optional :ivar num_proc: Number of processors to be used by the strain analysis peak fitting routine. 
- :vartype num_proc: int + :vartype num_proc: int, optional :ivar rel_height_cutoff: Used to excluded peaks based on the `find_peak` parameter as well as for peak fitting exclusion of the individual detector spectra (see the strain detector configuration :class:`~CHAP.edd.models.MCADetectorStrainAnalysis`). - Defaults to `None`. :vartype rel_height_cutoff: float, optional :ivar skip_animation: Skip the animation and plotting of the strain analysis fits, defaults to `False`. @@ -904,6 +906,7 @@ class StrainAnalysisConfig(MCACalibrationConfig): #:vartype oversampling: FIX find_peak_cutoff: Optional[confloat(ge=0.0, allow_inf_nan=False)] = 0.0 + max_nfev: Optional[conint(gt=0)] = None num_proc: Optional[conint(gt=0)] = max(1, os.cpu_count()//4) #oversampling: dict = {'num': 10} rel_height_cutoff: Optional[ diff --git a/CHAP/edd/processor.py b/CHAP/edd/processor.py index 2309ff78..fae88013 100755 --- a/CHAP/edd/processor.py +++ b/CHAP/edd/processor.py @@ -4,7 +4,6 @@ Add discription of EDD """ - # System modules from copy import deepcopy import os @@ -736,11 +735,10 @@ def _measure_dvl(self, scanned_vals): models.append({'model': model, 'prefix': f'{model}_'}) models.append({'model': 'gaussian'}) self.logger.debug('Fitting mean spectrum') - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata( - NXfield(masked_sum, 'y'), NXfield(x, 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': x, 'y': masked_sum}, + config={'models': models, 'method': 'trf'}, + **self.run_config) # Calculate / manually select diffraction volume length detector.dvl = float( @@ -1185,6 +1183,8 @@ class MCAEnergyCalibrationProcessor(_BaseEddProcessor): config: MCAEnergyCalibrationConfig detector_config: MCADetectorConfig + _peak_fit_results: dict = PrivateAttr(default={}) + @model_validator(mode='before') @classmethod def validate_mcaenergycalibrationprocessor_before(cls, data): @@ -1309,6 +1309,10 @@ def process(self, data): 
**self.config.model_dump(), 'detectors': [d.model_dump(exclude_defaults=True) for d in self.detector_config.detectors]} + # Add the fit results to the detector fields + for i, d in enumerate(self.detector_config.detectors): + configs['detectors'][i].update( + {'peak_fit_results': self._peak_fit_results[d.get_id()]}) return configs, PipelineData( name=self.__name__, data=self._figures, schema='common.write.ImageWriter') @@ -1369,7 +1373,9 @@ def _calibrate(self): self._energies, self._masks, self._mean_data, self._mask_index_ranges, self.detector_config.detectors): - self.logger.info(f'Calibrating detector {detector.get_id()}') + id_ = detector.get_id() + + self.logger.info(f'Calibrating detector {id_}') bins = low + np.arange(energies.size, dtype=np.int16) @@ -1378,13 +1384,11 @@ def _calibrate(self): for energy in peak_energies] buf, initial_peak_indices = self._get_initial_peak_positions( mean_data*np.asarray(mask).astype(np.int32), low, - detector.mask_ranges, input_indices, max_peak_index, - detector.get_id(), return_buf=self.save_figures) + detector.mask_ranges, input_indices, max_peak_index, id_, + return_buf=self.save_figures) if self.save_figures: self._figures.append( - (buf, - f'{detector.get_id()}' - '_energy_calibration_initial_peak_positions')) + (buf, f'{id_}_energy_calibration_initial_peak_positions')) # Construct the fit model and perform the fit models = [] @@ -1401,33 +1405,30 @@ def _calibrate(self): 'fwhm_min': detector.fwhm_min, 'fwhm_max': detector.fwhm_max}) self.logger.debug('Fitting spectrum') - fit = FitProcessor(**self.run_config) - mean_data_fit = fit.process( - NXdata( - NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'models': models, 'method': 'trf'}) + mean_data_fit = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) # Extract the fit results for the peaks - fit_peak_amplitudes = sorted([ + fit_peak_amplitudes = np.asarray([ 
mean_data_fit.best_values[f'peak{i+1}_amplitude'] for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') - fit_peak_indices = sorted([ + fit_peak_indices = np.asarray([ mean_data_fit.best_values[f'peak{i+1}_center'] for i in range(len(initial_peak_indices))]) self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') - fit_peak_fwhms = sorted([ - 2.35482*mean_data_fit.best_values[f'peak{i+1}_sigma'] + fit_peak_sigmas = np.asarray([ + mean_data_fit.best_values[f'peak{i+1}_sigma'] for i in range(len(initial_peak_indices))]) - self.logger.debug(f'Fit peak fwhms: {fit_peak_fwhms}') + self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') # FIX for now stick with a linear energy correction - fit = FitProcessor(**self.run_config) - energy_fit = fit.process( - NXdata( - NXfield(peak_energies, 'y'), - NXfield(fit_peak_indices, 'x')), - {'models': [{'model': 'linear'}]}) + energy_fit = FitProcessor.run( + data={'x': fit_peak_indices, 'y': peak_energies}, + config={'models': [{'model': 'linear'}]}, + **self.run_config) a = 0.0 b = float(energy_fit.best_values['slope']) c = float(energy_fit.best_values['intercept']) @@ -1446,8 +1447,7 @@ def _calibrate(self): bins_masked = bins[mask] fig, axs = plt.subplots(1, 2, figsize=(11, 4.25)) - fig.suptitle( - f'Detector {detector.get_id()} energy calibration') + fig.suptitle(f'Detector {id_} energy calibration') # Left plot: raw MCA data and best fit of peaks axs[0].set_title('MCA spectrum peak fit') axs[0].set_xlabel('Detector Channel (-)') @@ -1488,12 +1488,22 @@ def _calibrate(self): if self.save_figures: self._figures.append( - (fig_to_iobuf(fig), - f'{detector.get_id()}_energy_calibration_fit')) + (fig_to_iobuf(fig), f'{id_}_energy_calibration_fit')) if self.interactive: plt.show() plt.close() + # Collect the fit results for the peaks + # FIX assume Gaussian peaks for now and a=0 + self._peak_fit_results[id_] = { + 'heights': (fit_peak_amplitudes / + 
(np.sqrt(2*np.pi)*fit_peak_sigmas)).tolist(), + 'centers': (b*fit_peak_indices + c).tolist(), + 'fwhm': (b*fit_peak_sigmas*np.sqrt(8*np.log(2))).tolist(), + 'independent_dimension': { + 'name': 'Energy', 'unit': '-'}, + } + def _get_initial_peak_positions( self, y, low, index_ranges, input_indices, input_max_peak_index, detector_id, reset_flag=0, return_buf=False): @@ -1740,6 +1750,8 @@ class MCATthCalibrationProcessor(_BaseEddProcessor): config: MCATthCalibrationConfig detector_config: MCADetectorConfig + _peak_fit_results: dict = PrivateAttr(default={}) + @model_validator(mode='before') @classmethod def validate_mcatthcalibrationprocessor_before(cls, data): @@ -1875,6 +1887,10 @@ def process(self, data): **self.config.model_dump(), 'detectors': [d.model_dump(exclude_defaults=True) for d in self.detector_config.detectors]} + # Add the fit results to the detector fields + for i, d in enumerate(self.detector_config.detectors): + configs['detectors'][i].update( + {'peak_fit_results': self._peak_fit_results[d.get_id()]}) return configs, PipelineData( name=self.__name__, data=self._figures, schema='common.write.ImageWriter') @@ -2120,17 +2136,26 @@ def _direct_bragg_peak_fit( 'fwhm_max': detector.fwhm_max}) # Perform an unconstrained fit in terms of MCA bin index - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) best_fit = result.best_fit residual = result.residual # Extract the Bragg peak indices from the fit - i_bragg_fit = np.asarray( - [result.best_values[f'peak{i+1}_center'] - for i in range(len(e_bragg))]) + fit_peak_amplitudes = np.asarray([ + result.best_values[f'peak{i+1}_amplitude'] + for i in range(len(e_bragg))]) + self.logger.debug(f'Fit peak amplitudes: {fit_peak_amplitudes}') + fit_peak_indices 
= np.asarray([ + result.best_values[f'peak{i+1}_center'] + for i in range(len(e_bragg))]) + self.logger.debug(f'Fit peak center indices: {fit_peak_indices}') + fit_peak_sigmas = np.asarray([ + result.best_values[f'peak{i+1}_sigma'] + for i in range(len(e_bragg))]) + self.logger.debug(f'Fit peak sigmas: {fit_peak_sigmas}') # Fit a line through zero strain peak energies vs detector # energy bins @@ -2138,10 +2163,10 @@ def _direct_bragg_peak_fit( model = 'quadratic' else: model = 'linear' - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(e_bragg, 'y'), NXfield(i_bragg_fit, 'x')), - {'models': [{'model': model}]}) + result = FitProcessor.run( + data={'x': fit_peak_indices, 'y': e_bragg}, + config={'models': [{'model': model}]}, + **self.run_config) if quadratic_energy_calibration: a_fit = result.best_values['a'] b_fit = result.best_values['b'] @@ -2151,7 +2176,19 @@ def _direct_bragg_peak_fit( b_fit = result.best_values['slope'] c_fit = result.best_values['intercept'] e_bragg_unconstrained = ( - (a_fit*i_bragg_fit + b_fit) * i_bragg_fit + c_fit) + (a_fit*fit_peak_indices + b_fit) * fit_peak_indices + c_fit) + + # Collect the fit results for the peaks + # FWHM to first order (i.e. 
ignore non-zero a) + # FIX assume Gaussian peaks for now + self._peak_fit_results[detector.get_id()] = { + 'heights': (fit_peak_amplitudes / + (np.sqrt(2*np.pi)*fit_peak_sigmas)).tolist(), + 'centers': e_bragg_unconstrained.tolist(), + 'fwhm': (b_fit*fit_peak_sigmas*np.sqrt(8*np.log(2))).tolist(), + 'independent_dimension': { + 'name': 'Energy', 'unit': 'keV'}, + } return { 'best_fit_unconstrained': best_fit, @@ -2256,10 +2293,11 @@ def _direct_fit_tth_ecc( {'name': 'sigma', 'min': sig_min, 'max': sig_max}]}) # Perform the fit - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'parameters': parameters, 'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': bins[mask], 'y': mean_data[mask]}, + config={ + 'parameters': parameters, 'models': models, 'method': 'trf'}, + **self.run_config) # Extract values of interest from the best values tth_fit = np.degrees(result.best_values['tth']) @@ -2291,14 +2329,10 @@ def _direct_fit_tth_ecc( 'centers_range': b_fit * detector.centers_range, 'fwhm_min': b_fit * detector.fwhm_min, 'fwhm_max': b_fit * detector.fwhm_max}] - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), NXfield(bins[mask], 'x')), - {'parameters': parameters, 'models': models, 'method': 'trf'}) - fit = FitProcessor(**self.run_config) - result = fit.process( - NXdata(NXfield(mean_data[mask], 'y'), - NXfield(energies[mask], 'x')), - {'models': models, 'method': 'trf'}) + result = FitProcessor.run( + data={'x': energies[mask], 'y': mean_data[mask]}, + config={'models': models, 'method': 'trf'}, + **self.run_config) e_xrf_unconstrained = np.sort( [result.best_values[f'peak{i+1}_center'] for i in range(num_xrf)]) @@ -2538,6 +2572,21 @@ class StrainAnalysisProcessor(_BaseStrainProcessor): results as a list of updated points or update the result from the `setup` stage, defaults to `True`. 
:vartype update: bool, optional + :ivar standalone: Return results in standalone format suitable + for autonomous/streaming experiments, defaults to `False`. + When `True`, setup returns a standalone + `nexusformat.nexus.NXprocess` (not wrapped in an + `nexusformat.nexus.NXroot`), and update returns a list of + ``{'data': array, 'path': str}`` dicts suitable for writing + with ``common.NexusValuesWriter``. When both `setup` and + `update` are `True` the standalone `NXprocess` and the + values list are returned together as a tuple. + :vartype standalone: bool, optional + :ivar json_results: If updating, return an additional minimal + dictionary of results that can be written to .json file for + easier access by autonomous feedback experiment + drivers. Defaults to `False`. + :vartype json_results: bool, optional """ pipeline_fields: dict = Field( @@ -2552,6 +2601,8 @@ class StrainAnalysisProcessor(_BaseStrainProcessor): detector_config: MCADetectorConfig setup: Optional[bool] = True update: Optional[bool] = True + standalone: Optional[bool] = False + json_results: Optional[bool] = False @model_validator(mode='before') @classmethod @@ -2688,6 +2739,9 @@ def process(self, data): nxdata = nxentry[nxentry.default] # Load the validated calibration configuration + # FIX make this a class field and add to pipeline_fields too? 
+ # NB: be sure that adding this field does not mess up + # self.detector_config during pydantic validation calibration_config = self.get_config( data, schema='edd.models.MCATthCalibrationConfig', remove=False) @@ -2722,6 +2776,17 @@ def process(self, data): self.logger.warning( f'Skipping detector {detector_id} (Illegal data shape ' f'{raw_detector_data.shape})') + elif raw_detector_data.size == 0 and self.setup: + # 0-scan map: no spectra yet, include for setup + for k, v in nxdata[detector_id].attrs.items(): + detector.attrs[k] = v.nxdata + if self.config.rel_height_cutoff is not None: + detector.rel_height_cutoff = \ + self.config.rel_height_cutoff + detector.add_calibration( + calibration_detectors[ + int(calibration_detector_ids.index(detector_id))]) + detectors.append(detector) elif raw_detector_data.sum(): for k, v in nxdata[detector_id].attrs.items(): detector.attrs[k] = v.nxdata @@ -2773,25 +2838,53 @@ def process(self, data): # Apply the combined energy ranges mask self._apply_combined_mask() + # Populate _peak_fit_info when there are no spectra (0-scan setup) + no_raw_data = ( + bool(self._nxdata_detectors) + and self._nxdata_detectors[0].nxsignal.shape[0] == 0) + if no_raw_data: + self._populate_peak_fit_info() + # Setup and/or run the strain analysis - points = [] - if self.update: - points = self._strain_analysis() + results = {} + if self.update and not no_raw_data: + results = self._strain_analysis() if self.setup: - nxroot = self._get_nxroot(nxentry, calibration_config) - if points: - self.logger.info(f'Adding {len(points)} points') - self.add_points(nxroot, points, logger=self.logger) - self.logger.info('... 
done') + if self.standalone: + nxsetup = self._get_nxprocess(nxentry, calibration_config) else: - self.logger.warning('Skip adding points') - if not (self._figures or self._animation): - return nxroot - ret = [nxroot] + nxsetup = self._get_nxroot(nxentry, calibration_config) + if not self.standalone: + points = self._get_points(results) + if points: + self.logger.info(f'Adding {len(points)} points') + self.add_points(nxsetup, points, logger=self.logger) + self.logger.info(f'... done') + else: + self.logger.warning('Skip adding points') + if self.standalone and self.update: + # Return nxprocess structure and values separately for writer + values = self._get_values(results) + ret = [nxsetup, values] + else: + if not (self._figures or self._animation): + return nxsetup + ret = [nxsetup] else: + result = self._get_values(results) if self.standalone \ + else self._get_points(results) + ret = [result] + if self.json_results: + json_results = { + k: v.tolist() + for k, v in results.items() + if not k.endswith('intensity') # exclude raw detector data + and not k.endswith('best_fit') # exlude best_fit spectra + and not k.endswith('residual') + } + ret.append(json_results) if not (self._figures or self._animation): - return points - ret = [points] + return tuple(ret) if self._figures: ret.append( PipelineData( @@ -2830,14 +2923,32 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection.results = NXdata() nxdata = nxcollection.results self._linkdims(nxdata, det_nxdata) - nxdata.best_fit = NXfield(shape=shape, dtype=np.float64) + nxdata.best_fit = NXfield( + shape=shape, dtype=np.float64, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) nxdata.included_peaks = NXfield( - shape=[shape[0], len(hkls)], dtype=bool) - nxdata.included_peaks.attrs['hkls'] = dumps(peak_fit_info['hkls']) - nxdata.included_peaks.attrs['use_peaks'] = peak_fit_info['use_peaks'] - nxdata.residual = NXfield(shape=shape, dtype=np.float64) - nxdata.redchi = 
NXfield(shape=[shape[0]], dtype=np.float64) - nxdata.success = NXfield(shape=[shape[0]], dtype='bool') + shape=[shape[0], len(hkls)], dtype=bool, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) + nxdata.included_peaks.attrs['hkls'] = dumps( + peak_fit_info.get('hkls', '') + ) + nxdata.included_peaks.attrs['use_peaks'] = peak_fit_info.get( + 'use_peaks' + ) + nxdata.residual = NXfield( + shape=shape, dtype=np.float64, + maxshape=(None, *shape[1:]), chunks=(1, *shape[1:]) + ) + nxdata.redchi = NXfield( + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) + nxdata.success = NXfield( + shape=[shape[0]], dtype='bool', + maxshape=(None,), chunks=(1,) + ) # Peak-by-peak results for hkl in hkls: @@ -2857,9 +2968,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].amplitudes, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].amplitudes.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'counts'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'counts'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].amplitudes.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].amplitudes.attrs['signal'] = 'values' # Report HKL peak centers nxcollection[hkl_name].centers = NXdata() @@ -2867,9 +2982,13 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].centers, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].centers.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].centers.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) 
nxcollection[hkl_name].centers.attrs['signal'] = 'values' # Report HKL peak FWHMs nxcollection[hkl_name].sigmas = NXdata() @@ -2877,20 +2996,28 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): nxcollection[hkl_name].sigmas, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].sigmas.values = NXfield( - shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}) + shape=[shape[0]], dtype=np.float64, attrs={'units': 'keV'}, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].sigmas.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].sigmas.attrs['signal'] = 'values' - if peak_fit_info['peak_models'] == 'pvoigt': + if peak_fit_info.get('peak_models') == 'pvoigt': # Report HKL peak fractions nxcollection[hkl_name].fractions = NXdata() self._linkdims( nxcollection[hkl_name].fractions, det_nxdata, skip_field_dims=['energy']) nxcollection[hkl_name].fractions.values = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].fractions.errors = NXfield( - shape=[shape[0]], dtype=np.float64) + shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].fractions.attrs['signal'] = 'values' # Report HKL peak strains (unconstrained only) if fit_type == 'unconstrained': @@ -2900,13 +3027,43 @@ def _add_fit_nxcollection(self, nxdetector, fit_type, hkls, peak_fit_info): skip_field_dims=['energy']) values = np.full(shape=[shape[0]], fill_value=np.nan) nxcollection[hkl_name].strains.values = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].strains.errors = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) 
+ ) nxcollection[hkl_name].strains.residuals = NXfield( - value=values, shape=[shape[0]], dtype=np.float64) + value=values, shape=[shape[0]], dtype=np.float64, + maxshape=(None,), chunks=(1,) + ) nxcollection[hkl_name].strains.attrs['signal'] = 'values' + def _populate_peak_fit_info(self): + """Populate _peak_fit_info with all configured HKLs for each + detector, with all peaks marked as used. + + Called when no raw detector spectra are present (e.g. a 0-scan + setup run), so that the resulting NXprocess contains entries for + every HKL that could be encountered in a subsequent update run. + """ + # Local modules + from CHAP.edd.utils import get_peak_locations, get_unique_hkls_ds + + for detector in self.detector_config.detectors: + hkls, ds = get_unique_hkls_ds( + self.config.materials, tth_max=detector.tth_max, + tth_tol=detector.tth_tol) + hkls_fit = np.asarray([hkls[i] for i in detector.hkl_indices]) + ds_fit = np.asarray([ds[i] for i in detector.hkl_indices]) + peak_locations = get_peak_locations(ds_fit, detector.tth_calibrated) + self._peak_fit_info.append({ + 'hkls': [''.join(map(str, hkl)) for hkl in hkls_fit], + 'nominal_peak_centers': peak_locations.tolist(), + 'peak_models': detector.peak_models, + 'use_peaks': np.ones(len(hkls_fit), dtype=bool).tolist(), + }) def _create_animation( self, nxdata, energies, intensities, intensity_norms, best_fits, @@ -2996,18 +3153,48 @@ def animate(i): f'{detector_id}_strainanalysis_unconstrained_fits')) plt.close() - def _get_nxroot(self, nxentry, calibration_config): - """Return a NeXus style - `NXroot `__ - object initialized for the stress analysis. + + def _get_points(self, results): + """Convert strain analysis results to list-of-dicts format + expected by `add_points`. + + :param results: Strain analysis results, mapping NeXus path + strings to arrays with leading axis = num_points. 
+ :type results: dict[str, numpy.ndarray] + :return: One dict per map coordinate, each containing all + result values at that coordinate. + :rtype: list[dict[str, object]] + """ + if not results: + return [] + num_points = next(iter(results.values())).shape[0] + return [{k: v[i] for k, v in results.items()} + for i in range(num_points)] + + def _get_values(self, results): + """Convert strain analysis results to a list of value dicts + suitable for writing with ``common.NexusValuesWriter``. + + :param results: Strain analysis results, mapping NeXus path + strings to arrays with leading axis = num_points. + :type results: dict[str, numpy.ndarray] + :return: List of ``{'data': array, 'path': str}`` dicts. + :rtype: list[dict[str, object]] + """ + return [{'data': v, 'path': k} for k, v in results.items()] + + def _get_nxprocess(self, nxentry, calibration_config): + """Return a standalone NXprocess for the strain analysis + results & metadata. :param nxentry: Strain analysis map, including the raw MCA data. :type nxentry: nexusformat.nexus.NXentry :param calibration_config: 2&theta calibration configuration. - :type calibration_config: MCATthCalibrationConfig - :return: Strain analysis results and associated metadata. - :rtype: nexusformat.nexus.NXroot + :type calibration_config: + CHAP.edd.models.MCATthCalibrationConfig + :return: Strain analysis results & associated metadata. 
+ :rtype: nexusformat.nexus.NXprocess """ # Third party modules # pylint: disable=no-name-in-module @@ -3015,7 +3202,6 @@ def _get_nxroot(self, nxentry, calibration_config): NXdetector, NXfield, NXprocess, - NXroot, ) # pylint: enable=no-name-in-module @@ -3032,16 +3218,18 @@ def _get_nxroot(self, nxentry, calibration_config): 'StrainAnalysis Configuration, or re-run the pipeline with ' 'the --interactive flag.') - # Create the NXroot object - nxroot = NXroot() - nxroot[nxentry.nxname] = nxentry - nxroot[f'{nxentry.nxname}_strainanalysis'] = NXprocess() - nxprocess = nxroot[f'{nxentry.nxname}_strainanalysis'] + nxprocess = NXprocess() nxprocess.calibration_config = \ calibration_config.model_dump_json() nxprocess.strain_analysis_config = \ self.config.model_dump_json() + if len(self._peak_fit_info) == 0: + # FIX this is a temporary fix to be able to run update + # after setup. + self._peak_fit_info = [{}] * len(self.detector_config.detectors) + self.logger.warning('Missing peak_fit_info') + # Loop over the detectors to fill in the nxprocess for energies, mask, nxdata, detector, peak_fit_info in zip( self._energies, self._masks, self._nxdata_detectors, @@ -3073,35 +3261,53 @@ def _get_nxroot(self, nxentry, calibration_config): value=energies[mask], attrs={'units': 'keV'}) det_nxdata.norm = NXfield( dtype=np.float64, - shape=(num_points,)) + shape=(num_points,), + maxshape=(None,), chunks=(1,) + ) det_nxdata.tth = NXfield( dtype=np.float64, shape=(num_points,), - attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}) + attrs={'units':'degrees', 'long_name': '2\u03B8 (degrees)'}, + maxshape=(None,), chunks=(1,), + ) det_nxdata.uniform_strain = NXfield( dtype=np.float64, shape=(num_points,), - attrs={'long_name': 'Strain from uniform fit (\u03B5)'}) - #attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'}) + attrs={'long_name': 'Strain from uniform fit (\u03B5)'}, + # attrs={'long_name': 'Strain from uniform fit (\u03BC\u03B5)'} + maxshape=(None,), 
chunks=(1,) + ) + det_nxdata.unconstrained_strain = NXfield( dtype=np.float64, shape=(num_points,), attrs={'long_name': - 'Strain from unconstrained fit (\u03B5)'}) - #'Strain from unconstrained fit (\u03BC\u03B5)'}) + 'Strain from unconstrained fit (\u03B5)'}, + #'Strain from unconstrained fit (\u03BC\u03B5)'}, + maxshape=(None,), chunks=(1,) + ) det_nxdata.unconstrained_strain_stdev = NXfield( dtype=np.float64, shape=(num_points,), attrs={'long_name': 'Standard deviation in strain from unconstrained ' - 'fit (\u03B5)'}) - #'fit (\u03BC\u03B5)'}) + 'fit (\u03B5)'}, + #'fit (\u03BC\u03B5)'} + maxshape=(None,), chunks=(1,) + ) # Add the detector data + num_energy_bins = mask.sum() + _intensity = np.empty( + (num_points, num_energy_bins), dtype=np.float64) + for i in range(num_points): + _intensity[i] = data[i].astype(np.float64)[mask] det_nxdata.intensity = NXfield( - value=np.asarray([data[i].astype(np.float64)[mask] - for i in range(num_points)]), - attrs={'units': 'counts'}) + value=_intensity, + attrs={'units': 'counts'}, + maxshape=(None, num_energy_bins), + chunks=(1, num_energy_bins) + ) det_nxdata.attrs['signal'] = 'intensity' # Get the unique HKLs and lattice spacings for the strain @@ -3126,6 +3332,30 @@ def _get_nxroot(self, nxentry, calibration_config): tth_map = detector.get_tth_map((num_points,)) det_nxdata.tth.nxdata = tth_map + return nxprocess + + def _get_nxroot(self, nxentry, calibration_config): + """Return a `nexusformat.nexus.NXroot` object initialized for + the strain analysis. + + :param nxentry: Strain analysis map, including the raw + MCA data. + :type nxentry: nexusformat.nexus.NXentry + :param calibration_config: 2&theta calibration configuration. + :type calibration_config: + CHAP.edd.models.MCATthCalibrationConfig + :return: Strain analysis results & associated metadata. 
+ :rtype: nexusformat.nexus.NXroot + """ + # Third party modules + # pylint: disable=no-name-in-module + from nexusformat.nexus import NXroot + # pylint: enable=no-name-in-module + + nxroot = NXroot() + nxroot[nxentry.nxname] = nxentry + nxprocess = self._get_nxprocess(nxentry, calibration_config) + nxroot[f'{nxentry.nxname}_strainanalysis'] = nxprocess return nxroot def _linkdims( @@ -3135,9 +3365,16 @@ def _linkdims( object. """ # Third party modules - from nexusformat.nexus import NXfield + from nexusformat.nexus import NXfield, NXroot from nexusformat.nexus.tree import NXlinkfield + if not isinstance(nxgroup.nxroot, NXroot): + self.logger.debug( + 'Skipping linkdims -- type(nxgroup.nxroot) = ' + + f'{type(nxgroup.nxroot)}' + ) + return + if skip_field_dims is None: skip_field_dims = [] if oversampling_axis is None: @@ -3189,7 +3426,13 @@ def _linkdims( nxgroup.attrs['unstructured_axes'] = unstructured_axes def _strain_analysis(self): - """Perform the strain analysis on the full or partial map.""" + """Perform the strain analysis on the full or partial map. + + :return: Strain analysis results mapping NeXus path strings to + arrays with leading axis = num_points. Returns an empty + dict when no valid peaks are found. 
+ :rtype: dict[str, numpy.ndarray] + """ # Third party modules from nexusformat.nexus import NXfield @@ -3212,25 +3455,23 @@ def _strain_analysis(self): f'{self.detector_config.detectors[0].get_id()}_' 'strainanalysis_material_config')) - # Setup the points list with the map axes values + # Setup the results dict with the map axes values nxdata_ref = self._nxdata_detectors[0] axes = get_axes(nxdata_ref) if axes: - points = [ - {a: nxdata_ref[a].nxdata[i] for a in axes} - for i in range(nxdata_ref[axes[0]].size)] + num_points = nxdata_ref[axes[0]].size + results = {a: nxdata_ref[a].nxdata for a in axes} else: axes = ['index'] - points = [ - {'index': i} - for i in range(np.prod(nxdata_ref.nxsignal.shape[:-1]))] + num_points = int(np.prod(nxdata_ref.nxsignal.shape[:-1])) + results = {'index': np.arange(num_points)} for nxdata in self._nxdata_detectors: nxdata.attrs['axes'] = axes nxdata.index = NXfield( - np.arange(np.prod(nxdata_ref.nxsignal.shape[:-1])), + np.arange(num_points), 'index') - # Loop over the detectors to fill in the nxprocess + # Loop over the detectors to fill in the results dict for energies, mask, mean_data, nxdata, detector in zip( self._energies, self._masks, self._mean_data, self._nxdata_detectors, self.detector_config.detectors): @@ -3300,7 +3541,7 @@ def _strain_analysis(self): self.logger.warning( 'No matching peaks with heights above the threshold, ' f'skipping the fit for detector {detector.get_id()}') - return [] + return {} self._peak_fit_info.append({ 'hkls': ["".join(map(str, hkl)) for hkl in hkls_fit], 'nominal_peak_centers': peak_locations.tolist(), @@ -3312,8 +3553,9 @@ def _strain_analysis(self): uniform_results, unconstrained_results = get_spectra_fits( np.squeeze(intensities), energies[mask], peak_locations[use_peaks], detector, - num_proc=self.config.num_proc, **self.run_config) - if intensities.shape[0] == 1: + num_proc=self.config.num_proc, + max_nfev=self.config.max_nfev, **self.run_config) + if num_points == 1: 
uniform_results = {k: [v] for k, v in uniform_results.items()} unconstrained_results = { k: [v] for k, v in unconstrained_results.items()} @@ -3332,8 +3574,7 @@ def _strain_analysis(self): self.logger.info('... done') - # Add the fit results to the list of points - num_points = len(points) + # Compute the strain analysis results for all map points tth_map = detector.get_tth_map((nxdata.shape[0],)) nominal_centers = np.asarray( [get_peak_locations(d0, tth_map) @@ -3369,109 +3610,126 @@ def _strain_analysis(self): unconstrained_amplitudes_vary, insert_peak_indices, [False], axis=-1) - # Add points - for i, point in enumerate(points): - point.update({ - f'{detector.get_id()}/data/intensity': intensities[i], - f'{detector.get_id()}/data/norm': intensity_norms[i], - f'{detector.get_id()}/data/uniform_strain': - uniform_strain[i], - f'{detector.get_id()}/data/unconstrained_strain': - unconstrained_strain[i], - f'{detector.get_id()}/data/unconstrained_strain_stdev': - unconstrained_strain_stdev[i], - f'{detector.get_id()}/uniform_fit/results/best_fit': - uniform_results['best_fits'][i], - f'{detector.get_id()}/uniform_fit/results/included_peaks': - uniform_amplitudes_vary[i], - f'{detector.get_id()}/uniform_fit/results/residual': - uniform_results['residuals'][i], - f'{detector.get_id()}/uniform_fit/results/redchi': - uniform_results['redchis'][i], - f'{detector.get_id()}/uniform_fit/results/success': - uniform_results['success'][i], - f'{detector.get_id()}/unconstrained_fit/results/best_fit': - unconstrained_results['best_fits'][i], - f'{detector.get_id()}/unconstrained_fit/results/' - 'included_peaks': unconstrained_amplitudes_vary[i], - f'{detector.get_id()}/unconstrained_fit/results/residual': - unconstrained_results['residuals'][i], - f'{detector.get_id()}/unconstrained_fit/results/redchi': - unconstrained_results['redchis'][i], - f'{detector.get_id()}/unconstrained_fit/results/success': - unconstrained_results['success'][i], + # Store results as full arrays (leading 
axis = num_points) + det_id = detector.get_id() + results.update({ + f'{det_id}/data/intensity': intensities, + f'{det_id}/data/norm': intensity_norms, + f'{det_id}/data/uniform_strain': uniform_strain, + f'{det_id}/data/unconstrained_strain': unconstrained_strain, + f'{det_id}/data/unconstrained_strain_stdev': + unconstrained_strain_stdev, + f'{det_id}/uniform_fit/results/best_fit': + np.asarray(uniform_results['best_fits']), + f'{det_id}/uniform_fit/results/included_peaks': + uniform_amplitudes_vary, + f'{det_id}/uniform_fit/results/residual': + np.asarray(uniform_results['residuals']), + f'{det_id}/uniform_fit/results/redchi': + np.asarray(uniform_results['redchis']), + f'{det_id}/uniform_fit/results/success': + np.asarray(uniform_results['success']), + f'{det_id}/unconstrained_fit/results/best_fit': + np.asarray(unconstrained_results['best_fits']), + f'{det_id}/unconstrained_fit/results/included_peaks': + unconstrained_amplitudes_vary, + f'{det_id}/unconstrained_fit/results/residual': + np.asarray(unconstrained_results['residuals']), + f'{det_id}/unconstrained_fit/results/redchi': + np.asarray(unconstrained_results['redchis']), + f'{det_id}/unconstrained_fit/results/success': + np.asarray(unconstrained_results['success']), + }) + for j, hkl in enumerate(hkls_fit[use_peaks]): + hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) + uniform_fit_path = f'{det_id}/uniform_fit/{hkl_name}' + unconstrained_fit_path = \ + f'{det_id}/unconstrained_fit/{hkl_name}' + results.update({ + f'{uniform_fit_path}/amplitudes/values': + np.asarray(uniform_results['amplitudes'][j]), + f'{uniform_fit_path}/amplitudes/errors': + np.asarray(uniform_results['amplitudes_errors'][j]), + f'{uniform_fit_path}/centers/values': + uniform_centers[j], + f'{uniform_fit_path}/centers/errors': + np.asarray(uniform_results['centers_errors'][j]), + f'{uniform_fit_path}/sigmas/values': + np.asarray(uniform_results['sigmas'][j]), + f'{uniform_fit_path}/sigmas/errors': + 
np.asarray(uniform_results['sigmas_errors'][j]), + f'{unconstrained_fit_path}/amplitudes/values': + np.asarray(unconstrained_results['amplitudes'][j]), + f'{unconstrained_fit_path}/amplitudes/errors': + np.asarray( + unconstrained_results['amplitudes_errors'][j]), + f'{unconstrained_fit_path}/centers/values': + unconstrained_centers[j], + f'{unconstrained_fit_path}/centers/errors': + np.asarray( + unconstrained_results['centers_errors'][j]), + f'{unconstrained_fit_path}/sigmas/values': + np.asarray(unconstrained_results['sigmas'][j]), + f'{unconstrained_fit_path}/sigmas/errors': + np.asarray(unconstrained_results['sigmas_errors'][j]), }) - for j, hkl in enumerate(hkls_fit[use_peaks]): - hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) - uniform_fit_path = \ - f'{detector.get_id()}/uniform_fit/{hkl_name}' - unconstrained_fit_path = \ - f'{detector.get_id()}/unconstrained_fit/{hkl_name}' - point.update({ - f'{uniform_fit_path}/amplitudes/values': - uniform_results['amplitudes'][j][i], - f'{uniform_fit_path}/amplitudes/errors': - uniform_results['amplitudes_errors'][j][i], - f'{uniform_fit_path}/centers/values': - uniform_centers[j][i], - f'{uniform_fit_path}/centers/errors': - uniform_results['centers_errors'][j][i], - f'{uniform_fit_path}/sigmas/values': - uniform_results['sigmas'][j][i], - f'{uniform_fit_path}/sigmas/errors': - uniform_results['sigmas_errors'][j][i], - f'{unconstrained_fit_path}/amplitudes/values': - unconstrained_results['amplitudes'][j][i], - f'{unconstrained_fit_path}/amplitudes/errors': - unconstrained_results['amplitudes_errors'][j][i], - f'{unconstrained_fit_path}/centers/values': - unconstrained_centers[j][i], - f'{unconstrained_fit_path}/centers/errors': - unconstrained_results['centers_errors'][j][i], - f'{unconstrained_fit_path}/sigmas/values': - unconstrained_results['sigmas'][j][i], - f'{unconstrained_fit_path}/sigmas/errors': - unconstrained_results['sigmas_errors'][j][i], + if detector.peak_models == 'pvoigt': + results.update({ + 
f'{uniform_fit_path}/fractions/values': + np.asarray(uniform_results['fractions'][j]), + f'{uniform_fit_path}/fractions/errors': + np.asarray( + uniform_results['fractions_errors'][j]), + f'{unconstrained_fit_path}/fractions/values': + np.asarray( + unconstrained_results['fractions'][j]), + f'{unconstrained_fit_path}/fractions/errors': + np.asarray( + unconstrained_results['fractions_errors'][j]), }) - if detector.peak_models == 'pvoigt': - point.update({ - f'{uniform_fit_path}/fractions/values': - uniform_results['fractions'][j][i], - f'{uniform_fit_path}/fractions/errors': - uniform_results['fractions_errors'][j][i], - f'{unconstrained_fit_path}/fractions/values': - unconstrained_results['fractions'][j][i], - f'{unconstrained_fit_path}/fractions/errors': - unconstrained_results['fractions_errors'][j][i], - }) - if unconstrained_centers[j][i]: - point.update({ - f'{unconstrained_fit_path}/strains/values': - unconstrained_strains[j][i], - f'{unconstrained_fit_path}/strains/residuals': - unconstrained_strain[i] - - unconstrained_strains[j][i], - }) - if (unconstrained_results['centers_errors'][j][i] - is None): - point.update({ - f'{unconstrained_fit_path}/strains/errors': - None, - }) - else: - point.update({ - f'{unconstrained_fit_path}/strains/errors': - unconstrained_results[ - 'centers_errors'][j][i] / - unconstrained_centers[j][i], - }) - else: - point.update({ - f'{unconstrained_fit_path}/strains/values': None, - f'{unconstrained_fit_path}/strains/errors': None, - f'{unconstrained_fit_path}/strains/residuals': - None, - }) + # Strain values: NaN where unconstrained center is zero + has_center = unconstrained_centers[j].astype(bool) + centers_errors_j = unconstrained_results['centers_errors'][j] + strain_errors_raw = np.asarray([ + (e / c if e is not None else np.nan) + for e, c in zip(centers_errors_j, + unconstrained_centers[j]) + ]) + results.update({ + f'{unconstrained_fit_path}/strains/values': + np.where(has_center, unconstrained_strains[j], + np.nan), + 
f'{unconstrained_fit_path}/strains/errors': + np.where(has_center, strain_errors_raw, np.nan), + f'{unconstrained_fit_path}/strains/residuals': + np.where( + has_center, + unconstrained_strain - unconstrained_strains[j], + np.nan), + }) + + if self.json_results: + # Include placeholder values for unused peaks in results + placeholder = np.full(num_points, np.nan) + for j, hkl in enumerate(hkls_fit[~use_peaks]): + hkl_name = '_'.join(str(hkl)[1:-1].split(' ')) + for fitmode in ('unconstrained', 'uniform'): + fitparams = ('amplitudes', 'centers', 'sigmas') + if detector.peak_models == 'pvoigt': + fitparams = (*fitparams, 'fractions') + if fitmode == 'unconstrained': + fitparams = (*fitparams, 'strains') + for fitparam in fitparams: + fitquants = ('values', 'errors') + if fitmode == 'unconstrained' \ + and fitparam == 'strains': + fitquants = (*fitquants, 'residuals') + for fitquant in fitquants: + key = str( + f'{det_id}/{fitmode}_fit/{hkl_name}/' + f'{fitparam}/{fitquant}' + ) + results.update({key: placeholder}) # Create an animation of the fit points if (not self.config.skip_animation @@ -3480,7 +3738,7 @@ def _strain_analysis(self): nxdata, energies[mask], intensities, intensity_norms, unconstrained_results['best_fits'], detector.get_id()) - return points + return results if __name__ == '__main__': diff --git a/CHAP/edd/reader.py b/CHAP/edd/reader.py index fe597556..4a89b18d 100755 --- a/CHAP/edd/reader.py +++ b/CHAP/edd/reader.py @@ -16,6 +16,7 @@ conlist, constr, field_validator, + model_validator, ) # Local modules @@ -762,11 +763,12 @@ class SliceNXdataReader(Reader): object and slices all fields according to the provided slicing parameters. - :ivar scan_number: SPEC scan number. - :vartype scan_number: int + :ivar scan_numbers: Numbers of scans from which to read slices of + raw data. 
+ :vartype scan_numbers: list[int] """ - - scan_number: conint(ge=0) + scan_number: Optional[conint(gt=0)] = None + scan_numbers: Optional[conlist(item_type=conint(gt=0))] = None def read(self): """Reads a NeXus style @@ -783,32 +785,91 @@ def read(self): from nexusformat.nexus import NXentry, NXfield # Local modules - from CHAP.common import NexusReader + from CHAP.common.reader import NexusReader from CHAP.utils.general import nxcopy + # Read NXroot reader = NexusReader(**self.model_dump()) - nxroot = nxcopy(reader.read()) - nxdata = None - for nxname, nxobject in nxroot.items(): - if isinstance(nxobject, NXentry): - nxdata = nxobject.data + nxroot = reader.read() + + # Locate NXentry + nxentry = next( + (obj for obj in nxroot.values() if isinstance(obj, NXentry)), + None, + ) + if nxentry is None: + raise ValueError('Could not find NXentry group') + nxentry_nxpath = nxentry.nxpath + self.logger.info(f'Using NXentry at: {nxentry_nxpath}') + + # Make a copy of the NXroot, excluding everything but the + # NXentry of interest. Do this so we can just slice the + # NXfields in place in the copy (because copy is not tied to + # the original input file). 
+ exclude_nxpaths = [] + for v in nxroot.values(): + if v.nxpath != nxentry_nxpath: + exclude_nxpaths.append(v.nxpath) + nxroot = nxcopy(nxroot, exclude_nxpaths=exclude_nxpaths) + nxentry = nxroot[nxentry_nxpath] + + # Locate NXdata containining the "SCAN_N" NXfield + nxdata = getattr(nxentry, 'data', None) if nxdata is None: - msg = 'Could not find NXdata group' - self.logger.error(msg) - raise ValueError(msg) + self.logger.warning('NXdata group missing — searching fallback') - indices = np.argwhere( - nxdata.SCAN_N.nxdata == self.scan_number).flatten() - for nxname, nxobject in nxdata.items(): - if isinstance(nxobject, NXfield): - nxdata[nxname] = NXfield( - value=nxobject.nxdata[indices], - dtype=nxdata[nxname].dtype, - attrs=nxdata[nxname].attrs, - ) + for v in nxentry.values(): + if hasattr(v, 'SCAN_N'): + nxdata = v + break + + if nxdata is None: + raise ValueError('Cannot find SCAN_N dataset') + self.logger.info(f'Using NXdata at: {nxdata.nxpath}') + + # Get indicies of SCAN_N that match self.scan_number + scan_field = nxdata['SCAN_N'].nxdata + indices = np.flatnonzero(np.isin(scan_field, self.scan_numbers)) + + if indices.size == 0: + self.logger.warning( + f'scan_number {self.scan_number} not found in SCAN_N' + ) + self.logger.info(f'Slicing NXfields with: {indices}') + + # Slice only NXfields + for name, obj in list(nxdata.items()): + if isinstance(obj, NXfield): + self.logger.info(f'Slicing NXfield at: {obj.nxpath}') + nxdata[name] = obj[indices] return nxroot + @model_validator(mode='before') + @classmethod + def fill_scan_numbers(cls, data): + if not isinstance(data, dict): + return data + if 'scan_numbers' not in data or data['scan_numbers'] is None: + if data.get('scan_number') is not None: + data['scan_numbers'] = [data['scan_number']] + elif isinstance(data['scan_numbers'], int): + data['scan_numbers'] = [data['scan_numbers']] + elif isinstance(data['scan_numbers'], str): + from CHAP.utils.general import string_to_list + data['scan_numbers'] = 
string_to_list(data['scan_numbers']) + return data + + @model_validator(mode='after') + def validate_scan_numbers(self): + if self.scan_numbers is None: + raise ValueError( + 'scan_numbers is required; alternatively, provide scan_number') + if self.scan_number is not None \ + and self.scan_number not in self.scan_numbers: + self.scan_numbers.append(self.scan_number) + return self + class UpdateNXdataReader(Reader): """Companion to :class:`~CHAP.edd.reader.SetupNXdataReader` and :class:`~CHAP.common.processor.UpdateNXDataProcessor`. Constructs diff --git a/CHAP/edd/utils.py b/CHAP/edd/utils.py index 8079d4b5..dceb4a78 100755 --- a/CHAP/edd/utils.py +++ b/CHAP/edd/utils.py @@ -1311,10 +1311,10 @@ def get_spectra_fits( # Local modules from CHAP.utils.fit import FitProcessor - num_proc = kwargs.get('num_proc', 1) + num_proc = kwargs.pop('num_proc', 1) + max_nfev = kwargs.pop('max_nfev', 64000) rel_height_cutoff = detector.rel_height_cutoff num_peak = len(peak_locations) - nxdata = NXdata(NXfield(spectra, 'y'), NXfield(energies, 'x')) # Construct the fit model models = [] @@ -1341,6 +1341,7 @@ def get_spectra_fits( 'models': models, # 'plot': True, 'num_proc': num_proc, + 'max_nfev': max_nfev, 'rel_height_cutoff': rel_height_cutoff, # 'method': 'trf', 'method': 'leastsq', @@ -1350,8 +1351,8 @@ def get_spectra_fits( # Perform uniform fit # FIX make more generic for fit parameters - fit = FitProcessor(**kwargs) - uniform_fit = fit.process(nxdata, config) + uniform_fit = FitProcessor.run( + data={'x': energies, 'y': spectra}, config=config, **kwargs) uniform_success = uniform_fit.success if spectra.ndim == 1: if uniform_success: @@ -1518,7 +1519,8 @@ def get_spectra_fits( # Perform unconstrained fit config['models'][-1]['fit_type'] = 'unconstrained' - unconstrained_fit = fit.process(uniform_fit, config) + unconstrained_fit = FitProcessor.run( + data=uniform_fit, config=config, **kwargs) unconstrained_success = unconstrained_fit.success if spectra.ndim == 1: if 
unconstrained_success: diff --git a/CHAP/models.py b/CHAP/models.py index 7a7df819..435c2b57 100755 --- a/CHAP/models.py +++ b/CHAP/models.py @@ -33,8 +33,7 @@ class CHAPBaseModel(BaseModel): def dict(self, *args, **kwargs): """Dump the class implemention to a dictionary. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -49,8 +48,7 @@ class variables that have an alias., defaults to `True`. def model_dump(self, *args, **kwargs): """Dump the class implemention to a dictionary. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. :type exclude: dict or set, optional @@ -70,8 +68,7 @@ class variables that have an alias., defaults to `True`. def model_dump_json(self, *args, **kwargs): """Dump the class implemention to a JSON string. - :param \*\*kwargs: Arbitrary keyword arguments. - :type: dict + :param \*\*kwargs: Optional keyword arguments, including: :keyword exclude: Class variable(s) to omit from the output dictionary. 
:type exclude: dict or set, optional diff --git a/CHAP/pipeline.py b/CHAP/pipeline.py index ce854e1f..ba046a68 100755 --- a/CHAP/pipeline.py +++ b/CHAP/pipeline.py @@ -86,9 +86,9 @@ def validate_pipelineitem_after(self): self.logger.propagate = False log_handler = logging.StreamHandler() log_handler.setFormatter(logging.Formatter( - '{asctime}: {name:20}: {levelname}: {message}', + '{asctime}: {name:20} (L{lineno}): {levelname}: {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')) - self.logger.addHandler(log_handler) + self.logger.handlers = [log_handler] self.logger.setLevel(self.log_level) # Optinal, but it's already available in the 'name' field #if self.get_schema() is None: @@ -269,7 +269,7 @@ def get_data(data, name=None, schema=None, remove=True): `NXobject `__ object or matches a given name or schema. Pick the last item for which the `'name'` key matches `name` if set or the `'schema'` key matches - `schema` if set, pick the last match for a `NXobjecta` object + `schema` if set, pick the last match for a `NXobject` object otherwise. Return the data object. :param data: Input data. @@ -416,7 +416,7 @@ def get_pipelinedata_item(data, index=-1, remove=False): :return: Matching data item. :rtype: Any """ - if isinstance(data, list): + if isinstance(data, list) and isinstance(data[index], PipelineData): if remove: return data.pop(index)['data'] return data[index]['data'] @@ -444,18 +444,17 @@ def unwrap_pipelinedata(data): unwrapped_data = [data] return unwrapped_data - def execute(self, data):#, metadata, provenance): + def execute(self, data=None):#, metadata, provenance): """Execute the appropriate method of the object and return the result. :param data: Input data. - :type data: list[PipelineData] + :type data: list[PipelineData], optional :return: Wrapped result of executing read, process, or write. 
:rtype: PipelineData | tuple[PipelineData] """ # self._metadata = metadata # self._provenance = provenance - if 'data' in self._allowed_args: self._args['data'] = data t0 = time() @@ -467,6 +466,84 @@ def execute(self, data):#, metadata, provenance): f'Finished "{self._method}" in {time()-t0:.0f} seconds\n') return data + @classmethod + def run(cls, **kwargs): + """Execute the appropriate method of the object and return the + result. + + This class method gets and executes the appropriate method + (process, read or write) from the pipeline item it's called + from. It is intended to be called from a script or notebook + only and should not be called from other CHAP Processors, + Readers or Writers. + + The method expects the same parameters as those used to + instantiate its class object and run the process, read or + write method, in addition to any run time parameter in the + pipeline file config dictionary (see: + :class:`~CHAP.models.RunConfig)`. + + :param \*\*kwargs: Optional keyword arguments, including: + :keyword config: Initialization parameters for an instance + of the pipeline item this method is called from (often + used by Readers and Processors). + :type config: dict, optional + :keyword data: Input data (required for any Processor, but + is allowed to be `None` or `[]`). + :type data: list[PipelineData], optional + :keyword filename: Name of file to read (required for most + Readers and Writers). + :type filename: str, optional + :keyword force_overwrite: Flag to allow data in `filename` + to be overwritten if it already exists, defaults to + `False` (optional for Writers). + :type force_overwrite: bool, optional + :keyword remove: Flag to remove the dictionary from `data`, + defaults to `False` (optional for Writers). + :type remove: bool, optional + :return: Returned result from executing the underlying read, + process, or write method. 
+ :rtype: Any + """ + # System modules + from importlib import import_module + from inspect import isclass + from pkgutil import walk_packages + + # Local modules + from CHAP.utils.general import input_menu + + def _find_class_in_package(package_name, class_name): + package = import_module(package_name) + # Recursively walk through all submodules + found_classes = [] + for _, module_name, _ in walk_packages( + package.__path__, package.__name__ + "."): + try: + module = import_module(module_name) + # Check if the class exists in this module + if hasattr(module, class_name): + cls = getattr(module, class_name) + if (isclass(cls) and cls.__module__ == module_name + and cls not in found_classes): + found_classes.append(cls) + except ImportError: + continue + if not found_classes: + raise ImportError(f'Unable to find {class_name} in CHAP') + if len(found_classes) == 1: + return found_classes[0] + index = input_menu( + [v.__module__ for v in found_classes], + header=f'\nFound multiple classes named {class_name} in ' + f'CHAP\nUse {class_name} from:') + return found_classes[index] + + cls_name = cls.__name__ + mmc = _find_class_in_package('CHAP', cls_name) + item = mmc(modelmetaclass=mmc, **kwargs) + return item.execute(kwargs.get('data')) + class Pipeline(CHAPBaseModel): """Class representing a full `Pipeline` object. 
diff --git a/CHAP/runner.py b/CHAP/runner.py index 9e191d16..b7045658 100755 --- a/CHAP/runner.py +++ b/CHAP/runner.py @@ -227,9 +227,9 @@ def set_logger(log_level='INFO'): logger.setLevel(log_level) log_handler = logging.StreamHandler() log_handler.setFormatter(logging.Formatter( - '{asctime}: {name:20}: {levelname}: {message}', + '{asctime}: {name:20} (L{lineno}): {levelname}: {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{')) - logger.addHandler(log_handler) + logger.handlers = [log_handler] return logger, log_handler def run( diff --git a/CHAP/saxswaxs/models.py b/CHAP/saxswaxs/models.py new file mode 100644 index 00000000..25d30d22 --- /dev/null +++ b/CHAP/saxswaxs/models.py @@ -0,0 +1,406 @@ +"""Models to help construct containers for results of +``saxswaxs.*CorrectionProcessor`` tools.""" + +# System modules +from functools import cached_property +import os +import re +from typing import Literal, Optional, Union + +# Third party modules +from pydantic import ( + confloat, + conlist, + model_validator, + Field, + AliasChoices, +) + +# Local modules +from CHAP import version as chap_version +from CHAP.models import CHAPBaseModel +from CHAP.common.models.common import IndexSliceConfig +from CHAP.common.models.map import SpecScans + + +class Background(SpecScans): + """Configuration for background scan data associated with a + correction. + + Extends :class:`~CHAP.common.models.map.SpecScans` with an + optional index slice so that only a subset of scan steps is used + when reading background images. + + Accepts either ``idx_slice`` (a :class:`~CHAP.common.models.common.IndexSliceConfig` + dict) or the convenience field ``scan_step_indices`` (a list of + integers or a compact string such as ``"0-4, 6"``), but not both. + When ``scan_step_indices`` is given it is converted to the + equivalent ``idx_slice``; the indices must form a uniformly-spaced + sequence expressible as a Python :class:`slice`. 
+ + :ivar idx_slice: Index slice selecting which scan steps of the + background scan(s) to read and average. Defaults to all steps. + :vartype idx_slice: IndexSliceConfig + :ivar scan_step_indices: Convenience alternative to ``idx_slice``. + A list of integer step indices (or a compact string such as + ``"0-4, 6"``) that are converted to an ``idx_slice`` during + validation. The indices must be uniformly spaced. Mutually + exclusive with ``idx_slice``. + :vartype scan_step_indices: list[int] or str, optional + """ + + idx_slice: IndexSliceConfig = IndexSliceConfig() + scan_step_indices: Optional[Union[list[int], str]] = None + + @model_validator(mode='before') + @classmethod + def fill_idx_slice(cls, data): + scan_step_indices = data.get('scan_step_indices') + idx_slice = data.get('idx_slice') + if scan_step_indices is not None and idx_slice is not None: + raise ValueError( + 'Specify idx_slice or scan_step_indices, not both.') + if scan_step_indices is not None: + if isinstance(scan_step_indices, str): + from CHAP.utils.general import string_to_list + scan_step_indices = string_to_list(scan_step_indices) + # scan_step_indices is now list[int]; derive a uniform slice + indices = sorted(scan_step_indices) + if len(indices) == 1: + start, step = indices[0], 1 + else: + step = indices[1] - indices[0] + if step <= 0: + raise ValueError( + 'scan_step_indices must contain distinct, ' + 'positive-step values.') + diffs = [indices[i+1] - indices[i] + for i in range(len(indices) - 1)] + if len(set(diffs)) != 1: + raise ValueError( + 'scan_step_indices must be uniformly spaced so ' + 'they can be expressed as a slice.') + start = indices[0] + stop = indices[-1] + step + data['idx_slice'] = {'start': start, 'stop': stop, 'step': step} + return data + + def zarr_arrays(self, integration_shape): + """Return a dictionary describing the zarr array that will hold + the averaged integrated background intensity. 
+ + :param integration_shape: Shape of one frame of integration + results, as returned by + :attr:`~CHAP.common.models.integration.PyfaiIntegratorConfig.result_shape`. + :type integration_shape: tuple[int, ...] + :returns: Dict mapping the array name ``'I_background'`` to a + zarr array specification (``dtype``, ``shape``, and + ``attributes`` keys) compatible with + :func:`~CHAP.saxswaxs.utils.dict_to_zarr`. + :rtype: dict + """ + return { + 'I_background': { + 'attributes': { + 'long_name': 'Intensity (a.u)', + 'units': 'a.u,' + }, + 'dtype': 'float64', + 'shape': integration_shape, + } + } + + +class CorrectionConfig(CHAPBaseModel): + """Base configuration for a single SAXS/WAXS correction step. + + Describes one correction (flux, flux+absorption, or + flux+absorption+background) to be applied to azimuthally integrated + detector data. Subclasses pin ``correction_type`` and may require + additional fields such as ``background``. + + :ivar correction_type: Identifies the correction algorithm. + One of ``'flux'``, ``'flux_absorption'``, or + ``'flux_absorption_background'``. + :vartype correction_type: str + :ivar name: Human-readable name used as the key for this + correction's group in the output zarr / NeXus tree. + :vartype name: str + :ivar uncorrected_data_name: Name of the + :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig` + integration whose output serves as the uncorrected input for + this correction. Must match the ``name`` field of one of the + integrations in the associated + :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. + :vartype uncorrected_data_name: str + :ivar presample_intensity_reference_rate: Fixed reference counting + rate for the pre-sample beam intensity monitor. When ``None`` + the rate is computed from the scan data as + ``nanmean(presample_intensity / dwell_time_actual)``. + :vartype presample_intensity_reference_rate: float, optional + :ivar background: Background scan configuration. 
Required for + ``'flux_absorption'`` and ``'flux_absorption_background'`` + correction types. + :vartype background: Background, optional + """ + + correction_type: Literal['flux', 'flux_absorption', + 'flux_absorption_background'] + name: str = Field(validation_alias=AliasChoices('name', 'title')) + uncorrected_data_name: str = Field(validation_alias=AliasChoices( + 'uncorrected_data_name', 'uncorrected_data_title')) + presample_intensity_reference_rate: Optional[float] = None + background: Optional[Background] = None + + def zarr_tree(self, dataset_shape, dataset_chunks, integration_shape, + nxlinks=None): + """Return a dictionary representing the zarr tree for this + correction's output container. + + The returned tree is compatible with + :func:`~CHAP.saxswaxs.utils.dict_to_zarr` and, after conversion + with :class:`~CHAP.common.processor.ZarrToNexusProcessor`, + produces an ``NXprocess`` group containing a ``data`` sub-group + with an ``I_corrected`` dataset and, when a ``background`` is + configured, an ``I_background`` dataset. + + :param dataset_shape: Shape of the measurement (scan) dimensions + of the output dataset, excluding the integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``. + :type dataset_chunks: list[int] or str + :param integration_shape: Shape of one frame of integration + results for the integration named by + ``uncorrected_data_name``. + :type integration_shape: tuple[int, ...] + :param nxlinks: NeXus path(s) to link into the ``data`` group. + When the zarr tree is written to a ``.zarr`` file and + converted to ``.nxs`` with + :class:`~CHAP.common.processor.ZarrToNexusProcessor`, each + path produces an ``NXlink`` whose name is + ``os.path.basename(path)``. Accepts a single path string or + a list of path strings. All links must be explicit; none are + auto-generated. 
+ :type nxlinks: str or list[str], optional + :returns: Nested dict representing the zarr group tree for this + correction. + :rtype: dict + """ + if isinstance(nxlinks, str): + nxlinks = [nxlinks] + data_attrs = {} + if nxlinks: + data_attrs['__nxlinks__'] = { + os.path.basename(p): p for p in nxlinks + } + if self.background is None: + background_arrays = {} + else: + background_arrays = self.background.zarr_arrays(integration_shape) + data_attrs['background'] = str(self.background.model_dump()) + return { + # NXprocess + 'attributes': { + 'correction_type': self.correction_type, + 'default': 'data', + }, + 'children': { + 'program': 'CHAP.saxswaxs', + 'version': chap_version, + 'data': { + # NXdata + 'attributes': data_attrs, + 'children': { + 'I_corrected': { + 'attributes': { + 'long_name': 'Intensity (a.u)', + 'units': 'a.u,' + }, + 'dtype': 'float64', + 'shape': (*dataset_shape, *integration_shape), + }, + **background_arrays, + } + } + } + } + + @cached_property + def processor_name(self): + """Name of the processor class that implements this correction. + + Derived from ``correction_type`` by capitalising each + ``'_'``-separated word and appending ``'CorrectionProcessor'`` + (e.g. ``'flux_absorption'`` → + ``'FluxAbsorptionCorrectionProcessor'``). + + :type: str + """ + return ''.join( + [x.capitalize() for x in self.correction_type.split('_')] + + ['CorrectionProcessor'] + ) + + @cached_property + def processor_module(self): + """Module object containing the processor class for this + correction. + + :type: module + """ + return __import__('CHAP.saxswaxs.processor', + fromlist=[self.processor_name]) + + @cached_property + def processor_class(self): + """Processor class that implements this correction. + + :type: type + """ + return getattr(self.processor_module, self.processor_name) + + @property + def processor(self): + """A new instance of the processor class for this correction, + initialised with the current configuration. 
+ + :type: Processor + """ + return self.processor_class(config=self.model_dump()) + + +class CorrectionsConfig(CHAPBaseModel): + """Configuration container for an ordered list of SAXS/WAXS + correction steps to apply to integrated detector data. + + :ivar corrections: Ordered list of correction configurations. + :vartype corrections: list[CorrectionConfig] + """ + + corrections: conlist(item_type=CorrectionConfig) + + def zarr_tree(self, dataset_shape, dataset_chunks, + integration_shapes, nxlinks=None): + """Return a dictionary representing the zarr tree for all + corrections in this configuration. + + Each correction gets its own sub-group keyed by + :attr:`CorrectionConfig.name`. See + :meth:`CorrectionConfig.zarr_tree` for the structure of each + sub-group. + + :param dataset_shape: Shape of the measurement (scan) dimensions, + excluding integration dimensions. + :type dataset_shape: tuple[int, ...] + :param dataset_chunks: Chunk shape along the scan dimensions, or + ``'auto'``. + :type dataset_chunks: list[int] or str + :param integration_shapes: Mapping from integration name to the + shape of one integration result frame. Used to look up the + ``integration_shape`` for each correction via + :attr:`CorrectionConfig.uncorrected_data_name`. + :type integration_shapes: dict[str, tuple[int, ...]] + :param nxlinks: NeXus links to inject into each correction's + ``data`` group. May be a single path string or list of path + strings (forwarded to every correction), or a dict keyed by + correction name mapping each correction to its own path(s). + See :meth:`CorrectionConfig.zarr_tree` for details on how + individual paths are handled. + :type nxlinks: str or list[str] or dict[str, str or list[str]], + optional + :returns: Nested dict representing the zarr group tree for all + corrections. 
+ :rtype: dict + """ + if not isinstance(nxlinks, dict): + nxlinks = {corr.name: nxlinks for corr in self.corrections} + return { + 'root': { + 'attributes': {}, + }, + 'children': { + corr.name: corr.zarr_tree( + dataset_shape, dataset_chunks, + integration_shapes.get( + corr.uncorrected_data_name, None + ), + nxlinks=nxlinks.get(corr.name), + ) + for corr in self.corrections + } + } + + +class FluxCorrectionConfig(CorrectionConfig): + """Correction configuration for flux-only correction. + + Applies a flux correction that normalises the measured intensity by + the pre-sample beam monitor counts, referenced to a fixed counting + rate. No background scan is required. + """ + + correction_type: Literal['flux'] = 'flux' + + +class FluxAbsorptionCorrectionConfig(FluxCorrectionConfig): + """Correction configuration for combined flux and absorption + correction. + + Extends :class:`FluxCorrectionConfig` with a required background + scan used to determine the sample transmission. + + :ivar background: Background scan configuration used to compute + the sample transmission term. + :vartype background: Background + """ + + correction_type: Literal['flux_absorption'] = 'flux_absorption' + background: Background + + +class FluxAbsorptionBackgroundCorrectionConfig( + FluxAbsorptionCorrectionConfig): + """Correction configuration for combined flux, absorption, and + background-subtraction correction with optional thickness + normalisation. + + Extends :class:`FluxAbsorptionCorrectionConfig` with an integrated + background subtraction step and an optional sample thickness or + linear attenuation coefficient for thickness normalisation. At + most one of ``sample_thickness_cm`` and ``sample_mu_inv_cm`` may be + provided. + + :ivar background: Background scan configuration. + :vartype background: Background + :ivar sample_thickness_cm: Sample thickness in centimetres. When + provided, corrected intensities are divided by this value. + Mutually exclusive with ``sample_mu_inv_cm``. 
+ :vartype sample_thickness_cm: float, optional + :ivar sample_mu_inv_cm: Sample linear attenuation coefficient in + inverse centimetres. When provided, the effective thickness is + derived from the measured transmission. Mutually exclusive with + ``sample_thickness_cm``. + :vartype sample_mu_inv_cm: float, optional + """ + + correction_type: Literal[ + 'flux_absorption_background'] = 'flux_absorption_background' + background: Background + sample_thickness_cm: Optional[confloat(gt=0)] = None + sample_mu_inv_cm: Optional[confloat(gt=0)] = None + + @model_validator(mode='after') + def validate_thickness(self): + """Ensure ``sample_thickness_cm`` and ``sample_mu_inv_cm`` are + not both specified. + + :raises ValueError: If both fields are set. + :returns: The validated model instance. + :rtype: FluxAbsorptionBackgroundCorrectionConfig + """ + if self.sample_thickness_cm and self.sample_mu_inv_cm: + raise ValueError( + 'Use sample_thickness_cm OR sample_mu_inv_cm, not both.' + ) + return self diff --git a/CHAP/saxswaxs/processor.py b/CHAP/saxswaxs/processor.py index f2933d86..aa4be885 100755 --- a/CHAP/saxswaxs/processor.py +++ b/CHAP/saxswaxs/processor.py @@ -22,13 +22,23 @@ import numpy as np # Local modules +from CHAP.common.models.common import IndexSliceConfig from CHAP.common.models.map import ( + DetectorConfig, Detector, MapConfig, ) from CHAP.common.models.integration import PyfaiIntegrationConfig from CHAP.common.processor import ExpressionProcessor from CHAP.processor import Processor +from CHAP.pipeline import PipelineData +from CHAP.saxswaxs.models import ( + CorrectionsConfig, + FluxCorrectionConfig, + FluxAbsorptionCorrectionConfig, + FluxAbsorptionBackgroundCorrectionConfig, +) +from CHAP.saxswaxs.utils import dict_to_zarr class CfProcessor(Processor): @@ -214,36 +224,41 @@ def process( class FluxCorrectionProcessor(ExpressionProcessor): - """Processor for flux correction.""" + """Processor for applying a flux correction to azimuthally + integrated 
SAXS/WAXS intensity data. - def process( - self, data, presample_intensity_reference_rate=None, - nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - and `'dwell_time_actual'`, compute the flux corrected intensity - signal. - - :param data: Input data list containing items with names - `'intensity'`, `'presample_intensity'`, and - `'dwell_time_actual'` (if - `presample_intensity_reference_rate` is not specified). + Normalises the measured intensity by the pre-sample beam monitor + counts, referenced to a fixed counting rate configured via + :attr:`config`. + + :ivar config: Correction configuration. + :vartype config: FluxCorrectionConfig + """ + config: FluxCorrectionConfig + def process(self, data, nxprocess=False): + """Compute the flux-corrected intensity. + + Requires the input item named by ``config.uncorrected_data_name`` + (the raw integrated signal) and ``'presample_intensity'`` (beam + monitor counts). When + :attr:`~saxswaxs.models.FluxCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will set to - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param nxprocess: Flag to indicate the flux corrected data - should be returned as a NeXus style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux corrected version of input `'intensity'` data.
- :rtype: Any + :returns: Flux-corrected intensity array (or NXobject when + ``nxprocess`` is ``True``). + :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -251,7 +266,7 @@ def process( data, name='presample_intensity', ) intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, ) # nxfieldtable = { # 'intensity': intensity, @@ -270,7 +285,7 @@ def process( ) symtable = { 'presample_intensity_reference_rate': - presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity } @@ -285,41 +300,47 @@ def process( class FluxAbsorptionCorrectionProcessor(ExpressionProcessor): - """Processor for flux and absorption correction.""" + """Processor for applying combined flux and absorption correction + to azimuthally integrated SAXS/WAXS intensity data. - def process( - self, data, presample_intensity_reference_rate=None, - nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - `'postsample_intensity'`, `'background_presample_intensity'`, - `'background_postsample_intensity'`, and - `'dwell_time_actual'`, compute the flux and absorption - corrected intensity signal. - - :param data: Input data list containing all necessary data - labelled with their proper names. + In addition to flux normalisation (see + :class:`FluxCorrectionProcessor`), divides by the sample + transmission, which is estimated from the ratio of post-sample to + pre-sample intensities for both the sample and a background scan. + + :ivar config: Correction configuration. 
+ :vartype config: FluxAbsorptionCorrectionConfig + """ + config: FluxAbsorptionCorrectionConfig + def process(self, data, nxprocess=False): + """Compute the flux- and absorption-corrected intensity. + + Requires the input item named by ``config.uncorrected_data_name`` + plus ``'presample_intensity'``, ``'postsample_intensity'``, + ``'background_presample_intensity'``, and + ``'background_postsample_intensity'``. When + :attr:`~saxswaxs.models.FluxAbsorptionCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will be calculated with - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param nxprocess: Flag to indicate the flux and absorption - corrected data should be returned as a NeXus style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux and absorption corrected version of input - `'intensity'` data. - :rtype: Any + :returns: Flux- and absorption-corrected intensity array (or + NXobject when ``nxprocess`` is ``True``).
+ :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, #'intensity', ) - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -346,7 +367,7 @@ def process( symtable = { 'presample_intensity_reference_rate': - presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity, 'tt': tt @@ -363,61 +384,60 @@ def process( class FluxAbsorptionBackgroundCorrectionProcessor(ExpressionProcessor): - """Processor for flux, absorption, and background correction as - well as optional thickness correction.""" - - def process( - self, data, presample_intensity_reference_rate=None, - sample_thickness_cm=None, sample_mu_inv_cm=None, nxprocess=False): - """Given input data for `'intensity'`, `'presample_intensity'`, - `'postsample_intensity'`, `'background_presample_intensity'`, - `'background_postsample_intensity'`, `'background_intensity'`, - and `'dwell_time_actual'`, return flux, absorption and - background corrected intensity signal. - - :param data: Input data list containing all necessary data - labelled with their proper names. + """Processor for applying combined flux, absorption, and + background-subtraction correction to azimuthally integrated + SAXS/WAXS intensity data, with optional thickness normalisation. + + Extends :class:`FluxAbsorptionCorrectionProcessor` by subtracting + an integrated background signal after flux and absorption + correction. 
Thickness normalisation is applied when + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.sample_thickness_cm` + or + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.sample_mu_inv_cm` + is set in :attr:`config`. + + :ivar config: Correction configuration. + :vartype config: FluxAbsorptionBackgroundCorrectionConfig + """ + config: FluxAbsorptionBackgroundCorrectionConfig + def process(self, data, nxprocess=False): + """Compute the flux-, absorption-, and background-corrected + intensity, with optional thickness normalisation. + + Requires the input item named by ``config.uncorrected_data_name`` + plus ``'presample_intensity'``, ``'postsample_intensity'``, + ``'background_presample_intensity'``, + ``'background_postsample_intensity'``, and + ``'background_intensity'``. When + :attr:`~saxswaxs.models.FluxAbsorptionBackgroundCorrectionConfig.presample_intensity_reference_rate` + is not pre-configured in :attr:`config`, an additional item + named ``'dwell_time_actual'`` is also required and the + reference rate is computed on the fly as + ``np.nanmean(presample_intensity / dwell_time_actual)``. + + Thickness normalisation is controlled by :attr:`config`: if + ``config.sample_thickness_cm`` is set the corrected signal is + divided by that value; if ``config.sample_mu_inv_cm`` is set + the effective thickness is derived from the measured + transmission; otherwise no thickness normalisation is applied. + + :param data: Input data list. :type data: list[PipelineData] - :param presample_intensity_reference_rate: Reference counting - rate for the `'presample_intensity'` signal. If not - specified, it will be calculated with - `'numpy.nanmean(presample_intensity / - dwell_time_actual)'`. - :type presample_intensity_reference_rate: float, optional - :param sample_thickness_cm: Sample thickness in - centimeters. If specified, this processor will - additionally perform thickness correction. Use of this - parameter is mutualy exclusive with - use of `sample_mu_inv_cm`.
- :type sample_thickness_cm: float, optional - :param sample_mu_inv_cm: Sample linear attenuation coefficient - in inverse centimeters. If specified, this processor will - additionally perform thickness correction. Use of this - parameter is mutualy exclusive with use of - `sample_thickness_cm`. - :type sample_mu_inv_cm: float, optional - :param nxprocess: Flag to indicate the flux, absorption, and - background corrected data should be returned as a Nexus - style + :param nxprocess: Return the result as a NeXus `NXobject `__ - object. Defaults to `False`. + object. Defaults to ``False``. :type nxprocess: bool, optional - :returns: Flux, absorption and background corrected version of - input `'intensity'` data. - :rtype: Any + :returns: Flux-, absorption-, and background-corrected + intensity array (or NXobject when ``nxprocess`` is + ``True``). + :rtype: numpy.ndarray or nexusformat.nexus.NXobject """ - if sample_thickness_cm is not None and sample_mu_inv_cm is not None: - raise ValueError(( - 'Cannot use sample_thickness_cm and sample_mu_inv_cm' - ' at the same time' - )) - intensity = self.get_data( - data, name='intensity', + data, name=self.config.uncorrected_data_name, ) - if presample_intensity_reference_rate is None: - presample_intensity_reference_rate = super().process( + if self.config.presample_intensity_reference_rate is None: + self.config.presample_intensity_reference_rate = super().process( data, 'np.nanmean(presample_intensity / dwell_time_actual)' ) @@ -450,17 +470,17 @@ def process( background_intensity = np.broadcast_to( background_intensity, intensity.shape) - if sample_thickness_cm is not None: - t = sample_thickness_cm - elif sample_mu_inv_cm is not None: - t = -np.log(tt / sample_mu_inv_cm) + if self.config.sample_thickness_cm is not None: + t = self.config.sample_thickness_cm + elif self.config.sample_mu_inv_cm is not None: + t = -np.log(tt / self.config.sample_mu_inv_cm) else: t = 1 symtable = { 't': t, 'presample_intensity_reference_rate': 
- presample_intensity_reference_rate, + self.config.presample_intensity_reference_rate, 'intensity': intensity, 'presample_intensity': presample_intensity, 'tt': tt, @@ -498,7 +518,7 @@ class PyfaiIntegrationProcessor(Processor): init_var=True) config: PyfaiIntegrationConfig - def process(self, data, idx_slices=None): + def process(self, data): """Perform a set of integrations on 2D detector data. :param data: Input data. @@ -515,19 +535,11 @@ def process(self, data, idx_slices=None): # System modules import time - if idx_slices is None: - idx_slices = [{'start':0, 'step': 1}] - # Organize input for integrations input_data = {d['name']: d['data'] for d in [d for d in data if isinstance(d['data'], np.ndarray)]} ais = {ai.get_id(): ai for ai in self.config.azimuthal_integrators} - # Finalize idx slice for results - idx = tuple(slice(idx_slice.get('start'), - idx_slice.get('stop'), - idx_slice.get('step')) for idx_slice in idx_slices) - # Perform integration(s), package results for ZarrResultsWriter results = [] nframes = len(input_data[list(input_data.keys())[0]]) @@ -543,7 +555,6 @@ def process(self, data, idx_slices=None): [ { 'path': f'{integration.name}/data/I', - 'idx': idx, 'data': np.asarray(result['intensities']), }, ] @@ -551,119 +562,6 @@ def process(self, data, idx_slices=None): return results -class SetupResultsProcessor(Processor): - """Processor for creating an intital - `Zarr group `__ - object with empty datasets for filling in by - :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor` and - :class:`~CHAP.common.ZarrValuesWriter`. - - :ivar config: Initialization parameters for an instance of - :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. - :vartype config: dict - :ivar dataset_shape: Shape of the completed dataset that will be - processed later on (shape of the measurement itself, _not_ - including the dimensions of any signals collected at each point - in that measurement). 
- :vartype dataset_shape: int or list[int] - :ivar dataset_chunks: Extent of chunks along each dimension of the - completed dataset / measurement. Choose this according to how - you will process your data -- for example, if your - `dataset_shape` is `[m, n]`, and you are planning to process - each of the `m` rows as chunks, `dataset_chunks` should be - `[1, n]`. But if you plan to process each of the `n` columns as - chunks, `dataset_chunks` should be `[m, 1]`, - defaults to `"auto"`. - :vartype dataset_chunks: list[int] or Literal["auto"], optional - """ - - pipeline_fields: dict = Field( - default={ - 'config': 'common.models.integration.PyfaiIntegrationConfig' - }, - init_var=True) - config: PyfaiIntegrationConfig - dataset_shape: conlist(item_type=conint(gt=0), min_length=1) - dataset_chunks: Optional[ - Union[ - Literal['auto'], - conlist(item_type=conint(gt=0), min_length=1) - ]] = 'auto' - - def process(self, data): - """Create and return a - `Zarr group `__ - object to hold processed SAXS/WAXS data processed - by :class:`~CHAP.saxswaxs.PyfaiIntegrationProcessor`. - - :param data: Input data. - :type data: list[PipelineData] - :return: Empty structure for filling in SAXS/WAXS data. - :rtype: zarr.Group - """ - - # Get Zarr tree as dict from the PyfaiIntegrationConfig - tree = self.config.zarr_tree(self.dataset_shape, self.dataset_chunks) - - # Construct & return the root Zarr group - return self.zarr_setup(tree) - - def zarr_setup(self, tree): - """Create a - `Zarr group `__ - object based on a dictionary representing a Zarr tree of groups - and arrays. - - :param tree: Nested dictionary representing a Zarr tree of - groups and arrays. - :type tree: dict[str, Any] - :return: Zarr group corresponding to the contents of `tree`. 
- :rtype: zarr.Group - """ - # Third party modules - # pylint: disable=import-error - import zarr - from zarr.storage import MemoryStore - - def create_group_or_dataset(node, zarr_parent, indent=0): - """Create and return a - `Zarr group `__ - `Zarr dataset `__. - - :param node: Child Zarr tree group. - :type node: zarr.Group or zarr.Array - :param zarr_parent: Parent Zarr tree group. - :type zarr_parent: zarr.Group - :param indent: Indentation level, defaults to 0. - :type indent: int, optional - """ - # Set attributes if present - if 'attributes' in node: - for key, value in node['attributes'].items(): - zarr_parent.attrs[key] = value - # Create children (groups or datasets) - if 'children' in node: - for name, child in node['children'].items(): - if 'shape' in child or 'data' in child: - # It's a dataset - self.logger.debug(f'Adding dset: {name}') - zarr_parent.create_dataset( - name, - **child, - ) - # Set dataset attributes - if 'attributes' in child: - for key, value in child['attributes'].items(): - zarr_parent[name].attrs[key] = value - else: - # It's a group - group = zarr_parent.create_group(name) - create_group_or_dataset(child, group, indent=indent+2) - results = zarr.create_group(store=MemoryStore({})) - create_group_or_dataset(tree['root'], results) - return results - - class SetupProcessor(Processor): """Convenience Processor for setting up a container for SAXS/WAXS experiments. @@ -674,13 +572,13 @@ class SetupProcessor(Processor): :ivar pyfai_config: Initialization parameters for an instance of :class:`~CHAP.common.models.integration.PyfaiIntegrationConfig`. :vartype pyfai_config: dict, optional - :ivar detectors: Detector configurations. - :vartype detectors: DetectorConfig + :ivar detector_config: Detector configurations. 
+ :vartype detector_config: DetectorConfig :ivar dataset_shape: Shape of the completed dataset that will be processed later on (shape of the measurement itself, _not_ including the dimensions of any signals collected at each point - in that measurement). - :vartype dataset_shape: int or list[int] + in that measurement). Defaults to `[0]`. + :vartype dataset_shape: int or list[int], optional :ivar dataset_chunks: Extent of chunks along each dimension of the completed dataset / measurement. Choose this according to how you will process your data -- for example, if your @@ -703,7 +601,9 @@ class SetupProcessor(Processor): pipeline_fields: dict = Field( default={ 'map_config': 'common.models.map.MapConfig', - 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig' + 'detector_config': 'common.models.map.DetectorConfig', + 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig', + 'correction_config': 'saxswaxs.models.CorrectionsConfig', }, init_var=True) # map_config needs a default value because the map configuration @@ -711,10 +611,11 @@ class SetupProcessor(Processor): # the case, then map_config needs SOME value in order for the # Pipeline to pass validation. 
map_config: MapConfig = None - pyfai_config: PyfaiIntegrationConfig - detectors: conlist(item_type=Detector, min_length=1) + pyfai_config: PyfaiIntegrationConfig = None + detector_config: DetectorConfig = DetectorConfig(detectors=[]) + correction_config: CorrectionsConfig = CorrectionsConfig(corrections=[]) dataset_shape: Optional[ - conlist(item_type=conint(gt=0), min_length=1)] = None + conlist(item_type=conint(ge=0), min_length=1)] = [0] dataset_chunks: Optional[ Union[ Literal['auto'], @@ -745,53 +646,23 @@ def process(self, data): # pylint: enable=import-error # Local modules - from CHAP.common import ( + from CHAP.common.processor import ( MapProcessor, NexusToZarrProcessor, ) from CHAP.pipeline import PipelineData - #from CHAP.saxswaxs.processor import SetupResultsProcessor - - def set_logger(pipeline_item): - """Set the logger and logging handler for given pipeline - item. - - :param pipeline_item: Pipeline item. - :type pipeline_item: PipelineItem - :return: Input Pipeline item, with updated logger and - logging handler. 
- :rtype: PipelineItem - """ - pipeline_item.logger = self.logger - pipeline_item.logger.name = pipeline_item.__class__.__name__ - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter( - '{asctime}: {name:20} (from '+ self.__class__.__name__ - + '): {levelname}: {message}', - datefmt='%Y-%m-%d %H:%M:%S', style='{')) - pipeline_item.logger.removeHandler( - pipeline_item.logger.handlers[0]) - pipeline_item.logger.addHandler(handler) - return pipeline_item # Get NXroot container for raw data map map_processor_kwargs = { - 'config': self.map_config + 'config': self.map_config, + 'detector_config': self.detector_config, + 'remove_constant_dims': False, } - if self.raw_data: - map_processor_kwargs['detector_config'] = { - 'detectors': self.detectors - } - else: - map_processor_kwargs['detector_config'] = { - 'detectors': [] - } - setup_map_processor = set_logger( - MapProcessor( - **map_processor_kwargs - # config=self.map_config, - # detector_config={'detectors': self.detectors}, - ) + if not self.raw_data: + self.map_config.spec_scans[0].scan_numbers = [] + + setup_map_processor = self.setup_pipelineitem( + MapProcessor(**map_processor_kwargs) ) ddata = [ PipelineData( @@ -824,18 +695,126 @@ def set_logger(pipeline_item): self.dataset_chunks = [self.dataset_chunks] # Convert raw data map container to Zarr format - ddata_converter = set_logger(NexusToZarrProcessor()) + ddata_converter = self.setup_pipelineitem(NexusToZarrProcessor()) zarr_map = ddata_converter.process(ddata, chunks=self.dataset_chunks) - # Get Zarr container for integration results - setup_results_processor = set_logger( - SetupResultsProcessor( - config=self.pyfai_config, - dataset_shape=self.dataset_shape, - dataset_chunks=self.dataset_chunks, + # Get paths to independent_dimension arrays + dim_paths = [ + f'{self.map_config.title}/independent_dimensions/{dim.label}' + for dim in self.map_config.independent_dimensions + ] + + # Get Zarr container for pyfai integrations + 
zarr_pyfai_tree = self.pyfai_config.zarr_tree( + self.dataset_shape, self.dataset_chunks, + nxlinks=dim_paths + ) + zarr_pyfai = dict_to_zarr(zarr_pyfai_tree, logger=self.logger) + + + # Get zarr container for corrected datasets + integration_shapes = { + integration.name: integration.result_shape + for integration in self.pyfai_config.integrations + } + intg_by_name = { + intg.name: intg for intg in self.pyfai_config.integrations + } + corr_nxlinks = { + corr.name: ( + dim_paths + + [f'{corr.uncorrected_data_name}/data/I'] + + [ + f'{corr.uncorrected_data_name}/data/{coord}' + for coord in intg_by_name[ + corr.uncorrected_data_name].result_coords + ] ) + for corr in self.correction_config.corrections + } + zarr_corr = dict_to_zarr( + self.correction_config.zarr_tree( + self.dataset_shape, self.dataset_chunks, + integration_shapes, + nxlinks=corr_nxlinks, + ), + logger=self.logger, ) - zarr_results = setup_results_processor.process(data) + + # For corrections that include a background, read and integrate + # that background data and store it in zarr_corr + for corr_cfg in self.correction_config.corrections: + if corr_cfg.background is None: + continue + if not self.detector_config.detectors: + self.logger.warning( + f'No detectors configured; skipping background ' + f'processing for correction "{corr_cfg.name}"') + continue + # Read in background raw detector data + self.logger.info( + f'Reading background data for correction "{corr_cfg.name}"') + idx_slice = corr_cfg.background.idx_slice + bg_det_images = { + d.get_id(): [] for d in self.detector_config.detectors} + for scan_number in corr_cfg.background.scan_numbers: + scanparser = corr_cfg.background.get_scanparser(scan_number) + npts = int(scanparser.spec_scan_npts) + slice_stop = ( + min(idx_slice._slice.stop, npts) + if idx_slice._slice.stop > 0 else npts + ) + scan_indices = range(npts)[ + slice(idx_slice._slice.start, slice_stop, + idx_slice._slice.step)] + for i in scan_indices: + for det in 
self.detector_config.detectors: + bg_det_images[det.get_id()].append( + scanparser.get_detector_data(det.get_id(), i)) + # Average all background data to a single frame before + # processing with appropriate integration + for det_id in bg_det_images: + bg_det_images[det_id] = np.mean( + bg_det_images[det_id], axis=0, keepdims=True) + self.logger.info( + f'Integrating background data for correction ' + f'"{corr_cfg.name}"') + bg_pyfai_input = [ + PipelineData(name=det_id, data=imgs) + for det_id, imgs in bg_det_images.items() + ] + bg_pyfai_config = self.pyfai_config.model_copy( + update={'integrations': [ + intg for intg in self.pyfai_config.integrations + if intg.name == corr_cfg.uncorrected_data_name + ]} + ) + bg_integrated = self.setup_pipelineitem( + PyfaiIntegrationProcessor(config=bg_pyfai_config) + ).process(bg_pyfai_input)[0] + # Read background scalar data + bg_presample = self.map_config.presample_intensity.get_value( + corr_cfg.background, scan_number, -1, + self.map_config.scalar_data + )[idx_slice._slice] + bg_postsample = self.map_config.postsample_intensity.get_value( + corr_cfg.background, scan_number, -1, + self.map_config.scalar_data + )[idx_slice._slice] + # Fill in placeholder zarr arrays with real background + # data + data_group = zarr_corr[corr_cfg.name]['data'] + data_group['I_background'][:] = np.squeeze( + bg_integrated['data'], axis=0 + ) + bg_presample_arr = data_group.create_array( + 'background_presample_intensity', + shape=bg_presample.shape, dtype=bg_presample.dtype) + bg_presample_arr[:] = bg_presample + bg_postsample_arr = data_group.create_array( + 'background_postsample_intensity', + shape=bg_postsample.shape, dtype=bg_postsample.dtype) + bg_postsample_arr[:] = bg_postsample # Assemble containers for raw & processed data zarr_root = zarr.create_group(store=MemoryStore({})) @@ -853,10 +832,34 @@ async def copy_zarr_store(source_store, dest_store): buf = await source_store.get( k, prototype=default_buffer_prototype()) await 
dest_store.set(k, buf) - asyncio.run(copy_zarr_store(zarr_map.store, zarr_root.store)) - asyncio.run(copy_zarr_store(zarr_results.store, zarr_root.store)) + for zarr_group in (zarr_map, zarr_pyfai, zarr_corr): + asyncio.run(copy_zarr_store(zarr_group.store, zarr_root.store)) return zarr_root + def setup_pipelineitem(self, pipeline_item): + """Convenience method to use a nice logger when this + ``Processor`` calls on another ``PipelineItem``. Set the + logger and logging handler for given pipeline item. + + :param pipeline_item: Pipeline item. + :type pipeline_item: PipelineItem + :return: Input Pipeline item, with updated logger and + logging handler. + :rtype: PipelineItem + """ + import logging + + pipeline_item.logger = logging.getLogger( + pipeline_item.__class__.__name__, + ) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter( + '{asctime}: {name:20} (from '+ self.__class__.__name__ + + ') (L{lineno}): {levelname}: {message}', + datefmt='%Y-%m-%d %H:%M:%S', style='{')) + pipeline_item.logger.handlers = [handler] + return pipeline_item + class UnstructuredToStructuredProcessor(Processor): """Processor to aggregate "unstructured" data into a single NeXus @@ -1187,18 +1190,19 @@ class UpdateValuesProcessor(Processor): :ivar scan_number: Scan number from which to read and process a slice of raw data. :vartype scan_number: int - :ivar detectors: Detector configurations. - :vartype detectors: list[CHAP.common.models.map.Detector] + :ivar detector_config: Detector configurations. + :vartype detector_config: :class:`~CHAP.common.models.map.DetectorConfig` :ivar raw_data: Flag to indicate wether or not space for raw detector data should be included in the values returned, defaults to `True`. 
:vartype raw_data: bool, optional """ - pipeline_fields: dict = Field( - default={ + { 'map_config': 'common.models.map.MapConfig', - 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig' + 'detector_config': 'common.models.map.DetectorConfig', + 'pyfai_config': 'common.models.integration.PyfaiIntegrationConfig', + 'correction_config': 'saxswaxs.models.CorrectionsConfig', }, init_var=True) # map_config needs a default value because the map configuration @@ -1207,12 +1211,15 @@ class UpdateValuesProcessor(Processor): # Pipeline to pass validation. map_config: MapConfig = None pyfai_config: PyfaiIntegrationConfig + detector_config: DetectorConfig = DetectorConfig(detectors=[]) + correction_config: CorrectionsConfig spec_file: FilePath scan_number: conint(gt=0) - detectors: conlist(item_type=Detector, min_length=1) + filename: Optional[str] = None raw_data: Optional[bool] = True + idx_slice: Optional[IndexSliceConfig] = IndexSliceConfig() - def process(self, data, idx_slice=None): + def process(self, data): """Processes a slice of data for updating values in an existing container for a SAXS/WAXS experiment. @@ -1234,78 +1241,166 @@ def process(self, data, idx_slice=None): # Pass detector data to PyfaiIntegration processor # Concatenate & return results # System modules + from copy import deepcopy import logging import os # Local modules - from CHAP.common import MapSliceProcessor + from CHAP.common.map_utils import MapSliceProcessor from CHAP.pipeline import PipelineData - #from CHAP.saxswaxs.processor import PyfaiIntegrationProcessor - def set_logger(pipeline_item): - """Set the logger and logging handler for given pipeline - item. - - :param pipeline_item: Pipeline item. - :type pipeline_item: PipelineItem - :return: Input Pipeline item, with updated logger and - logging handler. 
- :rtype: PipelineItem - """ - pipeline_item.logger = self.logger - pipeline_item.logger.name = pipeline_item.__class__.__name__ - handler = logging.StreamHandler() - handler.setFormatter(logging.Formatter( - '{asctime}: {name:20} (from '+ self.__class__.__name__ - + '): {levelname}: {message}', - datefmt='%Y-%m-%d %H:%M:%S', style='{')) - pipeline_item.logger.removeHandler( - pipeline_item.logger.handlers[0]) - pipeline_item.logger.addHandler(handler) - return pipeline_item - - if idx_slice is None: - idx_slice = {'start':0, 'step': 1} + # Use a copy of input data so we can append to it inside this + # Processor without modifying the actual Pipeline's data + # unecessarily. + _data = deepcopy(data) # Read in slice of raw data - raw_values = set_logger( + raw_values = self.setup_pipelineitem( MapSliceProcessor( map_config=self.map_config, - detectors=self.detectors, + detector_config=self.detector_config, spec_file=str(self.spec_file), - scan_number=self.scan_number, + scan_numbers=[self.scan_number], + idx_slice=self.idx_slice ) - ).process(None, idx_slice=idx_slice) + ).process(None) def _get_detector_data(values, name): - "Get the detector data.""" + """Get the detector data.""" for v in values: if os.path.basename(v['path']) == name: return v['data'] return None # Use raw detector data as input to integration - for d in self.detectors: - data.append( + for d in self.detector_config.detectors: + _data.append( PipelineData( name=d.get_id(), data=_get_detector_data(raw_values, d.get_id()), ) ) # Get integrated data - processed_values = set_logger( + processed_values = self.setup_pipelineitem( PyfaiIntegrationProcessor(config=self.pyfai_config) - ).process(data, idx_slices=[idx_slice]) + ).process(_data) + + # Get corrected data + corrected_values = [ + { + 'data': self.setup_pipelineitem( + corr_cfg.processor + ).process( + self.get_corrections_input_data( + raw_values, processed_values, corr_cfg + ), + nxprocess=False, + ), + 'path': 
f'{corr_cfg.name}/data/I_corrected' + } + for corr_cfg in self.correction_config.corrections + ] if self.raw_data: - return raw_values + processed_values + return raw_values + processed_values + corrected_values - detector_ids = [d.get_id() for d in self.detectors] + detector_ids = [d.get_id() for d in self.detector_config.detectors] scalar_values = [ d for d in raw_values if not os.path.basename(d['path']) in detector_ids ] - return scalar_values + processed_values + return scalar_values + processed_values + corrected_values + + def get_corrections_input_data(self, raw_values, processed_values, + corr_cfg): + corr_data = [] + for x in ('dwell_time_actual', 'presample_intensity', + 'postsample_intensity'): + for d in raw_values: + if d['path'].endswith(x): + corr_data.append( + PipelineData(data=d['data'], name=x) + ) + break + for d in processed_values: + if d['path'].startswith(f'{corr_cfg.uncorrected_data_name}/'): + corr_data.append( + PipelineData( + data=d['data'], + name=corr_cfg.uncorrected_data_name, + ) + ) + if corr_cfg.background is not None: + if self.filename is None: + self.logger.warning( + f'No filename configured; cannot read background ' + f'intensities for correction "{corr_cfg.name}"') + else: + # System modules + import os + pre_path = (f'{corr_cfg.name}/data/' + 'background_presample_intensity') + post_path = (f'{corr_cfg.name}/data/' + 'background_postsample_intensity') + intens_path = f'{corr_cfg.name}/data/I_background' + if os.path.splitext(self.filename)[1] in ('.nxs', '.h5', + '.hdf5'): + import h5py + with h5py.File(self.filename, 'r') as f: + for path, name in ( + (pre_path, + 'background_presample_intensity'), + (post_path, + 'background_postsample_intensity'), + (intens_path, 'background_intensity')): + if path in f: + corr_data.append(PipelineData( + data=np.asarray(f[path]), + name=name, + )) + else: + self.logger.warning( + f'{path} not found in {self.filename}') + else: + import zarr + zarrfile = zarr.open(self.filename, 
mode='r') + for path, name in ( + (pre_path, 'background_presample_intensity'), + (post_path, 'background_postsample_intensity'), + (intens_path, 'background_intensity')): + try: + corr_data.append(PipelineData( + data=np.asarray(zarrfile[path]), + name=name, + )) + except KeyError: + self.logger.warning( + f'{path} not found in {self.filename}') + return corr_data + + def setup_pipelineitem(self, pipeline_item): + """Convenience method to use a nice logger when this + ``Processor`` calls on another ``PipelineItem``. Set the + logger and logging handler for given pipeline item. + + :param pipeline_item: Pipeline item. + :type pipeline_item: PipelineItem + :return: Input Pipeline item, with updated logger and + logging handler. + :rtype: PipelineItem + """ + import logging + + pipeline_item.logger = logging.getLogger( + pipeline_item.__class__.__name__, + ) + handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter( + '{asctime}: {name:20} (from '+ self.__class__.__name__ + + ') (L{lineno}): {levelname}: {message}', + datefmt='%Y-%m-%d %H:%M:%S', style='{')) + pipeline_item.logger.handlers = [handler] + return pipeline_item if __name__ == '__main__': diff --git a/CHAP/saxswaxs/utils.py b/CHAP/saxswaxs/utils.py new file mode 100644 index 00000000..4a9cdfb7 --- /dev/null +++ b/CHAP/saxswaxs/utils.py @@ -0,0 +1,60 @@ +import zarr + +from CHAP.runner import set_logger + +logger, _ = set_logger(log_level='DEBUG') + +def dict_to_zarr(tree, logger=logger): + """Create a + `Zarr group `__ + object based on a dictionary representing a Zarr tree of groups + and arrays. + + :param tree: Nested dictionary representing a Zarr tree of + groups and arrays. + :type tree: dict[str, Any] + :return: Zarr group corresponding to the contents of `tree`. 
+ :rtype: zarr.Group + """ + # Third party modules + # pylint: disable=import-error + import zarr + from zarr.storage import MemoryStore + + def create_group_or_dataset(node, zarr_parent, indent=0): + """Create and return a + `Zarr group `__ + `Zarr dataset `__. + + :param node: Child Zarr tree group. + :type node: zarr.Group or zarr.Array + :param zarr_parent: Parent Zarr tree group. + :type zarr_parent: zarr.Group + :param indent: Indentation level, defaults to 0. + :type indent: int, optional + """ + # Set attributes if present + if 'attributes' in node: + for key, value in node['attributes'].items(): + zarr_parent.attrs[key] = value + # Create children (groups or datasets) + if 'children' in node: + for name, child in node['children'].items(): + if 'shape' in child or 'data' in child: + # It's a dataset + logger.debug(f'Adding dset: {name}') + zarr_parent.create_dataset( + name, + **child, + ) + # Set dataset attributes + if 'attributes' in child: + for key, value in child['attributes'].items(): + zarr_parent[name].attrs[key] = value + else: + # It's a group + group = zarr_parent.create_group(name) + create_group_or_dataset(child, group, indent=indent+2) + results = zarr.create_group(store=MemoryStore({})) + create_group_or_dataset(tree, results) + return results diff --git a/CHAP/tomo/processor.py b/CHAP/tomo/processor.py index 18993aa5..9d794348 100755 --- a/CHAP/tomo/processor.py +++ b/CHAP/tomo/processor.py @@ -1627,7 +1627,6 @@ def _set_detector_bounds( # Try to get a fit from the bright field row_sum = np.sum(tbf, 1) num = len(row_sum) - fit = FitProcessor(**self.run_config) model = {'model': 'rectangle', 'parameters': [ {'name': 'amplitude', @@ -1641,11 +1640,10 @@ def _set_detector_bounds( 'min': 0.0, 'max': num}, {'name': 'sigma2', 'value': num/7.0, 'min': sys.float_info.min}]} - bounds_fit = fit.process( - data=NXdata( - NXfield(row_sum, 'y'), - NXfield(np.array(range(num)), 'x')), - config={'models': [model], 'method': 'trf'}) + bounds_fit = 
FitProcessor.run( + data={'x': np.array(range(num)), 'y':row_sum}, + config={'models': [model], 'method': 'trf'}, + **self.run_config) parameters = bounds_fit.best_values row_low_fit = parameters.get('center1', None) row_upp_fit = parameters.get('center2', None) diff --git a/CHAP/utils/fit.py b/CHAP/utils/fit.py index c626b74b..019608af 100755 --- a/CHAP/utils/fit.py +++ b/CHAP/utils/fit.py @@ -14,6 +14,7 @@ from shutil import rmtree from sys import float_info #from time import time +from typing import Optional # Third party modules try: @@ -24,8 +25,8 @@ HAVE_JOBLIB = True except ImportError: HAVE_JOBLIB = False -from nexusformat.nexus import NXdata import numpy as np +from pydantic import Field # Local modules from CHAP.processor import Processor @@ -35,6 +36,7 @@ index_nearest, quick_plot, ) +from CHAP.utils.models import FitConfig FLOAT_MIN = float_info.min FLOAT_MAX = float_info.max @@ -60,168 +62,193 @@ class FitProcessor(Processor): - """A processor to perform a fit on a data set or data map. """ + """A processor to perform a fit on a data set or data map. - def process(self, data, config=None): + :ivar config: Initialization parameters for an instance of + :class:`~CHAP.utils.models.FitConfig`. + :vartype config: dict, optional + """ + + pipeline_fields: dict = Field( + default = { + 'config': 'CHAP.utils.models.FitConfig'}, init_var=True) + config: Optional[FitConfig] = None + + def _get_pipelinedata_item(self, data, remove=True): + """Retrieve the input data to :meth:`process` from the list of + PipelineData items. + + :param data: Input data. + :type data: list[PipelineData] + :param remove: If there is a matching entry in `data`, remove + it from the list, defaults to `True`. + :type remove: bool, optional + :return: Matching data item(s). 
+ :rtype: Any + """ + # Retrieve the data to (re)fit from the pipeline + for i, d in reversed(list(enumerate(data))): + ddata = d.get('data') + if isinstance(ddata, (Fit, FitMap)): + if remove: + data.pop(i) + return ddata + if d.get('name') == 'signal': + if remove: + data.pop(i) + break + else: + raise ValueError( + f'Unable to extract suitable fit input data from {data}') + + # Retrieve the optional coordinates from the pipeline + try: + y = np.asarray(ddata) + for i, d in reversed(list(enumerate(data))): + if d.get('name') == 'coordinates': + x = np.asarray(d.get('data')) + if remove: + data.pop(i) + assert x.size == y.shape[-1] + break + else: + x = None + except (ValueError, TypeError) as exc: + raise ValueError( + f'Unable to extract suitable fit input data from {data}') + return x, y + + def process(self, data): """Fit the data and return a :class:`~CHAP.utils.fit.Fit` or :class:`~CHAP.utils.fit.FitMap` object depending on the - dimensionality of the input data. The input data should be or - contain a NeXus style - `NXdata `__ - object, with properly defined signal and axis, or a - :class:`~CHAP.utils.fit.Fit` or :class:`~CHAP.utils.fit.FitMap` - object from a previous fit. - - :param data: Input data containing the - nexusformat.nexus.NXdata object to fit. - :type data: list[PipelineData] or Fit or FitMap or - nexusformat.nexus.NXdata - :param config: Fit configuration. - :type config: dict, optional - :raises ValueError: Invalid input or configuration parameter. + dimensionality of the input data. The input data should be a + list of PipelineData items containing either one with a `data` + field of type :class:`~CHAP.utils.fit.Fit` or + :class:`~CHAP.utils.fit.FitMap` (to refit or continue a + previous fit), or one with the `name` of `signal` and an + array-like `data` field. In the latter case, optional + x-coordinates can be supplied by a second PipelineData item + with the `name` of `coordinates` and again an array-like + `data` field. 
+ + :param data: Input data. + :type data: list[PipelineData] :return: The fitted data object. :rtype: Fit or FitMap """ # Local modules - from CHAP.utils.models import ( - FitConfig, - Multipeak, - ) - - # Unwrap the PipelineData if called as a Pipeline Processor - if (not isinstance(data, (Fit, FitMap)) - and not isinstance(data, NXdata)): - data = self.get_pipelinedata_item(data) + from CHAP.utils.models import MultipeakModel - # Get the validated fit configuration - fit_config = None - if config is not None: - try: - fit_config = FitConfig(**config) - except Exception as exc: - raise RuntimeError from exc + # Unwrap the PipelineData + data = self._get_pipelinedata_item(data) if isinstance(data, (Fit, FitMap)): # Refit/continue the fit with possibly updated parameters fit = data - if isinstance(data, FitMap): - fit.fit(config=fit_config) - else: - fit.fit(config=fit_config) - if fit_config is not None: - if fit_config.print_report: - fit.print_fit_report() - if fit_config.plot: - fit.plot(skip_init=True) - - elif isinstance(data, NXdata): + fit.fit(config=self.config, max_nfev=self.config.max_nfev) + if self.config is not None and not isinstance(data, FitMap): + if self.config.print_report: + fit.print_fit_report() + if self.config.plot: + fit.plot(skip_init=True) - # Get the default NXdata object - try: - nxdata = data.get_default() - assert nxdata is not None - except (AssertionError, ValueError) as exc: - if nxdata is None or nxdata.nxclass != 'NXdata': - raise ValueError( - 'Invalid default pathway to an NXdata ' - f'object in ({data})') from exc + else: # Expand multipeak model if present found_multipeak = False - for i, model in enumerate(deepcopy(fit_config.models)): - if isinstance(model, Multipeak): + for i, model in enumerate(deepcopy(self.config.models)): + if isinstance(model, MultipeakModel): if found_multipeak: raise ValueError( - f'Invalid parameter models ({fit_config.models}) ' + f'Invalid parameter models ({self.config.models}) ' '(multiple 
instances of multipeak not allowed)') parameters, models = self.create_multipeak_model(model) if parameters: - fit_config.parameters += parameters - fit_config.models += models - fit_config.models.pop(i) + self.config.parameters += parameters + self.config.models += models + self.config.models.pop(i) found_multipeak = True # Instantiate the Fit or FitMap object and fit the data - if np.squeeze(nxdata.nxsignal).ndim == 1: - fit = Fit(nxdata, fit_config, self.logger) - fit.fit() - if fit_config.print_report: + if np.squeeze(data[1]).ndim == 1: + fit = Fit(data[1], self.config, self.logger, x=data[0]) + fit.fit(max_nfev=self.config.max_nfev) + if self.config.print_report: fit.print_fit_report() - if fit_config.plot: + if self.config.plot: fit.plot(skip_init=True) else: - fit = FitMap(nxdata, fit_config, self.logger) + fit = FitMap(data[1], self.config, self.logger, x=data[0]) fit.fit( - rel_height_cutoff=fit_config.rel_height_cutoff, - num_proc=fit_config.num_proc, plot=fit_config.plot, - print_report=fit_config.print_report) - else: - raise ValueError(f'Invalid input data ({type(data)}: {data})') + rel_height_cutoff=self.config.rel_height_cutoff, + max_nfev=self.config.max_nfev, + num_proc=self.config.num_proc, + plot=self.config.plot, + print_report=self.config.print_report) return fit @staticmethod - def create_multipeak_model(model_config): + def create_multipeak_model(model): """Create a multipeak model. - :param model_config: A Multipeak fit model class. - :type model_config: :class:`~CHAP.utils.models.Multipeak` + :param model: A Multipeak fit model class. + :type model: :class:`~CHAP.utils.models.MultipeakModel` :return: The fit parameters and fit model classes. 
:rtype: list[:attr:`~CHAP.utils.models.FitParameter`], list[:attr:`~CHAP.utils.models.FitConfig.models`] """ # Local modules from CHAP.utils.models import ( + PEAK_LIKE_MODELS, FitParameter, - models, ) - peak_model_name = model_config.peak_models - peak_model_class = models[peak_model_name]['class'] + peak_model_class = PEAK_LIKE_MODELS[model.peak_models] parameters = [] peak_models = [] - num_peak = len(model_config.centers) - if num_peak == 1 and model_config.fit_type == 'uniform': - model_config.fit_type = 'unconstrained' + num_peak = len(model.centers) + if num_peak == 1 and model.fit_type == 'uniform': + model.fit_type = 'unconstrained' sig_min = FLOAT_MIN sig_max = np.inf - if (model_config.fwhm_min is not None - or model_config.fwhm_max is not None): + if (model.fwhm_min is not None + or model.fwhm_max is not None): # Third party modules from asteval import Interpreter ast = Interpreter() - if model_config.fwhm_min is not None: - ast(f'fwhm = {model_config.fwhm_min}') - sig_min = ast(fwhm_factor[model_config.peak_models]) - if model_config.fwhm_max is not None: - ast(f'fwhm = {model_config.fwhm_max}') - sig_max = ast(fwhm_factor[model_config.peak_models]) + if model.fwhm_min is not None: + ast(f'fwhm = {model.fwhm_min}') + sig_min = ast(fwhm_factor[model.peak_models]) + if model.fwhm_max is not None: + ast(f'fwhm = {model.fwhm_max}') + sig_max = ast(fwhm_factor[model.peak_models]) prefix = '' - if model_config.fit_type == 'uniform': + if model.fit_type == 'uniform': parameters.append(FitParameter( name='scale_factor', value=1.0, min=FLOAT_MIN)) - for i, cen in enumerate(model_config.centers): + for i, cen in enumerate(model.centers): if num_peak > 1: prefix = f'peak{i+1}_' peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, {'name': 'center', 'expr': f'scale_factor*{cen}'}, {'name': 'sigma', 'min': sig_min, 'max': sig_max}])) else: - for i, cen in 
enumerate(model_config.centers): + for i, cen in enumerate(model.centers): if num_peak > 1: prefix = f'peak{i+1}_' - if model_config.centers_range == 0: + if model.centers_range == 0: peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, @@ -229,14 +256,14 @@ def create_multipeak_model(model_config): {'name': 'sigma', 'min': sig_min, 'max': sig_max} ])) else: - if model_config.centers_range is None: + if model.centers_range is None: cen_min = None cen_max = None else: - cen_min = cen - model_config.centers_range - cen_max = cen + model_config.centers_range + cen_min = cen - model.centers_range + cen_max = cen + model.centers_range peak_models.append(peak_model_class( - model=peak_model_name, + model_type=model.peak_models, prefix=prefix, parameters=[ {'name': 'amplitude', 'min': FLOAT_MIN}, @@ -260,12 +287,15 @@ def __init__(self, model, prefix=''): :type prefix: str, optional """ # Local modules - from CHAP.utils.models import models + #from CHAP.utils.models import MODEL_TYPE_TO_CLASS - self.func = models[model.model]['name'] + self.func = model.func #MODEL_TYPE_TO_CLASS[model.model_type] + self.func_args = model._func_args #MODEL_TYPE_TO_CLASS[model.model_type] self.param_names = [f'{prefix}{par.name}' for par in model.parameters] self.prefix = prefix - self._name = model.model + self._name = model.model_type + names = [f'{par.name}' for par in model.parameters] + self.func_args_indices = [names.index(arg) for arg in self.func_args] class Components(dict): @@ -291,24 +321,6 @@ def components(self): """ return self.values() - def add(self, model, prefix=''): - """Add a model to the model fit components dictionary. - - :param model: A fit model class. - :type model: :attr:`~CHAP.utils.models.FitConfig.models` - :param prefix: Model prefix. 
- :type prefix: str, optional - """ - # Local modules - from CHAP.utils.models import model_classes - - if not isinstance(model, model_classes): - raise ValueError(f'Invalid parameter model ({model})') - if not isinstance(prefix, str): - raise ValueError(f'Invalid parameter prefix ({prefix})') - name = f'{prefix}{model.model}' - self.__setitem__(name, Component(model, prefix)) - class Parameters(dict): """A dictionary of FitParameter objects, mimicking the @@ -364,7 +376,7 @@ class ModelResult(): """ def __init__( - self, model, parameters, x=None, y=None, method=None, ast=None, + self, model, parameters, *, x=None, y=None, method=None, ast=None, res_par_exprs=None, res_par_indices=None, res_par_names=None, result=None): """Initialize SetNumexprThreads. @@ -505,7 +517,9 @@ def eval_components(self, x=None, parameters=None): name = component._name else: name = component.prefix - result[name] = component.func(x, *par_values) + ppar_values = tuple( + par_values[i] for i in component.func_args_indices) + result[name] = component.func(x, *ppar_values) return result def fit_report(self, show_correl=False): @@ -577,22 +591,24 @@ def fit_report(self, show_correl=False): class Fit: """Wrapper class for scipy/lmfit.""" - def __init__(self, nxdata, config, logger): + def __init__(self, y, config, logger, x=None): """Initialize Fit. - :param nxdata: The input data. - :type nxdata: nexusformat.nexus.NXdata + :param y: Input signal data. + :type y: array-like :param config: Fit configuration. - :type config: CHAP.utils.models.FitConfig, optional + :type config: CHAP.utils.models.FitConfig :param logger: A python Logger object. :type logger: logging.Logger + :param x: Input coordinate data. 
+ :type x: array-like, optional """ self._code = config.code for model in config.models: - if model.model == 'expression' and self._code != 'lmfit': + if model.model_type == 'expression' and self._code != 'lmfit': self._code = 'lmfit' - logger.warning('Using lmfit instead of scipy with ' - 'an expression model') + logger.warning('Using lmfit instead of scipy with an ' + 'expression model') if self._code == 'scipy': # Local modules from CHAP.utils.fit import Parameters @@ -633,20 +649,16 @@ def __init__(self, nxdata, config, logger): # raise ValueError( # 'Invalid value of keyword argument try_linear_fit ' # f'({self._try_linear_fit})') - if nxdata is not None: - if isinstance(nxdata.attrs['axes'], str): - dim_x = nxdata.attrs['axes'] - else: - dim_x = nxdata.attrs['axes'][-1] - self._x = np.asarray(nxdata[dim_x]) - self._y = np.squeeze(nxdata.nxsignal) - if self._x.ndim != 1: - raise ValueError( - f'Invalid x dimension ({self._x.ndim})') - if self._x.size != self._y.size: + if y is not None: + self._y = np.squeeze(y) + if self._y.ndim != 1: raise ValueError( - f'Inconsistent x and y dimensions ({self._x.size} vs ' - f'{self._y.size})') + f'Invalid input signal dimension ({self._y.ndim})') + if x is None: + self._x = np.arange(self._y.size) + else: + self._x = x + assert self._x.size == self._y.size # if 'mask' in kwargs: # self._mask = kwargs.pop('mask') if True: #self._mask is None: @@ -928,14 +940,6 @@ def var_names(self): return None return getattr(self._result, 'var_names', None) - @property - def x(self): - """Return the input x-coordinates. - - :type: numpy.ndarray - """ - return self._x - @property def y(self): """Return the input y-coordinates. 
@@ -1010,12 +1014,10 @@ def add_model(self, model, prefix): RectangleModel, ) - if model.model == 'expression': + if model.model_type == 'expression': expr = model.expr else: expr = None - parameters = model.parameters - model_name = model.model if prefix is None: pprefix = '' @@ -1023,149 +1025,147 @@ def add_model(self, model, prefix): pprefix = prefix if self._code == 'scipy': new_parameters = [] - for par in deepcopy(parameters): + for par in deepcopy(model.parameters): + name = par.name self._parameters.add(par, pprefix) if self._parameters[par.name].expr is None: self._parameters[par.name].set(value=par.default) new_parameters.append(par.name) - self._res_num_pars += [len(parameters)] + if name in model.LINEAR_PARAMETERS: + self._linear_parameters.append(par.name) + elif name not in model.MODEL_PARAMETERS: + self._nonlinear_parameters.append(par.name) + self._res_num_pars += [len(model.parameters)] - if model_name == 'constant': - # Par: c - if self._code == 'lmfit': + if self._code == 'lmfit': + if model.model_type == 'constant': + # Par: c newmodel = ConstantModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}c') - elif model_name == 'linear': - # Par: slope, intercept - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}c') + elif model.model_type == 'linear': + # Par: slope, intercept newmodel = LinearModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}slope') - self._linear_parameters.append(f'{pprefix}intercept') - elif model_name == 'quadratic': - # Par: a, b, c - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}slope') + self._linear_parameters.append(f'{pprefix}intercept') + elif model.model_type == 'parabolic': + # Par: a, b, c newmodel = QuadraticModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}a') - self._linear_parameters.append(f'{pprefix}b') - self._linear_parameters.append(f'{pprefix}c') -# elif model_name == 'polynomial': -# # Par: c0, c1,..., c7 -# degree = 
kwargs.get('degree') -# if degree is not None: -# kwargs.pop('degree') -# if degree is None or not is_int(degree, ge=0, le=7): -# raise ValueError( -# 'Invalid parameter degree for build-in step model ' -# f'({degree})') -# if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}a') + self._linear_parameters.append(f'{pprefix}b') + self._linear_parameters.append(f'{pprefix}c') +# elif model.model_type == 'polynomial': +# # Par: c0, c1,..., c7 +# degree = kwargs.get('degree') +# if degree is not None: +# kwargs.pop('degree') +# if degree is None or not is_int(degree, ge=0, le=7): +# raise ValueError( +# 'Invalid parameter degree for build-in step model ' +# f'({degree})') # newmodel = PolynomialModel(degree=degree, prefix=prefix) -# for i in range(degree+1): -# self._linear_parameters.append(f'{pprefix}c{i}') - elif model_name == 'exponential': - # Par: amplitude, decay - if self._code == 'lmfit': +# for i in range(degree+1): +# self._linear_parameters.append(f'{pprefix}c{i}') + elif model.model_type == 'exponential': + # Par: amplitude, decay newmodel = ExponentialModel(prefix=prefix) - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}decay') - elif model_name == 'gaussian': - # Par: amplitude, center, sigma (fwhm, height) - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}decay') + elif model.model_type == 'gaussian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = GaussianModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'lorentzian': - # Par: amplitude, center, sigma (fwhm, height) - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + 
self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'lorentzian': + # Par: amplitude, center, sigma (fwhm, height) newmodel = LorentzianModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'pvoigt': - # Par: amplitude, center, sigma (fwhm, height), fraction - if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'pvoigt': + # Par: amplitude, center, sigma (fwhm, height), fraction newmodel = PseudoVoigtModel(prefix=prefix) # parameter norms for height and fwhm are needed to # get correct errors - self._linear_parameters.append(f'{pprefix}amplitude') - self._linear_parameters.append(f'{pprefix}fraction') - self._nonlinear_parameters.append(f'{pprefix}center') - self._nonlinear_parameters.append(f'{pprefix}sigma') -# elif model_name == 'step': -# # Par: amplitude, center, sigma -# form = kwargs.get('form') -# if form is not None: -# kwargs.pop('form') -# if (form is None or form not in -# ('linear', 'atan', 'arctan', 'erf', 'logistic')): -# raise ValueError( -# 'Invalid parameter form for build-in step model ' -# f'({form})') -# if self._code == 'lmfit': + self._linear_parameters.append(f'{pprefix}amplitude') + self._linear_parameters.append(f'{pprefix}fraction') + self._nonlinear_parameters.append(f'{pprefix}center') + self._nonlinear_parameters.append(f'{pprefix}sigma') +# elif model.model_type == 'step': +# # Par: amplitude, center, sigma +# form = kwargs.get('form') +# if form is not None: +# kwargs.pop('form') +# if (form is None or form not in +# ('linear', 'atan', 'arctan', 'erf', 
'logistic')): +# raise ValueError( +# 'Invalid parameter form for build-in step model ' +# f'({form})') # newmodel = StepModel(prefix=prefix, form=form) -# self._linear_parameters.append(f'{pprefix}amplitude') -# self._nonlinear_parameters.append(f'{pprefix}center') -# self._nonlinear_parameters.append(f'{pprefix}sigma') - elif model_name == 'rectangle': - # Par: amplitude, center1, center2, sigma1, sigma2 - form = 'atan' #kwargs.get('form') - #if form is not None: - # kwargs.pop('form') - # RV: Implement and test other forms when needed - if (form is None or form not in - ('linear', 'atan', 'arctan', 'erf', 'logistic')): - raise ValueError( - 'Invalid parameter form for build-in rectangle model ' - f'({form})') - if self._code == 'lmfit': +# self._linear_parameters.append(f'{pprefix}amplitude') +# self._nonlinear_parameters.append(f'{pprefix}center') +# self._nonlinear_parameters.append(f'{pprefix}sigma') + elif model.model_type == 'rectangle': + # Par: amplitude, center1, center2, sigma1, sigma2 + form = 'atan' #kwargs.get('form') + #if form is not None: + # kwargs.pop('form') + # RV: Implement and test other forms when needed + if (form is None or form not in + ('linear', 'atan', 'arctan', 'erf', 'logistic')): + raise ValueError( + 'Invalid parameter form for build-in rectangle model ' + f'({form})') newmodel = RectangleModel(prefix=prefix, form=form) - self._linear_parameters.append(f'{pprefix}amplitude') - self._nonlinear_parameters.append(f'{pprefix}center1') - self._nonlinear_parameters.append(f'{pprefix}center2') - self._nonlinear_parameters.append(f'{pprefix}sigma1') - self._nonlinear_parameters.append(f'{pprefix}sigma2') - elif model_name == 'expression' and self._code == 'lmfit': - # Third party modules - from asteval import ( - Interpreter, - get_ast_names, - ) + self._linear_parameters.append(f'{pprefix}amplitude') + self._nonlinear_parameters.append(f'{pprefix}center1') + self._nonlinear_parameters.append(f'{pprefix}center2') + 
self._nonlinear_parameters.append(f'{pprefix}sigma1') + self._nonlinear_parameters.append(f'{pprefix}sigma2') + elif model.model_type == 'expression': + # FIX move to a validator + # Third party modules + from asteval import ( + Interpreter, + get_ast_names, + ) - for par in parameters: - if par.expr is not None: - raise KeyError( - f'Invalid "expr" key ({par.expr}) in ' - f'parameter ({par}) for an expression model') - ast = Interpreter() - expr_parameters = [ - name for name in get_ast_names(ast.parse(expr)) - if (name != 'x' and name not in self._parameters - and name not in ast.symtable)] - if prefix is None: - newmodel = ExpressionModel(expr=expr) - else: - for name in expr_parameters: - expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) + for par in model.parameters: + if par.expr is not None: + raise KeyError( + f'Invalid "expr" key ({par.expr}) in ' + f'parameter ({par}) for an expression model') + ast = Interpreter() expr_parameters = [ - f'{prefix}{name}' for name in expr_parameters] - newmodel = ExpressionModel(expr=expr, name=model_name) - # Remove already existing names - for name in newmodel.param_names.copy(): - if name not in expr_parameters: - newmodel._func_allargs.remove(name) - newmodel._param_names.remove(name) - else: - raise ValueError(f'Unknown fit model ({model_name})') + name for name in get_ast_names(ast.parse(expr)) + if (name != 'x' and name not in self._parameters + and name not in ast.symtable)] + if prefix is None: + newmodel = ExpressionModel(expr=expr) + else: + for name in expr_parameters: + expr = sub(rf'\b{name}\b', f'{prefix}{name}', expr) + expr_parameters = [ + f'{prefix}{name}' for name in expr_parameters] + newmodel = ExpressionModel(expr=expr, name=model.model_type) + # Remove already existing names + for name in newmodel.param_names.copy(): + if name not in expr_parameters: + newmodel._func_allargs.remove(name) + newmodel._param_names.remove(name) + else: + raise ValueError(f'Unknown fit model ({model.model_type})') # 
Add the new model to the current one if self._code == 'scipy': if self._model is None: self._model = Components() - self._model.add(model, prefix) + self._model |= { + f'{prefix}{model.model_type}': Component(model, prefix)} else: if self._model is None: self._model = newmodel @@ -1201,23 +1201,28 @@ def add_model(self, model, prefix): value = par.value _min = par.min _max = par.max - if not (model_name == 'pvoigt' and 'fraction' in name): - if value is not None: - value *= self._norm[1] - if not np.isinf(_min) and abs(_min) != FLOAT_MIN: - _min *= self._norm[1] - if not np.isinf(_max) and abs(_max) != FLOAT_MIN: - _max *= self._norm[1] +# if not (model.model_type == 'pvoigt' and 'fraction' in name): +# if value is not None: +# value *= self._norm[1] +# if not np.isinf(_min) and abs(_min) != FLOAT_MIN: +# _min *= self._norm[1] +# if not np.isinf(_max) and abs(_max) != FLOAT_MIN: +# _max *= self._norm[1] + if value is not None: + value *= self._norm[1] + if not np.isinf(_min) and abs(_min) != FLOAT_MIN: + _min *= self._norm[1] + if not np.isinf(_max) and abs(_max) != FLOAT_MIN: + _max *= self._norm[1] par.set(value=value, min=_min, max=_max) # Initialize the model parameters - for parameter in deepcopy(parameters): + for parameter in deepcopy(model.parameters): name = parameter.name if name not in new_parameters: name = pprefix+name if name not in new_parameters: - raise ValueError( - f'Unable to match parameter {name}') + raise ValueError(f'Unable to match parameter {name}') if parameter.expr is None: self._parameters[name].set( value=parameter.value, min=parameter.min, @@ -1258,9 +1263,8 @@ def fit(self, config=None, **kwargs): :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional - :param kwargs: Additional key, value pairs to pass on directly - to the core fit routine. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the core fit routine. 
""" # Check input parameters if self._model is None: @@ -1332,11 +1336,13 @@ def fit(self, config=None, **kwargs): return None def plot( - self, y=None, *, y_title=None, title=None, result=None, + self, x=None, y=None, *, y_title=None, title=None, result=None, skip_init=False, plot_comp=True, plot_comp_legends=False, plot_residual=False, plot_masked_data=True, **kwargs): """Plot the best fit. + :param x: x-coordinates. + :type x: array-like, optional :param y: y-coordinates. :type y: array-like, optional :param y_title: y-axis label. @@ -1359,9 +1365,8 @@ class attribute. :type plot_residual: bool, optional :param plot_masked_data: :type plot_masked_data: bool, optional - :param kwargs: Additional key, value pairs to pass on directly - to the Matplotlib plot function. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the Matplotlib plot function. """ if result is None: result = self._result @@ -1374,34 +1379,45 @@ class attribute. plot_masked_data = False else: mask = self._mask + if x is not None: + if not isinstance(x, (tuple, list, np.ndarray)): + self._logger.warning( + 'Ignoring invalid parameter x ({type(x)})') + if len(x) != len(self._x): + self._logger.warning( + 'Ignoring parameter x in plot (wrong dimension)') + x = None + if x is None: + x = self._x if y is not None: if not isinstance(y, (tuple, list, np.ndarray)): - self._logger.warning('Ignorint invalid parameter y ({y}') - if len(y) != len(self._x): + self._logger.warning( + 'Ignoring invalid parameter y ({type(y)})') + if len(y) != len(x): self._logger.warning( 'Ignoring parameter y in plot (wrong dimension)') y = None if y is not None: if y_title is None or not isinstance(y_title, str): y_title = 'data' - plots += [(self._x, y, '.')] + plots += [(x, y, '.')] legend += [y_title] if self._y is not None: - plots += [(self._x, np.asarray(self._y), 'b.')] + plots += [(x, np.asarray(self._y), 'b.')] legend += ['data'] if plot_masked_data: - plots += 
[(self._x[mask], np.asarray(self._y)[mask], 'bx')] + plots += [(x[mask], np.asarray(self._y)[mask], 'bx')] legend += ['masked data'] if isinstance(plot_residual, bool) and plot_residual: - plots += [(self._x[~mask], result.residual, 'r-')] + plots += [(x[~mask], result.residual, 'r-')] legend += ['residual'] - plots += [(self._x[~mask], result.best_fit, 'k-')] + plots += [(x[~mask], result.best_fit, 'k-')] legend += ['best fit'] if not skip_init and hasattr(result, 'init_fit'): - plots += [(self._x[~mask], result.init_fit, 'g-')] + plots += [(x[~mask], result.init_fit, 'g-')] legend += ['init'] if plot_comp: - components = result.eval_components(x=self._x[~mask]) + components = result.eval_components(x=x[~mask]) num_components = len(components) if 'tmp_normalization_offset_' in components: num_components -= 1 @@ -1415,8 +1431,8 @@ class attribute. if len(modelname) > 20: modelname = f'{modelname[0:16]} ...' if isinstance(y_comp, (int, float)): - y_comp *= np.ones(self._x[~mask].size) - plots += [(self._x[~mask], y_comp, '--')] + y_comp *= np.ones(x[~mask].size) + plots += [(x[~mask], y_comp, '--')] if plot_comp_legends: if modelname[-1] == '_': legend.append(modelname[:-1]) @@ -1543,7 +1559,7 @@ def _create_prefixes(self, models): names = [] prefixes = [] for model in models: - names.append(f'{model.prefix}{model.model}') + names.append(f'{model.prefix}{model.model_type}') prefixes.append(model.prefix) counts = Counter(names) for model, count in counts.items(): @@ -1656,10 +1672,10 @@ def _setup_fit(self, config, guess=False): # Add constant offset for a normalized model if self._result is None and self._norm is not None and self._norm[0]: # Local modules - from CHAP.utils.models import Constant + from CHAP.utils.models import ConstantModel - model = Constant( - model='constant', + model = ConstantModel( + model_type='constant', parameters=[{ 'name': 'c', 'value': -self._norm[0], @@ -1672,14 +1688,14 @@ def _setup_fit(self, config, guess=False): # Local modules 
from CHAP.utils.models import ( FitConfig, - Multipeak, + MultipeakModel, ) # Expand multipeak model if present scale_factor = None for i, model in enumerate(deepcopy(config.models)): found_multipeak = False - if isinstance(model, Multipeak): + if isinstance(model, MultipeakModel): if found_multipeak: raise ValueError( f'Invalid parameter models ({config.models}) ' @@ -1763,10 +1779,9 @@ def _setup_fit(self, config, guess=False): if par.expr: self._res_par_exprs.append( {'expr': par.expr, 'index': i}) - else: - if par.vary: - self._res_par_indices.append(i) - self._res_par_names.append(name) + elif par.vary: + self._res_par_indices.append(i) + self._res_par_names.append(name) # Check for uninitialized parameters for name, par in self._parameters.items(): @@ -2082,28 +2097,29 @@ def _fit_nonlinear_model(self, x, y, **kwargs): else: bounds = (-np.inf, np.inf) init_params = deepcopy(self._parameters) -# t0 = time() lskws = { 'ftol': 1.49012e-08, 'xtol': 1.49012e-08, 'gtol': 10*FLOAT_EPS, } + max_nfev = kwargs.get('max_nfev') if self._method == 'leastsq': - lskws['maxfev'] = 64000 + if max_nfev is not None: + lskws['maxfev'] = max_nfev result = leastsq( self._residual, pars_init, args=(x, y), full_output=True, **lskws) else: - lskws['max_nfev'] = 64000 + if max_nfev is not None: + lskws['max_nfev'] = max_nfev result = least_squares( self._residual, pars_init, bounds=bounds, method=self._method, args=(x, y), **lskws) -# t1 = time() -# print(f'\n\nFitting took {1000*(t1-t0):.3f} ms\n\n') model_result = ModelResult( - self._model, self._parameters, x, y, self._method, self._ast, - self._res_par_exprs, self._res_par_indices, - self._res_par_names, result) + self._model, self._parameters, x=x, y=y, method=self._method, + ast=self._ast, res_par_exprs=self._res_par_exprs, + res_par_indices=self._res_par_indices, + res_par_names=self._res_par_names, result=result) model_result.init_params = init_params model_result.init_values = {} for name, par in init_params.items(): @@ 
-2113,12 +2129,9 @@ def _fit_nonlinear_model(self, x, y, **kwargs): fit_kws = {} # if 'Dfun' in kwargs: # fit_kws['Dfun'] = kwargs.pop('Dfun') -# t0 = time() model_result = self._model.fit( y, self._parameters, x=x, method=self._method, fit_kws=fit_kws, **kwargs) -# t1 = time() -# print(f'\n\nFitting took {1000*(t1-t0):.3f} ms\n\n') return model_result @@ -2304,8 +2317,11 @@ def _residual(self, pars, x, y): self._ast.eval(expr['expr']) for component, num_par in zip( self._model.components, self._res_num_pars): + values = self._res_par_values[n_par:n_par+num_par] + vvalues = [values[i] for i in component.func_args_indices] res += component.func( - x, *tuple(self._res_par_values[n_par:n_par+num_par])) +# x, *tuple(self._res_par_values[n_par:n_par+num_par])) + x, *tuple(vvalues)) n_par += num_par return res - y @@ -2313,15 +2329,17 @@ def _residual(self, pars, x, y): class FitMap(Fit): """Wrapper to the Fit class to fit data on a N-dimensional map.""" - def __init__(self, nxdata, config, logger): + def __init__(self, y, config, logger, x=None): """Initialize FitMap. - :param nxdata: The input data. - :type nxdata: nexusformat.nexus.NXdata + :param y: Input signal data. + :type y: array-like :param config: Fit configuration. - :type config: CHAP.utils.models.FitConfig, optional + :type config: CHAP.utils.models.FitConfig :param logger: A python Logger object. :type logger: logging.Logger + :param x: Input coordinate data. 
+ :type x: array-like, optional """ super().__init__(None, config, logger) self._best_errors = None @@ -2347,16 +2365,12 @@ def __init__(self, nxdata, config, logger): # At this point the fastest index should always be the signal # dimension so that the slowest ndim-1 dimensions are the # map dimensions - self._x = np.asarray(nxdata[nxdata.attrs['axes'][-1]]) - self._ymap = np.asarray(nxdata.nxsignal) - - # Check input parameters - if self._x.ndim != 1: - raise ValueError(f'Invalid x dimension ({self._x.ndim})') - if self._x.size != self._ymap.shape[-1]: - raise ValueError( - f'Inconsistent x and y dimensions ({self._x.size} vs ' - f'{self._ymap.shape[-1]})') + self._ymap = y + if x is None: + self._x = np.arange(self._ymap.shape[-1]) + else: + self._x = x + assert self._x.size == self._ymap.shape[-1] # Flatten the map # Store the flattened map in self._ymap_norm @@ -2387,7 +2401,7 @@ def __init__(self, nxdata, config, logger): self._y_range = ymap_max-ymap_min if self._y_range > 0.0: self._norm = (ymap_min, self._y_range) - self._ymap_norm = (self._ymap_norm-self._norm[0]) / self._norm[1] + self._ymap_norm = (self._ymap_norm-self._norm[0])/self._norm[1] else: self._redchi_cutoff *= self._y_range**2 @@ -2591,6 +2605,15 @@ def best_parameters(self, dims=None): """ if dims is None: return self._best_parameters +# FIX use something else, self._best_parameters is "reserved" to get the +# parameters in the EDD strain analysis and must return the order of the +# parameters in self.best_values and self.best_errors +# parameters_dict = {} +# for i, name in enumerate(self._best_parameters): +# parameters_dict[name] = { +# 'values': self._best_values[i], +# 'errors': self._best_errors[i]} +# return parameters_dict if (not isinstance(dims, (list, tuple)) or len(dims) != len(self._map_shape)): raise ValueError('Invalid parameter dims ({dims})') @@ -2629,10 +2652,12 @@ def freemem(self): self._logger.warning('Could not clean-up automatically.') def plot( - self, dims=None, *, 
y_title=None, plot_comp_legends=False, + self, x=None, dims=None, *, y_title=None, plot_comp_legends=False, plot_residual=False, plot_masked_data=True, **kwargs): """Plot the best fits. + :param x: x-coordinates. + :type x: array-like, optional :param dims: Map indices of the data point to plot, defaults to `None` which will plot the first data point. :type dims: list or tuple, optional @@ -2645,13 +2670,22 @@ def plot( :type plot_residual: bool, optional :param plot_masked_data: :type plot_masked_data: bool, optional - :param kwargs: Additional key, value pairs to pass on directly - to the Matplotlib plot function. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the Matplotlib plot function. """ # Third party modules from lmfit.models import ExpressionModel + if x is not None: + if not isinstance(x, (tuple, list, np.ndarray)): + self._logger.warning( + 'Ignoring invalid parameter x ({type(x)})') + if len(x) != len(self._x): + self._logger.warning( + 'Ignoring parameter x in plot (wrong dimension)') + x = None + if x is None: + x = self._x if dims is None: dims = [0]*len(self._map_shape) if (not isinstance(dims, (list, tuple)) @@ -2666,20 +2700,20 @@ def plot( if y_title is None or not isinstance(y_title, str): y_title = 'data' if self._mask is None: - mask = np.zeros(self._x.size).astype(bool) + mask = np.zeros(x.size).astype(bool) plot_masked_data = False else: mask = self._mask - plots = [(self._x, np.asarray(self._ymap[dims]), 'b.')] + plots = [(x, np.asarray(self._ymap[dims]), 'b.')] legend = [y_title] if plot_masked_data: plots += \ - [(self._x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] + [(x[mask], np.asarray(self._ymap)[(*dims,mask)], 'bx')] legend += ['masked data'] - plots += [(self._x[~mask], self.best_fit[dims], 'k-')] + plots += [(x[~mask], self.best_fit[dims], 'k-')] legend += ['best fit'] if plot_residual: - plots += [(self._x[~mask], self.residual[dims], 'r--')] + plots += [(x[~mask], 
self.residual[dims], 'r--')] legend += ['residual'] # Create current parameters parameters = deepcopy(self._parameters) @@ -2706,10 +2740,10 @@ def plot( modelname = f'{component._name}' if len(modelname) > 20: modelname = f'{modelname[0:16]} ...' - y = component.eval(params=parameters, x=self._x[~mask]) + y = component.eval(params=parameters, x=x[~mask]) if isinstance(y, (int, float)): - y *= np.ones(self._x[~mask].size) - plots += [(self._x[~mask], y, '--')] + y *= np.ones(x[~mask].size) + plots += [(x[~mask], y, '--')] if plot_comp_legends: legend.append(modelname) quick_plot( @@ -2720,9 +2754,8 @@ def fit(self, config=None, **kwargs): :param config: Fit configuration. :type config: CHAP.utils.models.FitConfig, optional - :param kwargs: Additional key, value pairs to pass on directly - to the core fit routine. - :type kwargs: dict + :param \*\*kwargs: Additional key, value pairs to pass on + directly to the core fit routine. """ # Check input parameters if self._model is None: @@ -3093,9 +3126,13 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): # Regular full fit result = self._fit_with_bounds_check(n, current_best_values, **kwargs) + if result.nfev == kwargs.get('max_nfev'): + self._logger.info( + f'Hit max_nfev limit for n={n}\n\tnfev: {result.nfev}') if self._rel_height_cutoff is not None: # Check for low heights peaks and refit without them + # FIX make sure to add "height" and "fwhm" to peak-like models heights = [] names = [] for component in result.components: @@ -3127,6 +3164,9 @@ def _fit(self, n, current_best_values, return_result=False, **kwargs): # Reset fixed amplitudes back to default self._parameters = deepcopy(parameters_save) self._parameter_bounds = deepcopy(parameters_bounds_save) + if result.nfev == kwargs.get('max_nfev'): + self._logger.info( + f'\nHit max_nfev limit again after refit for n={n}') if result.redchi >= self._redchi_cutoff: result.success = False diff --git a/CHAP/utils/general.py b/CHAP/utils/general.py 
index 6a717336..5b9cfb5a 100755 --- a/CHAP/utils/general.py +++ b/CHAP/utils/general.py @@ -2735,9 +2735,8 @@ def quick_imshow( :type grid_linewidth: int, optional :param colorbar: Include a colorbar, defaults to `False`. :type colorbar: bool, optional - :param kwargs: Any additional keyword parameters to pass on to + :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.imshow `__. - :type kwargs: dict, optional :raise: ValueError for invalid input data or parameters. :return: In-memory object as a byte stream represention if `return_fig` is set. @@ -2853,9 +2852,8 @@ def quick_plot( :type save_only: bool, optional :param block: Wait for the image to be closed before returning. :type block: bool, optional - :param kwargs: Any additional keyword parameters to pass on to + :param \*\*kwargs: Any additional keyword parameters to pass on to `matplotlib.pyplot.plot `__ - :type kwargs: dict, optional :raise: ValueError for invalid input data or parameters. """ #FIX: Update with return_buf diff --git a/CHAP/utils/models.py b/CHAP/utils/models.py index 99d8a714..657265a3 100755 --- a/CHAP/utils/models.py +++ b/CHAP/utils/models.py @@ -4,6 +4,7 @@ # System modules from typing import ( + ClassVar, Literal, Optional, Union, @@ -19,6 +20,7 @@ confloat, constr, field_validator, + model_validator, ) from typing_extensions import Annotated import numpy as np @@ -27,9 +29,6 @@ from CHAP.models import CHAPBaseModel from CHAP.utils.general import not_zero, tiny -# pylint: disable=no-member -tiny = np.finfo(np.float64).resolution -# pylint: enable=no-member s2pi = np.sqrt(2*np.pi) s2ln2 = np.sqrt(2*np.log(2)) @@ -39,7 +38,7 @@ def constant(x, c=0.0): :param c: Constant, defaults to `0`. :type c: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -58,7 +57,7 @@ def linear(x, slope=1.0, intercept=0.0): :type slope: float, optional :param intercept: Intercept, defaults to `0`. 
:type intercept: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -71,8 +70,8 @@ def linear(x, slope=1.0, intercept=0.0): return slope * x + intercept -#def quadratic(x, a=0.5, b=0.4, c=0.1): -def quadratic(x, a=0.0, b=0.0, c=0.0): +#def parabolic(x, a=0.5, b=0.4, c=0.1): +def parabolic(x, a=0.0, b=0.0, c=0.0): r"""Return a parabolic function. :param a: Quadratic polynomial coefficient, defaults to an @@ -84,7 +83,7 @@ def quadratic(x, a=0.0, b=0.0, c=0.0): :param c: Constant polynomial coefficient, defaults to an initial value of `0`. :type c: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -104,7 +103,7 @@ def exponential(x, amplitude=1.0, decay=1.0): :type amplitude: float, optional :param decay: Exponential decay, defaults to `1`. :type decay: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -128,7 +127,7 @@ def gaussian(x, amplitude=1.0, center=0.0, sigma=1.0): :type center: float, optional :param sigma: Standard deviation, defaults to `1`. :type sigma: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -160,7 +159,7 @@ def lorentzian(x, amplitude=1.0, center=0.0, sigma=1.0): :type center: float, optional :param sigma: Standard deviation, defaults to `1`. :type sigma: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. math:: @@ -196,7 +195,7 @@ def pvoigt(x, amplitude=1.0, center=0.0, sigma=1.0, fraction=0.5): :param fraction: Relative weight of the Gaussian and Lorentzian components, defaults to `0.5`. :type fraction: float, optional - :returns: The function evaluations. + :returns: Function evaluations. :rtype: numpy.ndarray .. 
math:: @@ -248,7 +247,7 @@ def rectangle( - ``'logistic'``: Sigmoidal (logistic function) transitions. :type form: str, optional - :returns: The evaluated rectangle function values. + :returns: Evaluated rectangle function values. :rtype: float or numpy.ndarray .. note:: @@ -289,53 +288,6 @@ def rectangle( return amplitude*rect -def validate_parameters(parameters, info): - """Validate the parameters. - - :param parameters: Fit model parameters. - :type parameters: list[FitParameter] - :param info: Model parameter validation information. - :type info: pydantic.ValidationInfo - :return: List of fit model parameters. - :rtype: list[FitParameter] - """ - # System imports - import inspect - - if 'model' in info.data: - model = info.data['model'] - else: - model = None - if model is None or model == 'expression': - return parameters - sig = dict(inspect.signature(models[model]['name']).parameters.items()) - sig.pop('x') - - # Check input model parameter validity - for par in parameters: - if par.name not in sig: - raise ValueError('Invalid parameter {par.name} in {model} model') - - # Set model parameters - output_parameters = [] - for sig_name, sig_par in sig.items(): - if model == 'rectangle' and sig_name == 'form': - continue - for par in parameters: - if sig_name == par.name: - break - else: - par = FitParameter(name=sig_name) - if sig_par.default != sig_par.empty: - par._default = sig_par.default - if model == 'pvoigt' and sig_name == 'fraction': - par.min = 0.0 - par.max = 1.0 - output_parameters.append(par) - - return output_parameters - - class FitParameter(CHAPBaseModel): """Class representing a specific fit parameter for the fit processor. @@ -489,232 +441,271 @@ def set(self, value=None, min=None, max=None, vary=None, expr=None): self.value = self.min self.expr = None -class Constant(CHAPBaseModel): + +class ConstantModel(CHAPBaseModel): """Class representing a Constant model component. 
- :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['constant'] + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['constant'] :ivar parameters: Function parameters, defaults to those auto generated from the function signature (excluding the independent variable). :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. + :ivar prefix: Model prefix, defaults to `''`. :vartype prefix: str, optional """ - model: Literal['constant'] + LINEAR_PARAMETERS: ClassVar[list[str]] = ['c'] + MODEL_PARAMETERS: ClassVar[list[str]] = [] + MODEL_IDENTIFIERS: ClassVar[list[str]] = [] + model_type: Literal['constant'] parameters: Annotated[ conlist(item_type=FitParameter), Field(validate_default=True)] = [] prefix: Optional[str] = '' - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + _func: PrivateAttr + _func_args: PrivateAttr + + @model_validator(mode='after') + def validate_model_after(self): + """Validate the model configuration and initialize the + appropriate parameters from the model function signature. + + :return: Validated and initialized configuration. 
+ :rtype: Model + """ + # System imports + from inspect import signature + + if self.model_type == 'expression': + return self + self._func = globals()[self.model_type] + sig = dict(signature(self._func).parameters) + sig.pop('x') + self._func_args = list(sig.keys()) + + # Check input model parameter validity + par_names = [] + for par in self.parameters: + if par.name not in sig: + raise ValueError( + 'Invalid parameter {par.name} in {self.model_type} model ' + f'valid function arguments: {list(sig.keys())}') + par_names.append(par.name) + + # Set model parameters + for sig_name, sig_par in sig.items(): + if sig_name in self.MODEL_IDENTIFIERS or sig_name in par_names: +# if ((self.model_type == 'rectangle' and sig_name == 'form') +# or sig_name in par_names): + continue + par = FitParameter(name=sig_name) + if sig_par.default != sig_par.empty: + par._default = sig_par.default + self.parameters.append(par) + par_names.append(par.name) + + # Perform any additional validation of model parameters + if hasattr(self, '_validate_parameters'): + self._validate_parameters() + return self + + @property + def func(self): + """Return the model function + + :type: function + """ + if hasattr(self, '_func'): + return self._func + return None -class Linear(CHAPBaseModel): +class LinearModel(ConstantModel): """Class representing a Linear model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['linear'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). 
+ :vartype model_type: Literal['linear'] """ - model: Literal['linear'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['slope', 'intercept'] + model_type: Literal['linear'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) - -class Quadratic(CHAPBaseModel): +class QuadraticModel(ConstantModel): """Class representing a Quadratic model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['quadratic'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['parabolic'] """ - model: Literal['quadratic'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['a', 'b', 'c'] + model_type: Literal['parabolic'] -class Exponential(CHAPBaseModel): +class ExponentialModel(ConstantModel): """Class representing an Exponential model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['exponential'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). 
- :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['exponential'] """ - model: Literal['exponential'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['exponential'] -class Gaussian(CHAPBaseModel): +class GaussianModel(ConstantModel): """Class representing a Gaussian model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['gaussian'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['gaussian'] """ - model: Literal['gaussian'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['gaussian'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'sigma': + par.min = 0.0 -class Lorentzian(CHAPBaseModel): +class LorentzianModel(ConstantModel): """Class representing a Lorentzian model component. 
- :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['lorentzian'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['lorentzian'] """ - model: Literal['lorentzian'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + model_type: Literal['lorentzian'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'sigma': + par.min = 0.0 -class PseudoVoigt(CHAPBaseModel): +class PseudoVoigtModel(ConstantModel): """Class representing a PseudoVoigt model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['pvoigt'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). 
+ :vartype model_type: Literal['pvoigt'] """ - model: Literal['pvoigt'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + MODEL_PARAMETERS: ClassVar[list[str]] = ['fraction'] + model_type: Literal['pvoigt'] - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'fraction': + par.min = 0.0 + par.max = 1.0 + elif par.name == 'sigma': + par.min = 0.0 -class Rectangle(CHAPBaseModel): +class RectangleModel(ConstantModel): """Class representing a Rectangle model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['rectangle'] - :ivar parameters: Function parameters, defaults to those auto - generated from the function signature (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. - :vartype prefix: str, optional + :ivar model_type: Model component base name (a prefix will be + added if multiple identical model components are added). 
+ :vartype model_type: Literal['rectangle'] """ - model: Literal['rectangle'] - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + LINEAR_PARAMETERS: ClassVar[list[str]] = ['amplitude'] + MODEL_IDENTIFIERS: ClassVar[list[str]] = ['form'] + model_type: Literal['rectangle'] + def _validate_parameters(self): + """Validate the model parameters.""" + for par in self.parameters: + if par.name == 'form': + assert form in ('linear', 'atan', 'arctan', 'erf', 'logistic') -class Expression(CHAPBaseModel): +class ExpressionModel(ConstantModel): """Class representing an Expression model component. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['expression'] + :ivar model_type: The model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['expression'] :ivar expr: Mathematical expression to represent the model component. :vartype expr: str - :ivar parameters: Function parameters, defaults to those auto - generated from the model expression (excluding the - independent variable). - :vartype parameters: list[FitParameter], optional - :ivar prefix: The model prefix, defaults to `''`. 
- :vartype prefix: str, optional """ - model: Literal['expression'] + model_type: Literal['expression'] expr: constr(strip_whitespace=True, min_length=1) - parameters: Annotated[ - conlist(item_type=FitParameter), - Field(validate_default=True)] = [] - prefix: Optional[str] = '' - _validate_parameters_parameters = field_validator( - 'parameters')(validate_parameters) + +# Available models for components of the fitting function +#MODEL_CLASSES = [ +# ConstantModel, +# LinearModel, +# QuadraticModel, +# ExponentialModel, +# GaussianModel, +# LorentzianModel, +# PseudoVoigtModel, +# RectangleModel, +# ExpressionModel, +#] + +# Reusable Discriminator Union for supported fit model components. +Model = Annotated[ +# FIX for Python 3.11+ Union[*MODEL_CLASSES], + Union[ + ConstantModel, + LinearModel, + QuadraticModel, + ExponentialModel, + GaussianModel, + LorentzianModel, + PseudoVoigtModel, + RectangleModel, + ExpressionModel, + ], + Field(discriminator='model_type') +] + +# Peak-like models: with amplitude, center and sigma as their +# function arguments +PEAK_LIKE_MODELS = { + 'gaussian': GaussianModel, + 'lorentzian': LorentzianModel, + 'pvoigt': PseudoVoigtModel, +} + +#MODEL_TYPE_TO_CLASS = {#v.model_type:v for v in MODEL_CLASSES} +# 'constant': constant, +# 'linear': linear, +# 'parabolic': parabolic, +# 'exponential': exponential, +# 'gaussian': gaussian, +# 'lorentzian': lorentzian, +# 'pvoigt': pvoigt, +# 'rectangle': rectangle, +#} -class Multipeak(CHAPBaseModel): +class MultipeakModel(CHAPBaseModel): """Class representing a multipeak model. - :ivar model: The model component base name (a prefix will be added - if multiple identical model components are added). - :vartype model: Literal['expression'] + :ivar model_type: The model component base name (a prefix will be + added if multiple identical model components are added). + :vartype model_type: Literal['expression'] :ivar centers: Peak centers. 
:vartype center: list[float] :ivar centers_range: Range of peak centers around their centers. @@ -730,7 +721,7 @@ class Multipeak(CHAPBaseModel): optional. """ - model: Literal['multipeak'] + model_type: Literal['multipeak'] centers: conlist(item_type=confloat(allow_inf_nan=False), min_length=1) centers_range: Optional[confloat(allow_inf_nan=False)] = None fit_type: Optional[Literal['uniform', 'unconstrained']] = 'unconstrained' @@ -739,75 +730,57 @@ class Multipeak(CHAPBaseModel): peak_models: Literal['gaussian', 'lorentzian', 'pvoigt'] = 'gaussian' -models = { - 'constant': {'name': constant, 'class': Constant}, - 'linear': {'name': linear, 'class': Linear}, - 'quadratic': {'name': quadratic, 'class': Quadratic}, - 'exponential': {'name': exponential, 'class': Exponential}, - 'gaussian': {'name': gaussian, 'class': Gaussian}, - 'lorentzian': {'name': lorentzian, 'class': Lorentzian}, - 'pvoigt': {'name': pvoigt, 'class': PseudoVoigt}, - 'rectangle': {'name': rectangle, 'class': Rectangle}, -} - -model_classes = ( - Constant, - Linear, - Quadratic, - Exponential, - Gaussian, - Lorentzian, - PseudoVoigt, - Rectangle, -) - - class FitConfig(CHAPBaseModel): """Class representing the configuration for the fit processor. :ivar code: Specifies is lmfit is used to perform the fit or if the scipy fit method is called directly, default to `'lmfit'`. :vartype code: Literal['lmfit', 'scipy'], optional + :ivar max_nfev: Maximum number of function evaluations in the + the strain analysis peak fitting routine. + :vartype max_nfev: int, optional + :ivar memfolder: Folder name for the temporary memory map if + multiple processors are used, defaults to `'joblib_memmap'`. + :vartype memfolder: str, optional + :ivar method: SciPy non-linear fit method, defaults to + `"leastsq"`. + :vartype method: Literal[ + 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] + :ivar models: The component(s) of the (composite) fit model. 
+ :vartype models: list[Model, MultipeakModel] + :ivar num_proc: The number of processors used in fitting a map + of data, defaults to `1`. + :vartype num_proc: int, optional :ivar parameters: Fit model parameters in addition to those implicitly defined through the build-in model functions, defaults to `[]`' :vartype parameters: list[:class:`~CHAP.utils.models.FitParameter`], optional - :ivar models: The component(s) of the (composite) fit model. - :vartype models: - list[:attr:`~CHAP.utils.models.FitConfig.models`] - :ivar rel_height_cutoff: Relative peak height cutoff for - peak fitting (any peak with a height smaller than - `rel_height_cutoff` times the maximum height of all peaks - gets removed from the fit model). - :vartype rel_height_cutoff: float, optional - :ivar num_proc: The number of processors used in fitting a map - of data, defaults to `1`. - :vartype num_proc: int, optional :ivar plot: Whether a plot of the fit result is generated, defaults to `False`. :vartype plot: bool, optional. :ivar print_report: Whether to generate a fit result printout, defaults to `False`. :vartype print_report: bool, optional. - :ivar memfolder: Folder name for the temporary memory map if - multiple processors are used, defaults to `'joblib_memmap'`. - :vartype memfolder: str, optional + :ivar rel_height_cutoff: Relative peak height cutoff for + peak fitting (any peak with a height smaller than + `rel_height_cutoff` times the maximum height of all peaks + gets removed from the fit model). 
+ :vartype rel_height_cutoff: float, optional """ code: Literal['lmfit', 'scipy'] = 'scipy' - parameters: conlist(item_type=FitParameter) = [] - models: conlist(item_type=Union[ - Constant, Linear, Quadratic, Exponential, Gaussian, Lorentzian, - PseudoVoigt, Rectangle, Expression, Multipeak], min_length=1) + max_nfev: Optional[conint(gt=0)] = None + memfolder: str = 'joblib_memmap' method: Literal[ 'leastsq', 'trf', 'dogbox', 'lm', 'least_squares'] = 'leastsq' - rel_height_cutoff: Optional[ - confloat(gt=0, lt=1.0, allow_inf_nan=False)] = None + models: conlist(item_type=Union[Model, MultipeakModel], min_length=1) num_proc: conint(gt=0) = 1 + parameters: conlist(item_type=FitParameter) = [] plot: StrictBool = False print_report: StrictBool = False - memfolder: str = 'joblib_memmap' + rel_height_cutoff: Optional[ + confloat(gt=0, lt=1.0, allow_inf_nan=False)] = None @field_validator('method') @classmethod diff --git a/examples/edd/edd_calibration_script.py b/examples/edd/edd_calibration_script.py new file mode 100644 index 00000000..b8ed5b9d --- /dev/null +++ b/examples/edd/edd_calibration_script.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""A python script version of the CHAP pipeline version of the +calibration example. + +Modify any user choices in the "Start user input" and "Optional user +input" blocks below. + +In particular: + +* Choose the path to your CHAP repo by setting sys.path. 
+""" + +# System modules +import os +from pprint import pprint +import sys +from tempfile import NamedTemporaryFile + +# Third party modules +import yaml + +# Local modules +from CHAP.pipeline import PipelineData +from CHAP.common.reader import SpecReader +from CHAP.edd.processor import ( + MCAEnergyCalibrationProcessor, + MCATthCalibrationProcessor, +) + +#------------------------------------------------------------------------------# +# Start user input +#------------------------------------------------------------------------------# + +# Choose the path to your CHAP repo +sys.path.append( + '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') +#sys.path.append( +# '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + +root = None #'examples/edd' +interactive = True +outputdir = 'output' +energy_calib = True +tth_calib = True + +if root is None: + root = os.getcwd() +if outputdir is None: + outputdir = '.' +if not os.path.isabs(outputdir): + outputdir = os.path.normpath( + os.path.realpath(os.path.join(root, outputdir))) + +energy_name = f'{outputdir}/energy_calibration_result.yaml' +tth_name = f'{outputdir}/tth_calibration_result.yaml' +spec_file = f'{root}/data/ceo2-5deg-80um-calib/spec.log' +scan_numbers = 1 +detector_ids = [0, 21] + +#------------------------------------------------------------------------------# +# Optional user input +#------------------------------------------------------------------------------# + +spec_config = { + 'station': 'id3a', + 'experiment_type': 'EDD', + 'spec_scans': [ + {'spec_file': spec_file, + 'scan_numbers': scan_numbers} + ], +} +energy_config = { + 'peak_energies': [34.276, 34.717, 39.255, 40.231], + 'max_peak_index': 1, + 'materials': [ + {'material_name': 'CeO2', + 'lattice_parameters': 5.41153, + 'sgnum': 225} + ], +} +energy_detector_config = { + 'baseline': True, + 'mask_ranges': [[650, 850]], +} +if detector_ids is not None: + energy_detector_config['detectors'] = 
[{'id':id_} for id_ in detector_ids] +tth_config = {'tth_initial_guess': 5.2} +tth_detector_config = { + 'energy_mask_ranges': [[65, 155]], +} + +#------------------------------------------------------------------------------# +# End user input +#------------------------------------------------------------------------------# + +if not os.path.isdir(outputdir): + os.makedirs(outputdir) +try: + NamedTemporaryFile(dir=outputdir) +except Exception as exc: + raise OSError( + 'Output directory not accessible for writing ({outputdir})') from exc + +#------------------------------------------------------------------------------# + +print(f'\nspec_config:') +pprint(spec_config) +if energy_calib: + print(f'\nenergy_config:') + pprint(energy_config) + print(f'\nenergy_detector_config:') + pprint(energy_detector_config) +if tth_calib: + print(f'\ntth_config:') + pprint(tth_config) + print(f'\ntth_detector_config:') + pprint(tth_detector_config) + +# Perform the energy calibration +if energy_calib: + # Read the calibration data + nxroot = SpecReader.run(config=spec_config) + + # Perform the energy calibration + data = [PipelineData(name='SpecReader', data=nxroot)] + energy_calib_config, images = MCAEnergyCalibrationProcessor.run( + data=data, config=energy_config, detector_config=energy_detector_config, + interactive=interactive) + + # Write the energy calibration results + with open(energy_name, 'w', encoding='utf-8') as f: + yaml.dump(energy_calib_config, f, sort_keys=False) + +# Perform the tth calibration +if tth_calib: + # Read the energy calibration results + with open(energy_name, encoding='utf-8') as f: + energy_calib_config = yaml.safe_load(f) + + # Read the calibration data + nxroot = SpecReader.run(config=spec_config) + + # Perform the tth calibration + data = [ + PipelineData( + data=energy_calib_config, + schema='edd.models.MCAEnergyCalibrationConfig'), + PipelineData(name='SpecReader', data=nxroot), + ] + tth_calib_config, images = MCATthCalibrationProcessor.run( 
+ data=data, config=tth_config, detector_config=tth_detector_config, + interactive=interactive) + with open(tth_name, 'w', encoding='utf-8') as f: + yaml.dump(tth_calib_config, f, sort_keys=False) + diff --git a/examples/tomo/tomo_one_plane.py b/examples/tomo/tomo_script_id3b.py similarity index 52% rename from examples/tomo/tomo_one_plane.py rename to examples/tomo/tomo_script_id3b.py index 353d760b..c6eed623 100644 --- a/examples/tomo/tomo_one_plane.py +++ b/examples/tomo/tomo_script_id3b.py @@ -1,16 +1,36 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +"""A python script version of the CHAP pipeline version of the id3b +example. +Modify any user choices in the "Start user input" and "Optional user +input" blocks below. + +In particular: + +* Choose the path to your CHAP repo by setting sys.path. + +* Select the run_type below from: + =========== ============================================== + run_type Description + =========== ============================================== + 'normal' Full regular reconstruction + (like running the default hollow_cube in CHAP) + 'one_plane' Reconstruction of two planes only + 'two_planes' Reconstruction of one plane only + 'roi' Reconstruction for a subset or rows + (find center at first and last row) + =========== ============================================== +""" + +# System modules +import os +from pprint import pprint import sys -#sys.path.append( -# '/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline') -sys.path.append( - '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') +from tempfile import NamedTemporaryFile # Third party modules -import matplotlib.pyplot as plt from nexusformat.nexus import nxload -from PIL import Image import yaml # Local modules @@ -29,74 +49,79 @@ # Start user input #------------------------------------------------------------------------------# +# Choose the path to your CHAP repo +#sys.path.append( +# 
'/home/rv43/Documents/Programs/repos/CHESSComputing/ChessAnalysisPipeline_main') +sys.path.append( + '/nfs/chess/sw/CHESS-software-releases/repos/prod/ChessAnalysisPipeline') + +# Select run_type +#run_type = 'normal' +#run_type = 'one_plane' +#run_type = 'two_planes' +run_type = 'roi' + interactive = True +outputdir = 'output' construct_chess_map = True -map_name = 'output/map.nxs' -chess_map_name = 'output/chess_map.nxs' -load_converted_map = True -img_row_bounds = [600, 606] # None, [80, 2080] or [1050, 1251] -recon_layer_index = [600, 1400] # One or two values - # ignored when using img_row_bounds reduce_data = True -reduced_data_name = 'output/reduced_data.nxs' +find_center = True +reconstruct_data = True + +map_name = f'{outputdir}/map.nxs' +chess_map_name = f'{outputdir}/chess_map.nxs' +reduced_data_name = f'{outputdir}/reduced_data.nxs' reduce_config = { - 'remove_stripe': {'remove_all_stripe': {}}, - 'img_row_bounds': img_row_bounds, +# 'remove_stripe': {'remove_all_stripe': {}}, } -find_center = True -find_center_name = 'output/find_center.yaml' +find_center_name = f'{outputdir}/find_center.yaml' find_center_config = { - 'gaussian_sigma': 0.75, + 'gaussian_sigma': 0.05, # 'remove_stripe_sigma': None, - 'ring_width': 5, + 'ring_width': 1, } -reconstruct_data = True -reconstructed_data_name = 'output/reconstructed_data.nxs' +reconstructed_data_name = f'{outputdir}/reconstructed_data.nxs' reconstruct_config = { - 'x_bounds': [90, 2470], - 'y_bounds': [600, 2000], - 'secondary_iters:': 25, - 'gaussian_sigma': 0.75, - 'remove_stripe_sigma': None, - 'ring_width': 5, + 'x_bounds': [15, 390], + 'y_bounds': [25, 380], + 'secondary_iters': 10, +# 'gaussian_sigma': 0.75, +# 'remove_stripe_sigma': None, + 'ring_width': 1, } tdf_scan_numbers = 1 tbf_scan_numbers = 2 map_config = { - 'title': 'badran-4510-b', + 'title': 'hollow_cube', 'station': 'id3b', 'experiment_type': 'TOMO', - 'sample': {'name': 'sample_190-1-Edge'}, + 'sample': {'name': 'hollow_cube'}, 
'spec_scans': [ - {'spec_file': 'data/sample_190-1-Edge/refine_1', + {'spec_file': 'raw/hollow_cube/hollow_cube', 'scan_numbers': 3} ], 'independent_dimensions': [ {'label': 'rotation_angles', 'units': 'degrees', 'data_type': 'scan_column', - 'name': 'GI_samphi'}, + 'name': 'theta'}, {'label': 'x_translation', 'units': 'mm', 'data_type': 'spec_motor', - 'name': 'saxx'}, + 'name': 'GI_samx'}, {'label': 'z_translation', 'units': 'mm', 'data_type': 'spec_motor', - 'name': 'saxz'} + 'name': 'GI_samz'} ], - 'presample_intensity': { - 'data_type': 'scan_column', - 'name': 'ic1' - } } -detector_config = {'detectors': [{'id': 'andor2'}],} +detector_config = {'detectors': [{'id': 'sim'}],} detector_setup = { - 'prefix': 'andor2', - 'rows': 2160, - 'columns': 2560, - 'pixel_size': [0.0065, 0.0065], - 'lens_magnification': 5.0, + 'prefix': 'sim', + 'rows': 40, + 'columns': 400, + 'pixel_size': [0.05, 0.005], + 'lens_magnification': 1.0, } remove_stripe_corrections = [ # 'remove_all_stripe', @@ -111,73 +136,121 @@ # 'remove_stripe_ti', ] +#------------------------------------------------------------------------------# +# Optional user input +#------------------------------------------------------------------------------# + +if run_type == 'normal': + img_row_bounds = [3, 35] + recon_layer_indices = [11, 28] +elif run_type == 'one_plane': + img_row_bounds = None + recon_layer_indices = 15 +elif run_type == 'two_planes': + img_row_bounds = None + recon_layer_indices = [11, 20] +elif run_type == 'roi': + img_row_bounds = None + recon_layer_indices = [11, 20] +else: + print('Pick valid values for img_row_bounds and recon_layer_indices') + img_row_bounds = None # list[int, int], optional + recon_layer_indices = None # int | list[int, int], optional + # ignored when using img_row_bounds + #------------------------------------------------------------------------------# # End user input #------------------------------------------------------------------------------# +if not 
os.path.isdir(outputdir): + os.makedirs(outputdir) +try: + NamedTemporaryFile(dir=outputdir) +except Exception as exc: + raise OSError( + 'Output directory not accessible for writing ({outputdir})') from exc + +if isinstance(recon_layer_indices, int): + recon_layer_indices = [recon_layer_indices] if img_row_bounds is None: - if len(recon_layer_index)== 1: + center_rows = None + if len(recon_layer_indices)== 1: detector_config['roi'] = [ - {'start': recon_layer_index, - 'end': recon_layer_index+1, + {'start': recon_layer_indices[0], + 'end': recon_layer_indices[0]+1, 'step': None}, None] - elif len(recon_layer_index) == 2: + elif len(recon_layer_indices) == 2: detector_config['roi'] = [ - {'start': recon_layer_index[0], - 'end': recon_layer_index[1], - 'step': recon_layer_index[1]-recon_layer_index[0]-1}, + {'start': recon_layer_indices[0], + 'end': recon_layer_indices[1]}, None] + if run_type == 'two_planes': + detector_config['roi'][0]['step'] = \ + recon_layer_indices[1]-recon_layer_indices[0]-1 + else: + detector_config['roi'][0]['step'] = None else: - print('Invalid value for recon_layer_index ({recon_layer_index})') - exit() + raise RuntimeError('Invalid value for recon_layer_indices ({recon_layer_indices})') else: assert len(img_row_bounds) == 2 - recon_layer_index = [] - detector_config['roi'] = [ - {'start': img_row_bounds[0], - 'end': img_row_bounds[1], - 'step': 1}, - None] - -# Construct the CHESS style tomo map + center_rows = recon_layer_indices + recon_layer_indices = [] + +reduce_config['img_row_bounds'] = img_row_bounds +find_center_config['center_rows'] = center_rows + +#------------------------------------------------------------------------------# + +if construct_chess_map: + print(f'\nmap_config:') + pprint(map_config) +if reduce_data: + print(f'\nreduce_config:') + pprint(reduce_config) +if find_center: + print(f'\nfind_center_config:') + pprint(find_center_config) +if reconstruct_data: + print(f'\nreconstruct_config:') + 
pprint(reconstruct_config) +print(f'\ndetector_config:') +pprint(detector_config) + +# Construct the CHAP style tomo map if not construct_chess_map: nxroot = nxload(chess_map_name) else: # Create the map for the tomo stack - tomo_map = MapProcessor(config=map_config, detector_config=detector_config) - - # Read the map for the tomo stack - tomofields = tomo_map.process(data=None) + tomofields = MapProcessor.run( + config=map_config, detector_config=detector_config) tomofields.save(map_name, mode='w') # Read the dark field - tdf_spec_reader = SpecReader( + darkfield = SpecReader.run( config={ - 'station': tomo_map.config.station, - 'experiment_type': tomo_map.config.experiment_type, - 'sample': tomo_map.config.sample, + 'station': map_config['station'], + 'experiment_type': map_config['experiment_type'], + 'sample': map_config['sample'], 'spec_scans': [ - {'spec_file': tomo_map.config.spec_scans[0].spec_file, + {'spec_file': map_config['spec_scans'][0].spec_file, 'scan_numbers': tdf_scan_numbers}], }, detector_config=detector_config) - darkfield = tdf_spec_reader.read() # Read the bright field - tbf_spec_reader = SpecReader( + brightfield = SpecReader.run( config={ - 'station': tomo_map.config.station, - 'experiment_type': tomo_map.config.experiment_type, - 'sample': tomo_map.config.sample, + 'station': map_config['station'], + 'experiment_type': map_config['experiment_type'], + 'sample': map_config['sample'], 'spec_scans': [ - {'spec_file': tomo_map.config.spec_scans[0].spec_file, + {'spec_file': map_config['spec_scans'][0].spec_file, 'scan_numbers': tbf_scan_numbers}], }, detector_config=detector_config) - brightfield = tbf_spec_reader.read() - # Convert to CHESS style tomography map + # Convert to CHAP style tomography map data = [ PipelineData( name='MapProcessor', data=tomofields, schema='tomofields'), @@ -188,8 +261,7 @@ name='YAMLReader', data=detector_setup, schema='tomo.models.Detector') ] - chess_map = TomoCHESSMapConverter() - _, _, nxroot = 
chess_map.process(data=data) + _, _, nxroot = TomoCHESSMapConverter.run(data=data) nxroot = nxroot['data'] nxroot.save(chess_map_name, mode='w') @@ -197,32 +269,21 @@ if not reduce_data: nxroot = nxload(reduced_data_name) else: - # Load as needed - if load_converted_map: - nxroot = nxload(chess_map_name) + nxroot = nxload(chess_map_name) # Reduce the data with remove_all_stripe data = [ PipelineData(name='TomoCHESSMapConverter', data=nxroot, schema=None)] - tomo = TomoReduceProcessor( - config=reduce_config, interactive=interactive) - (metadata, provenance, images, reduced_data) = tomo.process(data) + _, _, _, reduced_data = TomoReduceProcessor.run( + data=data, config=reduce_config, interactive=interactive) reduced_data = reduced_data['data'] reduced_data.save(reduced_data_name, mode='w') -# for (buf, ext), name in images.get('data', []): -# buf.seek(0) -# plt.imshow(Image.open(buf)) -# plt.axis('off') -# plt.tight_layout() -# plt.show(block=True) -# buf.close() -# plt.close() nxentry = reduced_data[reduced_data.default] nxdata = nxentry[nxentry.default] image_slice = nxdata.nxsignal[0,:,0,:] vmin = image_slice.min() vmax = image_slice.max() - if len(recon_layer_index): + if recon_layer_indices: quick_imshow( image_slice, title=f'Slice {detector_config["roi"][0]["start"]}', @@ -231,7 +292,7 @@ vmin=vmin, vmax=vmax, save_fig=True, block=True) - if len(recon_layer_index) == 2: + if len(recon_layer_indices) == 2: quick_imshow( image_slice, title=f'Slice {detector_config["roi"][0]["end"]}', @@ -246,10 +307,10 @@ data = [PipelineData( name='TomoCHESSMapConverter', data=nxroot, schema=None)] - tomo = TomoReduceProcessor( + _, _, _, reduced_data = TomoReduceProcessor.run( + data=data, config=reduce_config.update({'remove_stripe': {method: {}}}), interactive=interactive) - (metadata, provenance, images, reduced_data) = tomo.process(data) reduced_data = reduced_data['data'] reduced_data.save(f'reduced_data_{method}.nxs', mode='w') @@ -274,11 +335,10 @@ data = 
[PipelineData( name='TomoCHESSMapConverter', data=reduced_data, schema=None)] - tomo = TomoFindCenterProcessor( - interactive=interactive, config=find_center_config) - (metadata, provenance, images, center_config) = tomo.process(data) + _, _, _, center_config = TomoFindCenterProcessor.run( + data=data, config=find_center_config, interactive=interactive) center_config = center_config['data'] - with open(find_center_name, 'w') as f: + with open(find_center_name, 'w', encoding='utf-8') as f: yaml.dump(center_config, f, sort_keys=False) # Reconstruct @@ -287,7 +347,7 @@ pass else: reduced_data = nxload(reduced_data_name) - with open(find_center_name) as f: + with open(find_center_name, encoding='utf-8') as f: center_config = yaml.safe_load(f) data = [ @@ -300,10 +360,8 @@ # name='TomoFindCenterProcessor', data=center_config, # schema='tomo.models.TomoFindCenterConfig') ] - tomo = TomoReconstructProcessor( - config=reconstruct_config, - center_config=center_config, + _, _, _, reconstructed_data = TomoReconstructProcessor.run( + data=data, config=reconstruct_config, center_config=center_config, interactive=interactive) - (metadata, provenance, images, reconstructed_data) = tomo.process(data) reconstructed_data = reconstructed_data['data'] reconstructed_data.save(reconstructed_data_name, mode='w')