diff --git a/src/pygama/dsp/build_dsp.py b/src/pygama/dsp/build_dsp.py index 583542409..85a8a708b 100644 --- a/src/pygama/dsp/build_dsp.py +++ b/src/pygama/dsp/build_dsp.py @@ -23,10 +23,11 @@ log = logging.getLogger(__name__) -def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables: - list[str] = None, database: str = None, outputs: list[str] = - None, n_max: int = np.inf, write_mode: str = 'r', buffer_len: int - = 3200, block_width: int = 16, chan_config: dict = None) -> None: +def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict = None, + lh5_tables: list[str] = None, database: str = None, + outputs: list[str] = None, n_max: int = np.inf, + write_mode: str = 'r', buffer_len: int = 3200, + block_width: int = 16, chan_config: dict = None) -> None: """ Convert raw-tier LH5 data into dsp-tier LH5 data by running a sequence of processors via the :class:`~.processing_chain.ProcessingChain`. @@ -65,6 +66,22 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables: `lh5_tables` """ + if chan_config is not None: + # clear existing output files + if write_mode == 'r': + if os.path.isfile(f_dsp): + os.remove(f_dsp) + write_mode = 'a' + + for tb, dsp_config in chan_config.items(): + log.debug(f'processing table: {tb} with DSP config file {dsp_config}') + try: + build_dsp(f_raw, f_dsp, dsp_config, [tb], database, + outputs, n_max, write_mode, buffer_len, block_width) + except RuntimeError: + log.debug(f'table {tb} not found') + return + if isinstance(dsp_config, str): with open(dsp_config) as config_file: dsp_config = json.load(config_file) @@ -80,21 +97,15 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables: # if no group is specified, assume we want to decode every table in the file if lh5_tables is None: - lh5_tables = [] - lh5_keys = lh5.ls(f_raw) - - # sometimes 'raw' is nested, e.g g024/raw - for tb in lh5_keys: - if "raw" not in tb: - tbname = lh5.ls(lh5_file[tb])[0] - if "raw" in tbname: - tb = f'{tb}/{tbname}' # g024 + /raw - lh5_tables.append(tb) - - # make sure every group points to waveforms, if not, remove the group - for tb in lh5_tables: - if 'raw' not in tb: - lh5_tables.remove(tb) + lh5_tables = lh5.ls(f_raw) + + # check if group points to raw data; sometimes 'raw' is nested, e.g g024/raw + for i, tb in enumerate(lh5_tables): + if "raw" not in tb and lh5.ls(lh5_file, f"{tb}/raw"): + lh5_tables[i] = f'{tb}/raw' + elif not lh5.ls(lh5_file, tb): + del lh5_tables[i] + if len(lh5_tables) == 0: raise RuntimeError(f"could not find any valid LH5 table in {f_raw}") @@ -134,16 +145,6 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables: if n_max and n_max < tot_n_rows: tot_n_rows = n_max - # if we have separate DSP files for each table, read them in here - if chan_config is not None: - f_config = chan_config[tb] - with open(f_config) as config_file: - dsp_config = json.load(config_file) - log.debug(f'processing table: {tb} with DSP config file {f_config}') - - if not isinstance(dsp_config, dict): - raise RuntimeError(f'dsp_config for {tb} must be a dict') - chan_name = tb.split('/')[0] db_dict = database.get(chan_name) if database else None tb_name = tb.replace('/raw', '/dsp') @@ -180,12 +181,12 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables: write_start=write_offset+start_row) if log.level <= logging.INFO: - progress_bar.update(buffer_len) + progress_bar.update(n_rows) - if start_row+n_rows > tot_n_rows: + if start_row+n_rows >= tot_n_rows: break if log.level <= logging.INFO: progress_bar.close() - raw_store.write_object(dsp_info, 'dsp_info', f_dsp) + raw_store.write_object(dsp_info, 'dsp_info', f_dsp, wo_mode='o') diff --git a/src/pygama/dsp/processing_chain.py b/src/pygama/dsp/processing_chain.py index e19b115f8..ed69117a2 100644 --- a/src/pygama/dsp/processing_chain.py +++ b/src/pygama/dsp/processing_chain.py @@ -1082,8 +1082,8 @@ def write(self, start: int, end: int) -> None: self.raw_var[0:end-start, ...], 'unsafe') def __str__(self) -> str: - return (f"{self.var} linked to numpy.array({self.io_buf.shape}, " - f"{self.io_buf.dtype})@{self.io_buf.data})") + return (f"{self.var} linked to numpy.array(shape={self.io_buf.shape}, " + f"dtype={self.io_buf.dtype})") class LGDOArrayIOManager(IOManager): @@ -1130,7 +1130,7 @@ def write(self, start: int, end: int) -> None: self.raw_var[0:end-start, ...], 'unsafe') def __str__(self) -> str: - return f'{self.var} linked to {self.io_array}' + return f'{self.var} linked to lgdo.Array(shape={self.io_array.nda.shape}, dtype={self.io_array.nda.dtype}, attrs={self.io_array.attrs})' class LGDOArrayOfEqualSizedArraysIOManager(IOManager): """IO Manager for buffers that are numpy ArrayOfEqualSizedArrays""" @@ -1176,7 +1176,7 @@ def write(self, start: int, end: int) -> None: self.raw_var[0:end-start, ...], 'unsafe') def __str__(self) -> str: - return f'{self.var} linked to {self.io_array}' + return f'{self.var} linked to lgdo.ArrayOfEqualSizedArrays(shape={self.io_array.nda.shape}, dtype={self.io_array.nda.dtype}, attrs={self.io_array.attrs})' class LGDOWaveformIOManager(IOManager): @@ -1243,9 +1243,10 @@ def write(self, start: int, end: int) -> None: self.t0_buf[start:end, ...] = self.t0_var[0:end-start, ...] def __str__(self) -> str: - return (f"{self.var} linked to ") + return (f"{self.var} linked to pygama.lgdo.WaveformTable(" + f"values(shape={self.wf_table.values.nda.shape}, dtype={self.wf_table.values.nda.dtype}, attrs={self.wf_table.values.attrs}), " + f"dt(shape={self.wf_table.dt.nda.shape}, dtype={self.wf_table.dt.nda.dtype}, attrs={self.wf_table.dt.attrs}), " + f"t0(shape={self.wf_table.t0.nda.shape}, dtype={self.wf_table.t0.nda.dtype}, attrs={self.wf_table.t0.attrs}))") def build_processing_chain(lh5_in: lgdo.Table, dsp_config: dict | str, db_dict: dict = None, diff --git a/src/pygama/lgdo/lh5_store.py b/src/pygama/lgdo/lh5_store.py index c8f8096af..a2ff6d8cb 100644 --- a/src/pygama/lgdo/lh5_store.py +++ b/src/pygama/lgdo/lh5_store.py @@ -623,7 +623,11 @@ def write_object(self, # scalars elif isinstance(obj, Scalar): if name in group: - log.debug(f"overwriting '{name}' in '{group}'") + if wo_mode in ['o', 'a']: + log.debug(f'overwriting {name} in {group}') + del group[name] + else: + raise RuntimeError(f"tried to overwrite {name} in {group} for wo_mode {wo_mode}") ds = group.create_dataset(name, shape=(), data=obj.value) ds.attrs.update(obj.attrs) return