Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DSP Improvements #324

Merged
12 commits merged on Jul 26, 2022
65 changes: 33 additions & 32 deletions src/pygama/dsp/build_dsp.py
Expand Up @@ -23,10 +23,11 @@
log = logging.getLogger(__name__)


def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables:
list[str] = None, database: str = None, outputs: list[str] =
None, n_max: int = np.inf, write_mode: str = 'r', buffer_len: int
= 3200, block_width: int = 16, chan_config: dict = None) -> None:
def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict = None,
lh5_tables: list[str] = None, database: str = None,
outputs: list[str] = None, n_max: int = np.inf,
write_mode: str = 'r', buffer_len: int = 3200,
block_width: int = 16, chan_config: dict = None) -> None:
"""
Convert raw-tier LH5 data into dsp-tier LH5 data by running a sequence of
processors via the :class:`~.processing_chain.ProcessingChain`.
Expand Down Expand Up @@ -65,6 +66,22 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables:
`lh5_tables`
"""

if chan_config is not None:
# clear existing output files
if write_mode == 'r':
if os.path.isfile(f_dsp):
os.remove(f_dsp)
write_mode = 'a'

for tb, dsp_config in chan_config.items():
log.debug(f'processing table: {tb} with DSP config file {dsp_config}')
try:
build_dsp(f_raw, f_dsp, dsp_config, [tb], database,
outputs, n_max, write_mode, buffer_len, block_width)
except RuntimeError:
log.debug(f'table {tb} not found')
return

if isinstance(dsp_config, str):
with open(dsp_config) as config_file:
dsp_config = json.load(config_file)
Expand All @@ -80,21 +97,15 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables:

# if no group is specified, assume we want to decode every table in the file
if lh5_tables is None:
lh5_tables = []
lh5_keys = lh5.ls(f_raw)

# sometimes 'raw' is nested, e.g g024/raw
for tb in lh5_keys:
if "raw" not in tb:
tbname = lh5.ls(lh5_file[tb])[0]
if "raw" in tbname:
tb = f'{tb}/{tbname}' # g024 + /raw
lh5_tables.append(tb)

# make sure every group points to waveforms, if not, remove the group
for tb in lh5_tables:
if 'raw' not in tb:
lh5_tables.remove(tb)
lh5_tables = lh5.ls(f_raw)

# check if group points to raw data; sometimes 'raw' is nested, e.g g024/raw
for i, tb in enumerate(lh5_tables):
if "raw" not in tb and lh5.ls(lh5_file, f"{tb}/raw"):
lh5_tables[i] = f'{tb}/raw'
elif not lh5.ls(lh5_file, tb):
del lh5_tables[i]

if len(lh5_tables) == 0:
raise RuntimeError(f"could not find any valid LH5 table in {f_raw}")

Expand Down Expand Up @@ -134,16 +145,6 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables:
if n_max and n_max < tot_n_rows:
tot_n_rows = n_max

# if we have separate DSP files for each table, read them in here
if chan_config is not None:
f_config = chan_config[tb]
with open(f_config) as config_file:
dsp_config = json.load(config_file)
log.debug(f'processing table: {tb} with DSP config file {f_config}')

if not isinstance(dsp_config, dict):
raise RuntimeError(f'dsp_config for {tb} must be a dict')

chan_name = tb.split('/')[0]
db_dict = database.get(chan_name) if database else None
tb_name = tb.replace('/raw', '/dsp')
Expand Down Expand Up @@ -180,12 +181,12 @@ def build_dsp(f_raw: str, f_dsp: str, dsp_config: str | dict, lh5_tables:
write_start=write_offset+start_row)

if log.level <= logging.INFO:
progress_bar.update(buffer_len)
progress_bar.update(n_rows)

if start_row+n_rows > tot_n_rows:
if start_row+n_rows >= tot_n_rows:
break

if log.level <= logging.INFO:
progress_bar.close()

raw_store.write_object(dsp_info, 'dsp_info', f_dsp)
raw_store.write_object(dsp_info, 'dsp_info', f_dsp, wo_mode='o')
15 changes: 8 additions & 7 deletions src/pygama/dsp/processing_chain.py
Expand Up @@ -1082,8 +1082,8 @@ def write(self, start: int, end: int) -> None:
self.raw_var[0:end-start, ...], 'unsafe')

def __str__(self) -> str:
return (f"{self.var} linked to numpy.array({self.io_buf.shape}, "
f"{self.io_buf.dtype})@{self.io_buf.data})")
return (f"{self.var} linked to numpy.array(shape={self.io_buf.shape}, "
f"dtype={self.io_buf.dtype})")


class LGDOArrayIOManager(IOManager):
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def write(self, start: int, end: int) -> None:
self.raw_var[0:end-start, ...], 'unsafe')

def __str__(self) -> str:
return f'{self.var} linked to {self.io_array}'
return f'{self.var} linked to lgdo.Array(shape={self.io_array.nda.shape}, dtype={self.io_array.nda.dtype}, attrs={self.io_array.attrs})'

class LGDOArrayOfEqualSizedArraysIOManager(IOManager):
"""IO Manager for buffers that are numpy ArrayOfEqualSizedArrays"""
Expand Down Expand Up @@ -1176,7 +1176,7 @@ def write(self, start: int, end: int) -> None:
self.raw_var[0:end-start, ...], 'unsafe')

def __str__(self) -> str:
return f'{self.var} linked to {self.io_array}'
return f'{self.var} linked to lgdo.ArrayOfEqualSizedArrays(shape={self.io_array.nda.shape}, dtype={self.io_array.nda.dtype}, attrs={self.io_array.attrs})'


class LGDOWaveformIOManager(IOManager):
Expand Down Expand Up @@ -1243,9 +1243,10 @@ def write(self, start: int, end: int) -> None:
self.t0_buf[start:end, ...] = self.t0_var[0:end-start, ...]

def __str__(self) -> str:
return (f"{self.var} linked to <pygama.lgdo.WaveformTable: values: "
f"{self.wf_table.values}, dt: {self.wf_table.dt}, t0: "
f"{self.wf_table.t0}>")
return (f"{self.var} linked to pygama.lgdo.WaveformTable("
f"values(shape={self.wf_table.values.nda.shape}, dtype={self.wf_table.values.nda.dtype}, attrs={self.wf_table.values.attrs}), "
f"dt(shape={self.wf_table.dt.nda.shape}, dtype={self.wf_table.dt.nda.dtype}, attrs={self.wf_table.dt.attrs}), "
f"t0(shape={self.wf_table.t0.nda.shape}, dtype={self.wf_table.t0.nda.dtype}, attrs={self.wf_table.t0.attrs}))")


def build_processing_chain(lh5_in: lgdo.Table, dsp_config: dict | str, db_dict: dict = None,
Expand Down
6 changes: 5 additions & 1 deletion src/pygama/lgdo/lh5_store.py
Expand Up @@ -623,7 +623,11 @@ def write_object(self,
# scalars
elif isinstance(obj, Scalar):
if name in group:
log.debug(f"overwriting '{name}' in '{group}'")
if wo_mode in ['o', 'a']:
log.debug(f'overwriting {name} in {group}')
del group[name]
else:
raise RuntimeError(f"tried to overwrite {name} in {group} for wo_mode {wo_mode}")
ds = group.create_dataset(name, shape=(), data=obj.value)
ds.attrs.update(obj.attrs)
return
Expand Down